老饼讲解-机器学习 机器学习 神经网络 深度学习
逻辑回归与决策树

【代码】复现sklearn中CCP路径的实现代码(python)

作者 : 老饼 发表日期 : 2022-06-26 13:56:47 更新日期 : 2023-03-01 14:46:27
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com


在本例,我们扒出sklearn中CART决策树的CCP路径计算的逻辑,

并重新以简化的代码,展示它是如何计算的。

通过本文,更清晰sklearn计算CCP路径时是怎么计算的。



   01.sklearn自带决策树的CCP路径计算   


我们先展示一个sklearn包自带的cost_complexity_pruning_path的例子

在本文第二小节我们再自写代码与该结果比较


   例子代码   


# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree

#----------------数据准备----------------------------
iris = load_iris()    # 加载数据
X = iris.data
y = iris.target
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0)        
clf = clf.fit(X, y)     

#-------sklearn决策树计算ccp路径-----------------------
pruning_path = clf.cost_complexity_pruning_path(X, y)

#-------打印结果---------------------------    
 
print("\n====sklearn的CCP路径=================")
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities']) 



   运行结果   


====sklearn的CCP路径=================
ccp_alphas: [0.         0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]



   02.自写决策树的CCP路径计算   


上面的例子是调用sklearn的cost_complexity_pruning_path函数进行计算CCP路径,

本节我们抛开sklearn工具包,自行计算CCP路径。


   自写CCP计算代码   


# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
'''
----------------------------------------
本代码来自《老饼讲解-机器学习》www.bbbdata.com
本代码主要用于理解和学习决策树后剪枝的CCP路径计算
本代码重写sklearn中的决策树算法包提供的cost_complexity_pruning_path一致
结果与cost_complexity_pruning_path的默认设置一致。
------------------------------------------------------
'''

'''
get_tree的作用:将sklearn树模型的关键数据提取,存成更简单的对象
-----------
说明:
impurity 就是Gini系数 
impurity[i]=1-((value[i]/value[i].sum()) **2).sum()
'''
def get_tree(sk_tree):
    #--------------拷贝sklearn树模型关键信息--------------------
    children_left    = clf.tree_.children_left.copy()     # 左节点编号
    children_right   = clf.tree_.children_right.copy()    # 右节点编号
    feature          = clf.tree_.feature.copy()           # 分割的变量
    threshold        = clf.tree_.threshold.copy()         # 分割阈值
    impurity         = clf.tree_.impurity.copy()          # 不纯度(gini)
    n_node_samples   = clf.tree_.n_node_samples.copy()    # 样本个数
    value            = clf.tree_.value.copy()             # 样本分布
    n_sample         = value[0].sum()                     # 总样本个数
    node_num         = len(children_left)                 # 节点个数
    
    # ------------补充节点父节点信息---------------------------
    parent = np.zeros(node_num).astype(int)
    parent[0] = -1
    branch_idx = np.where(children_left!=-1)[0]
    for i in branch_idx:
        parent[children_left[i]] = i   
        parent[children_right[i]]= i 
    #-------------存成字典-----------------------------------------    
    tree = {
        'children_left':children_left
        ,'children_right':children_right
        ,'feature':feature
        ,'threshold':threshold
        ,'impurity':impurity
        ,'n_node_samples':n_node_samples
        ,'value':value
        ,'n_sample':n_sample
        ,'node_num':node_num
        ,'parent':parent
        }
    return tree


#-----------查找指定节点的所有子节点---------------------------------
def find_child_node(tree,node_idx):
    child_list = []
    un_use= [node_idx]
    while(len(un_use)>0):
        cur_node = un_use.pop()
        left_idx = tree['children_left'][cur_node]
        right_idx = tree['children_right'][cur_node]
        child_list.extend([left_idx,right_idx])
        if( tree['children_left'][left_idx]!=-1):
            un_use.append(left_idx)

        if( tree['children_right'][right_idx]!=-1):
            un_use.append(right_idx)
    return child_list   


#-----计算每个分枝节点的临界alpha和最小临界alpha所在节点-----------
def cal_branch_alpha(tree):
    #------------------数据准备与变量初始化-------------------------------------------
    node_num          = tree['node_num']
    alpha_list        = np.zeros(node_num)+float('inf')                 # 初始化临界alpha
    sub_leaf_num      = np.zeros(node_num)                              # 初始化子节点个数向量
    sub_leaf_impurity = np.zeros(node_num)                              # 初始化节点下所有叶子节点的不纯度总和
    node_impurity     = (tree['n_node_samples']/tree['n_sample']) * tree['impurity'] # 节点的加权不纯度
    branch_idx        = np.where(tree['children_left']!=-1)[0]          # 分枝节点
    leaf_idx          = np.where(tree['children_left']==-1)[0]          # 叶子节点
    
    #----------------# 计算sub_leaf_num和sub_leaf_impurity---------------------------
    for i in leaf_idx:      
        parent_idx = i
        while(parent_idx !=-1):
            sub_leaf_impurity[parent_idx] +=node_impurity[i]
            sub_leaf_num[parent_idx] +=1
            parent_idx = tree['parent'][parent_idx]
            
    # ------用sub_leaf_num和sub_leaf_impurity计算临界alpha---------------------
    alpha_list[branch_idx] = np.maximum((node_impurity[branch_idx] 
                           - sub_leaf_impurity[branch_idx])/(sub_leaf_num[branch_idx]-1),0)
    return alpha_list


#-----------计算树的不纯度-----------------------------------------
def cal_tree_impurities(tree):
    leaf_idx        = np.where(tree['children_left']==-1)[0]
    node_impurity   =  (tree['n_node_samples']/tree['n_sample']) * tree['impurity']
    tree_impurities = node_impurity[leaf_idx].sum()
    return tree_impurities


#----------节点剪枝---------------------------------------------------
def prune_nodes(tree,node_list):
    # 找出本次要删的所有节点编号(即剪枝节点的所有子孙节点)
    sub_node_list = []
    for i in node_list:
        cur_child_list = find_child_node(tree,i)
        sub_node_list.extend(cur_child_list)
    sub_node_list = list(set(sub_node_list))
    # 将剪枝节点置为子节点
    tree['children_left'][node_list]  = -1
    tree['children_right'][node_list] = -1
    tree['feature'][node_list]        = -2
    tree['threshold'][node_list]      = -2
    # 删除所有子孙节点的信息
    tree['children_left']  = np.delete(tree['children_left']  ,sub_node_list)
    tree['children_right'] = np.delete(tree['children_right'] ,sub_node_list)
    tree['feature']        = np.delete(tree['feature']        ,sub_node_list)
    tree['threshold']      = np.delete(tree['threshold']      ,sub_node_list)
    tree['impurity']       = np.delete(tree['impurity']       ,sub_node_list)
    tree['n_node_samples'] = np.delete(tree['n_node_samples'] ,sub_node_list)
    tree['value']          = np.delete(tree['value']          ,sub_node_list,axis=0)
    tree['parent']         = np.delete(tree['parent']         ,sub_node_list)
    tree['node_num']       = len(tree['children_left'])
    # 对剩余节点编号移位
    sub_node_list.sort(reverse = True)
    for i in sub_node_list:
        tree['children_left'][tree['children_left']>i]   -=1
        tree['children_right'][tree['children_right']>i]  -=1


# 计算CCP路径            
def cal_ccp_path(tree):
    ccp_alphas_list = [0]                                                # 初始化临界alpha
    impurities_list = [cal_tree_impurities(tree)]                        # 初始alpha对应的树不纯度
    prune_node_list = [[]]                                               # 初始化临界alpha对应删除的节点
    while(tree['node_num']>1):                                           # 逐轮对最小临界alpha节点剪枝
        alpha_list  = cal_branch_alpha(tree)                             # 计算各节点的临界alpha
        min_node    = list(np.where(alpha_list==alpha_list.min())[0])    # 最小临界alpha所在的分枝节点
        prune_nodes(tree,min_node)                                       # 对 最小临界alpha所在的分枝节点剪枝
        tree_impurities = cal_tree_impurities(tree)                      # 计算树的纯度
        ccp_alphas_list.append(alpha_list.min())                         # 记录本轮最小临界alpha
        impurities_list.append(tree_impurities)                          # 记录本轮树纯度
        prune_node_list.append(min_node)                                 # 记录本轮所剪节点
    return ccp_alphas_list,prune_node_list,impurities_list

#----------------数据准备----------------------------
iris = load_iris()    # 加载数据
X = iris.data
y = iris.target
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0)        
clf = clf.fit(X, y)     


#-------自写代码计算的ccp路径--------------------------- 
tree = get_tree(clf)   # 将sklearn决策树转换成自己定义的树,再进行路径计算
ccp_alphas_list,prune_node_list,impurities_list = cal_ccp_path(tree)


#-------打印结果---------------------------    
print("\n====自行计算的CCP路径=================")
print("ccp_alphas_list:",np.around(ccp_alphas_list,8))
print("impurities_list:",np.around(impurities_list,8))
print("prune_node_list:",prune_node_list)

代码运行版本:conda 4.10.3


   运行结果   


====自行计算的CCP路径=================
ccp_alphas_list: [0.         0.00415459 0.01305556 0.02966049 0.25979603 0.33333333]
impurities_list: [0.02666667 0.03082126 0.04387681 0.07353731 0.33333333 0.66666667]
prune_node_list: [[], [8], [4], [3], [2], [0]]

可见,自行计算方法与sklearn包中结果一致。






 End 







联系老饼