前文中,只是给了基础模型: 

PyTorch 实现 CIFAR-10 图像分类:从数据预处理到模型训练与评估-CSDN博客

今天我们增加交叉验证和超参数调优,

先看运行结果:
===== 在测试集上评估最终模型 =====
最终模型在测试集上的准确率:60.14%
最优模型已保存为 'cifar10_best_model.pth'(超参数:{'batch_size': 32, 'epochs': 5, 'lr': 0.01, 'momentum': 0.85})

Process finished with exit code 0
比基础模型准确率高了一点,

 完整代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from sklearn.model_selection import KFold, ParameterGrid  # 用于交叉验证和超参数网格搜索# --------------------------
# 1. 数据准备(与原代码一致,但后续会在训练集内部做交叉验证)
# --------------------------
# 数据预处理:标准化(与原代码相同)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 数据集路径(请替换为你的实际路径)
data_path = r'D:\workspace_py\deeplean\data'# 加载完整训练集和测试集(测试集始终不变,用于最终评估)
full_trainset = datasets.CIFAR10(root=data_path, train=True, download=False, transform=transform)
testset = datasets.CIFAR10(root=data_path, train=False, download=False, transform=transform)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# --------------------------
# 2. 定义CNN模型(与原代码一致)
# --------------------------
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# --------------------------
# 3. 交叉验证函数(核心新增)
# --------------------------
def cross_validate(model, train_dataset, k_folds=5, epochs=5, lr=0.001, batch_size=32, momentum=0.9):"""5折交叉验证:将训练集分成5份,每次用4份训练,1份验证,返回平均准确率"""kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)  # 固定随机种子,结果可复现fold_results = []  # 存储每折的验证准确率for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):print(f'\n===== 第 {fold + 1}/{k_folds} 折交叉验证 =====')# 1. 划分当前折的训练集和验证集train_subset = Subset(train_dataset, train_ids)  # 本次训练用的数据val_subset = Subset(train_dataset, val_ids)  # 本次验证用的数据# 2. 创建数据加载器train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)# 3. 初始化模型和优化器(每折都重新训练新模型,避免干扰)model_instance = Net()  # 重新实例化模型criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model_instance.parameters(), lr=lr, momentum=momentum)# 4. 训练当前折的模型for epoch in range(epochs):model_instance.train()  # 训练模式running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model_instance(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 每200步打印一次损失(简化输出)if i % 200 == 199:print(f'折 {fold + 1},轮次 {epoch + 1},第 {i + 1} 步:平均损失 {running_loss / 200:.3f}')running_loss = 0.0# 5. 在验证集上评估当前折的模型model_instance.eval()  # 验证模式correct = 0total = 0with torch.no_grad():for data in val_loader:images, labels = dataoutputs = model_instance(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_acc = 100 * correct / totalprint(f'第 {fold + 1} 折验证准确率:{val_acc:.2f}%')fold_results.append(val_acc)# 计算所有折的平均准确率(该超参数组合的最终得分)avg_acc = sum(fold_results) / len(fold_results)print(f'\n===== 该超参数组合的平均验证准确率:{avg_acc:.2f}% =====')return avg_acc# --------------------------
# 4. 超参数调优(核心新增)
# --------------------------
def hyperparameter_tuning(train_dataset):"""超参数网格搜索:尝试不同的超参数组合,用交叉验证选最优"""# 定义要测试的超参数组合(可根据需要增减)param_grid = {'lr': [0.001, 0.01],  # 学习率:尝试两个值'batch_size': [32, 64],  # 批大小:尝试两个值'momentum': [0.9, 0.85],  # 动量:尝试两个值'epochs': [5]  # 训练轮次(固定为5,减少计算量)}best_acc = 0.0best_params = None  # 存储最优超参数# 遍历所有超参数组合(共 2×2×2=8 种组合)for params in ParameterGrid(param_grid):print(f'\n---------- 测试超参数组合:{params} ----------')# 用交叉验证评估当前组合的性能current_acc = cross_validate(model=Net(),train_dataset=train_dataset,k_folds=5,epochs=params['epochs'],lr=params['lr'],batch_size=params['batch_size'],momentum=params['momentum'])# 记录最优组合if current_acc > best_acc:best_acc = current_accbest_params = paramsprint(f'★ 发现更优组合!当前最优准确率:{best_acc:.2f}%')print(f'\n===== 超参数调优完成 =====')print(f'最优超参数:{best_params}')print(f'最优平均验证准确率:{best_acc:.2f}%')return best_params# --------------------------
# 5. 主函数:执行超参数调优 + 最终训练 + 测试集评估
# --------------------------
if __name__ == '__main__':# 步骤1:超参数调优(用交叉验证选最优参数)print('===== 开始超参数调优(这一步比较慢,需要耐心等待)=====')best_params = hyperparameter_tuning(full_trainset)# 步骤2:用最优超参数在完整训练集上训练最终模型print('\n===== 用最优超参数训练最终模型 =====')final_model = Net()criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(final_model.parameters(),lr=best_params['lr'],momentum=best_params['momentum'])train_loader = DataLoader(full_trainset,batch_size=best_params['batch_size'],shuffle=True)# 训练最终模型(轮次与调优时一致)for epoch in range(best_params['epochs']):final_model.train()running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = final_model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 199:print(f'最终模型训练 - 轮次 {epoch + 1},第 {i + 1} 步:平均损失 {running_loss / 200:.3f}')running_loss = 0.0# 步骤3:在测试集上评估最终模型(用从未见过的测试数据)print('\n===== 在测试集上评估最终模型 =====')final_model.eval()test_loader = DataLoader(testset, batch_size=32, shuffle=False)correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = final_model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc = 100 * correct / totalprint(f'最终模型在测试集上的准确率:{test_acc:.2f}%')# 步骤4:保存最优模型torch.save(final_model.state_dict(), 'cifar10_best_model.pth')print(f"最优模型已保存为 'cifar10_best_model.pth'(超参数:{best_params})")

新增加的功能 :

(1)5 折交叉验证(cross_validate函数)
  • 作用:把训练集分成 5 份,每次用 4 份训练、1 份验证,重复 5 次,取平均准确率作为 “该参数组合的得分”。
  • 白话举例:相当于学生做 5 套模拟题,每次用 4 套复习、1 套测试,最后算平均分,比只做 1 套题更能反映真实水平。
  • 关键细节:每折都重新训练新模型,避免前一折的 “记忆” 影响结果。
(2)超参数调优(hyperparameter_tuning函数)
  • 作用:尝试不同的超参数组合(如学习率 0.001 vs 0.01,批大小 32 vs 64),用交叉验证选平均分最高的组合。
  • 白话举例:相当于学生尝试不同的复习方法(每天学 1 小时 vs 2 小时,刷题 vs 看笔记),通过模拟题平均分找到最适合自己的方法。
  • 参数网格:代码中测试了 8 种组合(2 学习率 ×2 批大小 ×2 动量),可根据需要增减(组合越多,计算时间越长)。
(3)最终模型训练
  • 用调优得到的 “最优超参数” 在完整训练集上重新训练模型(之前交叉验证只用了部分数据)。
  • 最后在独立的测试集上评估(测试集从未参与训练和调优,相当于 “高考”)。
3. 运行说明
  • 计算时间:超参数调优 + 交叉验证会比原代码慢很多(8 种组合 ×5 折 ×5 轮训练),建议在有 GPU 的环境运行。
  • 结果解读:最终会输出 “最优超参数” 和 “测试集准确率”,这个准确率比原代码更可信(排除了偶然因素)。
  • 可调整项param_grid中的参数可以修改(如增加学习率选项[0.0001, 0.001, 0.01]),但组合数会增加,计算时间变长。

通过这两个步骤,模型的性能和可靠性会显著提升,尤其适合数据量不大的场景(如医学影像、小数据集)。

交叉验证

一、什么是交叉验证?为什么需要它?

1. 核心问题:如何判断模型好坏?

假设你用一份训练集训练模型,然后用同一批数据测试,准确率 90%—— 这能说明模型好吗?不能!因为模型可能 “死记硬背” 了训练数据(过拟合),换一批新数据就不行了。

所以需要用 “没见过的数据” 来验证模型 —— 但我们只有一份训练集,怎么办?

2. 交叉验证的解决思路

交叉验证(以代码中的5 折交叉验证为例)就像 “多次模拟考试”:

  1. 把训练集分成 5 等份(比如 5 个小数据集 A、B、C、D、E)。
  2. 第一次:用 A、B、C、D 训练模型,用 E 验证(看模型在 E 上的准确率)。
  3. 第二次:用 A、B、C、E 训练,用 D 验证。
  4. 重复 5 次(每次换一份做验证集),最后取 5 次验证准确率的平均值。

这样做的好处:

  1. 避免 “一次验证” 的偶然性(比如刚好抽到简单的验证集)。
  2. 更全面地评估模型在不同数据分布上的表现,结果更可靠。
3. 代码中的交叉验证实现(cross_validate 函数)

代码里的cross_validate函数就是干这个的:

  1. KFold(n_splits=5)把训练集分成 5 份。
  2. 循环 5 次(每折):
    1. 每次从 5 份中选 4 份做 “临时训练集”,1 份做 “临时验证集”。
    2. 用临时训练集训练模型,用临时验证集算准确率。
  3. 最后返回 5 次准确率的平均值,作为这个模型 / 超参数组合的 “评分”。

二、什么是超参数调优?为什么需要它?

1. 超参数是什么?

超参数是训练前手动设定的参数,不是模型自己学出来的。比如代码中的:

  1. lr(学习率):模型更新参数的 “步长”,太大可能跑过头,太小可能学太慢。
  2. batch_size(批大小):每次训练用多少数据,影响训练速度和稳定性。
  3. momentum(动量):优化器的参数,帮助模型更快收敛。

这些参数直接影响模型的训练效果,但没有 “标准答案”,需要试出来。

2. 超参数调优的目的

找到一组最好的超参数组合,让模型的性能(比如准确率)达到最高。

比如:学习率 0.01 + 批大小 32 + 动量 0.9 可能比 学习率 0.001 + 批大小 64 + 动量 0.85 效果更好,我们需要找到这个 “更好” 的组合。

3. 代码中的超参数调优实现(网格搜索)

代码用了 “网格搜索” 的方法,原理很简单:

  1. 列清单:先定义每个超参数的可能取值(比如lr选 [0.001, 0.01],batch_size选 [32, 64])。
  2. 组合所有可能:把这些取值的所有搭配列出来(比如 2×2×2=8 种组合)。
  3. 逐个测试:对每种组合,用交叉验证算它的 “评分”(平均验证准确率)。
  4. 选最优:最后挑出评分最高的组合,作为 “最佳超参数”。

对应代码中的hyperparameter_tuning函数:

  1. param_grid定义了要测试的超参数和可能值。
  2. ParameterGrid自动生成所有组合。
  3. 循环每个组合,用cross_validate算分,保存最高分的组合。

三、交叉验证和超参数调优的关系

简单说:超参数调优是找最好的配方”,交叉验证是 “判断配方好不好的工具”

  1. 没有交叉验证,直接用一组数据测试超参数,可能因为 “运气好” 选错(比如刚好验证集简单)。
  2. 用交叉验证评估每个超参数组合,结果更可靠,能真正找到 “稳定好” 的组合。

总结

  1. 交叉验证:通过多次 “训练 - 验证” 划分,更可靠地评估模型性能,避免偶然性。
  2. 超参数调优:通过尝试不同的超参数组合(网格搜索),结合交叉验证的评分,找到让模型表现最好的 “参数配方”。

代码中,先通过超参数调优找到最好的参数,再用这些参数训练最终模型,最后在测试集上验证 —— 这样得到的模型更可能在新数据上表现良好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89909.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89909.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/89909.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决pip指令超时问题

用pip指令,在安装Django3.2时报错,询问ChatGpt后得到的解决方案pip 下载超时 —— 是 当前网络连接到 PyPI 官方源太慢或不稳定,甚至可能连不上了,而 pip 默认的超时时间又太短,就导致了中途失败:ReadTimeo…

Oracle定时清理归档日志

线上归档日志满了,系统直接崩了,为解决这个问题,创建每月定时清理归档日志。 创建文件名 delete_archivelog.rman CONFIGURE ARCHIVELOG DELETION POLICY CLEAR; RUN {ALLOCATE CHANNEL c1 TYPE DISK;DELETE ARCHIVELOG ALL COMPLETED BEFORE…

ELF 文件操作手册

目录 一、ELF 文件结构概述 二、查看 ELF 文件头信息 1、命令选项 2、示例输出 3、内核数据结构 三、ELF 程序头表 1、命令选项 2、示例输出 3、关键说明 4、内核数据结构 四、ELF 节头表详解 查看节头表信息 1、命令选项 2、示例输出 3、标志说明 4、重要节说…

深入浅出Python函数:参数传递、作用域与案例详解

🙋‍♀️ 博主介绍:颜颜yan_ ⭐ 本期精彩:深入浅出Python函数:参数传递、作用域与案例详解 🏆 热门专栏:零基础玩转Python爬虫:手把手教你成为数据猎人 🚀 专栏亮点:零基…

ps aux 和 ps -ef

在 Linux/Unix 系统中,ps aux 和 ps -ef 都是用于查看进程信息的命令,结合 grep node 可以筛选出与 Node.js 相关的进程。它们的核心功能相似,但在输出格式和选项含义上有区别:1. 命令对比命令含义主要区别ps auxBSD 风格语法列更…

Spark ML 之 LSH

src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala test("approxSimilarityJoin for self join") {val data = {for (i <- 0 until 24) yield Vectors

关键成功因素法(CSF)深度解析:从战略目标到数据字典

关键成功因素法由John Rockart提出&#xff0c;用于信息系统规划&#xff0c;帮助企业识别影响系统成功的关键因素&#xff0c;从而确定信息需求&#xff0c;指导信息技术管理。该方法通过识别关键成功因素&#xff0c;找出关键信息集合&#xff0c;确定系统开发优先级&#xf…

Django母婴商城项目实践(六)- Models模型之ORM操作

6、Models模型操作 1 ORM概述 介绍 Django对数据进行增删改操作是借助内置的ORM框架(Object Relational Mapping,对象关系映射)所提供的API方法实现的,允许你使用类和对象对数据库进行操作,从而避免通过SQL语句操作数据库。 简单来说,ORM框架的数据操作API是在 QuerySet…

【PTA数据结构 | C语言版】哥尼斯堡的“七桥问题”

本专栏持续输出数据结构题目集&#xff0c;欢迎订阅。 文章目录题目代码题目 哥尼斯堡是位于普累格河上的一座城市&#xff0c;它包含两个岛屿及连接它们的七座桥&#xff0c;如下图所示。 可否走过这样的七座桥&#xff0c;而且每桥只走过一次&#xff1f;瑞士数学家欧拉(Leo…

Redis 详解:从入门到进阶

文章目录前言一、什么是 Redis&#xff1f;二、Redis 使用场景1. 缓存热点数据2. 消息队列3. 分布式锁4. 限流与防刷5. 计数器、排行榜三、缓存三大问题&#xff1a;雪崩 / 穿透 / 击穿1. ❄️ 缓存雪崩&#xff08;Cache Avalanche&#xff09;2. &#x1f50d; 缓存穿透&…

QCustomPlot 使用教程

下载网址&#xff1a;官方网站&#xff1a;http://www.qcustomplot.com/我的环境是 window10 qt5.9.9 下载后&#xff0c;官网提供了很多例子。可以作为参考直接运行自己如何使用&#xff1a;第一步&#xff1a;使用QCustomPlot非常简单&#xff0c;只需要把qcustomplot.cpp和…

基于springboot+mysql的作业管理系统(源码+论文)

一、开发环境 1 Spring Boot框架简介 描述&#xff1a; 简化开发&#xff1a;Spring Boot旨在简化新Spring应用的初始搭建和开发过程。配置方式&#xff1a;采用特定的配置方式&#xff0c;减少样板化配置&#xff0c;使开发人员无需定义繁琐的配置。开发工具&#xff1a;可…

LVS 集群技术基础

LVS(linux virual server)LVS集群技术---NAT模式一.准备四台虚拟机1.client(eth0ip:172.254.100)2.lvs(eth0ip:172.254.200;eth1ip:192.168.0.200)3.rs1(eht0ip:192.168.0.10)4.rs2(eth0ip:192.168.0.20)二&#xff1a;在rs1和rs2安装httpd功能dnf/yum install htppd -y三&…

Oracle RU19.28补丁发布,一键升级稳

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 作者&#xff1a;IT邦德 中国DBA联盟(ACDU)成员&#xff0c;15年DBA工作经验 Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主&#xff0c;全网粉丝15万 擅长主流Oracle、MySQL、PG、高斯及…

lvs 集群技术

LVS概念LVS&#xff1a;Linux Virtual Server&#xff0c;负载调度器&#xff0c;是一种基于Linux操作系统内核的高性能、高可用网络服务负载均衡解决方案。LVS工作原理基于网络层&#xff08;四层&#xff0c;传输层&#xff09;的负载均衡技术&#xff0c;它通过内核级别的IP…

AR巡检和传统巡检的区别

随着工业4.0时代的到来&#xff0c;数字化转型逐渐成为各行各业提升效率、保障安全和降低成本的关键。而在这一转型过程中&#xff0c;巡检工作作为确保设备稳定运行的重要环节&#xff0c;逐步从传统方式走向智能化、数字化。尤其是增强现实&#xff08;AR&#xff09;技术的引…

Axure设计设备外壳 - AxureMost 落葵网

在UI设计中&#xff0c;设备外壳&#xff08;硬件外壳与界面中的“虚拟外壳”&#xff09;和背景是构成视觉体验的核心元素&#xff0c;它们不仅影响美观&#xff0c;更直接关联用户对功能的理解和操作效率。以下从设计角度详细解析其作用与使用逻辑&#xff1a; 一、设备外壳&…

基于深度学习的电信号分类识别与混淆矩阵分析

基于深度学习的电信号分类识别与混淆矩阵分析 1. 引言 1.1 研究背景与意义 电信号分类识别是信号处理领域的重要研究方向,在医疗诊断、工业检测、通信系统等多个领域有着广泛的应用。传统的电信号分类方法主要依赖于手工提取特征和浅层机器学习模型,但这些方法往往难以捕捉…

Git 和Gitee远程连接 上传和克隆

第一步创建远程库第二步初始化本地库创建链接删掉.idea 和target(这两个没用运行就自动生成了)右键空白处选择Git Bash Here 初始化本地库git init建立远程连接建立连接这里是我的地址&#xff0c;后面拼接你的地址git remote add origin https://gitee.com/liu-qing_liang/git…

零基础100天CNN实战计划:用Python从入门到图像识别高手

一、为什么你需要这份100天CNN学习计划&#xff1f; 在人工智能领域&#xff0c;卷积神经网络&#xff08;CNN&#xff09; 是计算机视觉的基石技术。无论是人脸识别、医学影像分析还是自动驾驶&#xff0c;CNN都扮演着核心角色。但对于初学者来说&#xff0c;面对复杂的数学公…