一、概念:
最近邻 (k-Nearest Neighbors, KNN) 算法是一种分类算法, 1968年由 Cover和 Hart 提出, 应用场景有字符识别、 文本分类、 图像识别等领域。
核心思想: 一个样本与数据集中的k个样本最相似, 如果这k个样本中的大多数属于某一个类别, 则该样本也属于这个类别。

——距离度量:
在选择两个实例相似性时,一般使用的欧式距离,又称之为欧几里得度量,它定义于欧几里得空间中。n维空间中两个点x1(x11,x12,…,x1n)与 x2(x21,x22,…,x2n)间的欧氏距离:

——k值选择
如果选择较小的K值,就相当于用较小的邻域中的训练实例进行预测,学习的近似误差会减小,只有与输入实例较近的训练实例才会对预测结果起作用,但学习的估计误差会增大,预测结果会对近邻的实例点分成敏感。如果邻近的实例点恰巧是噪声,预测就会出错。K值减小就意味着整体模型变复杂,分的不清楚,就容易发生过拟合。
如果选择较大K值,就相当于用较大邻域中的训练实例进行预测,其优点是可以减少学习的估计误差,但近似误差会增大,也就是对输入实例预测不准确。
简而言之,K值减小就意味着整体模型变复杂,分的不清楚,就容易发生过拟合。K值得增大就意味着整体模型变的简单。
在实际应用中,K值一般取一个比较小的数值,通常采用交叉验证法来选取最优的K值。
——流程:
1) 计算已知类别数据集中的点与当前点之间的距离
2) 按距离递增次序排序
3) 选取与当前点距离最小的k个点
4) 统计前k个点所在的类别出现的频率
5) 返回前k个点出现频率最高的类别作为当前点的预测分类
——优点:
1、简单有效
2、重新训练代价低
3、算法复杂度低
4、适合类域交叉样本
5、适用大样本自动分类
——缺点:
1、惰性学习
2、类别分类不标准化
3、输出可解释性不强
4、不均衡性
5、计算量较大
概念参考:https://blog.csdn.net/sinat_30353259/article/details/80901746
二、代码实现(丁香花数据集):
import numpy as np import pandas as pd from matplotlib import pyplot as plt %matplotlib inline lilac_data = pd.read_csv("Desktop/ML/syringa.csv") lilac_data.head()
python
运行
1234567
#绘制丁香花子图特征 fig,axes = plt.subplots(2,3,figsize=(20,10)) fig.subplots_adjust(hspace=0.3,wspace=0.2) axes[0,0].set_xlabel("sepal_length") axes[0,0].set_ylabel("sepal_width") axes[0,0].scatter(lilac_data.sepal_length[:50],lilac_data.sepal_width[:50],c="b") axes[0,0].scatter(lilac_data.sepal_length[50:100],lilac_data.sepal_width[50:100],c="g") axes[0,0].scatter(lilac_data.sepal_length[100:],lilac_data.sepal_width[100:],c="r") axes[0,0].legend(["daphne","syinga","willow"],loc=2) axes[0,1].set_xlabel("sepal_length") axes[0,1].set_ylabel("petal_length") axes[0,1].scatter(lilac_data.sepal_length[:50],lilac_data.petal_length[:50],c="b") axes[0,1].scatter(lilac_data.sepal_length[50:100],lilac_data.petal_length[50:100],c="g") axes[0,1].scatter(lilac_data.sepal_length[100:],lilac_data.petal_length[100:],c="r") axes[0,2].set_xlabel("sepal_length") axes[0,2].set_ylabel("petal_width") axes[0,2].scatter(lilac_data.sepal_length[:50],lilac_data.petal_width[:50],c="b") axes[0,2].scatter(lilac_data.sepal_length[50:100],lilac_data.petal_width[50:100],c="g") axes[0,2].scatter(lilac_data.sepal_length[100:],lilac_data.petal_width[100:],c="r") axes[1,0].set_xlabel("sepal_width") axes[1,0].set_ylabel("petal_width") axes[1,0].scatter(lilac_data.sepal_width[:50],lilac_data.petal_width[:50],c="b") axes[1,0].scatter(lilac_data.sepal_width[50:100],lilac_data.petal_width[50:100],c="g") axes[1,0].scatter(lilac_data.sepal_width[100:],lilac_data.petal_width[100:],c="r") axes[1,1].set_xlabel("sepal_width") axes[1,1].set_ylabel("petal_length") axes[1,1].scatter(lilac_data.sepal_width[:50],lilac_data.petal_length[:50],c="b") axes[1,1].scatter(lilac_data.sepal_width[50:100],lilac_data.petal_length[50:100],c="g") axes[1,1].scatter(lilac_data.sepal_width[100:],lilac_data.petal_length[100:],c="r") axes[1,2].set_xlabel("petal_length") axes[1,2].set_ylabel("petal_width") axes[1,2].scatter(lilac_data.petal_length[:50],lilac_data.petal_width[:50],c="b") axes[1,2].scatter(lilac_data.petal_length[50:100],lilac_data.petal_width[50:100],c="g") axes[1,2].scatter(lilac_data.petal_length[100:],lilac_data.petal_width[100:],c="r")
python
运行
12345678910111213141516171819202122232425262728293031323334353637383940
#切分数据集 from sklearn.model_selection import train_test_split feature_data = lilac_data.iloc[:,:-1] label_data = lilac_data["labels"] X_train,X_test,y_train,y_test = train_test_split(feature_data,label_data,test_size=0.3,random_state=2) X_test.head()
python
运行
12345678
#构建KNN模型 from sklearn.neighbors import KNeighborsClassifier def sklearn_classify(train_data,label_data,test_data,k_num): knn = KNeighborsClassifier(n_neighbors=k_num) knn.fit(train_data,label_data) predict_label = knn.predict(test_data) return predict_label y_predict = sklearn_classify(X_train,y_train,X_test,3) print(y_predict) [out]: array(['daphne', 'daphne', 'willow ', 'daphne', 'daphne', 'willow ', 'daphne', 'syringa', 'willow ', 'daphne', 'daphne', 'daphne', 'daphne', 'daphne', 'syringa', 'syringa', 'syringa', 'willow ', 'syringa', 'willow ', 'syringa', 'willow ', 'willow ', 'syringa', 'syringa', 'daphne', 'daphne', 'willow ', 'daphne', 'willow ', 'willow ', 'daphne', 'syringa', 'willow ', 'willow ', 'daphne', 'willow ', 'willow ', 'syringa', 'willow ', 'willow ', 'willow ', 'willow ', 'syringa', 'daphne'], dtype=object) #计算准确率 def get_accuracy(test_labels,pred_labels): correct = np.sum(test_labels == pred_labels) accur = correct/len(test_labels) return accur print(get_accuracy(y_test,y_predict)) [out]: 0.7777777777777778 #测试k在2~20内的准确率 normal_accuracy=[] k_value=range(1,10) for k in k_value: y_predict = sklearn_classify(X_train,y_train,X_test,k) accuracy = get_accuracy(y_test,y_predict) normal_accuracy.append(accuracy) plt.xlabel("k") plt.ylabel("arruracy") plt.yticks(np.linspace(0.6,1,10)) plt.plot(k_value,normal_accuracy,"r") plt.grid(True) #增加网格画布
python
运行
123456789101112131415161718192021222324252627282930313233343536373839404142434445
K等于4、6时,accuracy最高为0.8667,为了节约计算资源,K取局部最优为4。
相关知识
KNN分类算法介绍,用KNN分类鸢尾花数据集(iris)
【机器学习】KNN算法实现鸢尾花分类
基于KNN算法的鸢尾花分类教程
KNN算法实现鸢尾花数据集分类
Knn算法实现鸢尾花分类
KNN算法分类算法
【机器学习】KNN算法实现手写板字迹识别
机器学习入门(一) 之 K近邻算法(KNN算法)
Python实现kNN算法,使用鸢尾花作为测试数据
Python原生代码实现KNN算法(鸢尾花数据集)
网址: KNN算法 https://m.huajiangbk.com/newsview2500194.html
| 上一篇: 系统架构设计——使用结构图分解复 |
下一篇: 一图胜千言!机器学习模型可视化! |