这一部分（vgg.py）包括数据处理、模型定义、模型训练以及在测试集上的评估（混淆矩阵）。
1、第26行的 names 列表为数据集文件夹中每一类花对应的子文件夹名称。
2、第27行到第44行是数据处理部分，处理结果会保存到 npy 文件中；运行一次后可以将这部分注释掉，后续直接读取 npy 文件即可。
3、选择 VGG16 作为基础模型，在此基础上进行训练；通过设置 include_top=False，可以获得不含全连接层的基础网络。
"""Train (optionally) and evaluate a VGG16-based classifier for six flower classes.

Pipeline: read the class folders under ./data/train_data, cache the resized
images in x.npy / y.npy, split train/test, (optionally) fine-tune VGG16, then
load model1.h5 and report/plot a confusion matrix on the test split.
"""
import pandas as pd
import numpy as np
from tensorflow.keras.models import *
from tensorflow.keras.applications import ResNet50, VGG16, MobileNet, InceptionV3, NASNetLarge
import os
from tensorflow.keras import layers, optimizers, models
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.tree import DecisionTreeClassifier
import cv2
import glob
import sklearn.metrics as metrics
import matplotlib.pyplot as plt
import warnings
from tensorflow.keras.models import load_model

warnings.filterwarnings("ignore")
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
tf.debugging.set_log_device_placement(True)

# Sub-folder names under ./data/train_data, one per flower class. The list
# order fixes the integer class ids (0..5) used for training, the report and
# the confusion-matrix tick labels.
names = ['bee_balm', 'blackberry_lily', 'blanket_flower', 'bougainvillea', 'bromelia', 'foxglove']
labels = {name: index for index, name in enumerate(names)}

# --- Data preparation --------------------------------------------------------
# Reads every image, resizes it to 256x256 and caches the arrays in
# x.npy / y.npy. After the first run this section can be commented out; the
# loading code below only needs the .npy files.
X = []
Y = []
for i in names:
    for f in os.listdir(r"./data/train_data/{}".format(i)):
        print(f)
        Images = cv2.imread(r"./data/train_data/{}/{}".format(i, f))
        if Images is None:
            # cv2.imread returns None (it does not raise) for files it cannot
            # decode; skip them instead of crashing inside cv2.resize.
            print("skip unreadable file:", f)
            continue
        # INTER_CUBIC: bicubic interpolation over a 4x4 pixel neighbourhood.
        image = cv2.resize(Images, (256, 256), interpolation=cv2.INTER_CUBIC)
        X.append(image)
        Y.append(i)
X = np.array(X)
Y = np.array(Y)
print(X)
print(Y)
print("结束了")
np.save('x.npy', X)
np.save('y.npy', Y)

# --- Load cached data and encode ---------------------------------------------
X_path = 'x.npy'
Y_path = 'y.npy'
X = np.load(X_path)
Y = np.load(Y_path)
# Class name -> integer id -> one-hot. A KeyError here means y.npy contains a
# folder name that is not listed in `names`.
Y = to_categorical(np.array([labels[name] for name in Y]), len(names))
# Scale pixels to [0, 1]. float32 halves the memory of the float64 array that
# `X / 255` would produce; Keras' default float type is float32 anyway.
X = X.astype('float32') / 255

x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)


def model():
    """Build and compile the VGG16-based classifier.

    VGG16 is loaded with ImageNet weights and without its dense head
    (``include_top=False``); a dropout + dense head with one softmax output
    per class is stacked on top, and the whole network is left trainable.

    Returns:
        A compiled ``tf.keras`` Sequential model.
    """
    conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
    model = models.Sequential()
    model.add(conv_base)
    # model.add(GlobalAveragePooling2D())
    model.add(Dropout(0.3))
    model.add(layers.Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(layers.Dense(len(names), activation='softmax'))
    conv_base.trainable = True
    # `learning_rate` replaces the deprecated `lr` keyword, which newer
    # Keras versions reject.
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.Adam(learning_rate=0.0001),
                  metrics=['categorical_accuracy'])
    model.summary()
    return model


# --- Training (uncomment to train; writes model1.h5) -------------------------
# model = model()
# early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=6)
# model_checkpoint = ModelCheckpoint('model2.hdf5', monitor='loss', verbose=1, save_best_only=True)
# history = model.fit(x_train, y_train, epochs=20, batch_size=32, validation_data=(x_test, y_test), callbacks=[early_stop, model_checkpoint])
# model.save("model1.h5")

# --- Evaluation on the held-out split ----------------------------------------
model = load_model('model1.h5')
pred = model.predict(x_test)
y = np.argmax(pred, axis=-1)
y_test = np.argmax(y_test, axis=-1)
# Pass `labels` so the matrix is always len(names) x len(names), even if a
# class is absent from this split; otherwise the tick labels would misalign.
cm = confusion_matrix(y_test, y, labels=list(range(len(names))))
print(cm)
print(classification_report(y_test, y))
print(cm)

# Use a CJK-capable font so the Chinese axis labels and title render, and
# keep minus signs displayable with that font.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

plt.imshow(cm, cmap=plt.cm.BuPu)
indices = range(len(cm))
ax = plt.gca()
plt.xticks(indices, names, fontsize=8)
ax.xaxis.set_ticks_position("top")
plt.yticks(indices, names, fontsize=8)
plt.colorbar()
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('混淆矩阵')

# Annotate each cell with its count. Rows of `cm` are true classes and
# columns are predictions; imshow draws columns along x and rows along y, so
# the cell cm[row][col] sits at (x=col, y=row).
for row_index, row in enumerate(cm):
    for col_index, count in enumerate(row):
        plt.text(col_index, row_index, count, ha='center', va='center', fontdict={'size': 6})

# Save BEFORE show(): show() can close the figure, after which savefig would
# write an empty image.
plt.savefig("混淆矩阵.png")
plt.show()
先运行 vgg.py 进行训练，然后用 main.py 读取训练好的模型进行预测。
"""Predict the flower class of a single image with the trained model1.h5."""
from tensorflow.keras.models import *
import pandas as pd
import cv2
import numpy as np
import functools

# Integer class id -> flower name, in the same order as the labels used for
# training in vgg.py.
LABELS = {0: 'bee_balm', 1: 'blackberry_lily', 2: 'blanket_flower', 3: 'bougainvillea', 4: 'bromelia', 5: 'foxglove'}


@functools.lru_cache(maxsize=1)
def model():
    """Load and return the trained Keras model stored in ``model1.h5``.

    The loaded model is cached. The UI calls ``model()`` on every prediction,
    and re-reading the whole VGG16 weight file each time is slow. Restart the
    process to pick up a retrained ``model1.h5``.
    """
    return load_model('model1.h5')


def read(path):
    """Read the image at *path* and turn it into a one-image model batch.

    The image is resized to 256x256, scaled to [0, 1] and reshaped to
    (1, 256, 256, 3). It stays in OpenCV's BGR channel order, the same reader
    the training data was built with.

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of inside cv2.resize.
        raise FileNotFoundError("cannot read image: {}".format(path))
    img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_CUBIC)
    img = img / 255
    img = img.reshape(1, 256, 256, 3)
    return img


def pre(model, img):
    """Predict the flower class of every image in the batch *img*.

    Prints the result and returns a 1-D object array of class names, one
    entry per image in the batch.
    """
    pred = model.predict(img)
    y = np.argmax(pred, axis=-1)
    y = np.array([LABELS[i] for i in y], dtype=object)
    print('此花为:', y)
    return y


if __name__ == '__main__':
    path = r'./data/test/test6.jpg'
    img = read(path)
    model = model()
    pred = pre(model, img)
接下来进行 UI 界面的设计。
"""PyQt5 front end: pick an image, preview it and show the predicted flower."""
from PyQt5 import QtCore, QtGui, QtWidgets
import sys
from PyQt5 import QtCore, QtWidgets
from PyQt5.QtWidgets import QApplication, QFileDialog
from PyQt5.QtGui import QPixmap
import main as sb
from tensorflow.keras.models import load_model


class Ui_Form(object):
    """Flower-recognition window: shows a chosen image and its predicted class."""

    def setupUi(self, Form):
        """Create the widgets on *Form* and connect the two buttons."""
        Form.setObjectName("Form")
        Form.resize(765, 402)
        # Only used as the parent widget of the file dialog in openimg().
        self.centralwidget = QtWidgets.QWidget(Form)
        # Image preview area.
        self.label = QtWidgets.QLabel(Form)
        self.label.setGeometry(QtCore.QRect(70, 50, 256, 256))
        self.label.setObjectName("label")
        # "Start recognition" button -> prediction().
        self.pushButton = QtWidgets.QPushButton(Form)
        self.pushButton.setGeometry(QtCore.QRect(560, 300, 151, 61))
        self.pushButton.setObjectName("pushButton")
        # Borderless title banner; its HTML is set in retranslateUi().
        self.textBrowser = QtWidgets.QTextBrowser(Form)
        self.textBrowser.setGeometry(QtCore.QRect(420, 50, 256, 51))
        self.textBrowser.setStyleSheet("border:0px;\n")
        self.textBrowser.setObjectName("textBrowser")
        # "Load image" button -> openimg().
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(380, 300, 151, 61))
        self.pushButton_2.setObjectName("pushButton_2")
        # Result area; prediction() appends one line per run.
        self.textBrowser_1 = QtWidgets.QTextBrowser(Form)
        self.textBrowser_1.setGeometry(QtCore.QRect(420, 140, 261, 101))
        self.textBrowser_1.setObjectName("textBrowser1")
        # No image chosen yet; openimg() fills this in.
        self.img_file = ''
        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)
        self.pushButton.clicked.connect(self.prediction)
        self.pushButton_2.clicked.connect(self.openimg)

    def openimg(self):
        """Ask the user for a .jpg file and preview it in the label."""
        img_file, _ = QFileDialog.getOpenFileName(self.centralwidget, 'Open file', r'xhsb', 'Image files (*.jpg)')
        print(img_file)
        if not img_file:
            # Dialog was cancelled: keep the previous selection and preview.
            return
        self.img_file = img_file
        self.img = QPixmap(self.img_file)
        self.label.setPixmap(self.img)
        self.label.setScaledContents(True)

    def prediction(self):
        """Classify the image chosen in openimg() and append the result."""
        if not self.img_file:
            self.textBrowser_1.append('请先加载图片')
            return
        # Use the path the user actually selected; it need not live under
        # ./data/test.
        self.image = sb.read(self.img_file)
        model = sb.model()
        pred = sb.pre(model, self.image)
        # Distinct local name: binding a local called `str` would shadow the
        # builtin and make this str() call raise TypeError.
        pred_text = str(pred)
        self.textBrowser_1.append('<font size="8" color="#000000">' + '此花为:' + pred_text + '</font>')
        QtWidgets.QApplication.processEvents()  # keep the window responsive

    def retranslateUi(self, Form):
        """Set the window title and all user-visible texts."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "花卉识别"))
        self.label.setText(_translate("Form", "请上传图片"))
        self.pushButton.setText(_translate("Form", "开始识别"))
        self.textBrowser.setHtml(_translate("Form", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
            "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
            "p, li { white-space: pre-wrap; }\n"
            "</style></head><body style=\" font-family:'SimSun'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
            "<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:20pt;\">花卉识别系统</span></p></body></html>"))
        self.pushButton_2.setText(_translate("Form", "加载图片"))


if __name__ == '__main__':
    import PyQt5
    app = QApplication(sys.argv)
    ex = Ui_Form()
    window = PyQt5.QtWidgets.QMainWindow()
    ex.setupUi(window)
    window.show()
    sys.exit(app.exec_())
相关知识
利用VGG16做花数据集的识别
基于轻量化VGG的植物病虫害识别
基于残差网络迁移学习的花卉识别系统
基于卷积神经网络的樱桃叶片病虫害识别与防治系统,vgg16,resnet,swintransformer,模型融合(pytorch框架,python代码)
基于深度学习特征的植物病虫害检测
基于多视图特征融合的植物病害识别算法
基于深度学习和迁移学习的识花实践
农作物病虫害识别技术的发展综述
卷积神经网络训练花卉识别分类器
Keras复现VGG16及实现花卉分类
网址: 基于VGG16网络的花卉识别 https://m.huajiangbk.com/newsview375780.html
上一篇: 【实战】tensorflow 花 |
下一篇: sklearn |