首页 > 分享 > 鸢尾花类型识别的代码(自用)

鸢尾花类型识别的代码(自用)

# 引入使用的模块

import tensorflow as tf

from sklearn import datasets

from matplotlib import pyplot as plt

import numpy as np

# 导入数据,从datasets中导出鸢尾花数据的输入特征和标签

x_data=datasets.load_iris().data

y_data=datasets.load_iris().target

# 随机打乱数据,神经网络是模拟人脑,人脑对于信息接受的过程是随机的,所以输入的数据要是随机的

# seed:随机种子,当设置种子数后,每次生成的随机数都一样,目前为了方便教学所以进行了设置,实际使用过程可以不设置

np.random.seed(116) # 对数据和标签打乱,但为了保证数据和标签一一对应组队进行打乱,使用相同的seed,保证输入特征和标签一一对应

np.random.shuffle(x_data)

np.random.seed(116)

np.random.shuffle(y_data)

# 将打乱后的数据集分割为训练集和测试集,训练集一共有150组数据,75%作为训练部分,25%作为测试部分,训练集为前120行,测试集为后30行

x_train=x_data[:-30]

y_train=y_data[:-30]

x_test=x_data[-30:]

y_test=y_data[-30:]

# 转换x的数据类型,后面进行矩阵相乘时会因为数据类型不一致报错

# 运算过程是对x进行运算,y只是x的对应标签,所以要转换x的类型

x_train=tf.cast(x_train,tf.float32)

x_test=tf.cast(x_test,tf.float32)

# from_tensor_sclice函数使输入特征和标签值意义对应,把数据集分批次,每个批次batch组数据,分组运行会提高运行的效率

train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)

test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(32)

# 生成神经网络的参数,因为有4个输入特征(长宽高等),输出有3个(3类鸢尾花的类别),所以输入层为4个输入节点,输出层为3个神经元

# 用tf.Variable()标记参数可训练

# 使用seed使每次生成的随机数相同,同样为方便教学

w1=tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1)) # w1为权重

b1=tf.Variable(tf.random.truncated_normal([3],stddv=0.1,seed=1))

lr=0.1 # 设置学习率

train_loss_result=[] # 设置一个损失函数计算结果的列表,将计算的到的损失函数的值放进去,然后后续基于这些数值画图象

test_acc=[] # 将每轮的acc(accuracy,精度)记录在此列表中,为后续画acc曲线提供数据

epoch=500 # 循环500轮

loss_all=0 #每轮分4个step,loss_all记录四个step生成的4个loss的和,共有120个数据,每个batch有32组,所以要循环四轮才能够循环完120组数据

# 进行神经网络的训练

for epoch in range(epoch):

for step,(x_train,y_train) in enumerate(train_db):

with tf.GradientTape() as tape: # 使用with结构记录梯度信息

y=tf.matmul(x_train,w1)+b1 # 神经网络的方程为y=x(输入特征)*w(权重)+b(偏置)

y=tf.nn.softmax(y) # 此操作使输出为概率类型,此操作后与独热码同量级,能够相减求loss

y_=tf.one_hot(y_train,depth=3) # 将标签转换为独热码格式,方便计算loss和accuracy

loss=tf.reduce_mean(tf.square(y_-y)) # 采用均方误差损失函数mse=mean(sun(y-out)^2))

# 本程序是为了计算基于训练后的神经网络识别鸢尾花类型,所以需要将原类型标签(的概率,独热码)-神经网络得到的标签(的概率,独热码)

loss_all+=loss.numpy() # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确

# 计算loss对各个参数的梯度

grads=tape.gradient(loss,[w1,b1]) # 此 操作为求偏导操作,损失函数是为了寻求一组最优w和b,通过梯度法求解,所以需要对w和b求偏导

# 实现梯度更新 w1=w1-lr*w1_grad b1=b1-lr*b1_grad

w1.assign_sub(lr*grads[0]) # 参数w1自更新

b1.assign_sub(lr*grads[1]) # 参数b1自更新

# 每个epoch,打印loss信息

print("Epoch {},loss:{}".format(epoch,loss_all/4))

train_loss_results.append(loss_all / 4) # 将4个step的loss求平均记录在此变量中

loss_all=0 # loss_all归零,为记录下一个epoch的loss准备

# 测试部分

# total_correct为预测对的样本个数,total_number为测试的总样本数,将这两个变量都初始化为0

total_correct,total_number=0,0

for x_test,y_test in test_db: # 使用更新后的参数进行预测

y=tf.matmul(x_test,w1)+b1

y=tf.nn.softmax(y)

pred=tf.argmax(y,axis=1) # 返回y中最大值的索引,即预测的分类

# 将pred转换为y_test的数据类型

pred=tf.cast(pred,dtype=y_test.dtype)

# 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型

correct=tf.cast(tf.equal(pred,y_test),dtype=tf.int32)

# 将每个batch的correct数加起来

correct=tf.reduce.sum(correct)

#将所有batch中的correct数加起来

total_correct+=int(correct)

# total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数

total_number+=x_test.shape[0]

# 总的准确率等于total_correct/total_number

acc = total_correct / total_number

test_acc.append(acc)

print("Test_acc:", acc)

print("--------------------------")

#绘制loss曲线

plt.title('Loss Function Curve') #图片标题

plt.xlabel('Epoch') # x轴变量名称

plt.ylabel('Loss') # y轴变量名称

plt.plot(train_loss_results, label="$Loss$") # 逐点画出trian_loss_results值并连线,连线图标是Loss

plt.legend() # 画出曲线图标

plt.show() # 画出图像

# 绘制 Accuracy 曲线

plt.title('Acc Curve') # 图片标题

plt.xlabel('Epoch') # x轴变量名称

plt.ylabel('Acc') # y轴变量名称

plt.plot(test_acc, label="$Accuracy$") # 逐点画出test_acc值并连线,连线图标是Accuracy

plt.legend()

plt.show()

输出为:

x_data from datasets:

[[5.1 3.5 1.4 0.2]

[4.9 3. 1.4 0.2]

[4.7 3.2 1.3 0.2]

[4.6 3.1 1.5 0.2]

[5. 3.6 1.4 0.2]

[5.4 3.9 1.7 0.4]

[4.6 3.4 1.4 0.3]

[5. 3.4 1.5 0.2]

[4.4 2.9 1.4 0.2]

[4.9 3.1 1.5 0.1]

[5.4 3.7 1.5 0.2]

[4.8 3.4 1.6 0.2]

[4.8 3. 1.4 0.1]

[4.3 3. 1.1 0.1]

[5.8 4. 1.2 0.2]

[5.7 4.4 1.5 0.4]

[5.4 3.9 1.3 0.4]

[5.1 3.5 1.4 0.3]

[5.7 3.8 1.7 0.3]

[5.1 3.8 1.5 0.3]

[5.4 3.4 1.7 0.2]

[5.1 3.7 1.5 0.4]

[4.6 3.6 1. 0.2]

[5.1 3.3 1.7 0.5]

[4.8 3.4 1.9 0.2]

[5. 3. 1.6 0.2]

[5. 3.4 1.6 0.4]

[5.2 3.5 1.5 0.2]

[5.2 3.4 1.4 0.2]

[4.7 3.2 1.6 0.2]

[4.8 3.1 1.6 0.2]

[5.4 3.4 1.5 0.4]

[5.2 4.1 1.5 0.1]

[5.5 4.2 1.4 0.2]

[4.9 3.1 1.5 0.2]

[5. 3.2 1.2 0.2]

[5.5 3.5 1.3 0.2]

[4.9 3.6 1.4 0.1]

[4.4 3. 1.3 0.2]

[5.1 3.4 1.5 0.2]

[5. 3.5 1.3 0.3]

[4.5 2.3 1.3 0.3]

[4.4 3.2 1.3 0.2]

[5. 3.5 1.6 0.6]

[5.1 3.8 1.9 0.4]

[4.8 3. 1.4 0.3]

[5.1 3.8 1.6 0.2]

[4.6 3.2 1.4 0.2]

[5.3 3.7 1.5 0.2]

[5. 3.3 1.4 0.2]

[7. 3.2 4.7 1.4]

[6.4 3.2 4.5 1.5]

[6.9 3.1 4.9 1.5]

[5.5 2.3 4. 1.3]

[6.5 2.8 4.6 1.5]

[5.7 2.8 4.5 1.3]

[6.3 3.3 4.7 1.6]

[4.9 2.4 3.3 1. ]

[6.6 2.9 4.6 1.3]

[5.2 2.7 3.9 1.4]

[5. 2. 3.5 1. ]

[5.9 3. 4.2 1.5]

[6. 2.2 4. 1. ]

[6.1 2.9 4.7 1.4]

[5.6 2.9 3.6 1.3]

[6.7 3.1 4.4 1.4]

[5.6 3. 4.5 1.5]

[5.8 2.7 4.1 1. ]

[6.2 2.2 4.5 1.5]

[5.6 2.5 3.9 1.1]

[5.9 3.2 4.8 1.8]

[6.1 2.8 4. 1.3]

[6.3 2.5 4.9 1.5]

[6.1 2.8 4.7 1.2]

[6.4 2.9 4.3 1.3]

[6.6 3. 4.4 1.4]

[6.8 2.8 4.8 1.4]

[6.7 3. 5. 1.7]

[6. 2.9 4.5 1.5]

[5.7 2.6 3.5 1. ]

[5.5 2.4 3.8 1.1]

[5.5 2.4 3.7 1. ]

[5.8 2.7 3.9 1.2]

[6. 2.7 5.1 1.6]

[5.4 3. 4.5 1.5]

[6. 3.4 4.5 1.6]

[6.7 3.1 4.7 1.5]

[6.3 2.3 4.4 1.3]

[5.6 3. 4.1 1.3]

[5.5 2.5 4. 1.3]

[5.5 2.6 4.4 1.2]

[6.1 3. 4.6 1.4]

[5.8 2.6 4. 1.2]

[5. 2.3 3.3 1. ]

[5.6 2.7 4.2 1.3]

[5.7 3. 4.2 1.2]

[5.7 2.9 4.2 1.3]

[6.2 2.9 4.3 1.3]

[5.1 2.5 3. 1.1]

[5.7 2.8 4.1 1.3]

[6.3 3.3 6. 2.5]

[5.8 2.7 5.1 1.9]

[7.1 3. 5.9 2.1]

[6.3 2.9 5.6 1.8]

[6.5 3. 5.8 2.2]

[7.6 3. 6.6 2.1]

[4.9 2.5 4.5 1.7]

[7.3 2.9 6.3 1.8]

[6.7 2.5 5.8 1.8]

[7.2 3.6 6.1 2.5]

[6.5 3.2 5.1 2. ]

[6.4 2.7 5.3 1.9]

[6.8 3. 5.5 2.1]

[5.7 2.5 5. 2. ]

[5.8 2.8 5.1 2.4]

[6.4 3.2 5.3 2.3]

[6.5 3. 5.5 1.8]

[7.7 3.8 6.7 2.2]

[7.7 2.6 6.9 2.3]

[6. 2.2 5. 1.5]

[6.9 3.2 5.7 2.3]

[5.6 2.8 4.9 2. ]

[7.7 2.8 6.7 2. ]

[6.3 2.7 4.9 1.8]

[6.7 3.3 5.7 2.1]

[7.2 3.2 6. 1.8]

[6.2 2.8 4.8 1.8]

[6.1 3. 4.9 1.8]

[6.4 2.8 5.6 2.1]

[7.2 3. 5.8 1.6]

[7.4 2.8 6.1 1.9]

[7.9 3.8 6.4 2. ]

[6.4 2.8 5.6 2.2]

[6.3 2.8 5.1 1.5]

[6.1 2.6 5.6 1.4]

[7.7 3. 6.1 2.3]

[6.3 3.4 5.6 2.4]

[6.4 3.1 5.5 1.8]

[6. 3. 4.8 1.8]

[6.9 3.1 5.4 2.1]

[6.7 3.1 5.6 2.4]

[6.9 3.1 5.1 2.3]

[5.8 2.7 5.1 1.9]

[6.8 3.2 5.9 2.3]

[6.7 3.3 5.7 2.5]

[6.7 3. 5.2 2.3]

[6.3 2.5 5. 1.9]

[6.5 3. 5.2 2. ]

[6.2 3.4 5.4 2.3]

[5.9 3. 5.1 1.8]]

y_data from datasets:

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1

1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2

2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2

2 2]

x_data add index:

花萼长度 花萼宽度 花瓣长度 花瓣宽度

0 5.1 3.5 1.4 0.2

1 4.9 3.0 1.4 0.2

2 4.7 3.2 1.3 0.2

3 4.6 3.1 1.5 0.2

4 5.0 3.6 1.4 0.2

.. ... ... ... ...

145 6.7 3.0 5.2 2.3

146 6.3 2.5 5.0 1.9

147 6.5 3.0 5.2 2.0

148 6.2 3.4 5.4 2.3

149 5.9 3.0 5.1 1.8

[150 rows x 4 columns]

x_data add a column:

花萼长度 花萼宽度 花瓣长度 花瓣宽度 类别

0 5.1 3.5 1.4 0.2 0

1 4.9 3.0 1.4 0.2 0

2 4.7 3.2 1.3 0.2 0

3 4.6 3.1 1.5 0.2 0

4 5.0 3.6 1.4 0.2 0

.. ... ... ... ... ...

145 6.7 3.0 5.2 2.3 2

146 6.3 2.5 5.0 1.9 2

147 6.5 3.0 5.2 2.0 2

148 6.2 3.4 5.4 2.3 2

149 5.9 3.0 5.1 1.8 2

[150 rows x 5 columns]

相关知识

高光谱图像分类工具包:精准识别的利器
鸢尾花果实类型
卷积神经网络实现鸢尾花数据分类python代码实现
2022的七夕,奉上7个精美的表白代码,同时教大家快速改源码自用
推荐文章:探索手写汉字识别的新境界——HCCR
探索C二维码生成与识别的无限可能
Python原生代码实现KNN算法(鸢尾花数据集)
什么是手写识别?手写识别的方法、好处和挑战
【端到端】:从图像到识别:手写数字识别的完整流程
朴素贝叶斯算法对鸢尾花分类

网址: 鸢尾花类型识别的代码(自用) https://m.huajiangbk.com/newsview1570912.html

所属分类:花卉
上一篇: 你和你的女神之间,差了一个Ope
下一篇: 如何有效教学长春版三年级语文一片