首页 > 分享 > 基于VGG16网络的花卉识别

基于VGG16网络的花卉识别

环境：Python 3.7 + TensorFlow。本文使用卷积神经网络（VGG16）对花卉图片进行识别。

VGG.py

这一部分包括数据处理,模型定义,模型训练。
1、第26行的 names 列表是数据集中每一类花所在文件夹的名称（数据集按类别分文件夹存放）。
2、第27行到第44行是数据处理部分：读取并缩放图片后保存到 x.npy 和 y.npy。首次运行后可以把这部分注释掉，之后直接读取 npy 文件即可。
3、选择VGG16作为基础模型，在此基础上进行训练；通过设置 include_top=False，可以获得不含全连接层的基础网络。

"""VGG.py: data preparation, VGG16 fine-tuning and evaluation for the flower classes.

Flow:
1. If ``REBUILD_DATA``: read every image under ``./data/train_data/<class>/``,
   resize it to 256x256 and cache the arrays in ``x.npy`` / ``y.npy``.
2. Load the cached arrays, scale pixels to [0, 1], one-hot encode the labels
   and split 80/20 with a fixed seed.
3. If ``TRAIN``: fine-tune VGG16 and save ``model1.h5``.
4. Evaluate ``model1.h5`` on the held-out split, print the report and plot /
   save the confusion matrix.
"""
import pandas as pd
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.applications import ResNet50, VGG16, MobileNet, InceptionV3, NASNetLarge
import os
from tensorflow.keras import layers, optimizers, models
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.tree import DecisionTreeClassifier
import cv2
import glob
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
import warnings
from tensorflow.keras.models import load_model

warnings.filterwarnings("ignore")
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
# Logs the device of every op (very verbose); disable once GPU placement is confirmed.
tf.debugging.set_log_device_placement(True)

# Folder name of each class under TRAIN_DIR; the list position is the class index.
names = ['bee_balm', 'blackberry_lily', 'blanket_flower', 'bougainvillea', 'bromelia', 'foxglove']
labels = {name: index for index, name in enumerate(names)}
num_classes = len(names)

TRAIN_DIR = './data/train_data'
IMAGE_SIZE = (256, 256)  # (width, height), the order cv2.resize expects
X_path = 'x.npy'
Y_path = 'y.npy'

# True: re-read all images and rewrite x.npy / y.npy. Set to False after the first run.
REBUILD_DATA = True
# True: fine-tune VGG16 and save model1.h5. False: only evaluate the saved model.
TRAIN = False

if REBUILD_DATA:
    X = []
    Y = []
    for i in names:  # one sub-folder per class
        for f in os.listdir(os.path.join(TRAIN_DIR, i)):
            print(f)
            file_path = os.path.join(TRAIN_DIR, i, f)
            image = cv2.imread(file_path)
            if image is None:
                # cv2.imread returns None (it does not raise) for unreadable / non-image files.
                print(f"skipping unreadable image: {file_path}")
                continue
            # INTER_CUBIC: bicubic interpolation over a 4x4 pixel neighbourhood.
            image = cv2.resize(image, IMAGE_SIZE, interpolation=cv2.INTER_CUBIC)
            X.append(image)
            Y.append(i)
    X = np.array(X)
    Y = np.array(Y)
    print(X.shape, Y.shape)
    print("结束了")
    np.save(X_path, X)
    np.save(Y_path, Y)

X = np.load(X_path)
Y = np.load(Y_path)

# Folder name -> class index -> one-hot. An unknown name raises KeyError instead of becoming NaN.
Y = to_categorical([labels[name] for name in Y], num_classes)
# float32 halves memory compared with the float64 that X / 255 would produce.
X = X.astype(np.float32) / 255

# Fixed seed: the saved model must be evaluated on the same split it was trained without.
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)


def model():
    """Build and compile VGG16 (ImageNet weights, no top) plus a small dense head.

    The convolutional base stays trainable, so fitting fine-tunes all of VGG16
    with a low learning rate.

    Returns:
        A compiled ``Sequential`` model with a ``num_classes``-way softmax output.
    """
    conv_base = VGG16(weights='imagenet', include_top=False,
                      input_shape=(IMAGE_SIZE[1], IMAGE_SIZE[0], 3))
    net = models.Sequential()
    net.add(conv_base)
    net.add(Dropout(0.3))
    net.add(layers.Flatten())
    net.add(Dense(512, activation='relu'))
    net.add(layers.Dense(num_classes, activation='softmax'))
    conv_base.trainable = True
    # `lr` is a deprecated alias; `learning_rate` is the documented argument.
    net.compile(loss='categorical_crossentropy',
                optimizer=optimizers.Adam(learning_rate=0.0001),
                metrics=['categorical_accuracy'])
    net.summary()
    return net


if TRAIN:
    net = model()
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=6)
    # Select the best checkpoint on validation loss, consistent with early stopping.
    model_checkpoint = ModelCheckpoint('model2.hdf5', monitor='val_loss', verbose=1, save_best_only=True)
    history = net.fit(x_train, y_train, epochs=20, batch_size=32,
                      validation_data=(x_test, y_test),
                      callbacks=[early_stop, model_checkpoint])
    net.save("model1.h5")

trained = load_model('model1.h5')
pred = trained.predict(x_test)
y = np.argmax(pred, axis=-1)
y_test = np.argmax(y_test, axis=-1)

class_ids = list(range(num_classes))
# Fixed labels keep the matrix num_classes x num_classes even if a class is absent
# from the split, so the tick names below always line up with the rows/columns.
cm = confusion_matrix(y_test, y, labels=class_ids)
print(cm)
print(classification_report(y_test, y, labels=class_ids, target_names=names))

# Configure a CJK-capable font before any text is created so the Chinese labels render.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

plt.imshow(cm, cmap=plt.cm.BuPu)
indices = range(len(cm))
ax = plt.gca()
plt.xticks(indices, names, fontsize=8)
ax.xaxis.set_ticks_position("top")
plt.yticks(indices, names, fontsize=8)
plt.colorbar()
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵')
for row, counts in enumerate(cm):
    for col, count in enumerate(counts):
        # imshow draws column index on x and row index on y, so text goes at (col, row).
        plt.text(col, row, count, fontdict={'size': 6}, ha='center', va='center')
# Save before show(): once show() returns the figure may be closed and savefig would write a blank image.
plt.savefig("混淆矩阵.png")
plt.show()

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119

main.py

先运行 VGG.py 训练模型并保存为 model1.h5，然后用 main.py 加载训练好的模型进行预测。

"""main.py: classify a single flower image with the model trained by VGG.py."""
import functools

import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model

MODEL_PATH = 'model1.h5'
IMAGE_SIZE = (256, 256)  # (width, height); must match the size used for training
# Class index -> folder name; must match the class order used for training.
LABELS = {0: 'bee_balm', 1: 'blackberry_lily', 2: 'blanket_flower',
          3: 'bougainvillea', 4: 'bromelia', 5: 'foxglove'}


@functools.lru_cache(maxsize=4)
def model(path=MODEL_PATH):
    """Load the saved Keras model at *path*.

    The result is cached per path, so repeated calls (e.g. one per UI click)
    reuse the already-loaded model instead of re-reading the .h5 file.
    """
    return load_model(path)


def read(path):
    """Load the image at *path* as a single-image batch for ``pre``.

    The image is read in the same way as the training data (cv2, BGR), resized
    to 256x256 with bicubic interpolation, and scaled to [0, 1].

    Returns:
        ndarray of shape (1, 256, 256, 3).

    Raises:
        FileNotFoundError: if cv2 cannot read *path* (missing file or not an image).
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('cannot read image: {}'.format(path))
    img = cv2.resize(img, IMAGE_SIZE, interpolation=cv2.INTER_CUBIC)
    img = img / 255
    return img.reshape(1, IMAGE_SIZE[1], IMAGE_SIZE[0], 3)


def pre(model, img):
    """Predict the class name of every image in the batch *img* and print it.

    Args:
        model: any object with a Keras-style ``predict`` returning per-class scores.
        img: a preprocessed batch, e.g. from ``read``.

    Returns:
        1-D object ndarray of class names, one per image.
    """
    scores = model.predict(img)
    indices = np.argmax(scores, axis=-1)
    y = pd.Series(indices).map(LABELS).values
    print('此花为:', y)
    return y


if __name__ == '__main__':
    path = r'./data/test/test6.jpg'
    img = read(path)
    net = model()  # separate name so the `model` function is not shadowed
    pred = pre(net, img)

123456789101112131415161718192021222324252627282930

UI.py

基于 PyQt5 的图形界面设计：加载图片后点击“开始识别”，调用 main.py 中的函数进行识别。

"""UI.py: PyQt5 front end — load an image, preview it, and classify it via main.py."""
import sys

from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import QPixmap
from PyQt5.QtWidgets import QApplication, QFileDialog

import main as sb
from tensorflow.keras.models import load_model


class Ui_Form(object):
    """pyuic-style form: image preview on the left, title/result boxes and two buttons on the right."""

    def setupUi(self, Form):
        """Create all widgets on *Form* and connect the two buttons to their slots."""
        Form.setObjectName("Form")
        Form.resize(765, 402)
        # Only used as the parent widget of the file dialog.
        self.centralwidget = QtWidgets.QWidget(Form)
        self.label = QtWidgets.QLabel(Form)
        self.label.setGeometry(QtCore.QRect(70, 50, 256, 256))
        self.label.setObjectName("label")
        self.pushButton = QtWidgets.QPushButton(Form)
        self.pushButton.setGeometry(QtCore.QRect(560, 300, 151, 61))
        self.pushButton.setObjectName("pushButton")
        self.textBrowser = QtWidgets.QTextBrowser(Form)
        self.textBrowser.setGeometry(QtCore.QRect(420, 50, 256, 51))
        self.textBrowser.setStyleSheet("border:0px;\n")
        self.textBrowser.setObjectName("textBrowser")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(380, 300, 151, 61))
        self.pushButton_2.setObjectName("pushButton_2")
        self.textBrowser_1 = QtWidgets.QTextBrowser(Form)
        self.textBrowser_1.setGeometry(QtCore.QRect(420, 140, 261, 101))
        self.textBrowser_1.setObjectName("textBrowser1")
        # Slot state: currently selected image path and the lazily loaded model.
        self.img_file = ''
        self._model = None
        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)
        self.pushButton.clicked.connect(self.prediction)
        self.pushButton_2.clicked.connect(self.openimg)

    def openimg(self):
        """Ask the user for an image file and show it in the preview label."""
        img_file, _ = QFileDialog.getOpenFileName(
            self.centralwidget, 'Open file', r'xhsb',
            'Image files (*.jpg *.jpeg *.png *.bmp)')
        print(img_file)
        if not img_file:
            return  # dialog cancelled: keep the previous selection
        self.img_file = img_file
        self.img = QPixmap(self.img_file)
        self.label.setPixmap(self.img)
        self.label.setScaledContents(True)

    def prediction(self):
        """Classify the selected image and append the result to textBrowser_1."""
        if not self.img_file:
            self.textBrowser_1.append('请先加载图片')
            return
        try:
            # Use the file the user actually picked (not a rebuilt ./data/test/ path).
            self.image = sb.read(self.img_file)
            if self._model is None:
                self._model = sb.model()  # load once, reuse on later clicks
            pred = sb.pre(self._model, self.image)
        except Exception as exc:
            # GUI boundary: PyQt5 (>= 5.5) aborts the process on an unhandled
            # exception in a slot, so report the failure in the window instead.
            self.textBrowser_1.append('识别失败: {}'.format(exc))
            return
        pred_text = ', '.join(str(name) for name in pred)
        self.textBrowser_1.append('<font size="8" color="#000000">' + '此花为:' + pred_text + '</font>')
        QtWidgets.QApplication.processEvents()  # keep the UI responsive

    def retranslateUi(self, Form):
        """Set the user-visible texts of the form."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "花卉识别"))
        self.label.setText(_translate("Form", "请上传图片"))
        self.pushButton.setText(_translate("Form", "开始识别"))
        self.textBrowser.setHtml(_translate(
            "Form",
            "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
            "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
            "p, li { white-space: pre-wrap; }\n"
            "</style></head><body style=\" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
            "<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:20pt;\">花卉识别系统</span></p></body></html>"))
        self.pushButton_2.setText(_translate("Form", "加载图片"))


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = Ui_Form()
    window = QtWidgets.QMainWindow()
    ex.setupUi(window)
    window.show()
    sys.exit(app.exec_())

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071

相关知识

利用VGG16做花数据集的识别
基于轻量化VGG的植物病虫害识别
基于残差网络迁移学习的花卉识别系统
基于卷积神经网络的樱桃叶片病虫害识别与防治系统,vgg16,resnet,swintransformer,模型融合(pytorch框架,python代码)
基于深度学习特征的植物病虫害检测
基于多视图特征融合的植物病害识别算法
基于深度学习和迁移学习的识花实践
农作物病虫害识别技术的发展综述
卷积神经网络训练花卉识别分类器
Keras复现VGG16及实现花卉分类

网址: 基于VGG16网络的花卉识别 https://m.huajiangbk.com/newsview375780.html

所属分类:花卉
上一篇: 【实战】tensorflow 花
下一篇: sklearn