sklearn与决策树-应用
入门应用
一个简单的决策树分类例子
作者 : 老饼 日期 : 2022-06-26 09:46:57 更新 : 2022-08-03 20:45:40
本站原创文章,转载请说明来自《老饼讲解-机器学习》ml.bbbdata.com 


  CART决策树是常用的机器学习算法,它包括CART分类树与回归树,

回归树与分类树不同的地方在于,回归树的输出是数值,分类树输出的是类别。

本文展示一个用python(sklearn)实现的简单的CART分类树例子,用于学习sklearn分类树的调用方法



  01. 问题  


下面是一个简单的分类问题的数据与建模目标


       数  据      


现已采集150组 鸢尾花数据,
包括鸢尾花的四个特征与鸢尾花的类别。

数据如下(即sk-learn中的iris数据):
  
 花萼长度 sepal length (cm) 、花萼宽度 sepal width (cm)   
花瓣长度 petal length (cm) 、花瓣宽度 petal width (cm)  
山鸢尾:0,杂色鸢尾:1,弗吉尼亚鸢尾:2                   



      目   标      


我们希望通过采集的数据,
训练一个决策树模型,
之后应用该模型,
可以根据鸢尾花的四个特征去预测它的类别。




   02. 流程与代码   


     (一) 流 程     


1. 建立决策树模型                   
2. 用数据训练决策树模型         
3. 用训练好的决策树模型预测  


      (二) 代码    


from sklearn.datasets import load_iris
from sklearn import tree

#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据

#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier()         # sk-learn的决策树模型
clf = clf.fit(iris.data, iris.target)        # 用数据训练树模型构建()
r = tree.export_text(clf, feature_names=iris['feature_names'])


#---------------模型预测结果------------------------
text_x = iris.data[[0,1,50,51,100,101], :]
pred_target_prob = clf.predict_proba(text_x)        # 预测类别概率
pred_target = clf.predict(text_x)              # 预测类别

#---------------打印结果---------------------------
print("\n===模型======")
print(r)
print("\n===测试数据:=====")
print(text_x)
print("\n===预测所属类别概率:=====")
print(pred_target_prob)
print("\n===预测所属类别:======")
print(pred_target)



运行代码后,输出如下:


===模型======
|--- petal length (cm) <= 2.45
|   |--- class: 0
|--- petal length (cm) >  2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- petal width (cm) <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- petal width (cm) >  1.65
|   |   |   |   |--- class: 2
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- petal width (cm) <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- petal width (cm) >  1.55
|   |   |   |   |--- sepal length (cm) <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- sepal length (cm) >  6.95
|   |   |   |   |   |--- class: 2
|   |--- petal width (cm) >  1.75
|   |   |--- petal length (cm) <= 4.85
|   |   |   |--- sepal width (cm) <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- sepal width (cm) >  3.10
|   |   |   |   |--- class: 1
|   |   |--- petal length (cm) >  4.85
|   |   |   |--- class: 2

===测试数据:=====
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.3 3.3 6.  2.5]
 [5.8 2.7 5.1 1.9]]

===预测所属类别概率:=====
[[1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 1.]]

===预测所属类别:======
[0 0 1 1 2 2]


以上就是决策树的最简例子




 End 




联系老饼