1.数据集的获取。
使用SCIKIT-LEARN的自带的鸢尾花数据集,获取数据集.
2.数据集的划分。
基于hold-out法,构建训练集与测试集并且确保训练集与测试集内各类别占比一致。
要求:训练集80%,测试集20%。
3. 模型的学习。
利用训练集,学习两种复杂程度不同的CART分类树(用深度控制),可视化分类树的学习结果,并给出每一棵树的特征重要性评分。
4. 基于测试集的分类树的评价。
(1)结合测试集各样本的类别预测结果及真实类别答案,生成混淆矩阵,并可视化混淆矩阵
(2)基于混淆矩阵,估计每个类别的查准率、查全率、F1值,以及宏查准率、宏查全率、宏F1值;估计总体预测正确率.
5. 使用整个数据集学习上述两种不同深度的分类树, 可视化。
源码如下:
import pandas as pd
from matplotlib import pyplot as plt
from pandas.core.common import random_state
from sklearn.datasets import load_iris
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
from sklearn.model_selection import train_test_split
import seaborn as sns
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=random_state())
dtree_shallow = DecisionTreeClassifier(max_depth=2)
dtree_shallow.fit(X_train, y_train)
dtree_deep = DecisionTreeClassifier(max_depth=4)
dtree_deep.fit(X_train, y_train)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 7))
plot_tree(dtree_shallow, filled=True, rounded=True, ax=axes[0], feature_names=iris.feature_names,
class_names=iris.target_names)
axes[0].set_title('Shallow Decision Tree')
plot_tree(dtree_deep, filled=True, rounded=True, ax=axes[1], feature_names=iris.feature_names,
class_names=iris.target_names)
axes[1].set_title('Deep Decision Tree')
plt.show()
def show_score(tree):
importance = tree.feature_importances_
for i, v in enumerate(importance):
print('Feature: %0d, Score: %.5f' % (i, v))
print("深度为2的决策树评分为:")
show_score(dtree_shallow)
print("深度为4的决策树评分为:")
show_score(dtree_deep)
def show_confusion_matrix(tree, title):
y_pred = tree.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm)
ax = sns.heatmap(df_cm, annot=True, cmap="Purples")
ax.set_title(title)
ax.set_xlabel('predict target')
ax.set_ylabel('true target')
plt.show()
show_confusion_matrix(dtree_shallow, 'Confusion Matrix of ShallowTree')
show_confusion_matrix(dtree_deep, 'Confusion Matrix of DeepTree')
def show_performance_measurement(tree):
y_pred = tree.predict(X_test)
precision = precision_score(y_test, y_pred, average=None)
recall = recall_score(y_test, y_pred, average=None)
f1 = f1_score(y_test, y_pred, average=None)
accuracy = accuracy_score(y_test, y_pred)
macro_precision = precision_score(y_test, y_pred, average='macro')
macro_recall = recall_score(y_test, y_pred, average='macro')
macro_f1 = f1_score(y_test, y_pred, average='macro')
print(f'Precision: {precision}')
print(f'Recall: {recall}')
print(f'F1 score: {f1}')
print(f'Accuracy: {accuracy}')
print(f'Macro Precision: {macro_precision}')
print(f'Macro Recall: {macro_recall}')
print(f'Macro F1 score: {macro_f1}')
print("深度为2的决策树性能度量指标:")
show_performance_measurement(dtree_shallow)
print("深度为4的决策树性能度量指标:")
show_performance_measurement(dtree_shallow)
X = iris.data
y = iris.target
tree1 = DecisionTreeClassifier(max_depth=2)
tree1.fit(X, y)
tree2 = DecisionTreeClassifier(max_depth=4)
tree2.fit(X, y)
plt.figure(figsize=(15, 7))
plt.subplot(1, 2, 1)
plot_tree(tree1, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.title('Decision Tree with max depth 2')
plt.subplot(1, 2, 2)
plot_tree(tree2, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.title('Decision Tree with max depth 4')
plt.show()
运行结果与输出图片:
深度为2的决策树评分为:
Feature: 0, Score: 0.00000
Feature: 1, Score: 0.00000
Feature: 2, Score: 0.00000
Feature: 3, Score: 1.00000
深度为4的决策树评分为:
Feature: 0, Score: 0.01875
Feature: 1, Score: 0.01875
Feature: 2, Score: 0.05648
Feature: 3, Score: 0.90602
深度为2的决策树性能度量指标:
Precision: [1. 0.9 0.9]
Recall: [1. 0.9 0.9]
F1 score: [1. 0.9 0.9]
Accuracy: 0.9333333333333333
Macro Precision: 0.9333333333333332
Macro Recall: 0.9333333333333332
Macro F1 score: 0.9333333333333332
深度为4的决策树性能度量指标:
Precision: [1. 0.9 0.9]
Recall: [1. 0.9 0.9]
F1 score: [1. 0.9 0.9]
Accuracy: 0.9333333333333333
Macro Precision: 0.9333333333333332
Macro Recall: 0.9333333333333332
Macro F1 score: 0.9333333333333332
进程已结束,退出代码0
相关知识
python利用c4.5决策树对鸢尾花卉数据集进行分类(iris)
基于决策树构建鸢尾花数据的分类模型并绘制决策树模型
【2016年第1期】基于大数据的小麦蚜虫发生程度决策树预测分类模型
基于python编程的五种鲜花识别
【机器学习】R语言实现随机森林、支持向量机、决策树多方法二分类模型
第一个机器学习项目(鸢尾花分类问题)
python实践gcForest模型对鸢尾花数据集iris进行分类
TensorFlow 2建立神经网络分类模型——以iris数据为例
基于BP神经网络对鸢尾花的分类的研究
人工智能毕业设计基于python的花朵识别系统
网址: Python语言基于CART决策树的鸢尾花数据分类 https://m.huajiangbk.com/newsview387250.html
上一篇: 【===会唱歌的鸢尾花===摄影 |
下一篇: 绿萝什么意思 |