老饼讲解-机器学习 机器学习 神经网络 深度学习
机器学习入门
1.学前解惑
2.第一课:初探模型
3.第二课:逻辑回归与梯度下降
4.第三课:决策树
5.第四课:逻辑回归与决策树补充
6.第五课:常见的其它算法
7.第六课:综合应用

【代码】CART决策树代码(自实现)

作者 : 老饼 发表日期 : 2022-06-26 04:00:22 更新日期 : 2023-12-19 16:48:42
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com


本文展示如何使用python自行实现CART决策树(非调包)和CCP剪树,

算法逻辑来自matlab的fitctree函数,亲测结果与fitctree一致

通过代码的理解,可以更进一步理解CART决策树的算法逻辑和实现细节




  01. CART决策树代码-简介  



本节介绍本文的CART决策树代码实现的内容



   CART决策树-代码功能与算法来源   


本代码实现CART决策树的构建,预测和CCP剪枝功能。
 
代码的算法流程来自matlab自带的决策树包,构建的结果与matlab的fitctree一致
最后,代码中附加一个使用Demo,展示代码各个函数功能的使用
CART决策树的相关原理可见:
CART决策树模型简介与实例
CART决策树算法流程          




   CART决策树代码-运行后执行的内容   


代码运行后,就是运行Demo函数,
Demo函数主要是用sklearn自带的iris数据构训练一棵决策树,并进行剪枝,预测等操作
具体实现内容如下:
1、数据生成:使用sklearn自带的iris数据                                          
2、用iris数据构建一棵完全生长的决策树                                           
3、用构建好的决策树对样本进行预测,并打印出预测错误的样本编号 
4、将树去掉无效子节点                                                                    
5、计算各个节点的临界alpha                                                           
6、计算并打印CCP路径                                                                    
7、设定alpha=0.1,进行剪枝                                                             
8、根据节点编号进行剪枝                                                                




   02. CART决策树代码   



本节展示用python自实现一棵CART决策树的代码



    自实现CART决策树的代码    


用python自实现CART决策树(分类树)的代码如下:
# -*- coding: utf-8 -*-
"""
决策树CART(分类树)自实现
PASS:算法来源来matlab决策树工具包,笔者亲测,与软件包结果一致。
本代码来自老饼讲解-机器学习:www.bbbdata.com
"""
from sklearn.datasets import load_iris
import numpy as np
from copy import copy

# 节点类:决策树即为一串的树节点连接
class Node(object):
    def __init__(self,cMat,sample_idx):
        sample_class   = cMat[sample_idx].sum(axis=0)          # 统计各类别样本个数
        class_num    = (sample_class>0).sum()            # 类别个数
        node_class   = sample_class.argmax()             # 节点类别
        sample_num   = len(sample_idx)                # 样本个数
        err_num      = sample_num-sample_class[node_class]   # 错误样本个数
        
        self.id       = None                           # 节点ID,删除节点时可通过该值删除
        self.is_leaf    = 0                            # 是否是叶子节点
        self.left_node   = None                           # 左节点(也是Node类)
        self.right_node  = None                           # 右节点(也是Node类)
        self.cut_var    = None                           # 分割变量
        self.cut_val    = None                           # 分割值


        self.sample_idx   = sample_idx                      # 属于该节点的样本索引
        self.sample_num   = sample_num                      # 属于该节点的样本个数
        self.sample_class   = sample_class                     # 属于该节点的样本属于各类的个数,如[30,10]代表0类30个,1类10个
        self.class_num    = class_num                      # 属于该节点的样本类别个数
        self.node_class   = node_class                      # 该节点被判为属于哪一类别 


        self.err_num       = err_num                      # 节点上判断错误的样本个数
        self.leaf_errnum     = 0                         # 该节点下的叶子节点的总错误样本个数,用于计算alpha
        self.leaf_nodenum     = 0                         # 该节点下的叶子节点个数,用于计算alpha
        self.alpha        = None                        # 剪枝系数alpha
    
    # 将节点改为叶子节点    
    def be_leaf(self):       
        self.left_node   = None
        self.right_node   = None
        self.cut_var    = None
        self.cut_val    = None
        self.is_leaf    = 1
    
    # 将节点深拷贝(采用递归拷贝)    
    def copy(self):
        if(self.is_leaf==0):
            left_node    = self.left_node.copy()
            right_node   = self.right_node.copy()
            new_node   = copy(self)
            new_node.sample_idx   = self.sample_idx.copy()
            new_node.sample_class   = self.sample_class.copy()
            new_node.left_node    = left_node
            new_node.right_node   = right_node
        if(self.is_leaf==1):
          new_node=copy(self)
          new_node.sample_idx   = self.sample_idx.copy()
          new_node.sample_class   = self.sample_class.copy()
        return new_node

# 计算使用变量x,切值为cut_val时的收益函数(gini)        
def cal_gain(x,cMat,cut_val):
    is_left     = x<=cut_val
    left_rate   = is_left.sum()/len(is_left)               # 左节点样本个数占比
    right_rate   = 1- left_rate                      # 右节点样本个数占比
    left_cMat    = cMat[is_left]                     # 左节点类别
    right_cMat    = cMat[~is_left]                    # 右节点类别
    p_left     = left_cMat.sum(axis=0)/left_cMat.sum()    # 左节点各类别占比
    p_right     = right_cMat.sum(axis=0)/right_cMat.sum() if right_cMat.sum()>0 else  right_cMat.sum(axis=0)  


    g_left  = 1- (p_left**2).sum()    # 左节点基尼系数
    g_right  = 1- (p_right**2).sum()   # 右节点基尼系数
    gain    = -(left_rate)*g_left - (right_rate)*g_right # 收益值:-左右基尼系数加权和
    return gain

# 找出x变量的最佳切割点
def find_best_cut(x,cMat):
     unique_x = np.unique(x)
     best_cut_val = (unique_x[0]+unique_x[1])/2 if len(unique_x)>1 else unique_x[0]
     best_g = cal_gain(x,cMat,best_cut_val)
     for i in range(1,len(unique_x)-1):
         cut_val = (unique_x[i]+unique_x[i+1])/2 
         g = cal_gain(x,cMat,cut_val)
         if g>best_g:
             best_cut_val = cut_val
             best_g = g
     return best_g,best_cut_val

# 将类别转为类别矩阵
def class2Cmat(y):
    c_name= np.unique(y)
    c_num = len(c_name)
    cMat = np.zeros([len(y),c_num])
    for i in range(c_num):
        cMat[y==c_name[i],i]=1 
    return cMat,c_name
 
# 删除无用叶子: 叶子不能降低判别误差,则删
def prune_bad(node):
    un_use_node = [node]   # 从根节点开始
    delete_bad = 0         # 初始化是否删除过节点
    while(len(un_use_node)>0):
        cur_node = un_use_node.pop()
        if(cur_node.is_leaf==1):    # 如果是叶子节点,不必判断
            pass 
        elif((cur_node.left_node.is_leaf==1) &(cur_node.right_node.is_leaf==1)):             # 如果左右都是叶子,则判断是否删除
           if((cur_node.left_node.err_num+ cur_node.right_node.err_num) >=cur_node.err_num): # 叶子节点不能降低判别误差,则删
               cur_node.be_leaf()    # 将节点置为叶子节点
               delete_bad = 1        # 标记:已删过节点
        else:                        # 如果不是叶子,也不是倒算第二层节点,则把叶子添加到判断列表
           un_use_node.append(cur_node.left_node)  
           un_use_node.append(cur_node.right_node) 
    if(delete_bad==1):               # 如果删过节点,则将树重新判断
       prune_bad(node) 

# 给节点及其子孙设置id序号       
def set_node_id(node,next_id):
     node.id = next_id
     if(node.is_leaf==0):
         next_id = set_node_id(node.left_node,next_id+1)
         set_node_id(node.right_node,next_id)
     return next_id+1

# 计算各个节点的临界alpha,并返回最小临界alpha
def cal_alpha(node,total_num=None):
    total_num = node.sample_num if(total_num is None) else total_num
    if(node.is_leaf==0):  # 如果不是叶子节点,获取左右节点下的叶子节点的错误样本个数,并计算临界alpha
       left_err,left_nodenum,left_min_alpha    = cal_alpha(node.left_node,total_num)         # 获取左节点下的叶子节点的错误样本个数
       right_err,right_nodenum,right_min_alpha = cal_alpha(node.right_node,total_num)        # 获取右节点下的叶子节点的错误样本个数
       node.leaf_errnum  = left_err + right_err                                              # 计算当前节点下的叶子节点的错误样本个数
       node.leaf_nodenum = left_nodenum + right_nodenum                                      # 计算当前了点下的叶子节点个数
       node.alpha = max(((node.err_num-node.leaf_errnum)/total_num)/(node.leaf_nodenum-1),0) # 计算临界alpha
       min_alpha  = min(left_min_alpha,right_min_alpha,node.alpha)                           # 记录最小临界alpha
    else:    # 如果是叶子节点,直接获取叶子节点错误个数
        node.leaf_errnum  = node.err_num  
        node.leaf_nodenum = 1
        min_alpha = float('inf')
    return node.leaf_errnum,node.leaf_nodenum,min_alpha

# 对临界alpha<=alpha的节点剪枝
def prune_alpha(node,alpha):
    prune_list=[]
    if(node.is_leaf==1):
        return prune_list
    if(node.alpha<=alpha):
        prune_list.append(node.id)
        left_prune_list  = prune_alpha(node.left_node,0)
        right_prune_list = prune_alpha(node.right_node,0)
        node.be_leaf()
    else:
        left_prune_list  = prune_alpha(node.left_node,alpha)
        right_prune_list = prune_alpha(node.right_node,alpha)
    prune_list.extend(left_prune_list)
    prune_list.extend(right_prune_list)
    return prune_list
        
# 获取迭代剪枝的最小alpha
'''
先计算树最小临界alpha,剪掉最小临界alpha的叶子,再计算树小临界,再剪..再计算,再剪,直到只剩根节点,
返回每轮的最小临界alpha
'''
def cal_prune_list(node):
    cnode = node.copy()
    min_alpha_list = []
    prune_list = []
    while((cnode.is_leaf==0) and len(min_alpha_list)<1000):
        leaf_errnum,leaf_nodenum,min_alpha = cal_alpha(cnode)
        cur_prune_list = prune_alpha(cnode,min_alpha)
        prune_list.append(cur_prune_list)
        min_alpha_list.append(min_alpha)
    return min_alpha_list,prune_list


#根据alpha值最大剪枝
def prune(node,alpha):
    p_node = node.copy()
    min_alpha_list,prune_list = cal_prune_list(p_node)
    prune_term = np.argwhere(np.array(min_alpha_list)<=alpha)
    
    if(len(prune_term)>0):
        prune_term=prune_term[-1][0]+1
        for i in range(prune_term):
            leaf_errnum,leaf_nodenum,min_alpha = cal_alpha(p_node)
            prune_alpha(p_node,min_alpha)
    return p_node


# 剪掉指定ID的节点
def prune_nodes(node,id_list):
    p_node = node.copy()
    un_use_node=[p_node]
    while((len(un_use_node)>0) and len(id_list)>0 ):
        cur_node = un_use_node.pop()
        if(cur_node.id in id_list):
            cur_node.be_leaf()
        elif(cur_node.is_leaf==0):
            un_use_node.append(cur_node.left_node)
            un_use_node.append(cur_node.right_node)
    return p_node

# predict
def predict(node,x):
    while(node.is_leaf==0):
       node = node.left_node if x[node.cut_var]<=node.cut_val else node.right_node
    return node.node_class
        
# 打印树
def print_node(node,deep=0,var_name_list=[],show_sample_class=0,show_alpha_info=0): 
    node_id = '('+str(node.id)+')'
    alpha_info = ' (leaf_errnum:'+str(node.leaf_errnum)+')' +  ' (alpha:'+str(node.alpha)+')'  if show_alpha_info==1 else ''
    if(node.is_leaf==0):
         var_name  = 'x' + str(node.cut_var)  if(len(var_name_list)==0) else var_name_list[node.cut_var]
         left_sample_class  = "("+str(node.left_node.sample_class)+")"  if show_sample_class==1 else ''
         right_sample_class = "("+str(node.right_node.sample_class)+")" if show_sample_class==1 else ''
         
         print('  |'*deep+"--"+node_id+var_name+"<="+str(node.cut_val)+left_sample_class+alpha_info)
         print_node(node.left_node,deep+1,var_name_list=var_name_list,show_sample_class=show_sample_class,show_alpha_info=show_alpha_info)
         print('  |'*deep+"--"+node_id+var_name+">"+str(node.cut_val)+right_sample_class+alpha_info)
         print_node(node.right_node,deep+1,var_name_list=var_name_list,show_sample_class=show_sample_class,show_alpha_info=show_alpha_info)
    else:
        print('  |'*deep+"--"+node_id+"class="+str(node.node_class) +alpha_info)

# 主程序:构建树        
def build_tree(x,y):
    min_leaf_num = 10                            # 参数预设
    n_samples,n_feture = x.shape
    cMat,c_name = class2Cmat(y)                  # 将y转为类别矩阵
        
    root_node = Node(cMat,np.arange(n_samples))  # 树初始化
    un_use_node=[root_node]
    
    # 树构建主流程
    while(len(un_use_node)>0 ):                # 如果还有节点未分裂完成
        # --------- 弹出节点 ------------------------------------------
        cur_node = un_use_node.pop()           # 弹出一个未完成分裂的节点
        node_x = x[cur_node.sample_idx]        # 获取节点样本的x
        cur_cMat = cMat[cur_node.sample_idx]   # 获取节点样本的y(类别矩阵形式)
        not_leaf  = (cur_node.sample_num>=min_leaf_num)& (cur_node.class_num>1) # 判断是否未达叶子条件
        #  ---------- 分裂或设为叶子 --------------------------------------
        if(not_leaf):                          # 如果未能成为叶子,继续分裂
            best_var = 0                       # 预设第一个变量为最佳变量
            best_g,best_cut_val = find_best_cut(node_x[:,best_var],cur_cMat) # 预设第一个变量的最佳切割为节点最佳切割
            for i in range(1,n_feture):                                      # 历遍变量,找出每个变量的最佳切割,再比较哪个变量的最佳切割最好
                g,cut_val = find_best_cut(node_x[:,i],cur_cMat)              # 找出该变量的最佳切割,与最佳收益
                if g>best_g:                                                 # 更新最佳变量、最佳切割、最佳收益
                  best_g       = g         
                  best_cut_val = cut_val
                  best_var     = i 
                  
            cur_node.cut_var = best_var                                      # 把最佳变量作为本节点最佳变量
            cur_node.cut_val = best_cut_val                                  # 把最佳切割作为本节点最佳切割
            is_left = node_x[:,best_var]<=best_cut_val                       # 找出左节点样本
            cur_node.left_node  = Node(cMat,cur_node.sample_idx[is_left ])   # 新建左节点
            cur_node.right_node = Node(cMat,cur_node.sample_idx[~is_left ])  # 新建右节点
            un_use_node.extend([cur_node.left_node,cur_node.right_node])     # 把左右节点添加到分裂池
        else:
           cur_node.be_leaf()                                                # 如果节点已达叶子条件,则设为叶子节点 
    set_node_id(root_node,1)                                                 # 给树每个节点添加序号
    return root_node                                                         # 返回根节点,即树

# 各个功能的使用demo
def test_demo():
    # -----加载数据-----------------
    iris = load_iris()    
    var_name_list=['sepal_length','sepal_width','petal_length','petal_width']
    x = iris.data  
    y = iris.target
    
    # ----构建完全生长的树-----------
    tree = build_tree(x,y)       
             
    # -----预测---------------------
    predict_y = np.zeros(y.shape)
    for i in range(x.shape[0]):
       predict_y[i]  = predict(tree,x[i])
       
    # -----打印信息-------------------------------------   
    print("\n------全生长树:------")
    print_node(tree,var_name_list=var_name_list)    
    print("\n------预测错误样本:------")
    print(np.argwhere(predict_y != y ))
    
    #-----剪枝-----------------------------------
    # 去掉无效叶子
    prune_bad(tree)                                       # 去掉无效叶子
    print("\n------去掉无效叶子的树------------")          # 打印树
    print_node(tree,var_name_list=var_name_list)      
    
    # 根据alpha剪枝
    # 带临界alpha信息的树
    cal_alpha(tree,tree.sample_num)                        # 获取临界alpha信息
    print("\n------树的临界alpha信息----------")            # 打印树
    print_node(tree,var_name_list=var_name_list,show_alpha_info=1)      
    
    # 多轮迭代式剪枝得到的CCP路径
    min_alpha_list,prune_list = cal_prune_list(tree)        # 模拟迭代式删除临界alpha
    print("\n--CCP路径--")      # 打信息
    print("每轮alpha:",min_alpha_list)
    print("每轮剪除节点:",prune_list)
    
    
    p_tree = prune(tree,0.1 )                               # 根据alpha剪枝(会多轮迭代)
    print("\n------指定alpha进行剪枝后的决策树-------")       # 打印树
    print_node(p_tree,var_name_list=var_name_list)
    
    # 其它剪枝      
    p_tree =prune_nodes(p_tree,[5,9])                       # 指定节点剪枝
    prune_alpha(tree,0.1)                                   # 根据alpha剪枝(只剪一轮)

# 调用测试Demo
test_demo()




    代码运行结果    


代码运行后的输出如下:
 1、用iris作为训练数据构建的完全生长的决策树 
 
自实现CART决策树代码-运行结果1
  
 2、决策树在训练样本中预测错误的样本编号
 自实现CART决策树代码-运行结果2
 
 3、将完全生长的树去掉无效节点的后得到的决策树
 自实现CART决策树代码-运行结果3
  
4、展示树的节点信息,包括alpha

 自实现CART决策树代码-运行结果4
 
 5、打印出ccp路径
 自实现CART决策树代码-运行结果5 
  
6、根据CCP路径信息,我们指定alpha=0.1,进行剪枝后得到的决策树
 
自实现CART决策树代码-运行结果6







  03. CART决策树代码说明   



本节对代码中的相关函数进行说明



test_demo: 测试用例主函数,直接运行时就是执行该函数。


1、数据生成:使用sklearn自带的iris数据
2、用iris数据构建一棵完全生长的决策树
3、用构建好的决策树对样本进行预测,并打印出预测错误的样本编号
4、将树去掉无效子节点
5、计算各个节点的临界alpha
6、计算并打印CCP路径
7、设定alpha=0.1,进行剪枝
8、根据节点编号进行剪枝


build_tree:决策树构建主函数,用于构建一棵CART决策树


决策树构建主函数,用于构建一棵CART决策树


print_node:打印决策树


将决策树结构打印出来,并可选择是否显示节点上的相关信息。


predict:决策树的预测函数


传入决策树和要预测的x,即可得到决策树的预测结果


cal_prune_list:计算CCP路径


每次按最小alpha值,迭代剪枝,直到剪完整棵树,最后返回剪枝的路径。
注:剪枝动作是模拟进行的,并不影响树本身。


prune_bad:对无效节点进行剪枝


如果节点的分枝并不能降低判别误差,说明节点的分枝是无益的,对这些无效节点进行剪枝


prune_nodes:决策树的节点剪枝函数


传入要剪掉的节点编号,对树进行剪枝


prune:决策树的alpha剪枝函数


传入alpha值,按alpha值进行剪枝




8个辅助函数:用于辅助计算的函数。


Node:节点类
cal_alpha:计算各个节点的临界alpha,并返回最小临界alpha
prune_alpha:剪掉当前树节点alpha<=某个值的节点
set_node_id:重新设置子孙节点编号
find_best_cut:找出x变量的最佳切割点
cal_gain:找出x变量计算使用变量x,切值为cut_val时的收益函数(gini)最佳切割点
class2Cmat:将类别转为类别矩阵
cal_gain:找出x变量计算使用变量x,切值为cut_val时的收益函数(gini)最佳切割点




以上就是自现实一棵CART决策树的代码了~






 End 






联系老饼