首页 > 分享 > Python语言基于CART决策树的鸢尾花数据分类

Python语言基于CART决策树的鸢尾花数据分类

1.数据集的获取。

使用SCIKIT-LEARN的自带的鸢尾花数据集,获取数据集.

2.数据集的划分。

基于hold-out法,构建训练集与测试集并且确保训练集与测试集内各类别占比一致。

要求:训练集80%,测试集20%。

3. 模型的学习。

利用训练集,学习两种复杂程度不同的CART分类树(用深度控制),可视化分类树的学习结果,并给出每一棵树的特征重要性评分。

4. 基于测试集的分类树的评价。

(1)结合测试集各样本的类别预测结果及真实类别答案,生成混淆矩阵,并可视化混淆矩阵

(2)基于混淆矩阵,估计每个类别的查准率、查全率、F1值,以及宏查准率、宏查全率、宏F1值;估计总体预测正确率.

5. 使用整个数据集学习上述两种不同深度的分类树, 可视化

源码如下:

import pandas as pd

from matplotlib import pyplot as plt

from pandas.core.common import random_state

from sklearn.datasets import load_iris

from sklearn.tree import plot_tree

from sklearn.tree import DecisionTreeClassifier

from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score

from sklearn.model_selection import train_test_split

import seaborn as sns

iris = load_iris()

X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=random_state())

dtree_shallow = DecisionTreeClassifier(max_depth=2)

dtree_shallow.fit(X_train, y_train)

dtree_deep = DecisionTreeClassifier(max_depth=4)

dtree_deep.fit(X_train, y_train)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 7))

plot_tree(dtree_shallow, filled=True, rounded=True, ax=axes[0], feature_names=iris.feature_names,

class_names=iris.target_names)

axes[0].set_title('Shallow Decision Tree')

plot_tree(dtree_deep, filled=True, rounded=True, ax=axes[1], feature_names=iris.feature_names,

class_names=iris.target_names)

axes[1].set_title('Deep Decision Tree')

plt.show()

def show_score(tree):

importance = tree.feature_importances_

for i, v in enumerate(importance):

print('Feature: %0d, Score: %.5f' % (i, v))

print("深度为2的决策树评分为:")

show_score(dtree_shallow)

print("深度为4的决策树评分为:")

show_score(dtree_deep)

def show_confusion_matrix(tree, title):

y_pred = tree.predict(X_test)

cm = confusion_matrix(y_test, y_pred)

df_cm = pd.DataFrame(cm)

ax = sns.heatmap(df_cm, annot=True, cmap="Purples")

ax.set_title(title)

ax.set_xlabel('predict target')

ax.set_ylabel('true target')

plt.show()

show_confusion_matrix(dtree_shallow, 'Confusion Matrix of ShallowTree')

show_confusion_matrix(dtree_deep, 'Confusion Matrix of DeepTree')

def show_performance_measurement(tree):

y_pred = tree.predict(X_test)

precision = precision_score(y_test, y_pred, average=None)

recall = recall_score(y_test, y_pred, average=None)

f1 = f1_score(y_test, y_pred, average=None)

accuracy = accuracy_score(y_test, y_pred)

macro_precision = precision_score(y_test, y_pred, average='macro')

macro_recall = recall_score(y_test, y_pred, average='macro')

macro_f1 = f1_score(y_test, y_pred, average='macro')

print(f'Precision: {precision}')

print(f'Recall: {recall}')

print(f'F1 score: {f1}')

print(f'Accuracy: {accuracy}')

print(f'Macro Precision: {macro_precision}')

print(f'Macro Recall: {macro_recall}')

print(f'Macro F1 score: {macro_f1}')

print("深度为2的决策树性能度量指标:")

show_performance_measurement(dtree_shallow)

print("深度为4的决策树性能度量指标:")

show_performance_measurement(dtree_shallow)

X = iris.data

y = iris.target

tree1 = DecisionTreeClassifier(max_depth=2)

tree1.fit(X, y)

tree2 = DecisionTreeClassifier(max_depth=4)

tree2.fit(X, y)

plt.figure(figsize=(15, 7))

plt.subplot(1, 2, 1)

plot_tree(tree1, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)

plt.title('Decision Tree with max depth 2')

plt.subplot(1, 2, 2)

plot_tree(tree2, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)

plt.title('Decision Tree with max depth 4')

plt.show()

运行结果与输出图片:

深度为2的决策树评分为:

Feature: 0, Score: 0.00000

Feature: 1, Score: 0.00000

Feature: 2, Score: 0.00000

Feature: 3, Score: 1.00000

深度为4的决策树评分为:

Feature: 0, Score: 0.01875

Feature: 1, Score: 0.01875

Feature: 2, Score: 0.05648

Feature: 3, Score: 0.90602

深度为2的决策树性能度量指标:

Precision: [1. 0.9 0.9]

Recall: [1. 0.9 0.9]

F1 score: [1. 0.9 0.9]

Accuracy: 0.9333333333333333

Macro Precision: 0.9333333333333332

Macro Recall: 0.9333333333333332

Macro F1 score: 0.9333333333333332

深度为4的决策树性能度量指标:

Precision: [1. 0.9 0.9]

Recall: [1. 0.9 0.9]

F1 score: [1. 0.9 0.9]

Accuracy: 0.9333333333333333

Macro Precision: 0.9333333333333332

Macro Recall: 0.9333333333333332

Macro F1 score: 0.9333333333333332

进程已结束,退出代码0

相关知识

python利用c4.5决策树对鸢尾花卉数据集进行分类(iris)
基于决策树构建鸢尾花数据的分类模型并绘制决策树模型
【2016年第1期】基于大数据的小麦蚜虫发生程度决策树预测分类模型
基于python编程的五种鲜花识别
【机器学习】R语言实现随机森林、支持向量机、决策树多方法二分类模型
第一个机器学习项目(鸢尾花分类问题)
python实践gcForest模型对鸢尾花数据集iris进行分类
TensorFlow 2建立神经网络分类模型——以iris数据为例
基于BP神经网络对鸢尾花的分类的研究
人工智能毕业设计基于python的花朵识别系统

网址: Python语言基于CART决策树的鸢尾花数据分类 https://m.huajiangbk.com/newsview387250.html

所属分类:花卉
上一篇: 【===会唱歌的鸢尾花===摄影
下一篇: 绿萝什么意思