网格搜索(Grid Search)详细教学

1. 什么是网格搜索?

在机器学习模型中,算法的**超参数(Hyperparameters)**对模型的表现起着决定性作用。比如:

  • KNN 的邻居数量 n_neighbors

  • SVM 的惩罚系数 C 和核函数参数 gamma

  • 随机森林的决策树数量 n_estimators

这些超参数不会在训练过程中自动学习得到,而是需要我们人为设定。网格搜索(Grid Search)是一种最常见的超参数优化方法:
它通过
遍历给定参数网格中的所有组合
,使用交叉验证来评估每组参数的效果,最终选出表现最优的一组。

通俗理解:
👉 网格搜索 = 穷举法找最佳参数。


2. 网格搜索的核心思想

  1. 定义参数范围(网格):例如 C=[0.1, 1, 10]gamma=[0.01, 0.1, 1]

  2. 训练所有组合:即 (C=0.1, gamma=0.01)(C=0.1, gamma=0.1)...直到 (C=10, gamma=1)

  3. 交叉验证评估:每组参数都会在 k 折交叉验证下计算平均性能指标(如准确率、F1 分数)。

  4. 选择最佳参数:选出指标最优的一组参数作为最终模型配置。


3. 为什么要用网格搜索?

  • 超参数选择自动化:不用凭感觉拍脑袋。

  • 保证找到最优解:只要网格覆盖范围足够大,就不会遗漏最佳参数组合。

  • 结合交叉验证:结果更加稳健,避免过拟合或欠拟合。

但缺点也明显:

  • 计算开销大:参数范围和组合越多,训练越耗时。

  • 不适合大规模搜索:参数维度高时可能出现“维度灾难”。


4. Scikit-Learn 中的网格搜索工具

sklearn.model_selection.GridSearchCV 是最常用的网格搜索实现。

4.1 函数原型

GridSearchCV(estimator,          # 基础模型,如SVC()、RandomForestClassifier()param_grid,         # 参数字典或列表,定义搜索空间scoring=None,       # 评估指标(accuracy、f1、roc_auc等)n_jobs=None,        # 并行任务数,-1表示使用所有CPUcv=None,            # 交叉验证折数,如cv=5verbose=0,          # 日志等级,1=简单进度条,2=详细refit=True,         # 是否在找到最优参数后重新训练整个模型return_train_score=False  # 是否返回训练集得分
)

GridSearchCV 常用参数表:

分类参数类型说明常用取值
核心estimatorestimator 对象基础模型,必须实现 fit / predictSVC()RandomForestClassifier()
param_griddict / list要搜索的参数空间,键=参数名,值=候选值列表{'C':[0.1,1,10], 'gamma':[0.01,0.1,1]}
评估scoringstr / callable模型评估指标accuracyf1_macroroc_aucneg_mean_squared_error
cvint / 生成器交叉验证方式5(5折交叉验证)、KFold(10)
refitbool / str用最佳参数在全训练集上重新训练True(默认)、'f1_macro'(多指标时指定)
效率n_jobsint并行任务数,-1=使用所有CPU-14
pre_dispatchint / str并行调度策略'2*n_jobs'(默认)
日志verboseint输出日志等级0=无输出,1=进度,2=详细
错误处理error_scorestr / numeric参数报错时的分数np.nan(默认)、0
调试return_train_scorebool是否返回训练集得分(用于过拟合分析)False(默认)、True


5. 网格搜索实战案例

5.1 示例数据集

以鸢尾花(Iris)分类为例,使用 SVM 模型。

import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 定义模型
svc = SVC()

5.2 设置参数网格

param_grid = {'C': [0.1, 1, 10, 100],          # 惩罚系数'gamma': [1, 0.1, 0.01, 0.001],  # 核函数参数'kernel': ['rbf', 'linear']      # 核函数类型
}

5.3 执行网格搜索

grid = GridSearchCV(estimator=svc,param_grid=param_grid,scoring='accuracy',cv=5,verbose=2,n_jobs=-1
)
grid.fit(X_train, y_train)

5.4 输出结果

print("最佳参数:", grid.best_params_)
print("最佳得分:", grid.best_score_)
print("测试集准确率:", grid.best_estimator_.score(X_test, y_test))

结果示例


6. 网格搜索的可视化

我们可以把不同参数组合的表现绘制出来,直观查看最优解在哪个区域:

import matplotlib.pyplot as pltresults = pd.DataFrame(grid.cv_results_)# 只绘制 C 与 gamma 的得分热力图(kernel=rbf)
scores = results[results.param_kernel == 'rbf'].pivot(index='param_gamma',columns='param_C',values='mean_test_score'
)plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('C')
plt.ylabel('gamma')
plt.colorbar()
plt.xticks(np.arange(len(scores.columns)), scores.columns)
plt.yticks(np.arange(len(scores.index)), scores.index)
plt.title('Grid Search Accuracy Heatmap')
plt.show()

7. 网格搜索的进阶技巧

  1. 缩小搜索范围:先用较粗粒度搜索,再在最优附近细化搜索。

  2. 并行计算n_jobs=-1 可利用多核 CPU。

  3. 随机搜索(RandomizedSearchCV):当参数空间太大时,可考虑随机抽样搜索,更高效。

  4. 贝叶斯优化:如 OptunaHyperopt,比网格搜索更智能。


8. 注意事项

  • 参数空间不要过大,否则计算量爆炸。

  • 交叉验证的折数 cv 不宜过大,通常 5 或 10。

  • 选择合适的评分指标 scoring,分类问题常用 accuracyf1_macro,回归问题用 neg_mean_squared_error 等。

  • 最终模型建议用 grid.best_estimator_,而不是手动再初始化。


9. 总结

  • **网格搜索(Grid Search)**是一种系统化的超参数优化方法,通过遍历参数网格+交叉验证,找到表现最优的参数组合。

  • sklearn 中,GridSearchCV 是核心工具。

  • 它简单易用,但计算成本高,不适合大规模问题。

  • 实际应用中常结合粗到细搜索、随机搜索、贝叶斯优化来提升效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/919603.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/919603.shtml
英文地址,请注明出处:http://en.pswp.cn/news/919603.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【LeetCode】18. 四数之和

文章目录18. 四数之和题目描述示例 1:示例 2:提示:解题思路算法一:排序 双指针(推荐)算法二:通用 kSum(含 2Sum 双指针)复杂度关键细节代码实现要点完整题解代码18. 四数…

Go语言入门(10)-数组

访问数组元素:数组中的每个元素都可以通过“[]”和一个从0开始的索引进行访问数组的长度可由内置函数len来确定。在声明数组时,未被赋值元素的值是对应类型的零值。下面看一个例子package mainfunc main(){var planets [8]stringplanets[0] "Mercu…

为什么经过IPSec隧道后HTTPS会访问不通?一次隧道环境下的实战分析

在运维圈子里,大家可能都遇到过这种奇怪的问题:浏览器能打开 HTTP 网站,但一换成 HTTPS,页面就死活打不开。前段时间,我们就碰到这么一个典型案例。故障现象某公司系统在 VPN 隧道里访问 HTTPS 服务,结果就…

【Linux系统】进程信号:信号的产生和保存

上篇文章我们介绍了Syetem V IPC的消息队列和信号量,那么信号量和我们下面要介绍的信号有什么关系吗?其实没有关系,就相当于我们日常生活中常说的老婆和老婆饼,二者并没有关系1. 认识信号1.1 生活角度的信号解释(快递比…

WEB服务器(静态/动态网站搭建)

简介 名词:HTML(超文本标记语言),网站(多个网页组成一台网站),主页,网页,URL(统一资源定位符) 网站架构:LAMP(linux(系统)+apache(服务器程序)+mysql(数据库管理软件)+php(中间软件)) 静态站点 Apache基础 Apache官网:www.apache.org 软件包名称:…

开发避坑指南(29):微信昵称特殊字符存储异常修复方案

异常信息 Cause: java.sql.SQLException: Incorrect string value: \xF0\x9F\x8D\x8B\xE5\xBB... for column nick_name at row 1异常背景 抽奖大转盘,抽奖后需要保存用户抽奖记录,用户再次进入游戏时根据抽奖记录判断剩余抽奖机会。保存抽奖记录时需要…

leetcode-python-242有效的字母异位词

题目&#xff1a; 给定两个字符串 s 和 t &#xff0c;编写一个函数来判断 t 是否是 s 的 字母异位词。 示例 1: 输入: s “anagram”, t “nagaram” 输出: true 示例 2: 输入: s “rat”, t “car” 输出: false 提示: 1 < s.length, t.length < 5 * 104 s 和 t 仅…

【ARM】Keil MDK如何指定单文件的优化等级

1、 文档目标解决在MDK中如何对于单个源文件去设置优化等级。2、 问题场景在正常的项目开发中&#xff0c;我们通常都是针对整个工程去做优化&#xff0c;相当于整个工程都是使用一个编译器优化等级去进行的工程构建。那么在一些特定的情况下&#xff0c;工程师需要保证我的部分…

零基础学Java第二十二讲---异常(2)

续接上一讲 目录 一、异常的处理&#xff08;续&#xff09; 1、异常的捕获-try-catch捕获并处理异常 1.1关于异常的处理方式 2、finally 3、异常的处理流程 二、自定义异常类 1、实现自定义异常类 一、异常的处理&#xff08;续&#xff09; 1、异常的捕获-try-catch捕…

自建开发工具IDE(一)之拖找排版—仙盟创梦IDE

自建拖拽布局排版在 IDE 中的优势及初学者开发指南在软件开发领域&#xff0c;用户界面&#xff08;UI&#xff09;的设计至关重要。自建拖拽布局排版功能为集成开发环境&#xff08;IDE&#xff09;带来了诸多便利&#xff0c;尤其对于初学者而言&#xff0c;是踏入开发领域的…

GitHub Copilot - GitHub 推出的AI编程助手

本文转载自&#xff1a;GitHub Copilot - GitHub 推出的AI编程助手 - Hello123工具导航。 ** 一、GitHub Copilot 核心定位 GitHub Copilot 是由 GitHub 与 OpenAI 联合开发的 AI 编程助手&#xff0c;基于先进大语言模型实现代码实时补全、错误检测及文档生成&#xff0c;显…

基于截止至 2025 年 6 月 4 日,在 App Store 上进行交易的设备数据统计,iOS/iPadOS 各版本在所有设备中所占比例详情

iOS 和 iPadOS 使用情况 基于截止至 2025 年 6 月 4 日&#xff0c;在 App Store 上进行交易的设备数据统计。 iPhone 在过去四年推出的设备中&#xff0c;iOS 18 的普及率达 88。 88% iOS 188% iOS 174% 较早版本 所有的设备中&#xff0c;iOS 18 的普及率达 82。 82% iOS 189…

云计算-k8s实战指南:从 ServiceMesh 服务网格、流量管理、limitrange管理、亲和性、环境变量到RBAC管理全流程

介绍 本文是一份 Kubernetes 与 ServiceMesh 实战操作指南,涵盖多个核心功能配置场景。从 Bookinfo 应用部署入手,详细演示了通过 Istio 创建 Ingress Gateway 实现外部访问,以及基于用户身份、请求路径的服务网格路由规则配置,同时为应用微服务设置了默认目标规则。 还包…

Vue 3项目中的路由管理和状态管理系统

核心概念理解 1. 整体架构关系 这两个文件构成了Vue应用的导航系统和状态管理系统&#xff1a; Router&#xff08;路由&#xff09;&#xff1a;控制页面跳转和URL变化Store&#xff08;状态&#xff09;&#xff1a;管理全局数据和用户状态两者协同工作实现权限控制 2. 数据流…

Linux Capability 解析

文章目录1. 权限模型演进背景2. Capability核心原理2.1 能力单元分类2.2 进程三集合2.3 文件系统属性3. 完整能力单元表4. 高级应用场景4.1 能力边界控制4.2 编程控制4.3 容器安全5. 安全实践建议6. 潜在风险提示 1. 权限模型演进背景 在传统UNIX权限模型中&#xff0c;采用二进…

vue 监听 sessionStorage 值的变化

<template><div class"specific-storage-watcher"><h3>仅监听 userId 变化</h3><p>当前 userId: {{ currentUserId }}</p><p v-if"changeRecord">最近变化: {{ changeRecord }}</p><button click"…

IDEA:控制台中文乱码

目录一、设置字符编码为 UTF-8一、设置字符编码为 UTF-8 点击菜单 File -> settings -> Eitor -> File Encodings , 将字符全局编码、项目编码、配置文件编码统一设置为UTF-8, 然后点击 Apply 应用设置&#xff0c;点击 OK 关闭对话框:

[Sql Server]特殊数值计算

任务一&#xff1a;求下方的Num列的中值:参考代码:use Test go SELECT DISTINCTPERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY Num) over()AS MedianSalary FROM MedianTest;任务二: 下方表中,每个选手有多个评委打分&#xff0c;求每个选手的评委打分中值。参考代码:use Tes…

01-Docker概述

Docker 的主要目标是:Build, Ship and Run Any App, Anywhere,也就是通过对应用组件的封装、分发、部署、运行等生命周期的管理,使用户的 APP 及其运行环境能做到一次镜像,处处运行。 Docker 运行速度快的原因: 由于 Docker 不需要 Hypervisor(虚拟机)实现硬件资源虚拟化…

Laravel中如何使用php-casbin

一、&#x1f680; 安装和配置 1. 安装包 composer require casbin/laravel-authz2. 发布配置文件 php artisan vendor:publish这会生成两个重要文件&#xff1a; config/lauthz.php - 主配置文件config/lauthz-rbac-model.conf - RBAC 模型配置文件 3. 运行数据库迁移 php…