首页 > 分享 > KNN算法实例——鸢尾花种类预测

KNN算法实例——鸢尾花种类预测

线性回归的scikit-learn实现

scikit-learn中提供了一个KNeighborClassifier类来实现k近邻法分类模型

方法:

fit(X,y):训练模型
predict:使用模型来预测,返回待预测样本的标记。
score(X,y):返回在(X,y)上预测的准确率。
predict_proba(X):返回样本为每种标记的概率。
kneighbors([X,n_neighbors,return_distance]):返回样本点的k近邻点。如果return_diatance=True,同时还返回到这些近邻点的距离。
kneighbors_graph([X,n_neighbors,model]):返回样本点的连接图。

导入模块:from sklearn.neighbors import KNeighborsClassifier

案例:一个简单的KNN实例

from sklearn.neighbors import KNeighborsClassifier

x=[[1],[2],[3],[4]]

y=[0,0,1,1]

estimator = KNeighborsClassifier(n_neighbors=3)

estimator.fit(x,y)

ret = estimator.predict([[2.51]])

print(ret)

ret1 = estimator.predict([[2.52]])

print(ret1)

ret2 = estimator.predict([[1.50]])

print(ret2)

结果:

案例:鸢尾花种类预测

Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。关于数据集的具体介绍:

可视化显示鸢尾花种类分布:

import seaborn as sns

import matplotlib.pyplot as plt

import pandas as pd

from sklearn.datasets import load_iris,fetch_20newsgroups

from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier

from pylab import mpl

mpl.rcParams["font.sans-serif"] = ["SimHei"]

mpl.rcParams["axes.unicode_minus"] = False

def iris_plot(data,col1,col2):

sns.lmplot(x=col1,y=col2,data=data,hue="target",fit_reg=False)

plt.title("鸢尾花数据显示")

plt.show

iris = load_iris()

iris_d = pd.DataFrame(data=iris.data,columns=['Sepal_Lenght','Sepal_Width','Petal_Length','Petal_Width'])

iris_d["target"] = iris.target

iris_plot(iris_d,'Sepal_Width','Petal_Length')

 图像显示:

数据预处理 

在现实生活问题中,我们得到的原始数据往往非常混乱、不全面,机器学习模型往往无法从中有效识别并提取信息。数据和特征决定了机器学习的上限,而模型和算法只是逼近这个上限而已,在采集完数据后,机器学习建模的首要步骤以及主要步骤便是数据预处理。

在此文仅展示归一化与标准化操作:

归一化:

"""

归一化演示

:return:None

"""

from sklearn.preprocessing import MinMaxScaler

data = pd.read_csv("数据文件路径")

print(data)

transfer = MinMaxScaler(feature_range=(3,5))

ret_data = transfer.fit_transform(data[["x","y","time"]])

print("归一化之后的数据为:n",ret_data)

标准化:

"""

标准化演示

:return:None

"""

from sklearn.preprocessing import StandardScaler

data = pd.read_csv("数据文件路径")

print(data)

transfer = StandardScaler()

ret_data = transfer.fit_transform(data[["x","y","time"]])

print("标准化之后的数据为:n",ret_data)

print("每一列的方差为:n",transfer.var_)

print("每一列的平均值为:n",transfer.mean_)

种类预测:

import pandas as pd

from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier

data = load_iris

x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.2,random_state=22)

transfer = StandardScaler()

x_train = transfer.fit_transform(x_train)

x_test = transfer.transform(x_test)

estimator = KNeighborsClassifier(n_neighbors=9,algorithm='kd_tree')

estimator.fit(x_train,y_train)

y_predict = estimator.predict(x_test)

print("预测结果为:n",y_predict)

print("比对真实值和预测值:n",y_predict==y_test)

评估结果:

 获取准确率:

score = estimator.score(x_test,y_test)

print("准确率为:n",score)

 模型的保存与加载

import joblib

joblib.dump(estimator,"路径+文件名")

estimator = joblib.load("路径+文件名")

K值调优:

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split,GridSearchCV

from sklearn.preprocessing import StandardScaler

from sklearn.neighbors import KNeighborsClassifier

estimator = KNeighborsClassifier(n_neighbors=5)

param_grid = {"n_neighbors": [1,3,5,7]}

estimator = GridSearchCV(estimator,param_grid=param_grid,cv=5)

estimator.fit(x_train,y_train)

y_pre = estimator.predict(x_test)

print("预测值是:n",y_pre)

score = estimator.score(x_test,y_test)

print("准确率为:n",score)

print("在交叉验证中验证的最好结果:n",estimator.best_score_)

print("最好的参数模型:n",estimator.best_estimator_)

print("交叉验证后的准确率结果:n",estimator.cv_results_)

结果:

至此鸢尾花的种类预测功能已基本完成!

如有任何错误或疑问,欢迎评论区留言 

相关知识

KNN算法实现鸢尾花数据集分类
【机器学习】应用KNN实现鸢尾花种类预测
KNN分类算法介绍,用KNN分类鸢尾花数据集(iris)
使用KNN算法对鸢尾花种类预测【百变AI秀】
Python实现KNN算法(鸢尾花数据)
【机器学习】KNN算法实现鸢尾花分类
Knn算法实现鸢尾花分类
Python原生代码实现KNN算法(鸢尾花数据集)
【机器学习】基于KNN算法实现鸢尾花数据集的分类
KNN鸢尾花分类

网址: KNN算法实例——鸢尾花种类预测 https://m.huajiangbk.com/newsview1842735.html

所属分类:花卉
上一篇: 机器学习入门实战1:鸢尾花分类
下一篇: 【机器学习实战入门项目】基于机器