引言:从二分类到多分类
逻辑回归是机器学习中最基础也最重要的算法之一,但初学者常常困惑:逻辑回归明明是二分类算法,如何能处理多分类问题呢?本文将带你深入了解逻辑回归的多分类策略,并通过完整的鸢尾花分类代码实现。一、多分类问题无处不在
在我们的日常生活和工作中,多分类问题比比皆是:邮件分类:工作(y=1)、朋友(y=2)、家庭(y=3)、爱好(y=4)天气预测:晴天(y=1)、多云(y=2)、雨天(y=3)、雪天(y=4)医疗诊断:健康(y=1)、感冒(y=2)、流感(y=3)这些场景都需要算法能够区分多个类别,而逻辑回归通过巧妙的扩展就能胜任这些任务。
二、One-vs-All策略揭秘
1. 核心思想
One-vs-All(一对多,也称为One-vs-Rest)策略将多分类问题转化为多个二分类问题: 对于N个类别,训练N个独立的二分类器第i个分类器将第i类作为正类,其余所有类别作为负类预测时,选择所有分类器中预测概率最高的类别2. 数学表达
对于第i类,我们的假设函数为:
h θ ( i ) ( x ) = P ( y = i ∣ x ; θ ) h_theta^{(i)}(x) = P(y = i|x;theta) hθ(i)(x)=P(y=i∣x;θ)
预测时选择:
max i h θ ( i ) ( x ) max_i h_theta^{(i)}(x) imaxhθ(i)(x)
三、鸢尾花分类完整实现
使用Python和scikit-learn库完整实现鸢尾花的多分类任务。1. 环境准备
# 导入必要的库 import numpy as np import pandas as pd from sklearn.datasets import load_iris from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.metrics import (accuracy_score,confusion_matrix,classification_report) import matplotlib.pyplot as plt import seaborn as sns from sklearn.multiclass import OneVsRestClassifier from sklearn.preprocessing import StandardScaler
2. 数据加载与探索
# 设置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 或 'Microsoft YaHei' plt.rcParams['axes.unicode_minus'] = False# 加载鸢尾花数据集 iris = load_iris() # 特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] # sepal length:花萼长度 sepal width:花萼宽度 petal length: 花瓣长度 petal width: 花瓣宽度 X = iris.data # 特征矩阵 (150, 4) # 目标类别: ['setosa' 'versicolor' 'virginica'] # setosa 山鸢尾 versicolor 变色鸢尾 virginica 维吉尼亚鸢尾 y = iris.target # 标签 (150,)# 查看特征名称和目标类别 print("特征名称:", iris.feature_names) print("目标类别:", iris.target_names)# 将数据转换为DataFrame便于可视化 iris_df = pd.DataFrame(X, columns=iris.feature_names) iris_df['species'] = y iris_df['species'] = iris_df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})# 绘制特征分布图 sns.pairplot(iris_df, hue='species', palette='husl') plt.suptitle("鸢尾花特征分布", y=1.02) plt.show()
特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] 目标类别: ['setosa' 'versicolor' 'virginica']
3. 数据预处理
# 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)print(f"训练集样本数: {len(X_train)}") print(f"测试集样本数: {len(X_test)}")# 特征标准化(逻辑回归通常需要) scaler = StandardScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test)
训练集样本数: 120 测试集样本数: 30
4. 模型训练与评估
# 构建逻辑回归模型 log_reg = LogisticRegression(C=1000, # 正则化强度的倒数solver='sag', # 随机平均梯度下降max_iter=1000, # 最大迭代次数random_state=42 ) # 使用 OneVsRestClassifier 包装 ovr_classifier = OneVsRestClassifier(log_reg)# 训练模型 ovr_classifier.fit(X_train, y_train)# 在训练集和测试集上评估 train_acc = ovr_classifier.score(X_train, y_train) test_acc = ovr_classifier.score(X_test, y_test)print(f"训练集准确率: {train_acc:.2%}") print(f"测试集准确率: {test_acc:.2%}")# 更详细的评估报告 y_pred = ovr_classifier.predict(X_test) print("n分类报告:") print(classification_report(y_test, y_pred, target_names=iris.target_names))# 绘制混淆矩阵 cm = confusion_matrix(y_test, y_pred) plt.figure(figsize=(6,6)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',xticklabels=iris.target_names,yticklabels=iris.target_names) plt.xlabel('预测标签') plt.ylabel('真实标签') plt.title('混淆矩阵') plt.show()
训练集准确率: 96.67% 测试集准确率: 96.67%分类报告:precision recall f1-score supportsetosa 1.00 1.00 1.00 10versicolor 1.00 0.90 0.95 10virginica 0.91 1.00 0.95 10accuracy 0.97 30macro avg 0.97 0.97 0.97 30 weighted avg 0.97 0.97 0.97 30
5. 决策边界可视化
# 为可视化,只使用两个主要特征 X_train_2d = X_train[:, :2] X_test_2d = X_test[:, :2]# 重新训练一个2D模型 log_reg_2d = LogisticRegression(C=1000,solver='sag',max_iter=2000,random_state=42 )ovr_classifier_2d = OneVsRestClassifier(log_reg_2d) ovr_classifier_2d.fit(X_train_2d, y_train) # 必须先训练模型# 创建网格点 x_min, x_max = X_train_2d[:, 0].min() - 1, X_train_2d[:, 0].max() + 1 y_min, y_max = X_train_2d[:, 1].min() - 1, X_train_2d[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))# 预测每个网格点的类别 Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape)# 预测每个网格点的类别 Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape)# 绘制决策边界 plt.figure(figsize=(10, 6)) plt.contourf(xx, yy, Z, alpha=0.4, cmap='Pastel2') scatter = plt.scatter(X_train_2d[:, 0], X_train_2d[:, 1], c=y_train,cmap='Dark2', edgecolor='black')# 添加图例和标签 legend_elements = scatter.legend_elements()[0] plt.legend(legend_elements,iris.target_names,title="鸢尾花种类") plt.xlabel(iris.feature_names[0]) plt.ylabel(iris.feature_names[1]) plt.title("逻辑回归多分类决策边界") plt.show()
四、关键参数解析
在构建逻辑回归模型时的重要参数:
C=1000:正则化强度的倒数,较小的值表示更强的正则化。这里设为较大的值,相当于减少正则化。multi_class=‘ovr’:指定使用One-vs-Rest策略处理多分类问题。scikit-learn还支持’multinomial’选项,使用softmax函数直接进行多分类。solver=‘sag’:优化算法选择随机平均梯度下降(Stochastic Average Gradient),适合大数据集。其他可选算法包括:‘liblinear’:适合小数据集;‘newton-cg’:牛顿法;‘lbfgs’:拟牛顿法。max_iter=1000:最大迭代次数,确保模型能够收敛。