老饼讲解-机器学习 机器学习 神经网络 深度学习
决策树

【讲解】sklearn提取决策树数据

作者 : 老饼 发表日期 : 2022-06-26 09:49:55 更新日期 : 2023-12-08 16:32:19
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com


在用sklearn训练完决策树后,往往我们需要将决策树中的数据提取出来使用,

本文讲解如何提取训练好的决策树数据,并用demo说明每个数据的意义



  01. 问题  

用sklearn建好决策树后,可以打印出树的结构:
但往往我们提取图中的数据(例如用于将决策树转化成规则代码),
那图中的数据究竟在哪呢?
 
下面我们讲解如何在sklearn训练好决策树后,
提取决策树中的数据。




     决策树信息存储位置     


决策树模型主要有如下两类信息:
 
👉1. 树结构信息                                                                  
 
树结构信息以左右节点编号的形式来表示        
👉2.节点信息                                                                    
 
树节点信息包括分割变量、分割修士、不纯度等等
树模型的信息存储在决策树模型对象clf的属性中
具体如下:
决策树结构信息 
左节点编号  :  clf.tree_.children_left    
 右节点编号  :  clf.tree_.children_right  
   节点信息   
分割的变量  :  clf.tree_.feature                 
 分割的阈值 :  clf.tree_.threshold              
 不纯度(gini) :  clf.tree_.impurity             
 样本个数   :  clf.tree_.n_node_samples   
 样本分布   :  clf.tree_.value                     


  关于节点预测值  


sklearn并没有直接存决策树的类别(概率)预测值,
我们需要借助 样本分布 clf.tree_.value,
节点预测类别:样本最多的一类就是节点的预测类别,
节点预测类别的概率:样本占比则是预测概率。         



   03. 实例解说  


下面我们展示一个提取决策树信息的实例,

更加具体说明各个信息的使用


  决策树实例  


# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import graphviz 
#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(random_state=0,max_depth=3)        
clf = clf.fit(iris.data, iris.target)     
#---------------树结构可视化-----------------------------
dot_data = tree.export_graphviz(clf) 
graph = graphviz.Source(dot_data)  
graph                                    # 需要独立运行
#---------------提取模型结构数据--------------------------
children_left    = clf.tree_.children_left            # 左节点编号
children_right   = clf.tree_.children_right            # 右节点编号
feature        = clf.tree_.feature               # 分割的变量
threshold       = clf.tree_.threshold              # 分割阈值
impurity       = clf.tree_.impurity               # 不纯度(gini)
n_node_samples   = clf.tree_.n_node_samples            # 样本个数
value         = clf.tree_.value                 # 样本分布

#-------------打印------------------------------
print("children_left:",children_left)        
print("children_right:",children_right)
print("feature:",feature)
print("threshold:",threshold)
print("impurity:",impurity)
print("n_node_samples:",n_node_samples)
print("value:",value)


运行结果:

children_left : [ 1 -1  3  4 -1 -1  7 -1 -1]
children_right: [ 2 -1  6  5 -1 -1  8 -1 -1]
feature       : [ 3 -2 3 2 -2 -2 2 -2 -2]
threshold     : [ 0.80000001 -2.  1.75   4.95000005 -2.   -2.  4.85000014 -2.   -2.]
impurity      : [0.66666667 0.  0.5   0.16803841 0.04079861 0.44444444 0.04253308 0.44444444 0.  ]
n_node_samples: [150 50 100 54 48  6 46  3 43]
value         : [[[50. 50. 50.]][[50.  0.  0.]] [[ 0. 50. 50.]] [[ 0. 49.  5.]] [[ 0. 47.  1.]] [[ 0.  2.  4.]] 
                [[ 0.  1. 45.]] [[ 0.  1.  2.]] [[ 0.  0. 43.]]]

借助决策树拓扑图与输出值对比,可以较容易理解各个数据在树中的意义



   提取树结构   


树结构信息存在children_left和children_right ,
它们记录了左右节点编号

children_left[0]    = 1  
代表 第0(根节点)个节点左节点编号为1 

children_right[0]  = 2  
代表 第0(根节点)个节点右节点编号为2 
由上可知,
根节点的左节点编号为1,右节点编号为2,
左节点1和节点2的子节点去哪找呢,
继续代入 children_left和 children_right即可。
左节点1的子节点编号: 
左子节点 children_left[1] = -1,
右子节点children_right[1] =-1, 
-1代表没有子节点
(即说明左节点1是叶子节点)。

右节点2的子节点编号: 
左子节点   children_left[2]   = 3   
右子节点   children_right[2] = 6  
....
如此类推,即知树结构。



  提取节点信息  


第0个节点的信息:
 
分割变量   :feature[0] = 3                               
分割阈值   :threshold[0] =0.8                         
不纯度(gini系数) :impurity[0]  = 0.66666667 
样本个数      :n_node_samples[0] = 150       
样本分布      :value[0] = [50 50 50]              


第1个节点的信息:
 
分割变量   :feature[1]  = -2  (-2代表是叶子节点,该值没意义)   
分割阈值   :threshold[1] = -2 (-2代表是叶子节点,该值没意义)
不纯度(gini系数) :impurity[1]= 0                                           
样本个数       :n_node_samples[1] = 50                                
样本分布       :value[1]= [50 0 0]                                          
......
......
如此类推即可。







 End 





联系老饼