老饼讲解-机器学习 机器学习 神经网络 深度学习
集成方法

【代码】Adaboost简单Demo(sklearn)

作者 : 老饼 发表日期 : 2022-06-26 14:15:15 更新日期 : 2024-01-21 17:05:53
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com



本文展示一个python的sklearn中实现Adaboost的Demo,

在自写代码中,我们再复现该Demo,比较sklearn包的结果与自写代码的结果是否一致


   

   

  01. sklearn实现Adaboost的代码   



本节展示一个用sklearn实现Adaboost的代码示例 



  代码简介 


代码展示一个利用python的sklearn实现Adaboost的Demo,
Demo先是随机生成一组二分类数据,然后调用Adaboost包进行训练,
最后展示预测结果和模型的相关参数




   Demo代码  


# -*- coding: utf-8 -*-
"""
本代码展示一个调用sklearn包实现Adaboost提升树算法的Demo
本代码来自《老饼讲解-机器学习》www.bbbdata.com
"""
import matplotlib.pyplot as plt
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_gaussian_quantiles

# ==================== 数据生成 ==================================
# -----生成训练数据-------
# 生成2维正态分布,生成的数据按分位数分为两类,500个样本,2个样本特征,协方差系数为2
X, y = make_gaussian_quantiles(cov=2.0,n_samples=500, n_features=2,n_classes=2, random_state=1) 
plt.axis('off')
plt.scatter(X[:, 0], X[:, 1], c=y)

# ================== 模型训练 ========================
bdt = AdaBoostClassifier(DecisionTreeClassifier(max_depth=2, min_samples_split=20, min_samples_leaf=5),
             algorithm="SAMME",
             n_estimators=50, learning_rate=0.8)
bdt.fit(X, y)

#==================== 预测 ================================
pred = bdt.predict(X)
proba = bdt.predict_proba(X)

#============== 打印结果 =================================
plt.figure()
plt.axis('off')
plt.scatter(X[:, 0], X[:, 1], c=pred)
print("----前10个样本预测结果-----:")
print(proba[1:10,1])
print("\n-------各个决策器的权重系数:--------:")
print(bdt.estimator_weights_)

代码运行版本:conda 4.10.3



  代码运行结果  


代码运行结果如下:
-----------------类别预测结果------------
----前10个样本预测结果-----:
[0.4401229  0.44322703 0.38694471 0.56047588 0.40757558 0.5991921
 0.47065033 0.4575639  0.56047588]

-------各个决策器的权重系数:--------:
[0.92214361 0.82369487 0.85884339 0.98420957 0.91186682 0.66167451
 0.50686971 0.72871313 0.60459277 0.7782458  0.8684703  0.58504242
 0.62814664 0.57422949 0.54349101 0.67251002 0.36467185 0.5005236
 0.60913436 0.42339819 0.43601101 0.45785634 0.73436682 0.49574203
 0.48327574 0.35112571 0.3133096  0.31226401 0.42923141 0.43631673
 0.45377492 0.46902642 0.36274209 0.28099441 0.37300286 0.48235625
 0.60807215 0.39493869 0.67715658 0.33012475 0.3236838  0.57414975
 0.372842   0.2476105  0.26227468 0.35676776 0.32916502 0.41230541
 0.20373823 0.23041444]
从结果中可以看到,基本类别都预测正确了,
把决策器个数设置更大,准确率还能再提高,这里作为Demo,我们不再优化








  End  






联系老饼