活动地址:CSDN21天学习挑战赛
目录
构造神经网络的layers函数
layers.Conv2D
参数说明
总结
layers.MaxPooling2D layers.AveragePooling2D
参数说明
总结
layers.Dropout
参数说明
总结
layers.Flatten
layers.Dense
参数说明
总结
layers.Rescaling
参数说明
总结
花朵识别
导入数据
构建卷积神经网络
编译训练网络模型
预测
构造神经网络时主要用到的都是 keras.layers库内 的类的构造函数
layers.Conv2D主要是用来 形成卷积层
参数说明def __init__(self,
filters,
kernel_size,
strides=(1, 1),
padding='valid',
data_format=None,
dilation_rate=(1, 1),
groups=1,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
filters:int型,表示卷积核数,也即特征图数,与输出的shape的channels一致kernel_size:核大小,int型二元组,一般是(3,3)strides:移动步长,int型二元组,默认(1,1),表示(左右移步长,上下移步长)padding:str型,表示是否采用边缘填充0,默认是valid无填充,值为same表示填充data_format:一般默认是channels_last,即input_shape=(batch_size,w,h,channels)也有channels_first 即input_shape=(batch_size,channels,w,h)activation:激活函数,值之前已讨论过use_bias:是否使用偏置参数,bool型,默认True,使用偏置参数kernel_regularize 对权重参数正则化bias_regularize 对偏置参数正则化activity_regularize 对输出的正则化input-shape 定义输入类型,默认data_format是channels_first,input_shape是(batch_size,height,weight,channels) 且通常batch_size不用指定,指定的话是2的幂
总结一般使用时,只需要传参filters、kernel_size、padding、activation、input_shape即可
(过拟合时也可能会用到kernel_regularize)
layers.MaxPooling2D 是构建采用最大池化方法的池化层
layers.AveragePooling2D 是构建采用平均池化方法的池化层
两种类的构造函数基本一致,下面统一说明
参数说明def __init__(self,
pool_size=(2, 2),
strides=None,
padding='valid',
data_format=None,
**kwargs):
pool_size:int型二元组,默认值(2,2),表示的是池化核的大小,每pool_size取最大值strides:步长,int型二元组或int数值,一般不用设置,默认取pool_size值padding:str型,值为valid表示不采用边缘填充,same表示边缘填充PS:输出大小的计算:
采用边缘填充:input_shape/strides 向上取整
不采用边缘填充:(input_shape-pool_size+1/strides)
总结一般使用时,并不需要传参,有时会只传pool_size
layers.Dropout 是构建过拟合时采用的丢弃层
参数说明def __init__(self, rate, noise_shape=None, seed=None, **kwargs) rate:丢弃率,表示每次训练中该层的灭活比,一般值是0~1(1会报错)noise_shape:一维的元组或者列表,长度与输入类型相同,且里面的元素要么为1,要么为输入类型值,哪个轴为1,哪个轴就会被一致地dropoutseed:随机数种子 总结
一般使用时,只需要传参rate即可
Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。
一般使用时,无参调用
layers.Dense 是 构建全连接层
参数说明def __init__(self,
units,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
units:int型,表示全连接层的输出维度activation:str型,表示激活函数,一般采用的是"relu",其他激活函数见之前的博客中use_bias:bool型,表示是否使用偏置参数,默认值是Truekernel_regularize 对权重参数正则化bias_regularize 对偏置参数正则化activity_regularize 对输出的正则化 总结使用时,传units和activation即可,过拟合时有时会传 kernel_regularize
layers.Rescaling 主要是构建缩放层,进行归一化或者标准化
参数说明def __init__(self, scale, offset=0., **kwargs) scale:表示图像像素值的缩放公式,一般图像是 1/255.0offset:偏移基准量,最后图像的像素值=offset+scale公式得到的结果还可以指定input_shape
eg:归一化 layers.Rescaling(1/255.0)
归到[-1,1] layers.Rescaling(1/127.5,offset=-1)
总结使用时也主要是用来归一化,一般是指定scale和input_shape即可
花朵识别所采用的数据集是通过 tf_flowers | TensorFlow Datasets 下载得到的
具体下载方式可见之前的博客,这里不再赘述。链接如下: 深度学习——怎样读取本地数据作为训练集和测试集_尘心平的博客-CSDN博客
这里是涉及到了本地数据如何导入,具体操作在之前的博客中也有介绍。链接如下:
深度学习——怎样读取本地数据作为训练集和测试集_尘心平的博客-CSDN博客
这里附上代码:
data_url = "E:/Download/flower_photos/flower_photos"
class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
image_width = 256
image_height = 256
train_data = image_dataset_from_directory(
directory=data_url,
class_names=class_names,
image_size=(image_height, image_width),
validation_split=0.2,
subset='training',
seed=123
)
test_data = image_dataset_from_directory(
directory=data_url,
class_names=class_names,
image_size=(image_height, image_width),
validation_split=0.2,
subset='validation',
seed=123
)
buffer_size = 800
train_data = train_data.cache().shuffle(buffer_size).prefetch(buffer_size=tf.data.AUTOTUNE)
test_data = test_data.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
为解决过拟合的问题,构建了Dropout层,且设置灭活比为0.6
model = models.Sequential([
layers.Rescaling(1 / 255.0, input_shape=(image_height, image_width,3)),
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(None, None, 3)),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Conv2D(64, (3, 3), activation='relu'),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dropout(0.6),
layers.Dense(64, activation='relu'),
layers.Dense(5)
])
model.summary()
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(train_data, epochs=10,
validation_data=test_data)
模型预测代码:
pre = model.predict(test_data)
for x in range(10):
print(pre[x])
for x in range(10):
print(class_names[np.array(pre[x]).argmax()])
绘制一部分测试集图片以进行验证:
plt.figure(figsize=(20, 10))
for test_image, test_label in test_data.as_numpy_iterator():
for i in range(10):
plt.subplot(5, 10, i + 1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(test_image[i] / 255.0, cmap=plt.cm.binary)
plt.xlabel(class_names[test_label[i]])
plt.show()
输出loss val_loss accuracy val_accuracy 变化曲线:
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy ')
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('value')
plt.ylim([0, 1.5])
plt.legend(loc='lower left')
plt.show()
test_loss, test_acc = model.evaluate(test_data, verbose=1)
print(test_loss, test_acc)
可以看到val_loss和loss均成下降趋势,模型较好
输出预测结果值:
[ 1.9147255 0.54115176 0.32046622 -2.2310185 0.6476273 ]
[ 1.5751086 0.25440234 -0.90600187 -1.9116721 0.9208099 ]
[-2.0533743 0.74125975 -1.6817499 2.0241127 -0.01674169]
[ 6.0942693 -0.64630526 -1.418276 -3.2073083 -0.05675798]
[-3.6862264 0.8362238 -2.3960068 5.0659366 1.2542382]
[-4.6353736 -2.6309597 5.1571536 -2.1555338 10.289471 ]
[-0.19035766 4.5748014 0.73688596 -6.761815 -1.4931774 ]
[-0.23641998 -0.66789824 3.328943 -1.2963772 2.43171 ]
[-4.2409375e-03 -3.9561117e+00 7.4320264e+00 -4.5852895e+00 1.0197869e+01]
[ 7.151232 3.0763505 3.9734569 -12.45473 2.183545 ]
相对应的预测标签:
实际图片:
预测结果与实际结果较为符合
相关知识
基于深度学习的花卉识别研究
智能农业的植物病虫害识别:如何实现准确的预测和控制1.背景介绍 智能农业是指利用人工智能、大数据、物联网等新技术,对农业
基于深度学习的花卉识别(附数据与代码)
深度学习花朵识别系统的设计与实现
使用Python实现深度学习模型:智能农业病虫害检测与防治
Keras学习笔记(二)Keras模型的创建
keras框架——基于深度学习CNN神经网络的水果成熟度识别分类系统源码
撒花!《神经网络与深度学习》中文教程正式开源!全书 pdf、ppt 和代码一同放出
Keras复现VGG16及实现花卉分类
05
网址: 深度学习——以花朵识别为例,分析构造神经网络时用到的各个类构造函数(Dense、Conv2D、Flatten等) https://m.huajiangbk.com/newsview550769.html
上一篇: 「陕西爱尚花果山网络科技有限公司 |
下一篇: 花了三天时间终于搞懂 Docke |