import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import MaxPool2D, BatchNormalization, Dropout, Activation, Conv2D, Flatten
# Learning point 1: download and extract a compressed dataset from the web
tf.keras.utils.get_file(
    fname, origin, untar=False, md5_hash=None, file_hash=None,
    cache_subdir='datasets', hash_algorithm='auto', extract=False,
    archive_format='auto', cache_dir=None
)
'''
Parameter notes:
fname:          file name; if an absolute path such as "/path/to/file.txt" is given,
                the file is saved at that location (optional)
origin:         URL of the file
untar:          boolean, whether the file should be decompressed
md5_hash:       MD5 hash used to verify the file (file_hash supports sha256 and md5)
cache_subdir:   folder used to cache the data; if an absolute path "/path/to/folder"
                is given, the file is stored under that path
hash_algorithm: hash algorithm used to verify the file; one of 'md5', 'sha256', or
                'auto'; the default 'auto' detects the algorithm in use
extract:        if True, try to extract the file as an archive such as tar or zip
archive_format: archive format to try to extract; one of 'auto', 'tar', 'zip', or
                None; 'tar' covers tar, tar.gz, and tar.bz files; the default 'auto'
                is ['tar', 'zip']; None or an empty list returns no match found
cache_dir:      where the file is cached; if None, it defaults to the .keras folder
                in the home directory
'''
# Download the data
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(fname='flower_photos',   # local file name after download
                                   origin=dataset_url,      # URL of the dataset
                                   untar=True,              # whether to extract the file
                                   cache_dir='/content/drive/MyDrive/DL/DL 100例/花朵数据')
data_dirs = '/content/drive/MyDrive/DL/DL 100例/花朵数据/datasets/flower_photos'
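As a quick sanity check (a minimal sketch, assuming the extraction path set above), you can list the class subdirectories the extracted archive should contain:

import os

# flower_photos should contain one folder per class
# (daisy, dandelion, roses, sunflowers, tulips) plus a LICENSE.txt file.
print(sorted(os.listdir(data_dirs)))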
import pathlib
import PIL
data_dir = pathlib.Path(data_dirs)
data_dir
PosixPath('/content/drive/MyDrive/DL/DL 100例/花朵数据/datasets/flower_photos')
# Count all the jpg files in the folder
len(list(data_dir.glob('*/*.jpg')))
3670
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))
# Data preprocessing
# Use image_dataset_from_directory to load the images on disk into a tf.data.Dataset
batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset='training',
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3670 files belonging to 5 classes. Using 2936 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3670 files belonging to 5 classes. Using 734 files for validation.
class_names = train_ds.class_names
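The class names are taken from the subdirectory names, in alphabetical order. A quick check (for flower_photos these should be the five flower folders):

print(class_names)
# Expected for flower_photos: ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']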
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(32, 180, 180, 3)
(32,)
# Configure the dataset
# shuffle():  shuffle the data
# prefetch(): prefetch data to speed up execution
# cache():    cache the dataset in memory to speed up execution
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
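To see whether cache() and prefetch() actually pay off, here is a minimal timing sketch (using only the standard-library time module; the speedup shows up from the second pass, once the cache is populated):

import time

for run in range(2):
    start = time.perf_counter()
    for images, labels in train_ds:   # full pass over the training data
        pass
    print(f"pass {run}: {time.perf_counter() - start:.2f}s")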
num_classes = 5

"""
If the arithmetic behind convolution kernels is unclear, see:
https://blog.csdn.net/qq_38251616/article/details/114278995
"""

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, (3, 3), activation='relu'),  # conv layer 1, 3x3 kernel
    layers.MaxPooling2D((2, 2)),                   # pooling layer 1, 2x2 downsampling
    layers.Conv2D(32, (3, 3), activation='relu'),  # conv layer 2, 3x3 kernel
    layers.MaxPooling2D((2, 2)),                   # pooling layer 2, 2x2 downsampling
    layers.Conv2D(64, (3, 3), activation='relu'),  # conv layer 3, 3x3 kernel
    layers.Flatten(),                              # flatten, bridging the conv layers and the dense layers
    layers.Dense(128, activation='relu'),          # fully connected layer for further feature extraction
    layers.Dense(num_classes)                      # output layer, one logit per class
])

model.summary()  # print the network structure
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= rescaling (Rescaling) (None, 180, 180, 3) 0 _________________________________________________________________ conv2d (Conv2D) (None, 178, 178, 16) 448 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 89, 89, 16) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 87, 87, 32) 4640 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 43, 43, 32) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 41, 41, 64) 18496 _________________________________________________________________ flatten (Flatten) (None, 107584) 0 _________________________________________________________________ dense (Dense) (None, 128) 13770880 _________________________________________________________________ dense_1 (Dense) (None, 5) 645 ================================================================= Total params: 13,795,109 Trainable params: 13,795,109 Non-trainable params: 0 _________________________________________________________________
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
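Note that the model's last Dense layer emits raw logits (no softmax), which is exactly why the loss is constructed with from_logits=True. At inference time the logits can be converted to probabilities explicitly; a minimal sketch on one validation batch:

for images, labels in val_ds.take(1):
    logits = model(images, training=False)
    probs = tf.nn.softmax(logits, axis=-1)        # logits -> class probabilities
    preds = tf.argmax(probs, axis=-1)
    print(preds[:5].numpy(), labels[:5].numpy())  # predicted vs. true labels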
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3
)
Epoch 1/3
92/92 [==============================] - 375s 2s/step - loss: 1.5899 - accuracy: 0.3607 - val_loss: 1.1485 - val_accuracy: 0.5041
Epoch 2/3
92/92 [==============================] - 94s 1s/step - loss: 1.0672 - accuracy: 0.5698 - val_loss: 1.0633 - val_accuracy: 0.5736
Epoch 3/3
92/92 [==============================] - 92s 997ms/step - loss: 0.8770 - accuracy: 0.6614 - val_loss: 0.9982 - val_accuracy: 0.6185
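fit() returns a History object whose history dict records the per-epoch metrics. A minimal sketch for plotting the curves (the key names match the metrics compiled above):

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(acc))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.show()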