使用网格搜索和K折交叉验证来优化 XGBoost 分类模型¶
本篇文章将重点介绍如何在分类任务中对 XGBoost 进行优化,XGBoost 的优势在于其处理大规模数据、提高模型准确性的同时能够防止过拟合,然而要充分发挥 XGBoost 在分类任务中的潜力,选择合适的超参数至关重要。
为了寻找最佳的超参数组合,通常会借助 网格搜索 和 K 折交叉验证 优化技术。网格搜索通过系统地遍历多个超参数组合来确定最佳配置,而 K 折交叉验证则通过将数据集分成 K 个子集以评估模型的泛化能力,结合这两种方法,可以有效地避免单次训练可能带来的过拟合风险,并为模型选择最佳的超参数。
接下来,我们将通过具体的示例代码,详细演示如何运用网格搜索和K折交叉验证优化 XGBoost 分类模型的过程。
1. 数据预处理¶
从 Dataset.csv 中读取数据,并将其分为训练集和测试集,为后续的模型训练和评估做好准备,通过分层采样(stratify
参数)保证训练集和测试集的类别分布一致,这对于模型的公平性和泛化能力至关重要,该数据集为二分类数据。
2. 定义 XGB 模型参数¶
上述超参数的解释如下:
参数 | 描述 |
---|---|
learning_rate |
学习率,控制每一步的步长,用于防止过拟合,通常值范围:0.01 ~ 0.1 |
booster |
提升方法,这里使用梯度提升树(Gradient Boosting Tree) |
objective |
损失函数,这里使用逻辑回归,用于二分类任务;如果是多分类任务,可选值为 multi:softmax 或 multi:softprob 。'multi:softmax' 或 objective: 'multi:softprob' ,multi:softmax: 输出预测类别,multi:softprob : 输出每个类别的概率 |
num_class |
类别数,二分类任务为 2,多分类任务为类别数 |
max_leaves |
每棵树的叶子节点数量,控制模型复杂度,较大值可以提高模型复杂度但可能导致过拟合 |
verbosity |
控制 XGBoost 输出信息的详细程度,0 表示无输出,1 表示输出进度信息 |
seed |
随机种子,用于重现模型的结果 |
nthread |
并行运算的线程数量,-1 表示使用所有可用的 CPU 核心 |
colsample_bytree |
每棵树随机选择的特征比例,用于增加模型的泛化能力 |
subsample |
每次迭代时随机选择的样本比例,用于增加模型的泛化能力 |
eval_metric |
评价指标,二分类使用对数损失(logloss ),多分类使用交叉熵(cross-entropy)或者多分类对数损失(mlogloss ),其他指标还有 merror 等。 |
3. 定义模型¶
使用定义的参数初始化XGBoost回归模型
4. 定义参数网格¶
下面我们定义一个网格参数,
定义需要进行网格搜索的参数范围,包括树的数量(n_estimators
)、最大深度(max_depth
)和最小节点权重(min_child_weight
),这些参数用于调整模型的复杂度和防止过拟合。
当然除了这些参数以外还存在很对参数。下面给出一些参考(这些参数并没有参与网格调参),网格参数范围多少当对运行速度成倍数增长
5. 创建 GridSearchCV 对象¶
GridSearchCV
用于进行网格搜索和交叉验证,scoring
设为 neg_root_mean_squared_error
,即负均方根误差,作为评估指标,cv=5
表示使用5折交叉验证,n_jobs=-1
表示使用所有可用的 CPU 核心进行并行计算。
6. 训练模型¶
GridSearchCV 训练好之后,我们输出最优模型参数:
5 folds
表示使用 5 折交叉验证,即数据集被分成 5 个子集,每次使用其中的 4 个子集进行训练,1 个子集进行验证,这样进行 5 次,每个子集都作为一次验证集,- 125 candidates: 表示总共有 125 组不同的参数组合需要评估(5x5x5),也就是定义的网格参数存在的组合形式,
- totalling 625 fits: 因为每个参数组合都要经过 5 次交叉验证,所以总共进行 625 次模型训练和评估(5x125=625)
- 最后模型找到的最佳参数组合为
max_depth=7
、min_child_weight=5
、n_estimators=500
,使用这些参数,模型在验证集上的平均 RMSE 分数为 0.4666597060108463。
7. 最优参数下的模型¶
使用找到的最优参数重新训练模型,得到最终的最佳模型并保存。
8. 模型评价¶
9. 可视化¶
这里采用 Matplotlib 进绘制实际值 (y_test
) 和预测值 (y_pred
) 之间的散点图,并在图中添加一个理想对角线(y = x
),以便比较实际值和预测值的关系。
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6), dpi=1200)
# 绘制 y_pred 和 y_test 的散点图
plt.scatter(y_test, y_pred, alpha=0.3, color='blue', label='Predicted vs Actual')
# 绘制一条 y = x 的对角线,用于参考
max_value = max(max(y_test), max(y_pred))
plt.plot([0, max_value], [0, max_value], color='red', linestyle='--', linewidth=2, label='Ideal Line (y = x)')
plt.title(f'Actual vs Predicted Values\nR-squared: {r2:.2f}')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.legend()
plt.grid(True)
plt.show()