一、决策树简介
决策树(Decision Tree)是一种常见的分类和回归算法,其可处理离散型和连续型数据,在数据挖掘、机器学习等领域被广泛应用。
决策树的结构类似一棵树,每个节点表示一个属性,叶子节点表示一个类别(或回归值),在决策时沿着树结构前进,并根据节点所表示的属性值进行选择,直至到达叶子节点。
二、决策树构建
决策树的构建过程包括分裂属性的选择和树的剪枝两个主要环节。
属性选择
属性选择的目标是找到最优属性,使其能够按照属性值将训练集中的样本划分到正确的类别中,通常使用信息增益(Information Gain)或信息增益比(Information Gain Ratio)来选择划分属性。
信息增益的计算公式如下:
def calc_information_gain(Y, X): info_D = calc_entropy(Y) # 计算数据集的熵 m = X.shape[1] # 特征数 info_Dv = np.zeros((m, 1)) for i in range(m): # 计算按照第i个特征划分后的条件熵 Dv = split_dataset(X, Y, i) info_Dv[i] = calc_cond_entropy(Y, Dv) gain = info_D - info_Dv # 计算信息增益 return gain
其中calc_entropy(Y)计算数据集的熵,calc_cond_entropy(Y, Dv)计算按照某个特征划分后的条件熵,split_dataset(X, Y, i)将数据集按照第i个特征划分。
树的剪枝
决策树的剪枝是为了防止过拟合,采用预剪枝和后剪枝两种方法。
预剪枝是在决策树构建过程中,限制树的大小来防止过拟合,例如限制树的深度或叶子节点的最小样本数。
后剪枝是在决策树构建完成后,对决策树进行剪枝来降低过拟合风险,常用的剪枝方法包括C4.5和CART算法。
三、Python实现决策树
数据集
我们以鸢尾花数据集为例,数据集包含150个样本,每个样本包含4个特征,分别是花萼长度(sepal_length)、花萼宽度(sepal_width)、花瓣长度(petal_length)和花瓣宽度(petal_width),以及类别(setosa、versicolor和virginica)。
首先导入数据集:
import pandas as pd df = pd.read_csv('iris.csv') X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].to_numpy() Y = df['species'].to_numpy()
将数据集划分为训练集和测试集:
from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)
决策树算法
首先定义决策树节点的类(tree_node.py):
class TreeNode: def __init__(self, feature_idx=None, threshold=None, mode=None, leaf=False): self.feature_idx = feature_idx # 特征索引 self.threshold = threshold # 特征阈值 self.mode = mode # 类别(叶子节点才会有) self.leaf = leaf # 是否为叶子节点 self.children = {} # 子节点列表 def split(self, X, Y, criterion='entropy'): # 根据信息增益选择最优特征 if criterion == 'entropy': gain = calc_information_gain(Y, X) else: # criterion == 'gini' gain = calc_gini_gain(Y, X) i = np.argmax(gain) # 按照最优特征分裂数据集 Dv = split_dataset(X, Y, i) # 如果信息增益为0,返回当前节点 if np.array_equal(gain, np.zeros((X.shape[1], 1))): return self # 创建子节点并递归构建子树 for k, v in Dv.items(): node = TreeNode(leaf=(len(v) == 1)) if not node.leaf: # 非叶子节点 node = node.split(X[v], Y[v], criterion=criterion) # 递归构建子树 else: # 叶子节点 node.mode = Y[v][0] # 叶子节点为样本中最普遍的类别 self.children[k] = node self.feature_idx = i return self
定义决策树的类(decision_tree.py):
class DecisionTree: def __init__(self, criterion='entropy', max_depth=None, min_samples_leaf=1): self.root = None # 决策树根节点 self.criterion = criterion # 划分标准:'entropy'或'gini' self.max_depth = max_depth # 最大深度 self.min_samples_leaf = min_samples_leaf # 叶节点最小样本数 def fit(self, X, Y): self.root = TreeNode(leaf=(len(Y) == 1)) if not self.root.leaf: # 非叶子节点 self.root = self.root.split(X, Y, criterion=self.criterion) # 构建树 else: # 叶子节点 self.root.mode = Y[0] # 叶子节点为样本中最普遍的类别 def predict(self, X): Y_pred = np.array([], dtype=int) for x in X: node = self.root while not node.leaf: i = node.feature_idx if x[i] <= node.threshold: node = node.children[0] else: node = node.children[1] Y_pred = np.append(Y_pred, node.mode) return Y_pred
应用决策树
构建决策树并对数据进行分类(main.py):
tree = DecisionTree(max_depth=3) tree.fit(X_train, Y_train) Y_pred = tree.predict(X_test)
计算分类精度,并绘制决策树(plot_tree.py):
from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt from matplotlib.patches import FancyArrowPatch def plot_node(node_text, center_pt, parent_pt, node_type): arrow_args = dict(arrowstyle="<-") if node_type == 'leaf': create_plot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction', xytext=center_pt, textcoords='axes fraction', va="center", ha="center", bbox=node_type, arrowprops=arrow_args) else: create_plot.ax1.annotate(node_text, xy=parent_pt, xycoords='axes fraction', xytext=center_pt, textcoords='axes fraction', va="center", ha="center", bbox=node_type) def plot_arrow(center_pt, parent_pt, arrow_type): arrow_args = dict(arrowstyle=arrow_type) create_plot.ax1.annotate('', xy=parent_pt, xycoords='axes fraction', xytext=center_pt, textcoords='axes fraction', va="center", ha="center", bbox=dict(), arrowprops=arrow_args) def get_leafs_count(tree_node): if tree_node.leaf: return 1 count = 0 for child in tree_node.children.values(): count += get_leafs_count(child) return count def get_tree_depth(tree_node): if tree_node.leaf: return 1 max_depth = 0 for child in tree_node.children.values(): depth = get_tree_depth(child) + 1 if depth > max_depth: max_depth = depth return max_depth def plot_tree(tree): leaf_nodes_count = get_leafs_count(tree.root) tree_depth = get_tree_depth(tree.root) axprops = dict(xticks=[], yticks=[]) create_plot.ax1 = plt.subplot(111, frameon=False, **axprops) plot_node(str(tree.root.mode), (0.5, 0.8), (0.5, 1.0), 'root') plot_tree_node(tree.root, (0.5, 0.8), leaf_nodes_count, tree_depth) plt.axis('off') plt.show() def plot_tree_node(tree_node, parent_center, leaf_nodes_count, tree_depth): if tree_node.leaf: # 叶子节点 return h_unit = 1.0 / tree_depth v_unit = 1.0 / leaf_nodes_count height = 0 for k, child in tree_node.children.items(): center = (parent_center[0] - v_unit, parent_center[1] - h_unit) plot_arrow(center, parent_center, '<-') if child.leaf: # 叶子节点 node_type = {'fc': '0.8', 'ec': 'black', 'boxstyle': 'round'} node_text = str(child.mode) else: # 非叶子节点 node_type = {'fc': '0.8', 'ec': 'black'} node_text = 'X[{}] <= {:.2f}'.format(child.feature_idx, child.threshold) plot_node(node_text, center, parent_center, node_type) plot_tree_node(child, center, leaf_nodes_count, tree_depth) height += 1
运行结果如下:
accuracy: 0.9778
四、小结
本文介绍了Python实现决策树的方法,包括决策树的构建过程、属性选择、树的剪枝方法等,同时提供了完整的代码实现。在实际应用中,决策树算法可以用于分类和回归等场景,并且能够处理离散型和连续型数据,是数据挖掘、机器学习等领域的重要算法。