网格搜索(Grid Search)详细教学
1. 什么是网格搜索?
在机器学习模型中,算法的**超参数(Hyperparameters)**对模型的表现起着决定性作用。比如:
KNN 的邻居数量
n_neighbors
SVM 的惩罚系数
C
和核函数参数gamma
随机森林的决策树数量
n_estimators
这些超参数不会在训练过程中自动学习得到,而是需要我们人为设定。网格搜索(Grid Search)是一种最常见的超参数优化方法:
它通过遍历给定参数网格中的所有组合,使用交叉验证来评估每组参数的效果,最终选出表现最优的一组。
通俗理解:
👉 网格搜索 = 穷举法找最佳参数。
2. 网格搜索的核心思想
定义参数范围(网格):例如
C=[0.1, 1, 10]
,gamma=[0.01, 0.1, 1]
。训练所有组合:即
(C=0.1, gamma=0.01)
、(C=0.1, gamma=0.1)
...直到(C=10, gamma=1)
。交叉验证评估:每组参数都会在 k 折交叉验证下计算平均性能指标(如准确率、F1 分数)。
选择最佳参数:选出指标最优的一组参数作为最终模型配置。
3. 为什么要用网格搜索?
超参数选择自动化:不用凭感觉拍脑袋。
保证找到最优解:只要网格覆盖范围足够大,就不会遗漏最佳参数组合。
结合交叉验证:结果更加稳健,避免过拟合或欠拟合。
但缺点也明显:
计算开销大:参数范围和组合越多,训练越耗时。
不适合大规模搜索:参数维度高时可能出现“维度灾难”。
4. Scikit-Learn 中的网格搜索工具
sklearn.model_selection.GridSearchCV
是最常用的网格搜索实现。
4.1 函数原型
GridSearchCV(estimator, # 基础模型,如SVC()、RandomForestClassifier()param_grid, # 参数字典或列表,定义搜索空间scoring=None, # 评估指标(accuracy、f1、roc_auc等)n_jobs=None, # 并行任务数,-1表示使用所有CPUcv=None, # 交叉验证折数,如cv=5verbose=0, # 日志等级,1=简单进度条,2=详细refit=True, # 是否在找到最优参数后重新训练整个模型return_train_score=False # 是否返回训练集得分
)
GridSearchCV
常用参数表:
分类 | 参数 | 类型 | 说明 | 常用取值 |
---|---|---|---|---|
核心 | estimator | estimator 对象 | 基础模型,必须实现 fit / predict | SVC() 、RandomForestClassifier() |
param_grid | dict / list | 要搜索的参数空间,键=参数名,值=候选值列表 | {'C':[0.1,1,10], 'gamma':[0.01,0.1,1]} | |
评估 | scoring | str / callable | 模型评估指标 | accuracy 、f1_macro 、roc_auc 、neg_mean_squared_error |
cv | int / 生成器 | 交叉验证方式 | 5 (5折交叉验证)、KFold(10) | |
refit | bool / str | 用最佳参数在全训练集上重新训练 | True (默认)、'f1_macro' (多指标时指定) | |
效率 | n_jobs | int | 并行任务数,-1=使用所有CPU | -1 、4 |
pre_dispatch | int / str | 并行调度策略 | '2*n_jobs' (默认) | |
日志 | verbose | int | 输出日志等级 | 0 =无输出,1 =进度,2 =详细 |
错误处理 | error_score | str / numeric | 参数报错时的分数 | np.nan (默认)、0 |
调试 | return_train_score | bool | 是否返回训练集得分(用于过拟合分析) | False (默认)、True |
5. 网格搜索实战案例
5.1 示例数据集
以鸢尾花(Iris)分类为例,使用 SVM 模型。
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定义模型
svc = SVC()
5.2 设置参数网格
param_grid = {'C': [0.1, 1, 10, 100], # 惩罚系数'gamma': [1, 0.1, 0.01, 0.001], # 核函数参数'kernel': ['rbf', 'linear'] # 核函数类型
}
5.3 执行网格搜索
grid = GridSearchCV(estimator=svc,param_grid=param_grid,scoring='accuracy',cv=5,verbose=2,n_jobs=-1
)
grid.fit(X_train, y_train)
5.4 输出结果
print("最佳参数:", grid.best_params_)
print("最佳得分:", grid.best_score_)
print("测试集准确率:", grid.best_estimator_.score(X_test, y_test))
结果示例:
6. 网格搜索的可视化
我们可以把不同参数组合的表现绘制出来,直观查看最优解在哪个区域:
import matplotlib.pyplot as pltresults = pd.DataFrame(grid.cv_results_)# 只绘制 C 与 gamma 的得分热力图(kernel=rbf)
scores = results[results.param_kernel == 'rbf'].pivot(index='param_gamma',columns='param_C',values='mean_test_score'
)plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('C')
plt.ylabel('gamma')
plt.colorbar()
plt.xticks(np.arange(len(scores.columns)), scores.columns)
plt.yticks(np.arange(len(scores.index)), scores.index)
plt.title('Grid Search Accuracy Heatmap')
plt.show()
7. 网格搜索的进阶技巧
缩小搜索范围:先用较粗粒度搜索,再在最优附近细化搜索。
并行计算:
n_jobs=-1
可利用多核 CPU。随机搜索(RandomizedSearchCV):当参数空间太大时,可考虑随机抽样搜索,更高效。
贝叶斯优化:如
Optuna
、Hyperopt
,比网格搜索更智能。
8. 注意事项
参数空间不要过大,否则计算量爆炸。
交叉验证的折数
cv
不宜过大,通常 5 或 10。选择合适的评分指标
scoring
,分类问题常用accuracy
、f1_macro
,回归问题用neg_mean_squared_error
等。最终模型建议用
grid.best_estimator_
,而不是手动再初始化。
9. 总结
**网格搜索(Grid Search)**是一种系统化的超参数优化方法,通过遍历参数网格+交叉验证,找到表现最优的参数组合。
在
sklearn
中,GridSearchCV
是核心工具。它简单易用,但计算成本高,不适合大规模问题。
实际应用中常结合粗到细搜索、随机搜索、贝叶斯优化来提升效率。