本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com
在本例,我们扒出sklearn中CART决策树的CCP路径计算的逻辑,
并重新以简化的代码,展示它是如何计算的。
通过本文,可以更清楚地了解sklearn是如何计算CCP路径的。
我们先展示一个调用sklearn自带的cost_complexity_pruning_path函数的例子,
在本文第二小节再用自写代码计算CCP路径,并与该结果进行比较。
例子代码
# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
# Example: let sklearn itself compute the CCP path of a decision tree on iris.
iris = load_iris()
X, y = iris.data, iris.target
# Fully grown tree (ccp_alpha=0); nodes with fewer than 10 samples are not split.
clf = tree.DecisionTreeClassifier(min_samples_split=10, ccp_alpha=0).fit(X, y)
# sklearn's cost-complexity pruning path; the result is indexed by the
# keys 'ccp_alphas' and 'impurities' below.
pruning_path = clf.cost_complexity_pruning_path(X, y)
# Print the reference result that the hand-written code is compared against.
print("\n====sklearn的CCP路径=================")
print("ccp_alphas:", pruning_path['ccp_alphas'])
print("impurities:", pruning_path['impurities'])
运行结果
====sklearn的CCP路径=================
ccp_alphas: [0. 0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]
上面的例子是调用sklearn的cost_complexity_pruning_path函数计算CCP路径,
本节我们抛开sklearn工具包,自行计算CCP路径。
自写CCP计算代码
# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
'''
----------------------------------------
本代码来自《老饼讲解-机器学习》www.bbbdata.com
本代码主要用于理解和学习决策树后剪枝的CCP路径计算
本代码重写sklearn中的决策树算法包提供的cost_complexity_pruning_path一致
结果与cost_complexity_pruning_path的默认设置一致。
------------------------------------------------------
'''
'''
get_tree的作用:将sklearn树模型的关键数据提取,存成更简单的对象
-----------
说明:
impurity 就是Gini系数
impurity[i]=1-((value[i]/value[i].sum()) **2).sum()
'''
def get_tree(sk_tree):
#--------------拷贝sklearn树模型关键信息--------------------
children_left = clf.tree_.children_left.copy() # 左节点编号
children_right = clf.tree_.children_right.copy() # 右节点编号
feature = clf.tree_.feature.copy() # 分割的变量
threshold = clf.tree_.threshold.copy() # 分割阈值
impurity = clf.tree_.impurity.copy() # 不纯度(gini)
n_node_samples = clf.tree_.n_node_samples.copy() # 样本个数
value = clf.tree_.value.copy() # 样本分布
n_sample = value[0].sum() # 总样本个数
node_num = len(children_left) # 节点个数
# ------------补充节点父节点信息---------------------------
parent = np.zeros(node_num).astype(int)
parent[0] = -1
branch_idx = np.where(children_left!=-1)[0]
for i in branch_idx:
parent[children_left[i]] = i
parent[children_right[i]]= i
#-------------存成字典-----------------------------------------
tree = {
'children_left':children_left
,'children_right':children_right
,'feature':feature
,'threshold':threshold
,'impurity':impurity
,'n_node_samples':n_node_samples
,'value':value
,'n_sample':n_sample
,'node_num':node_num
,'parent':parent
}
return tree
#-----------查找指定节点的所有子节点---------------------------------
def find_child_node(tree,node_idx):
child_list = []
un_use= [node_idx]
while(len(un_use)>0):
cur_node = un_use.pop()
left_idx = tree['children_left'][cur_node]
right_idx = tree['children_right'][cur_node]
child_list.extend([left_idx,right_idx])
if( tree['children_left'][left_idx]!=-1):
un_use.append(left_idx)
if( tree['children_right'][right_idx]!=-1):
un_use.append(right_idx)
return child_list
#-----计算每个分枝节点的临界alpha和最小临界alpha所在节点-----------
def cal_branch_alpha(tree):
#------------------数据准备与变量初始化-------------------------------------------
node_num = tree['node_num']
alpha_list = np.zeros(node_num)+float('inf') # 初始化临界alpha
sub_leaf_num = np.zeros(node_num) # 初始化子节点个数向量
sub_leaf_impurity = np.zeros(node_num) # 初始化节点下所有叶子节点的不纯度总和
node_impurity = (tree['n_node_samples']/tree['n_sample']) * tree['impurity'] # 节点的加权不纯度
branch_idx = np.where(tree['children_left']!=-1)[0] # 分枝节点
leaf_idx = np.where(tree['children_left']==-1)[0] # 叶子节点
#----------------# 计算sub_leaf_num和sub_leaf_impurity---------------------------
for i in leaf_idx:
parent_idx = i
while(parent_idx !=-1):
sub_leaf_impurity[parent_idx] +=node_impurity[i]
sub_leaf_num[parent_idx] +=1
parent_idx = tree['parent'][parent_idx]
# ------用sub_leaf_num和sub_leaf_impurity计算临界alpha---------------------
alpha_list[branch_idx] = np.maximum((node_impurity[branch_idx]
- sub_leaf_impurity[branch_idx])/(sub_leaf_num[branch_idx]-1),0)
return alpha_list
#-----------计算树的不纯度-----------------------------------------
def cal_tree_impurities(tree):
leaf_idx = np.where(tree['children_left']==-1)[0]
node_impurity = (tree['n_node_samples']/tree['n_sample']) * tree['impurity']
tree_impurities = node_impurity[leaf_idx].sum()
return tree_impurities
#----------节点剪枝---------------------------------------------------
def prune_nodes(tree,node_list):
# 找出本次要删的所有节点编号(即剪枝节点的所有子孙节点)
sub_node_list = []
for i in node_list:
cur_child_list = find_child_node(tree,i)
sub_node_list.extend(cur_child_list)
sub_node_list = list(set(sub_node_list))
# 将剪枝节点置为子节点
tree['children_left'][node_list] = -1
tree['children_right'][node_list] = -1
tree['feature'][node_list] = -2
tree['threshold'][node_list] = -2
# 删除所有子孙节点的信息
tree['children_left'] = np.delete(tree['children_left'] ,sub_node_list)
tree['children_right'] = np.delete(tree['children_right'] ,sub_node_list)
tree['feature'] = np.delete(tree['feature'] ,sub_node_list)
tree['threshold'] = np.delete(tree['threshold'] ,sub_node_list)
tree['impurity'] = np.delete(tree['impurity'] ,sub_node_list)
tree['n_node_samples'] = np.delete(tree['n_node_samples'] ,sub_node_list)
tree['value'] = np.delete(tree['value'] ,sub_node_list,axis=0)
tree['parent'] = np.delete(tree['parent'] ,sub_node_list)
tree['node_num'] = len(tree['children_left'])
# 对剩余节点编号移位
sub_node_list.sort(reverse = True)
for i in sub_node_list:
tree['children_left'][tree['children_left']>i] -=1
tree['children_right'][tree['children_right']>i] -=1
# 计算CCP路径
def cal_ccp_path(tree):
ccp_alphas_list = [0] # 初始化临界alpha
impurities_list = [cal_tree_impurities(tree)] # 初始alpha对应的树不纯度
prune_node_list = [[]] # 初始化临界alpha对应删除的节点
while(tree['node_num']>1): # 逐轮对最小临界alpha节点剪枝
alpha_list = cal_branch_alpha(tree) # 计算各节点的临界alpha
min_node = list(np.where(alpha_list==alpha_list.min())[0]) # 最小临界alpha所在的分枝节点
prune_nodes(tree,min_node) # 对 最小临界alpha所在的分枝节点剪枝
tree_impurities = cal_tree_impurities(tree) # 计算树的纯度
ccp_alphas_list.append(alpha_list.min()) # 记录本轮最小临界alpha
impurities_list.append(tree_impurities) # 记录本轮树纯度
prune_node_list.append(min_node) # 记录本轮所剪节点
return ccp_alphas_list,prune_node_list,impurities_list
#----------------数据准备----------------------------
iris = load_iris() # 加载数据
X = iris.data
y = iris.target
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0)
clf = clf.fit(X, y)
#-------自写代码计算的ccp路径---------------------------
tree = get_tree(clf) # 将sklearn决策树转换成自己定义的树,再进行路径计算
ccp_alphas_list,prune_node_list,impurities_list = cal_ccp_path(tree)
#-------打印结果---------------------------
print("\n====自行计算的CCP路径=================")
print("ccp_alphas_list:",np.around(ccp_alphas_list,8))
print("impurities_list:",np.around(impurities_list,8))
print("prune_node_list:",prune_node_list)
代码运行版本:conda 4.10.3
运行结果
====自行计算的CCP路径=================
ccp_alphas_list: [0. 0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities_list: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]
prune_node_list: [[], [8], [4], [3], [2], [0]]
可见,自行计算方法与sklearn包中结果一致。
End