您的位置:

Python实现决策树

一、决策树简介

决策树(Decision Tree)是一种常见的分类和回归算法,其可处理离散型和连续型数据,在数据挖掘、机器学习等领域被广泛应用。

决策树的结构类似一棵树,每个节点表示一个属性,叶子节点表示一个类别(或回归值),在决策时沿着树结构前进,并根据节点所表示的属性值进行选择,直至到达叶子节点。

二、决策树构建

决策树的构建过程包括分裂属性的选择和树的剪枝两个主要环节。

属性选择

属性选择的目标是找到最优属性,使其能够按照属性值将训练集中的样本划分到正确的类别中,通常使用信息增益(Information Gain)或信息增益比(Information Gain Ratio)来选择划分属性。

信息增益的计算公式如下:

def calc_information_gain(Y, X):
    info_D = calc_entropy(Y)  # 计算数据集的熵
    m = X.shape[1]  # 特征数
    info_Dv = np.zeros((m, 1))
    for i in range(m):
        # 计算按照第i个特征划分后的条件熵
        Dv = split_dataset(X, Y, i)
        info_Dv[i] = calc_cond_entropy(Y, Dv)
    gain = info_D - info_Dv  # 计算信息增益
    return gain

其中calc_entropy(Y)计算数据集的熵,calc_cond_entropy(Y, Dv)计算按照某个特征划分后的条件熵,split_dataset(X, Y, i)将数据集按照第i个特征划分。

树的剪枝

决策树的剪枝是为了防止过拟合,采用预剪枝和后剪枝两种方法。

预剪枝是在决策树构建过程中,限制树的大小来防止过拟合,例如限制树的深度或叶子节点的最小样本数。

后剪枝是在决策树构建完成后,对决策树进行剪枝来降低过拟合风险,常用的剪枝方法包括C4.5和CART算法。

三、Python实现决策树

数据集

我们以鸢尾花数据集为例,数据集包含150个样本,每个样本包含4个特征,分别是花萼长度(sepal_length)、花萼宽度(sepal_width)、花瓣长度(petal_length)和花瓣宽度(petal_width),以及类别(setosa、versicolor和virginica)。

首先导入数据集:

import pandas as pd

df = pd.read_csv('iris.csv')
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].to_numpy()
Y = df['species'].to_numpy()

将数据集划分为训练集和测试集:

from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)

决策树算法

首先定义决策树节点的类(tree_node.py):

class TreeNode:
    def __init__(self, feature_idx=None, threshold=None, mode=None, leaf=False):
        self.feature_idx = feature_idx  # 特征索引
        self.threshold = threshold  # 特征阈值
        self.mode = mode  # 类别(叶子节点才会有)
        self.leaf = leaf  # 是否为叶子节点
        self.children = {}  # 子节点列表

    def split(self, X, Y, criterion='entropy'):
        # 根据信息增益选择最优特征
        if criterion == 'entropy':
            gain = calc_information_gain(Y, X)
        else:  # criterion == 'gini'
            gain = calc_gini_gain(Y, X)
        i = np.argmax(gain)
        # 按照最优特征分裂数据集
        Dv = split_dataset(X, Y, i)
        # 如果信息增益为0,返回当前节点
        if np.array_equal(gain, np.zeros((X.shape[1], 1))):
            return self
        # 创建子节点并递归构建子树
        for k, v in Dv.items():
            node = TreeNode(leaf=(len(v) == 1))
            if not node.leaf:  # 非叶子节点
                node = node.split(X[v], Y[v], criterion=criterion)  # 递归构建子树
            else:  # 叶子节点
                node.mode = Y[v][0]  # 叶子节点为样本中最普遍的类别
            self.children[k] = node
            self.feature_idx = i
        return self

定义决策树的类(decision_tree.py):

class DecisionTree:
    def __init__(self, criterion='entropy', max_depth=None, min_samples_leaf=1):
        self.root = None  # 决策树根节点
        self.criterion = criterion  # 划分标准:'entropy'或'gini'
        self.max_depth = max_depth  # 最大深度
        self.min_samples_leaf = min_samples_leaf  # 叶节点最小样本数

    def fit(self, X, Y):
        self.root = TreeNode(leaf=(len(Y) == 1))
        if not self.root.leaf:  # 非叶子节点
            self.root = self.root.split(X, Y, criterion=self.criterion)  # 构建树
        else:  # 叶子节点
            self.root.mode = Y[0]  # 叶子节点为样本中最普遍的类别

    def predict(self, X):
        Y_pred = np.array([], dtype=int)
        for x in X:
            node = self.root
            while not node.leaf:
                i = node.feature_idx
                if x[i] <= node.threshold:
                    node = node.children[0]
                else:
                    node = node.children[1]
            Y_pred = np.append(Y_pred, node.mode)
        return Y_pred

应用决策树

构建决策树并对数据进行分类(main.py):

tree = DecisionTree(max_depth=3)
tree.fit(X_train, Y_train)
Y_pred = tree.predict(X_test)

计算分类精度,并绘制决策树(plot_tree.py):

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

def plot_node(node_text, center_pt, parent_pt, node_type):
    arrow_args = dict(arrowstyle="<-")
    if node_type == 'leaf':
        create_plot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                                 xytext=center_pt, textcoords='axes fraction',
                                 va="center", ha="center", bbox=node_type, arrowprops=arrow_args)
    else:
        create_plot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction',
                                 xytext=center_pt, textcoords='axes fraction',
                                 va="center", ha="center", bbox=node_type)

def plot_arrow(center_pt, parent_pt, arrow_type):
    arrow_args = dict(arrowstyle=arrow_type)
    create_plot.ax1.annotate('', xy=parent_pt, xycoords='axes fraction',
                             xytext=center_pt, textcoords='axes fraction',
                             va="center", ha="center", bbox=dict(), arrowprops=arrow_args)

def get_leafs_count(tree_node):
    if tree_node.leaf:
        return 1
    count = 0
    for child in tree_node.children.values():
        count += get_leafs_count(child)
    return count

def get_tree_depth(tree_node):
    if tree_node.leaf:
        return 1
    max_depth = 0
    for child in tree_node.children.values():
        depth = get_tree_depth(child) + 1
        if depth > max_depth:
            max_depth = depth
    return max_depth

def plot_tree(tree):
    leaf_nodes_count = get_leafs_count(tree.root)
    tree_depth = get_tree_depth(tree.root)
    axprops = dict(xticks=[], yticks=[])
    create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plot_node(str(tree.root.mode), (0.5, 0.8), (0.5, 1.0), 'root')
    plot_tree_node(tree.root, (0.5, 0.8), leaf_nodes_count, tree_depth)
    plt.axis('off')
    plt.show()

def plot_tree_node(tree_node, parent_center, leaf_nodes_count, tree_depth):
    if tree_node.leaf:  # 叶子节点
        return
    h_unit = 1.0 / tree_depth
    v_unit = 1.0 / leaf_nodes_count
    height = 0
    for k, child in tree_node.children.items():
        center = (parent_center[0] - v_unit, parent_center[1] - h_unit)
        plot_arrow(center, parent_center, '<-')
        if child.leaf:  # 叶子节点
            node_type = {'fc': '0.8', 'ec': 'black', 'boxstyle': 'round'}
            node_text = str(child.mode)
        else:  # 非叶子节点
            node_type = {'fc': '0.8', 'ec': 'black'}
            node_text = 'X[{}] <= {:.2f}'.format(child.feature_idx, child.threshold)
        plot_node(node_text, center, parent_center, node_type)
        plot_tree_node(child, center, leaf_nodes_count, tree_depth)
        height += 1

运行结果如下:

accuracy: 0.9778

四、小结

本文介绍了Python实现决策树的方法,包括决策树的构建过程、属性选择、树的剪枝方法等,同时提供了完整的代码实现。在实际应用中,决策树算法可以用于分类和回归等场景,并且能够处理离散型和连续型数据,是数据挖掘、机器学习等领域的重要算法。