import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tqdm import tqdm
# --- Experiment configuration -------------------------------------------------
# Root directory scanned by flow_from_directory; each class is one sub-folder.
dataset_path = "./flowers"
# Passed to Keras as target_size, which Keras reads as (height, width) in pixels.
image_size = (320, 240)
batch_size = 32   # images per training / validation batch
num_classes = 5   # units in the softmax output layer; should equal the number of class folders
epochs = 10       # passes over the training split
# One generator serves both splits. rescale maps 8-bit pixels into [0, 1], and
# validation_split=0.2 reserves 20% of the images for the 'validation' subset.
datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)

# Build the training iterator first, then the validation iterator, with
# identical settings apart from `subset` so the two splits cannot drift apart.
train_generator, validation_generator = (
    datagen.flow_from_directory(
        dataset_path,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset=split_name,
    )
    for split_name in ('training', 'validation')
)
def _show_sample_grid(generator, rows=3, cols=3):
    """Plot up to rows*cols images from the next batch of *generator*.

    Consumes one batch from the iterator and returns that batch of images.
    Slicing the batch (instead of indexing 0..8 unconditionally) means a batch
    with fewer than rows*cols images leaves the grid partly empty rather than
    raising IndexError.
    """
    images, _ = next(generator)
    plt.figure(figsize=(10, 10))
    for position, image in enumerate(images[: rows * cols], start=1):
        plt.subplot(rows, cols, position)
        # Pixels are floats in [0, 1] because the shared datagen uses rescale=1./255.
        plt.imshow(image)
        plt.axis("off")
    plt.show()
    return images


# Preview one batch from each split; sample_images ends up holding the
# validation batch, exactly as before.
for sample_source in (train_generator, validation_generator):
    sample_images = _show_sample_grid(sample_source)
# Three Conv2D + MaxPooling2D stages (32 -> 64 -> 128 filters) extract features.
# A dense head with dropout then classifies into `num_classes` softmax outputs.
model = keras.Sequential(
    [
        layers.Conv2D(
            32, (3, 3), activation='relu',
            input_shape=(image_size[0], image_size[1], 3),  # (height, width, RGB)
        ),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ]
)

# One-hot labels (class_mode='categorical') pair with categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Manual training loop: one pass over the training split per epoch with a tqdm
# progress bar, followed by a validation pass.
#
# len() of a Keras DirectoryIterator is ceil(samples / batch_size). The final
# partial batch is therefore included, and a split smaller than one batch still
# gets a step. Floor division yielded 0 steps there, which silently skipped
# training and evaluated with steps=0.
train_steps_per_epoch = len(train_generator)
val_steps_per_epoch = len(validation_generator)
for epoch in range(epochs):
    print("Epoch {}/{}".format(epoch + 1, epochs))
    # The context manager closes the bar even if a training step raises.
    with tqdm(total=train_steps_per_epoch, unit="batch") as pbar:
        for _ in range(train_steps_per_epoch):
            x_batch, y_batch = next(train_generator)
            # Compiled with metrics=['accuracy'], so this returns [loss, accuracy].
            loss, acc = model.train_on_batch(x_batch, y_batch)
            pbar.update(1)
            pbar.set_description("Train loss: {:.4f}, acc: {:.4f}".format(loss, acc))
    if val_steps_per_epoch:
        val_loss, val_acc = model.evaluate(validation_generator, steps=val_steps_per_epoch)
        print("Validation loss: {:.4f}, acc: {:.4f}".format(val_loss, val_acc))
    else:
        print("No validation samples found; skipping validation.")
# Persist the trained model, then reload it from disk and classify one image
# with the reloaded copy (this also checks that the saved file round-trips).
model_path = "flower_classification_model.h5"
model.save(model_path)

test_image_path = "./flowers/test3-t.jpg"
test_image = tf.keras.preprocessing.image.load_img(test_image_path, target_size=image_size)
test_image_array = tf.keras.preprocessing.image.img_to_array(test_image)
# Apply the same normalization used for training (datagen has rescale=1./255).
# Without this the network would receive 0-255 pixels instead of [0, 1].
test_image_array = datagen.standardize(test_image_array)
test_image_array = tf.expand_dims(test_image_array, 0)  # add a leading batch axis

saved_model = keras.models.load_model(model_path)
predictions = saved_model.predict(test_image_array)
# The batch holds a single image: take row 0 and convert the index to a Python int,
# rather than comparing ints against a shape-(1,) tensor.
predicted_class_index = int(tf.argmax(predictions[0]))
# class_indices maps class name -> index; invert it to look the name up directly.
index_to_class = {index: name for name, index in train_generator.class_indices.items()}
predicted_class = index_to_class[predicted_class_index]
print("预测结果:", predicted_class)  # "Prediction:" followed by the class folder name
相关知识
卷积神经网络(CNN)鲜花的识别
几个卷积神经网络(CNN)可视化的网站
满满干货!一文快速实现CNN(卷积神经网络)识别花朵
CNN卷积神经网络:花卉分类
花朵识别系统Python+卷积神经网络算法+人工智能+深度学习+计算机课设项目+TensorFlow+模型训练
【花卉识别系统】Python+卷积神经网络算法+人工智能+深度学习+图像识别+算法模型
Tensorflow基于卷积神经网络(CNN)的手写数字识别
基于卷积神经网络的农作物病虫害识别系统
基于卷积神经网络的花卉识别技术 Flower Recognition Based on Convolutional Neural Networks
基于卷积神经网络和集成学习的材质识别和分割方法研究
网址: CNN卷积神经网络花朵识别(5种) https://m.huajiangbk.com/newsview1101253.html
上一篇: 蔷薇的寓意
下一篇: 识别花app免费下载