# 引入使用的模块
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np
# 导入数据,从datasets中导出鸢尾花数据的输入特征和标签
x_data=datasets.load_iris().data
y_data=datasets.load_iris().target
# 随机打乱数据,神经网络是模拟人脑,人脑对于信息接受的过程是随机的,所以输入的数据要是随机的
# seed:随机种子,当设置种子数后,每次生成的随机数都一样,目前为了方便教学所以进行了设置,实际使用过程可以不设置
np.random.seed(116) # 对数据和标签打乱,但为了保证数据和标签一一对应组队进行打乱,使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
# 将打乱后的数据集分割为训练集和测试集,训练集一共有150组数据,75%作为训练部分,25%作为测试部分,训练集为前120行,测试集为后30行
x_train=x_data[:-30]
y_train=y_data[:-30]
x_test=x_data[-30:]
y_test=y_data[-30:]
# 转换x的数据类型,后面进行矩阵相乘时会因为数据类型不一致报错
# 运算过程是对x进行运算,y只是x的对应标签,所以要转换x的类型
x_train=tf.cast(x_train,tf.float32)
x_test=tf.cast(x_test,tf.float32)
# from_tensor_sclice函数使输入特征和标签值意义对应,把数据集分批次,每个批次batch组数据,分组运行会提高运行的效率
train_db=tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(32)
test_db=tf.data.Dataset.from_tensor_slices((x_test,y_test)).batch(32)
# 生成神经网络的参数,因为有4个输入特征(长宽高等),输出有3个(3类鸢尾花的类别),所以输入层为4个输入节点,输出层为3个神经元
# 用tf.Variable()标记参数可训练
# 使用seed使每次生成的随机数相同,同样为方便教学
w1=tf.Variable(tf.random.truncated_normal([4,3],stddev=0.1,seed=1)) # w1为权重
b1=tf.Variable(tf.random.truncated_normal([3],stddv=0.1,seed=1))
lr=0.1 # 设置学习率
train_loss_result=[] # 设置一个损失函数计算结果的列表,将计算的到的损失函数的值放进去,然后后续基于这些数值画图象
test_acc=[] # 将每轮的acc(accuracy,精度)记录在此列表中,为后续画acc曲线提供数据
epoch=500 # 循环500轮
loss_all=0 #每轮分4个step,loss_all记录四个step生成的4个loss的和,共有120个数据,每个batch有32组,所以要循环四轮才能够循环完120组数据
# 进行神经网络的训练
for epoch in range(epoch):
for step,(x_train,y_train) in enumerate(train_db):
with tf.GradientTape() as tape: # 使用with结构记录梯度信息
y=tf.matmul(x_train,w1)+b1 # 神经网络的方程为y=x(输入特征)*w(权重)+b(偏置)
y=tf.nn.softmax(y) # 此操作使输出为概率类型,此操作后与独热码同量级,能够相减求loss
y_=tf.one_hot(y_train,depth=3) # 将标签转换为独热码格式,方便计算loss和accuracy
loss=tf.reduce_mean(tf.square(y_-y)) # 采用均方误差损失函数mse=mean(sun(y-out)^2))
# 本程序是为了计算基于训练后的神经网络识别鸢尾花类型,所以需要将原类型标签(的概率,独热码)-神经网络得到的标签(的概率,独热码)
loss_all+=loss.numpy() # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确
# 计算loss对各个参数的梯度
grads=tape.gradient(loss,[w1,b1]) # 此 操作为求偏导操作,损失函数是为了寻求一组最优w和b,通过梯度法求解,所以需要对w和b求偏导
# 实现梯度更新 w1=w1-lr*w1_grad b1=b1-lr*b1_grad
w1.assign_sub(lr*grads[0]) # 参数w1自更新
b1.assign_sub(lr*grads[1]) # 参数b1自更新
# 每个epoch,打印loss信息
print("Epoch {},loss:{}".format(epoch,loss_all/4))
train_loss_results.append(loss_all / 4) # 将4个step的loss求平均记录在此变量中
loss_all=0 # loss_all归零,为记录下一个epoch的loss准备
# 测试部分
# total_correct为预测对的样本个数,total_number为测试的总样本数,将这两个变量都初始化为0
total_correct,total_number=0,0
for x_test,y_test in test_db: # 使用更新后的参数进行预测
y=tf.matmul(x_test,w1)+b1
y=tf.nn.softmax(y)
pred=tf.argmax(y,axis=1) # 返回y中最大值的索引,即预测的分类
# 将pred转换为y_test的数据类型
pred=tf.cast(pred,dtype=y_test.dtype)
# 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
correct=tf.cast(tf.equal(pred,y_test),dtype=tf.int32)
# 将每个batch的correct数加起来
correct=tf.reduce.sum(correct)
#将所有batch中的correct数加起来
total_correct+=int(correct)
# total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
total_number+=x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
test_acc.append(acc)
print("Test_acc:", acc)
print("--------------------------")
#绘制loss曲线
plt.title('Loss Function Curve') #图片标题
plt.xlabel('Epoch') # x轴变量名称
plt.ylabel('Loss') # y轴变量名称
plt.plot(train_loss_results, label="$Loss$") # 逐点画出trian_loss_results值并连线,连线图标是Loss
plt.legend() # 画出曲线图标
plt.show() # 画出图像
# 绘制 Accuracy 曲线
plt.title('Acc Curve') # 图片标题
plt.xlabel('Epoch') # x轴变量名称
plt.ylabel('Acc') # y轴变量名称
plt.plot(test_acc, label="$Accuracy$") # 逐点画出test_acc值并连线,连线图标是Accuracy
plt.legend()
plt.show()
输出为:
x_data from datasets:
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[4.7 3.2 1.3 0.2]
[4.6 3.1 1.5 0.2]
[5. 3.6 1.4 0.2]
[5.4 3.9 1.7 0.4]
[4.6 3.4 1.4 0.3]
[5. 3.4 1.5 0.2]
[4.4 2.9 1.4 0.2]
[4.9 3.1 1.5 0.1]
[5.4 3.7 1.5 0.2]
[4.8 3.4 1.6 0.2]
[4.8 3. 1.4 0.1]
[4.3 3. 1.1 0.1]
[5.8 4. 1.2 0.2]
[5.7 4.4 1.5 0.4]
[5.4 3.9 1.3 0.4]
[5.1 3.5 1.4 0.3]
[5.7 3.8 1.7 0.3]
[5.1 3.8 1.5 0.3]
[5.4 3.4 1.7 0.2]
[5.1 3.7 1.5 0.4]
[4.6 3.6 1. 0.2]
[5.1 3.3 1.7 0.5]
[4.8 3.4 1.9 0.2]
[5. 3. 1.6 0.2]
[5. 3.4 1.6 0.4]
[5.2 3.5 1.5 0.2]
[5.2 3.4 1.4 0.2]
[4.7 3.2 1.6 0.2]
[4.8 3.1 1.6 0.2]
[5.4 3.4 1.5 0.4]
[5.2 4.1 1.5 0.1]
[5.5 4.2 1.4 0.2]
[4.9 3.1 1.5 0.2]
[5. 3.2 1.2 0.2]
[5.5 3.5 1.3 0.2]
[4.9 3.6 1.4 0.1]
[4.4 3. 1.3 0.2]
[5.1 3.4 1.5 0.2]
[5. 3.5 1.3 0.3]
[4.5 2.3 1.3 0.3]
[4.4 3.2 1.3 0.2]
[5. 3.5 1.6 0.6]
[5.1 3.8 1.9 0.4]
[4.8 3. 1.4 0.3]
[5.1 3.8 1.6 0.2]
[4.6 3.2 1.4 0.2]
[5.3 3.7 1.5 0.2]
[5. 3.3 1.4 0.2]
[7. 3.2 4.7 1.4]
[6.4 3.2 4.5 1.5]
[6.9 3.1 4.9 1.5]
[5.5 2.3 4. 1.3]
[6.5 2.8 4.6 1.5]
[5.7 2.8 4.5 1.3]
[6.3 3.3 4.7 1.6]
[4.9 2.4 3.3 1. ]
[6.6 2.9 4.6 1.3]
[5.2 2.7 3.9 1.4]
[5. 2. 3.5 1. ]
[5.9 3. 4.2 1.5]
[6. 2.2 4. 1. ]
[6.1 2.9 4.7 1.4]
[5.6 2.9 3.6 1.3]
[6.7 3.1 4.4 1.4]
[5.6 3. 4.5 1.5]
[5.8 2.7 4.1 1. ]
[6.2 2.2 4.5 1.5]
[5.6 2.5 3.9 1.1]
[5.9 3.2 4.8 1.8]
[6.1 2.8 4. 1.3]
[6.3 2.5 4.9 1.5]
[6.1 2.8 4.7 1.2]
[6.4 2.9 4.3 1.3]
[6.6 3. 4.4 1.4]
[6.8 2.8 4.8 1.4]
[6.7 3. 5. 1.7]
[6. 2.9 4.5 1.5]
[5.7 2.6 3.5 1. ]
[5.5 2.4 3.8 1.1]
[5.5 2.4 3.7 1. ]
[5.8 2.7 3.9 1.2]
[6. 2.7 5.1 1.6]
[5.4 3. 4.5 1.5]
[6. 3.4 4.5 1.6]
[6.7 3.1 4.7 1.5]
[6.3 2.3 4.4 1.3]
[5.6 3. 4.1 1.3]
[5.5 2.5 4. 1.3]
[5.5 2.6 4.4 1.2]
[6.1 3. 4.6 1.4]
[5.8 2.6 4. 1.2]
[5. 2.3 3.3 1. ]
[5.6 2.7 4.2 1.3]
[5.7 3. 4.2 1.2]
[5.7 2.9 4.2 1.3]
[6.2 2.9 4.3 1.3]
[5.1 2.5 3. 1.1]
[5.7 2.8 4.1 1.3]
[6.3 3.3 6. 2.5]
[5.8 2.7 5.1 1.9]
[7.1 3. 5.9 2.1]
[6.3 2.9 5.6 1.8]
[6.5 3. 5.8 2.2]
[7.6 3. 6.6 2.1]
[4.9 2.5 4.5 1.7]
[7.3 2.9 6.3 1.8]
[6.7 2.5 5.8 1.8]
[7.2 3.6 6.1 2.5]
[6.5 3.2 5.1 2. ]
[6.4 2.7 5.3 1.9]
[6.8 3. 5.5 2.1]
[5.7 2.5 5. 2. ]
[5.8 2.8 5.1 2.4]
[6.4 3.2 5.3 2.3]
[6.5 3. 5.5 1.8]
[7.7 3.8 6.7 2.2]
[7.7 2.6 6.9 2.3]
[6. 2.2 5. 1.5]
[6.9 3.2 5.7 2.3]
[5.6 2.8 4.9 2. ]
[7.7 2.8 6.7 2. ]
[6.3 2.7 4.9 1.8]
[6.7 3.3 5.7 2.1]
[7.2 3.2 6. 1.8]
[6.2 2.8 4.8 1.8]
[6.1 3. 4.9 1.8]
[6.4 2.8 5.6 2.1]
[7.2 3. 5.8 1.6]
[7.4 2.8 6.1 1.9]
[7.9 3.8 6.4 2. ]
[6.4 2.8 5.6 2.2]
[6.3 2.8 5.1 1.5]
[6.1 2.6 5.6 1.4]
[7.7 3. 6.1 2.3]
[6.3 3.4 5.6 2.4]
[6.4 3.1 5.5 1.8]
[6. 3. 4.8 1.8]
[6.9 3.1 5.4 2.1]
[6.7 3.1 5.6 2.4]
[6.9 3.1 5.1 2.3]
[5.8 2.7 5.1 1.9]
[6.8 3.2 5.9 2.3]
[6.7 3.3 5.7 2.5]
[6.7 3. 5.2 2.3]
[6.3 2.5 5. 1.9]
[6.5 3. 5.2 2. ]
[6.2 3.4 5.4 2.3]
[5.9 3. 5.1 1.8]]
y_data from datasets:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
x_data add index:
花萼长度 花萼宽度 花瓣长度 花瓣宽度
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2
.. ... ... ... ...
145 6.7 3.0 5.2 2.3
146 6.3 2.5 5.0 1.9
147 6.5 3.0 5.2 2.0
148 6.2 3.4 5.4 2.3
149 5.9 3.0 5.1 1.8
[150 rows x 4 columns]
x_data add a column:
花萼长度 花萼宽度 花瓣长度 花瓣宽度 类别
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
.. ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2
[150 rows x 5 columns]
相关知识
高光谱图像分类工具包:精准识别的利器
鸢尾花果实类型
卷积神经网络实现鸢尾花数据分类python代码实现
2022的七夕,奉上7个精美的表白代码,同时教大家快速改源码自用
推荐文章:探索手写汉字识别的新境界——HCCR
探索C二维码生成与识别的无限可能
Python原生代码实现KNN算法(鸢尾花数据集)
什么是手写识别?手写识别的方法、好处和挑战
【端到端】:从图像到识别:手写数字识别的完整流程
朴素贝叶斯算法对鸢尾花分类
网址: 鸢尾花类型识别的代码(自用) https://m.huajiangbk.com/newsview1570912.html
上一篇: 你和你的女神之间,差了一个Ope |
下一篇: 如何有效教学长春版三年级语文一片 |