HelloAI
L2 第 4 篇 🐣 难度 🕒 11 分钟

决策树:最直观可解释的 ML 算法

一连串"是不是"问题,组成一棵决策的树。它是 XGBoost、LightGBM 等竞赛王者算法的基石。

阿莱
2026/6/27

逻辑回归很强,但有个缺陷:它只能学线性决策边界

如果数据本质上是非线性的——比如”年龄 < 18 或 > 60 时风险高,中间低”——逻辑回归就不太行。

决策树完美解决这种问题——它通过一连串”if-else”判断做分类。最像人类思维的 ML 算法

一个例子

任务:判断一个人是否会买某款手机。

人类销售可能这么想:

预算 > 5000 元吗?
├── 是:年龄 < 35?
│       ├── 是 → 大概率买(高消费年轻人)
│       └── 否:是否换机超过 2 年了?
│               ├── 是 → 中概率买
│               └── 否 → 大概率不买
└── 否:性别?
        ├── 男 → 看是否游戏玩家
        ├── 女 → 看是否摄影爱好者
        └── ...

这就是决策树——一棵从根到叶的判断树。

树的结构

元素含义
根节点(root)第一个判断
内部节点(internal)中间的每个判断
叶节点(leaf)最终的预测
分支(branch)“是/否”的两条路

每个内部节点问一个关于某个特征的问题:“这个特征 > 某个值吗?“

怎么”训”出一棵树

直觉:在每个节点,找最能”分开”两类样本的那个特征+阈值

衡量”分得好不好”的指标有两个常用的:

1. 基尼不纯度(Gini Impurity)

G=1kpk2G = 1 - \sum_k p_k^2

pkp_k 是该节点里属于类 kk 的比例。

  • 全是同一类(纯):G=0G = 0
  • 类别均匀混合(最不纯):GG 最大

2. 信息熵(Entropy)

H=kpklogpkH = -\sum_k p_k \log p_k

L1-05 我们讲过。熵越大越不纯。

训练时:在每个节点,遍历所有特征 + 所有可能阈值,找让”子节点不纯度”最低的那个划分

这是一个贪心算法——不保证全局最优,但实际效果通常很好。

用 sklearn 跑一遍

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# 经典鸢尾花数据集
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 训练
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X_train, y_train)

# 评估
print(f"准确率: {tree.score(X_test, y_test):.3f}")

# 可视化整棵树
plt.figure(figsize=(15, 8))
plot_tree(tree, filled=True, feature_names=['花萼长','花萼宽','花瓣长','花瓣宽'],
          class_names=['Setosa', 'Versicolor', 'Virginica'])
plt.savefig('tree.png')

跑完会得到一棵可视化的树,你能直接读出”模型是怎么决策的”——这是其它算法没有的可解释性。

几个重要超参数

参数作用
max_depth树的最大深度(防过拟合的最重要参数)
min_samples_split一个节点要分裂至少需要多少样本
min_samples_leaf叶节点最少多少样本
criterion’gini’ 或 ‘entropy’

经验:先用 max_depth=5 试,过拟合就降到 3,欠拟合就加到 10。

决策树的杀手锏:可解释性

tree.feature_importances_ 给你每个特征的重要度(每次划分时它贡献了多少信息增益):

import pandas as pd

feature_names = ['花萼长','花萼宽','花瓣长','花瓣宽']
importances = pd.Series(tree.feature_importances_, index=feature_names)
print(importances.sort_values(ascending=False))

# 花瓣长     0.91
# 花瓣宽     0.07
# 花萼长     0.02
# 花萼宽     0.00

模型告诉你:花瓣长度几乎决定了一切

这种可解释性在以下场景必须

  • 银行信贷模型(监管要求”为什么拒贷”)
  • 医疗诊断辅助(医生要看 reasoning)
  • 司法量刑参考(法律可解释)

真实情况:欧盟 GDPR 第 22 条明确禁止”完全自动化、对个人有重大影响、且无法解释”的决策。这一条法律单独让决策树(及衍生方法)在金融/医疗领域永远有需求

树能处理什么

天生处理良好

  • 数字特征 + 类别特征混合(不用 one-hot 编码)
  • 特征不同量纲(不用标准化)
  • 非线性关系
  • 缺失值(sklearn 实现就支持)
  • 多类别分类(不用搞 one-vs-rest)

这让决策树预处理特别简单——直接 fit。

决策树的致命弱点

过拟合: 单棵树深了就会”死记硬背”——把训练数据里的每个噪声都建模。

不稳定: 训练数据稍微变一点,树的结构可能完全不同。

难学到对角决策边界: 它只能做”竖切”或”横切”,对角线得分很多步逼近,不优雅。

单棵性能有限: 单决策树在标准基准上往往输给逻辑回归、SVM。

但是……集成起来无敌

接下来的故事很精彩——

单棵树弱,但让很多棵树投票,效果会强大到逆天

算法思路
随机森林训 N 棵独立的树,投票表决
GBDT / XGBoost / LightGBM树 1 学一部分;树 2 学树 1 留下的错;树 3 学树 1+2 的错;累加

这两套思路(Bagging 和 Boosting)的具体讲解在 L2-05XGBoost / LightGBM 是 Kaggle 竞赛过去 10 年的霸主——比深度学习还能打。

💡 一个工业事实

在很多业务里——XGBoost 至今仍然比神经网络效果好。 特别是结构化表格数据(Excel 风格的):

  • 用户特征 + 行为
  • 金融数据
  • 医疗特征
  • 推荐场景的 ranking

如果你的数据是表格 + 数千到数百万样本,先试 XGBoost,再考虑神经网络

一个完整的示例

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report

# 真实数据:乳腺癌分类
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练
tree = DecisionTreeClassifier(max_depth=4, random_state=42)
tree.fit(X_train, y_train)

# 交叉验证(更稳定的评估)
cv_scores = cross_val_score(tree, X_train, y_train, cv=5)
print(f"交叉验证准确率: {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")

# 测试集报告
y_pred = tree.predict(X_test)
print(classification_report(y_test, y_pred))

跑完通常能拿到 93% 左右准确率——单棵树,5 行代码

下一篇我们用随机森林把它推到 97%。

下一篇:《随机森林 + 集成学习:让一群树投票》

📬

读到这里说明你认真在学 🎯

订阅每周精选 —— 下一篇新文章 / 新可视化第一时间送到邮箱。

💬

讨论区

· 用 GitHub 账号登录评论
⚠️ Giscus 评论未配置 —— 在 src/components/Comments.astro 顶部填入 仓库 ID 和分类 ID(见组件注释里的配置步骤)。