首页 > 分享 > ResNet

ResNet

ResNet-V1-50卷积神经网络迁移学习进行不同品种的花的分类识别

0.0882018.12.12 06:40:56字数 510阅读 10,340

运行环境

python3.6.3、tensorflow1.10.0
Intel@AIDevCloud:Intel Xeon Gold 6128 processors集群

数据和模型来源

数据集:http://download.tensorflow.org/example_images/flower_photos.tgz
模型:http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

思路

数据集分析及处理

数据集文件解压后共有五个文件夹,每个文件夹都包含一定数量的花的图片,一个文件夹对应一个品种,图片各种尺寸都有,均为jpg格式,均为彩色图片。这里利用tensorflow提供的图片处理工具将所有图片转为300×300×3的格式,然后将所有图片的80%当作训练集,10%当作验证集,10%当作测试集,并且将训练集进行随机打乱,将得到的数据存在一个numpy文件中,以待后续训练使用。

模型构建

这里采用了ResNet-V1-50卷积神经网络来进行训练,模型结构在slim中都提供好了,另外采用官方已经训练好的参数进行迁移学习,只是在模型的最后根据问题的实际需要再定义一层输出层,只训练最后的自定义的全连接输出层的参数,训练500次,每次batch样本数取32,学习率取0.0001。

源代码

load_data.py

# -*- coding: UTF-8 -*- #Author:Yinli import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platform import gfile #定义输入文件夹和数据存储文件名 INPUT_DATA = 'flower_photos' OUTPUT_FILE = 'flower_processed_data.npy' #设定验证集和测试集的百分比 VALIDATION_PERCENTAGE = 10 TEST_PERCENTAGE = 10 def create_image_list(sess, testing_percentage, validation_percentage): #列出输入文件夹下的所有子文件夹,此时sub_dirs里面除了有子文件夹还有它自身,在第一个 sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] #设置一个bool值,指定第一次循环的时候跳过母文件夹 is_root_dir = True #print(sub_dirs) #初始化数据矩阵 training_images = [] training_labels = [] testing_images = [] testing_labels = [] validation_images = [] validation_labels= [] current_label = 0 #分别处理每个子文件夹 for sub_dir in sub_dirs: #跳过第一个值,即跳过母文件夹 if is_root_dir: is_root_dir = False continue #获取子目录中的所有图片文件 extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] #用列表记录所有图片文件 file_list = [] #获取此子目录的名字比如daisy dir_name = os.path.basename(sub_dir) #对此子目录中所有图片后缀的文件 for extension in extensions: #获取每种图片的所有正则表达式 file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension) print(file_glob) #将所有符合正则表达式的文件名加入文件列表 file_list.extend(glob.glob(file_glob)) print(file_list) #如果没有文件跳出循环 if not file_list: continue print("processing ", dir_name) i = 0 #对于每张图片 for file_name in file_list: i+=1 #打开图片文件 #print("process num : ",i," processing", file_name, file=f) image_raw_data = gfile.FastGFile(file_name,'rb').read() #解码 image = tf.image.decode_jpeg(image_raw_data) #如果图片格式不是float32则转为float32 if image.dtype != tf.float32: image = tf.image.convert_image_dtype(image, dtype=tf.float32) #将图片源数据转为299*299 image = tf.image.resize_images(image, [300,300]) #得到此图片的数据 image_value = sess.run(image) print(np.shape(image_value)) #生成一个100以内的数 chance = np.random.randint(100) #按概率随机分到三个数据集中 if chance < validation_percentage: validation_images.append(image_value) validation_labels.append(current_label) elif chance < (testing_percentage + validation_percentage): testing_images.append(image_value) testing_labels.append(current_label) else: training_images.append(image_value) training_labels.append(current_label) if i%200 == 0: print("processing...") #处理完此种品种就将标签+1 current_label += 1 #将训练数据和标签以同样的方式打乱 state = np.random.get_state() np.random.shuffle(training_images) np.random.set_state(state) np.random.shuffle(training_labels) #返回所有数据 return np.asarray([training_images, training_labels, validation_images, validation_labels, testing_images, testing_labels]) def main(): with tf.Session() as sess: processed_data = create_image_list(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE) #将数据存到文件中 np.save(OUTPUT_FILE, processed_data) if __name__ == "__main__": main() resnet.py

# -*- coding: UTF-8 -*- # Author:Yinli import numpy as np import tensorflow as tf import tensorflow.contrib.slim as slim # 加载通过slim定义好的resnet_v1模型 import tensorflow.contrib.slim.python.slim.nets.resnet_v1 as resnet_v1 # 数据文件 INPUT_DATA = "./flower_processed_data.npy" # 保存训练好的模型 TRAIN_FILE = "./save_model/my_model" # 提供的已经训练好的模型 CKPT_FILE = "./resnet_v1_50.ckpt" # 定义训练所用参数 LEARNING_RATE = 0.0001 STEPS = 500 BATCH = 32 N_CLASSES = 5 # 这里指出了不需要从训练好的模型中加载的参数,就是最后的自定义的全连接层 CHECKPOINT_EXCLUDE_SCOPES = 'Logits' # 指定最后的全连接层为可训练的参数 TRAINABLE_SCOPES = 'Logits' # 加载所有需要从训练好的模型加载的参数 def get_tuned_variables(): ##不需要加载的范围 exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")] # 初始化需要加载的参数 variables_to_restore = [] # 遍历模型中的所有参数 for var in slim.get_model_variables(): # 先指定为不需要移除 excluded = False # 遍历exclusions,如果在exclusions中,就指定为需要移除 for exclusion in exclusions: if var.op.name.startswith(exclusion): excluded = True break # 如果遍历完后还是不需要移除,就把参数加到列表里 if not excluded: variables_to_restore.append(var) return variables_to_restore # 获取所有需要训练的参数 def get_trainable_variables(): # 同上 scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")] variables_to_train = [] # 枚举所有需要训练的参数的前缀,并找到这些前缀的所有参数 for scope in scopes: variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) variables_to_train.extend(variables) return variables_to_train def main(): # 加载数据 processed_data = np.load(INPUT_DATA) training_images = processed_data[0] n_training_example = len(training_images) training_labels = processed_data[1] validation_images = processed_data[2] validation_labels = processed_data[3] testing_images = processed_data[4] testing_labels = processed_data[5] print("there is %d training examples, %d validation examples, %d testing examples" % (n_training_example, len(validation_labels), len(testing_labels))) # 定义数据格式 images = tf.placeholder(tf.float32, [None, 300, 300, 3], name='input_images') labels = tf.placeholder(tf.int64, [None], name='labels') # 定义模型,因为给出的只有参数,并没有模型,这里需要指定模型的具体结构 with slim.arg_scope(resnet_v1.resnet_arg_scope()): # logits就是最后预测值,images就是输入数据,指定num_classes=None是为了使resnet模型最后的输出层禁用 logits, _ = resnet_v1.resnet_v1_50(images, num_classes=None) #自定义的输出层 with tf.variable_scope("Logits"): #将原始模型的输出数据去掉维度为2和3的维度,最后只剩维度1的batch数和维度4的300*300*3 #也就是将原来的二三四维度全部压缩到第四维度 net = tf.squeeze(logits, axis=[1,2]) #加入一层dropout层 net = slim.dropout(net, keep_prob=0.5,scope='dropout_scope') #加入一层全连接层,指定最后输出大小 logits = slim.fully_connected(net, num_outputs=N_CLASSES, scope='fc') # 获取需要训练的变量 trainable_variables = get_trainable_variables() # 定义损失,模型定义的时候已经考虑了正则化了 tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0) # 定义训练过程 train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss()) # 定义测试和验证过程 with tf.name_scope('evaluation'): correct_prediction = tf.equal(tf.argmax(logits, 1), labels) evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # 定义加载模型的函数,就是重新定义load_fn函数,从文件中获取参数,获取指定的变量,忽略缺省值 load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE, get_tuned_variables(), ignore_missing_vars=True) # 定义保存新的训练好的模型的函数 saver = tf.train.Saver() with tf.Session() as sess: # 初始化没有加载进来的变量,一定要在模型加载之前,否则会将训练好的参数重新赋值 init = tf.global_variables_initializer() sess.run(init) # 加载训练好的模型 print("加载谷歌训练好的模型...") load_fn(sess) start = 0 end = BATCH for i in range(STEPS): # 训练... sess.run(train_step, feed_dict={images: training_images[start:end], labels: training_labels[start:end]}) # 间断地保存模型,并在验证集上验证 if i % 50 == 0 or i + 1 == STEPS: saver.save(sess, TRAIN_FILE, global_step=i) validation_accuracy = sess.run(evaluation_step, feed_dict={images: validation_images, labels: validation_labels}) print("经过%d次训练后,在验证集上的正确率为%.3f" % (i, validation_accuracy)) # 更新起始和末尾 start = end if start == n_training_example: start = 0 end = start + BATCH if end > n_training_example: end = n_training_example # 训练完了在测试集上测试正确率 testing_accuracy = sess.run(evaluation_step, feed_dict={images: testing_images, labels: testing_labels}) print("最后在测试集上的正确率为%.3f" % testing_accuracy) if __name__ == '__main__': main()

运行结果

result.png

结果分析

从结果中可以看到,利用已经训练好的复杂模型的参数,再根据问题加上一层自定义的输出层,可以在短时间内利用较少的资源将模型迁移到不同的问题上,在200次训练的时候就可以在这个问题上达到90%的正确率,经过500次训练后可以在测试集上达到接近95%的正确率,验证了目前的主流卷积神经网络具有很好的普适性和迁移性。

最后编辑于

:2019.03.28 04:47:30

更多精彩内容,就在简书APP

"小礼物走一走,来简书关注我"

还没有人赞赏,支持一下

序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...

沈念sama阅读 203,671评论 6赞 477

序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...

文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...

文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...

正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...

茶点故事阅读 63,642评论 5赞 365

文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...

那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...

文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...

序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...

正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...

茶点故事阅读 35,608评论 2赞 321

正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...

茶点故事阅读 37,698评论 1赞 329

序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...

正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...

茶点故事阅读 38,958评论 3赞 307

文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...

文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...

我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...

正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...

茶点故事阅读 42,425评论 2赞 342

被以下专题收入,发现更多相似内容

推荐阅读更多精彩内容

1、MobileNet GoogleMobileNets:用于移动视觉应用的高效卷积神经网络的张量流实现 在ten...

AI信仰者阅读 3,302评论 1赞 15

原文链接:https://yq.aliyun.com/articles/178374 0. 简介 在过去,我写的主...

tensorflow中文社区对官方文档进行了完整翻译。鉴于官方更新不少内容,而现有的翻译基本上都已过时。故本人对更...

周乘阅读 10,204评论 4赞 27

译者按: 祖师爷Hinton 带领的小组经典之作,深度学习开山祖师 Hinton率领的谷歌团队多次夺冠 ,主力成员...

晚上又重温了一遍《西西里的美丽传说》。一部很老的片子,也是我很喜欢的片子。真心喜欢莫妮卡*贝鲁奇的性感优雅,也喜欢...

窝窝透阅读 836评论 4赞 6

相关知识

基于pytorch搭建ResNet神经网络用于花类识别
7 Resnet深度残差网络实现102种花卉分类
ResNet残差网络在PyTorch中的实现——训练花卉分类器
基于ResNet对花朵分类研究
Pytorch resnet花朵识别(5种花)附完整代码
基于卷积神经网络的樱桃叶片病虫害识别与防治系统,vgg16,resnet,swintransformer,模型融合(pytorch框架,python代码)
度学习pytorch实战六:ResNet50网络图像分类篇自建花数据集图像分类(5类)超详细代码
基于python编程的五种鲜花识别
基于MSDB
Pytorch框架实战——102类花卉分类

网址: ResNet https://m.huajiangbk.com/newsview516089.html

所属分类:花卉
上一篇: 浅析寿山石的鉴别方法 – 根盆网
下一篇: 【经典卷积神经网络】之AlexN