手写数字识别是计算机视觉领域的“Hello World”,也是深度学习入门的经典案例。它通过训练模型识别0-9的手写数字图像(如MNIST数据集),帮助我们快速掌握神经网络的核心流程。本文将以PyTorch框架为基础,带你从数据加载、模型构建到训练评估,完整实现一个手写数字识别系统。

二、数据加载与预处理:认识MNIST数据集

1. MNIST数据集简介

MNIST是手写数字的标准数据集,包含:

  • 训练集:60,000张28x28的灰度图(0-9数字)
  • 测试集:10,000张同尺寸图片
  • 每张图片已归一化(像素值0-1),标签为0-9的整数

2. 代码实现:下载与加载数据

使用torchvision.datasets可直接下载MNIST,transforms.ToTensor()将图片转为PyTorch张量(通道优先格式:[1,28,28],1为灰度通道数)。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 下载训练集(60,000张)
train_data = datasets.MNIST(root="data",       # 数据存储路径train=True,        # 标记为训练集download=True,     # 自动下载(首次运行时)transform=ToTensor()  # 转为张量(shape: [1,28,28])
)# 下载测试集(10,000张)
test_data = datasets.MNIST(root="data",train=False,       # 标记为测试集download=True,transform=ToTensor()
)

3. 数据封装:DataLoader批量加载

DataLoader将数据集打包为可迭代的批量数据,支持随机打乱(训练集)、多线程加载等。

device = "cuda" if torch.cuda.is_available() else "cpu"  # 自动选择GPU/CPU
batch_size = 64  # 每批64张图片(可根据显存调整)# 训练集DataLoader(打乱顺序)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 测试集DataLoader(不打乱顺序)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

三、模型构建:设计卷积神经网络(CNN)

1. 为什么选择CNN?

手写数字识别需要捕捉图像的局部特征(如笔画边缘、拐点),而CNN的卷积层通过滑动窗口提取局部模式,池化层降低计算量,全连接层完成分类,非常适合处理图像任务。

2. 模型结构详解(附代码注释)

以下是我们定义的CNN模型,包含3个卷积块和1个全连接输出层:

class CNN(nn.Module):def __init__(self):super().__init__()  # 继承PyTorch模块基类# 卷积块1:输入1通道(灰度图)→ 输出8通道特征图self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,    # 输入通道数(灰度图)out_channels=8,   # 输出8个特征图(8个卷积核)kernel_size=5,    # 卷积核尺寸5x5(覆盖局部区域)stride=1,         # 滑动步长1(不跳跃)padding=2         # 边缘填充2圈0(保持输出尺寸不变)),nn.ReLU(),  # 非线性激活(引入复杂模式)nn.MaxPool2d(kernel_size=2)  # 最大池化(2x2窗口,尺寸减半))# 卷积块2:特征抽象(8→16→32通道)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2),  # 8→16通道,5x5卷积,填充2(尺寸不变)nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2), # 16→32通道,5x5卷积,填充2(尺寸不变)nn.ReLU(),nn.MaxPool2d(kernel_size=2)  # 尺寸减半(14→7))# 卷积块3:特征精炼(32→256通道,保留空间信息)self.conv3 = nn.Sequential(nn.Conv2d(32, 256, 5, 1, 2),  # 32→256通道,5x5卷积,填充2(尺寸不变)nn.ReLU())# 全连接输出层:256*7*7维特征→10类概率self.out = nn.Linear(256 * 7 * 7, 10)  # 10对应0-9数字类别def forward(self, x):"""前向传播:定义数据流动路径"""x = self.conv1(x)  # 输入:[64,1,28,28] → 输出:[64,8,14,14](池化后尺寸减半)x = self.conv2(x)  # 输入:[64,8,14,14] → 输出:[64,32,7,7](两次卷积+池化)x = self.conv3(x)  # 输入:[64,32,7,7] → 输出:[64,256,7,7](仅卷积)x = x.view(x.size(0), -1)  # 展平:[64,256,7,7] → [64,256*7*7](全连接需要一维输入)output = self.out(x)       # 输出:[64,10](每个样本对应10类的得分)return output

3. 关键参数计算(以输入28x28为例)

  • conv1后:卷积核5x5,填充2,输出尺寸(28-5+2*2)/1 +1=28;池化后尺寸28/2=14 → 输出[64,8,14,14]
  • conv2后:两次卷积保持14x14,池化后14/2=7 → 输出[64,32,7,7]
  • conv3后:卷积保持7x7 → 输出[64,256,7,7]
  • 展平后256*7*7=12544维向量 → 全连接到10类

四、训练配置:损失函数与优化器

1. 损失函数:交叉熵损失(CrossEntropyLoss)

手写数字识别是多分类任务,交叉熵损失函数直接衡量模型输出概率与真实标签的差异。PyTorch的nn.CrossEntropyLoss已集成Softmax操作(无需手动添加)。

2. 优化器:随机梯度下降(SGD)

优化器负责根据损失值更新模型参数。这里选择SGD(学习率lr=0.1),简单且对小数据集友好(也可尝试Adam等更复杂的优化器)。

model = CNN().to(device)  # 模型加载到GPU/CPU
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # SGD优化器

五、训练循环:让模型“学习”特征

1. 训练逻辑概述

训练过程的核心是“前向传播→计算损失→反向传播→更新参数”,重复直到模型收敛。具体步骤:

  1. 模型设为训练模式(model.train());
  2. 遍历训练数据,按批输入模型;
  3. 计算预测值与真实标签的损失;
  4. 反向传播计算梯度(loss.backward());
  5. 优化器更新参数(optimizer.step());
  6. 清空梯度(optimizer.zero_grad())避免累积。

2. 代码实现:训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启训练模式(影响Dropout/BatchNorm等层)total_loss = 0  # 记录总损失for batch_idx, (x, y) in enumerate(dataloader):x, y = x.to(device), y.to(device)  # 数据加载到GPU/CPU# 1. 前向传播:模型预测pred = model(x)# 2. 计算损失:预测值 vs 真实标签loss = loss_fn(pred, y)total_loss += loss.item()  # 累加批次损失# 3. 反向传播:计算梯度optimizer.zero_grad()  # 清空历史梯度loss.backward()        # 反向传播计算当前梯度# 4. 更新参数:根据梯度调整模型权重optimizer.step()# 每100个批次打印一次损失(监控训练进度)if (batch_idx + 1) % 100 == 0:print(f"批次 {batch_idx+1}/{len(dataloader)}, 当前损失: {loss.item():.4f}")avg_loss = total_loss / len(dataloader)print(f"训练完成,平均损失: {avg_loss:.4f}")

六、测试评估:验证模型泛化能力

1. 测试逻辑概述

测试阶段需关闭模型的随机操作(如Dropout),用测试集评估模型的泛化能力。核心指标是准确率(正确预测的样本比例)。

2. 代码实现:测试函数

def test(dataloader, model):model.eval()  # 开启评估模式(关闭Dropout等随机层)correct = 0   # 记录正确预测数total = 0     # 记录总样本数with torch.no_grad():  # 关闭梯度计算(节省内存)for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model(x)  # 模型预测# 统计正确数:pred.argmax(1)取预测概率最大的类别correct += (pred.argmax(1) == y).sum().item()total += y.size(0)  # 累加批次样本数accuracy = correct / totalprint(f"测试准确率: {accuracy * 100:.2f}%")return accuracy

七、完整训练与结果

1. 运行训练循环

我们训练10个epoch(遍历整个训练集10次):

# 训练10轮
for epoch in range(10):print(f"
====={epoch+1} 轮训练 =====")train(train_dataloader, model, loss_fn, optimizer)# 测试最终效果
print("
===== 最终测试 =====")
test_acc = test(test_dataloader, model)

2. 典型输出结果

假设训练10轮后,测试准确率可能达到98.5%+(具体取决于超参数和硬件):

===== 第 1 轮训练 =====
批次 100/938, 当前损失: 0.2145
...
训练完成,平均损失: 0.1234===== 第 10 轮训练 =====
批次 100/938, 当前损失: 0.0321
...
训练完成,平均损失: 0.0189===== 最终测试 =====
测试准确率: 98.76%

八、改进方向:让模型更强大

当前模型已能较好识别手写数字,但仍有优化空间:

1. 调整超参数

  • 学习率:若损失下降缓慢,降低lr(如0.01);若波动大,增大lr
  • 批量大小:增大batch_size(如128)可加速训练(需更大显存)。
  • 训练轮次:增加epoch(如20轮),但需防止过拟合(训练损失持续下降,测试损失上升)。

2. 添加正则化

  • Batch Normalization:在卷积层后添加nn.BatchNorm2d(out_channels),加速收敛并稳定训练。
    self.conv1 = nn.Sequential(nn.Conv2d(1,8,5,1,2),nn.BatchNorm2d(8),  # 新增nn.ReLU(),nn.MaxPool2d(2)
    )
    
  • Dropout:在全连接层前添加nn.Dropout(p=0.5),随机断开神经元,防止过拟合。
    self.out = nn.Sequential(nn.Dropout(0.5),  # 新增nn.Linear(256*7*7, 10)
    )
    

3. 使用更深的网络

当前模型仅3个卷积块,对于复杂任务(如ImageNet),可使用ResNet等残差网络,通过跳跃连接(Skip Connection)解决深层网络的梯度消失问题。

九、总结

通过本文,你已完成从数据加载到模型训练的全流程,掌握了:

  • 数据预处理:使用torchvision加载标准数据集,DataLoader批量管理数据;
  • 模型构建:设计CNN的核心组件(卷积层、激活函数、池化层);
  • 训练与评估:理解损失函数、优化器的作用,掌握训练循环和测试逻辑。

手写数字识别是深度学习的起点,你可以尝试修改模型结构(如增加卷积层)、更换数据集(如Fashion-MNIST)或调整超参数,进一步探索深度学习的魅力!

动手建议:运行代码时,尝试将device改为cpu(无GPU时),观察训练速度变化;或修改kernel_size(如3x3),对比模型性能差异。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94607.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94607.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/94607.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实战笔记——构建智能Agent:SpreadJS代码助手

目录 前言 解决思路 需求理解 MCP Server LangGraph 本教程目标 技术栈 第一部分:构建 MCP Server - 工具服务化的基础架构 第二部分:Tools 实现 第三部分:基于 LangGraph 构建智能 Agent 第四部分:服务器和前端搭建 前…

【Word】用 Python 轻松实现 Word 文档对比并生成可视化 HTML 报告

在日常工作和学习中,我们经常需要对两个版本的文档进行比对,比如合同修改、论文修订、报告更新等。手动逐字检查不仅耗时费力,还容易遗漏细节。 今天,我将带你使用 Python python-docx difflib 实现一个自动化 Word 文档对比工具…

从0开始搭建一个前端项目(vue + vite + typescript)

版本 node:v22.17.1 pnpm:v10.13.1 vue:^3.5.18 vite:^7.0.6 typescipt:~5.8.0脚手架初始化vue pnpm create vuelatest只选择: TypeScript, JSX 3. 用vscode打开创建的项目,并删除多余的代码esl…

1.ImGui-环境安装

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 IMGUI是一个被广泛应用到逆向里面的,它可以用来做外部的绘制,比如登录界面&…

基于springboot的二手车交易系统

博主介绍:java高级开发,从事互联网行业六年,熟悉各种主流语言,精通java、python、php、爬虫、web开发,已经做了六年的毕业设计程序开发,开发过上千套毕业设计程序,没有什么华丽的语言&#xff0…

修改win11任务栏时间字体和小图标颜色

1 打开运行提示框 在桌面按快捷键winR,然后如下图所示输入regedit2 查找路径 1、在路径处粘贴路径计算机\HKEY_CURRENT_USER\Software\Microsoft\Windows\CurrentVersion\Themes\Personalize 2、如下图所示,双击打开ColorPrevalence,将里面的…

第13集 当您的USB设备不在已实测支持列表,如何让TOS-WLink支持您的USB设备--答案Wireshark USB抓包

问:当您的USB设备不在已实测支持列表,如何让TOS-WLink支持您的USB设备? 答案:使用Wireshark USB抓包,日志发给我 为什么要抓包: USB设备种类繁多;TOS-WLink是单片机,内存紧张&#…

[灵动微电子 MM32BIN560CN MM32SPIN0280]读懂电机MCU之比较器

作为刚接触微控制器的初学者,在看到MM32SPIN0280用户手册中“比较器”相关内容时,是不是会感到困惑?比如“5个通用比较器”“轮询功能”“迟滞电压”这些术语,好像都和电机控制有关,但又不知道具体怎么用。别担心&…

⸢ 贰 ⸥ ⤳ 安全架构:数字银行安全体系规划

👍点「赞」📌收「藏」👀关「注」💬评「论」 🔥更多文章戳👉Whoami!-CSDN博客🚀 在金融科技深度融合的背景下,信息安全已从单纯的技术攻防扩展至架构、合规、流程与创新的…

布隆过滤器完全指南:从原理到实战

布隆过滤器完全指南:从原理到实战 摘要:本文深入解析布隆过滤器的核心原理、实现细节和实际应用,提供完整的Java实现代码,并探讨性能优化策略。适合想要深入理解概率数据结构的开发者阅读。 前言 在大数据时代,如何快速判断一个元素是否存在于海量数据集合中?传统的Hash…

​嵌入式Linux学习 - 网络服务器实现与客户端的通信

1.单循环服务器 2.并发服务器 1. 设置socket属性 2. 进程 ​3. 线程 3.多路IO复用模型 - 提高并发程度 1. 区别 2. IO处理模型 1. 阻塞IO模型 2. 非阻塞IO模型 3. 信号驱动IO 4. IO多路复用 3. 特点 4. 函数接口 1. select 2. poll 3. epoll 半包 1.单循环服务…

Mybatis中缓存机制的理解以及优缺点

文章目录一、MyBatis 缓存机制详解1. 一级缓存(Local Cache)2. 二级缓存(Global Cache)3. 缓存执行顺序二、MyBatis 缓存的优点三、MyBatis 缓存的缺点四、适用场景与最佳实践总结MyBatis 提供了完善的缓存机制,用于减…

Rust 登堂 之 类型转换(三)

Rust 是类型安全的语言,因此在Rust 中做类型转换不是一件简单的事,这一章节,我们将对Rust 中的类型转换进行详尽讲解。 高能预警,本章节有些难,可以考虑学了进阶后回头再看 as 转换 先来看一段代码 fn main() {let a…

【MySQL 为什么默认会给 id 建索引? MySQL 主键索引 = 聚簇索引?】

MySQL 索引 MySQL 为什么默认会给 id 建索引? & MySQL 主键索引 聚簇索引? 结论:在 MySQL (InnoDB) 中,主键索引是自动创建的聚簇索引,不需要删除,其他索引是补充优化。 1. MySQL 的id 索引是怎么来的…

[光学原理与应用-321]:皮秒深紫外激光器产品不同阶段使用的工具软件、对应的输出文件

在皮秒深紫外激光器的开发过程中,不同阶段使用的工具软件及其对应的输出文件如下:一、设计阶段工具软件:Zemax OpticStudio:用于光学系统的初步设计和仿真,包括光线追迹、像差分析、优化设计等。MATLAB:用于…

openEuler常用操作指令

openEuler常用操作指令 一、前言 1.简介 openEuler是由开放原子开源基金会孵化的全场景开源操作系统项目,面向数字基础设施四大核心场景(服务器、云计算、边缘计算、嵌入式),全面支持ARM、x86、RISC-V、loongArch、PowerPC、SW…

Python爬虫实战:构建网易云音乐个性化音乐播放列表同步系统

1. 引言 1.1 研究背景 在数字音乐生态中,各大音乐平台凭借独家版权、个性化推荐等优势占据不同市场份额。根据国际唱片业协会(IFPI)2024 年报告,全球流媒体音乐用户已突破 50 亿,其中超过 60% 的用户同时使用 2 个及以上音乐平台。用户在不同平台积累的播放列表包含大量…

vscode 配置 + androidStudio配置

插件代码片段 饿了么 icon{"Print to console": {"prefix": "ii-ep-","body": ["i-ep-"],"description": "elementPlus Icon"} }Ts 初始化模版{"Print to console": {"prefix": &q…

DQN(深度Q网络):深度强化学习的里程碑式突破

本文由「大千AI助手」原创发布,专注用真话讲AI,回归技术本质。拒绝神话或妖魔化。搜索「大千AI助手」关注我,一起撕掉过度包装,学习真实的AI技术! ✨ 1. DQN概述:当深度学习遇见强化学习 DQN(D…

个人博客运行3个月记录

个人博客 自推一波,目前我的Hexo个人博客已经优化的足够好了, 已经足够稳定的和简单进行发布和管理,但还是有不少问题,总之先记下来再说 先总结下 关于评论系统方面,我从Waline (快速上手 | Waline) 更换成了&#x…