数据集包含150个数据集(其中120个是训练集iris_training.csv,30个是测试集iris_test.csv),分为3类(Setosa,Versicolour,Virginica),每类50个数据,每个数据包含4个属性:花萼长度,花萼宽度,花瓣长度,花瓣宽度。
120,4,setosa,versicolor,virginica 6.4,2.8,5.6,2.2,2 5.0,2.3,3.3,1.0,1 4.9,2.5,4.5,1.7,2 . . . . . . . . . . . . . . . 4.4,2.9,1.4,0.2,0 4.8,3.0,1.4,0.1,0 5.5,2.4,3.7,1.0,1 12345678910
由于标签是鸢尾花的类别,因此将标签转换成独热编码[1, 0, 0], [0, 1, 0], [0, 0, 1]
可以用pandas包里的get_dummies()函数实现,我这里是自己写的代码
train = np.array(pd.read_csv('D:/tensorflow_exercise/data/iris_training.csv') # 读取训练集数据 x_train = train[:, 0:4] # x_train是训练集的特征 rows = train.shape[0] y_train = np.array(np.zeros([rows, 3])) # 将标签转化为one-hot编码 for r in range(rows): label = int(train[r][4]) y_train[r][label] = 1 1234567
同理,测试集的话只需要修改读取路径和设置变量x_test和y_test即可
2、建立模型采用tensorflow建立一个简单的线性模型:
W = tf.Variable(tf.zeros([4, 3])) b = tf.Variable(tf.zeros([3]) + 0.01) output = tf.nn.softmax(tf.matmul(xs, W) + b) 123
输出接一个softmax函数
损失函数为交叉熵:
loss = -tf.reduce_sum(ys * tf.log(output + 1e-10)) 1
采用梯度下降法最小化loss,学习率设置为0.001:
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss) 1 3、模型训练
模型建立好后,通过tf.global_variables_initializer()对变量进行初始化
模型总共训练1000次,每100次输出loss查看训练过程
for i in range(1000): sess.run(train_step, feed_dict={xs: x_train, ys: y_train}) if i % 100 == 0: print('Loss(train set):%.2f' % (sess.run(loss, feed_dict={xs: x_train, ys: y_train}))) 1234 4、鸢尾花种类预测
模型训练完毕之后,即可将测试集输入模型进行预测。由于预测结果是独热编码,所以准确率计算使用tf.argmax()函数来实现。返回值是预测结果中最大值的索引,由于独热编码的性质,返回的索引值即为类别。
然后使用tf.equal()判断是否与实际类别一致(返回值为bool型)。所以需要通过一个tf.cast()函数来转换为[0, 1]值,最后取平均值求出准确率。[这个是直接参考的tensorflow中文教程]
access = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1)) accuracy = tf.reduce_mean(tf.cast(access, "float")) 12 5、结果
最后在训练集上以及测试集都得到一个较满意的结果
--------------------开始训练模型---------------- Loss(train set):125.14 Loss(train set):67.55 Loss(train set):30.55 Loss(train set):23.07 Loss(train set):20.45 Loss(train set):18.60 Loss(train set):17.22 Loss(train set):16.14 Loss(train set):15.28 Loss(train set):14.57 --------------------训练结束-------------------- ********************性能评价******************** 训练集准确率: 0.975 测试集准确率: 0.96666664
1234567891011121314151617附上详细代码:
import pandas as pd import numpy as np import tensorflow as tf def main(): # 读取训练集数据 train = np.array(pd.read_csv('D:/tensorflow_exercise/data/iris_training.csv')) x_train = train[:, 0:4] rows = train.shape[0] y_train = np.array(np.zeros([rows, 3])) # 将标签转化为one-hot编码 for r in range(rows): label = int(train[r][4]) y_train[r][label] = 1 # 读取测试集数据 test = np.array(pd.read_csv('D:/tensorflow_exercise/data/iris_test.csv')) x_test = test[:, 0:4] rows = test.shape[0] y_test = np.array(np.zeros([rows, 3])) for r in range(rows): label = int(test[r][4]) y_test[r][label] = 1 xs = tf.placeholder(dtype='float', shape=[None, 4]) ys = tf.placeholder(dtype='float', shape=[None, 3]) W = tf.Variable(tf.zeros([4, 3])) b = tf.Variable(tf.zeros([3]) + 0.01) output = tf.nn.softmax(tf.matmul(xs, W) + b) # 输出加个softmax层 loss = -tf.reduce_sum(ys * tf.log(output + 1e-10)) # 损失函数用交叉熵 train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss) # 梯度下降法最小化损失函数 access = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1)) accuracy = tf.reduce_mean(tf.cast(access, "float")) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) print('--------------------开始训练模型--------------------') for i in range(1000): sess.run(train_step, feed_dict={xs: x_train, ys: y_train}) if i % 100 == 0: print('Loss(train set):%.2f' % (sess.run(loss, feed_dict={xs: x_train, ys: y_train}))) print('--------------------训练结束--------------------nn') print('************************性能评价************************') print('训练集准确率:', sess.run(accuracy, {xs: x_train, ys: y_train})) print('测试集准确率:', sess.run(accuracy, {xs: x_test, ys: y_test})) sess.close() if __name__ == '__main__': main()
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253相关知识
深度学习入门——基于TensorFlow的鸢尾花分类实现(TensorFlow
基于TensorFlow实现LSTM的鸢尾花数据分类
TensorFlow学习记录(八)
TensorFlow使用BP神经网络实现鸢尾花分类
TensorFlow 2建立神经网络分类模型——以iris数据为例
基于BP神经网络对鸢尾花的分类的研究
基于神经网络——鸢尾花识别(Iris)
基于TensorFlow训练花朵识别模型的源码和Demo
Tensorflow鸢尾花分类(数据加载与特征处理)
使用鸢尾花数据集构建神经网络模型
网址: 基于tensorflow实现的鸢尾花预测模型 https://m.huajiangbk.com/newsview1354157.html
上一篇: 沁园春·身边花草系列之九·鸢尾 |
下一篇: 鸾尾花和鸢尾花一样吗 |