文章目录

  • 手写数字识别项目
    • 一、准备数据集
    • 二、定义模型
    • 三、模型训练
      • 3.1 导入依赖库
      • 3.2 设备设置(CPU/GPU 自动选择)
      • 3.3 超参数定义
      • 3.4数据集准备
        • 1.获取数据集
        • 2.划分训练集与验证集
        • 3.创建 DataLoader(按批次加载数据)
      • 3.5模型初始化与断点续训
      • 3.6损失函数与优化器定义
      • 3.7训练函数(train ())
      • 3.8验证函数(valid ())
      • 3.9主训练循环(多轮训练与验证)
    • 四、模型训练完整代码
    • 五、总结流程

手写数字识别项目

一、准备数据集

首先我们创建一个卷积模型,训练的时候就需要一个原始的数据集,那么数据集哪里来?Pytorch官网其实有一些数据集,数据集地址
在这里插入图片描述

我们使用到的数据集是MNIST

导入包

import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

使用数据集,所有的官方数据集都继承 torch.utils.data.Dataset,如果你没有数据集,那download = True,它会联网下载到你本地。

# label: 数据集传入的标签值
def target_transform(label):return torch.tensor(label)ds = MNIST(root='./data',  # 保存或读取数据的目录train=True,  # 是否加载训练数据集download=False,  # 是否下载数据集transform=ToTensor(),  # 用于转换图片的函数# target_transform=target_transform  # 用于转换标签的函数target_transform=lambda label: torch.tensor(label)  # 直接匿名函数转换成张量
)

测试打印数据

print(len(ds))
print(ds[0])
print(ds[0][0].shape)

二、定义模型

简单的图像识别模型的套路:卷积 -> 激活 -> 池化 -> … -> 卷积 -> 激活 -> 池化 ->展平 -> 全连接层 -> 激活-> … -> 全连接层输出,会将图片缩小的同时增加通道数,当特征图缩小到 10 以内,就结束卷积过程。之后我们会讲到LeNet5模型,这儿我们简单的定义一个模型进行训练。

from torch import nn# 卷积激活池化 模块
class ConvActivatePool(nn.Module):def __init__(self, in_channels, out_channels, kernel_size):super().__init__()# 一般卷积后会选择让图片大小保持不变 进行填充self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')self.relu = nn.ReLU()# 池化在此处提取了特征的同时,让图片下采样了self.pool = nn.MaxPool2d(2)def forward(self, x):x = self.conv(x)x = self.relu(x)y = self.pool(x)return yclass NumberRecognition(nn.Module):def __init__(self):super().__init__()self.cap1 = ConvActivatePool(1, 64, 11)self.cap2 = ConvActivatePool(64, 128, 5)# 分类层self.classifier = nn.Sequential(# 展平nn.Flatten(start_dim=1),# 全连接层nn.Linear(128 * 7 * 7, 2048),nn.ReLU(),nn.Dropout(p=0.3),nn.Linear(2048, 1024),nn.ReLU(),# 输出结果为 10 分类,所以输出层全连接输出 10nn.Linear(1024, 10))# x 形状 (N, C=1, H=28, W=28)def forward(self, x):x = self.cap1(x)# N x 64 x 14 x 14x = self.cap2(x)# N x 128 x 7 x 7# 图片缩小到 10 以内,则停止卷积# 调用分类器,对图片进行分类y = self.classifier(x)return yif __name__ == '__main__':import torchmodel = NumberRecognition()x = torch.rand(16, 1, 28, 28)y = model(x)print(y.shape)

三、模型训练

3.1 导入依赖库

import math
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model import NumberRecognition

3.2 设备设置(CPU/GPU 自动选择)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

3.3 超参数定义

EPOCH = 10          # 训练轮次:整个训练集遍历10次
LR = 1e-2           # 学习率:控制参数更新的步长(1e-2 = 0.01)
BATCH_SIZE = 10     # 批次大小:每次训练用10个样本更新一次参数
val_rate = 0.2      # 验证集比例:从训练集中划分20%作为验证集

3.4数据集准备

1.获取数据集
ds = MNIST(root='./data',        # 数据集保存路径(若不存在会自动创建)train=True,           # 加载训练集(False则加载测试集)download=False,       # 是否自动下载数据集(首次运行需设为True)transform=ToTensor(), # 对图像的变换:PIL→Tensor(0-1归一化+维度调整)target_transform=lambda label: torch.tensor(label) # 对标签的变换:int→Tensor
)
2.划分训练集与验证集
ds_total_len = len(ds)          # 总样本数:MNIST训练集共60000个样本
train_len = int(ds_total_len * (1 - val_rate)) # 训练集样本数:60000×0.8=48000
val_len = ds_total_len - train_len             # 验证集样本数:60000×0.2=12000
train_ds, val_ds = random_split(ds, [train_len, val_len]) # 随机划分
3.创建 DataLoader(按批次加载数据)
# 计算总批次数(向上取整,避免最后一批样本被丢弃)
train_total_batch = math.ceil(train_len / BATCH_SIZE) # 48000/10=4800批
val_total_batch = math.ceil(val_len / BATCH_SIZE)     # 12000/10=1200批# 训练集DataLoader
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True  # 训练集每次epoch前打乱样本顺序(避免模型记忆样本顺序,提升泛化)
)# 验证集DataLoader
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True  # 验证集打乱无意义(仅计算损失),建议设为False以提高效率
)

3.5模型初始化与断点续训

# 初始化自定义模型(NumberRecognition在model.py中定义,需确保输入输出维度匹配)
model = NumberRecognition()# 尝试加载历史模型参数(支持断点续训)
try:# 加载参数文件(weights_only=True)state_dict = torch.load('./weights/model.pth', weights_only=True)model.load_state_dict(state_dict) # 将参数加载到模型中print('加载模型参数成功')
except:# 若文件不存在(首次训练),打印提示print('未找到模型参数')# 将模型迁移到指定设备(CPU/GPU)
model.to(device)

3.6损失函数与优化器定义

# 损失函数:交叉熵损失(适合多分类任务,如MNIST的10类数字)
loss_fn = nn.CrossEntropyLoss()# 优化器:Adam优化器(常用优化器,结合SGD的动量和RMSprop的自适应学习率)
optimizer = torch.optim.Adam(model.parameters(),  # 需优化的参数(模型的所有权重和偏置)lr=LR,               # 学习率(与超参数一致)weight_decay=1e-4    # L2正则化(权重衰减,防止模型参数过大导致过拟合)
)

3.7训练函数(train ())

# 全局变量:累计训练损失和批次数量(用于计算平均损失)
train_total_loss = 0.
train_count = 0def train():global train_total_loss, train_count # 声明使用全局变量print('开始训练')model.train() # 将模型设为“训练模式”(关键!启用Dropout/BatchNorm更新)# 遍历训练集DataLoader,每次取一个批次for i, (images, labels) in enumerate(train_dl):# 1. 将数据迁移到指定设备(与模型设备一致)images, labels = images.to(device), labels.to(device)# 2. 清空上一轮的梯度(PyTorch梯度会累加,不清空会导致梯度错误)optimizer.zero_grad()# 3. 前向传播:模型预测输出y_pred = model(images) # 输出形状:(BATCH_SIZE, 10),每一行是10个类的得分# 4. 计算损失(预测值与真实标签的差距)loss = loss_fn(y_pred, labels)# 5. 累计损失和批次数量(用于后续计算平均损失)train_total_loss += loss.item() # loss是Tensor,用.item()转为Python数值train_count += 1# 6. 反向传播:计算参数梯度(自动微分核心)loss.backward()# 7. 优化器更新参数(根据梯度调整权重和偏置)optimizer.step()# 每100个批次打印一次训练进度(避免打印过于频繁)if (i + 1) % 100 == 0:avg_loss = train_total_loss / train_countprint(f'BATCH: [{i + 1}/{train_total_batch}]; loss: {avg_loss:.4f}')# 返回本轮训练的平均损失(用于epoch结束时打印)return train_total_loss / train_count

3.8验证函数(valid ())

def valid():# 局部变量:累计验证损失和批次数量(每轮验证重新初始化,避免与训练混淆)val_total_loss = 0.val_count = 0print('开始验证')model.eval() # 将模型设为“评估模式”(关键!禁用Dropout/BatchNorm更新)# 禁用梯度计算(验证阶段无需反向传播,节省内存和时间)with torch.no_grad():# 遍历验证集DataLoaderfor i, (images, labels) in enumerate(val_dl):# 1. 数据迁移到指定设备images, labels = images.to(device), labels.to(device)# 2. 前向传播(无梯度计算)y_pred = model(images)# 3. 计算验证损失loss = loss_fn(y_pred, labels)val_total_loss += loss.item()val_count += 1# 每100个批次打印验证进度if (i + 1) % 100 == 0:avg_loss = val_total_loss / val_countprint(f'BATCH: [{i + 1}/{val_total_batch}]; loss: {avg_loss:.4f}')# 返回本轮验证的平均损失return val_total_loss / val_count

3.9主训练循环(多轮训练与验证)

# 遍历所有训练轮次
for epoch in range(EPOCH):print(f'\nEPOCH: [{epoch + 1}/{EPOCH}]') # 打印当前轮次(从1开始更直观)# 1. 训练本轮并获取训练平均损失train_loss = train()# 2. 验证本轮并获取验证平均损失val_loss = valid()# 3. 打印本轮训练结果print(f'EPOCH END; train loss: {train_loss:.4f}; val loss: {val_loss:.4f}')# 训练结束后,保存最终模型参数(覆盖原有文件)
torch.save(model.state_dict(), './weights/model.pth')
print('\n模型参数已保存至 ./weights/model.pth')

四、模型训练完整代码

import math
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from model import NumberRecognitiondevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')EPOCH = 10
LR = 1e-2
BATCH_SIZE = 10
val_rate = 0.2ds = MNIST(root='./data', train=True, download=False, transform=ToTensor(),target_transform=lambda label: torch.tensor(label))ds_total_len = len(ds)
train_len = int(ds_total_len * (1 - val_rate))
val_len = len(ds) - train_len
train_ds, val_ds = random_split(ds, [train_len, val_len])train_total_batch = math.ceil(train_len / BATCH_SIZE)
val_total_batch = math.ceil(val_len / BATCH_SIZE)train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)model = NumberRecognition()
try:state_dict = torch.load('./weights/model.pth', weights_only=True)model.load_state_dict(state_dict)print('加载模型参数成功')
except:print('未找到模型参数')model.to(device)loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)train_total_loss = 0.
train_count = 0def train():global train_total_loss, train_countprint('开始训练')model.train()for i, (images, labels) in enumerate(train_dl):# 3. 将数据放到设备上images, labels = images.to(device), labels.to(device)optimizer.zero_grad()y = model(images)loss = loss_fn(y, labels)train_total_loss += loss.item()train_count += 1loss.backward()optimizer.step()if (i + 1) % 100 == 0:print(f'BATCH: [{i + 1}/{train_total_batch}]; loss: {train_total_loss / train_count}')return train_total_loss / train_countdef valid():val_total_loss = 0.val_count = 0print('开始验证')model.eval()with torch.no_grad():for i, (images, labels) in enumerate(val_dl):images, labels = images.to(device), labels.to(device)y = model(images)loss = loss_fn(y, labels)val_total_loss += loss.item()val_count += 1if (i + 1) % 100 == 0:print(f'BATCH: [{i + 1}/{val_total_batch}]; loss: {val_total_loss / val_count}')return val_total_loss / val_countfor epoch in range(EPOCH):print(f'EPOCH: [{epoch + 1}/{EPOCH}]')train_loss = train()val_loss = valid()print(f'EPOCH END; train loss: {train_loss}; val loss: {val_loss}')torch.save(model.state_dict(), './weights/model.pth')

五、总结流程

  1. 加载 MNIST 公开手写数字数据集(训练集)
  2. 划分训练集与验证集(用于监控过拟合)
  3. 加载自定义的数字识别模型(NumberRecognition),支持断点续训(加载历史参数)
  4. 定义训练 / 验证流程,使用交叉熵损失和 Adam 优化器训练模型
  5. 训练完成后保存模型参数,便于后续推理或继续训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95418.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95418.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/95418.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

批量给文件夹添加文件v2【件批量复制工具】

代码功能介绍 这个代码的功能就是一个,给某个文件夹里面添加某个文件(含父级文件夹下的每一个子文件夹) 举个例子,父级文件夹是:“D:\Desktop\1,要添加的文件路径是:D:\1.txt” 则最后会把文件…

Qt实现2048小游戏:看看AI如何评估棋盘策略实现“人机合一

2048 是一款经典的数字益智游戏,其简单的规则背后蕴含着丰富的策略性。该项目不仅完整实现了 2048 的核心玩法,还包含了一个基于启发式评估和蒙特卡洛方法的智能 AI 玩家。 我们将从项目整体架构入手,逐一解析游戏核心逻辑、UI 渲染、事件处理、AI 策略等关键模块,并通过展…

封装红黑树实现mysetmymap

1. 源码分析 set实例化rb_tree时第二个模板参数给的是key&#xff0c;map实例化rb_tree时第⼆个模板参数给的是 pair<const key,T>&#xff0c;这样一颗红黑树既可以实现key搜索场景的set&#xff0c;也可以实现key/value搜索场 景的map源码里面模板参数是用T代表value&…

以OWTB为核心以客户为基础的三方仓运配一体化平台分析V0.2

一、系统概述以OWTB&#xff08;Order-Warehouse-Transportation-Billing&#xff0c;订单-仓储-运输-结算&#xff09;为核心的三方仓运配一体化平台&#xff0c;是专为第三方物流企业打造的深度定制化解决方案。该平台以第三方仓运配为主线&#xff0c;以多客户/多SKU/个性化…

技术框架之脚手架实现

一、 序言在日常的企业级Java开发中&#xff0c;我们经常会发现自己在重复地做着一些项目初始化工作&#xff1a;创建相似的项目结构、引入一堆固定的依赖包、编写通用的配置文件、拷贝那些几乎每个项目都有的基础工具类和日志配置。这些工作不仅枯燥乏味&#xff0c;而且容易出…

小迪安全v2023学习笔记(七十七讲)—— 业务设计篇隐私合规检测重定向漏洞资源拒绝服务

文章目录前记WEB攻防——第七十七天业务设计篇&隐私合规检测&URL重定向&资源拒绝服务&配合项目隐私合规 - 判断规则&检测项目介绍案例演示URL重定向 - 检测判断&钓鱼配合介绍黑盒测试看业务功能看参数名goole语法搜索白盒测试跳转URL绕过思路钓鱼配合资…

用AI做旅游攻略,真能比人肉整理靠谱?

大家好&#xff0c;我是极客团长&#xff01; 作为一个沉迷研究 “AI 工具怎么渗透日常生活” 的科技博主&#xff0c;我开了个 AI 解决生活小事系列。 前两期聊了用 AI 写新闻博客、扒商业报告&#xff0c;后台一堆人催更&#xff1a;能不能搞点接地气的&#xff1f;比如&am…

Axure RP 9 Mac 交互原型设计

原文地址&#xff1a;Axure RP 9 Mac 交互原型设计 安装教程 Axure RP 9是一款功能强大的原型设计和协作工具。 它不仅能够帮助用户快速创建出高质量的原型设计&#xff0c;还能促进团队成员之间的有效协作&#xff0c;从而极大地提高数字产品开发的效率和质量。 拥有直观易…

多线程——线程状态

目录 1.线程的状态 1.1 NEW 1.2 RUNNABLE 1.3 BLOCKED 1.4 WAITING 1.5 TIMED_WAITING 1.6 TERMINATED 2.线程状态的相互转换 在上期的学习中&#xff0c;已经理解线程的启动&#xff08;start()&#xff09;、休眠&#xff08;sleep()&#xff09;、中断&#xff08;i…

IMX6ULL的设备树文件简析

先分析一个完整的设备树&#xff0c;是怎么表达各种外设信息的。以imux6ull开发板为例进行说明。这个文件里就一个设备信息才这么点内容&#xff0c;是不是出问题了&#xff1f;当然不是&#xff0c;我们知道dts文件是可包含的&#xff0c;所以&#xff0c;最终形成的一个完整文…

【ARM】PACK包管理

1、 文档目标对 pack 包的管理有更多的了解。2、 问题场景客户在安装了过多的 pack 包导致软件打开比较慢&#xff0c;各种 pack 包颜色的区别&#xff0c;及图标不同。3、软硬件环境1&#xff09;、软件版本&#xff1a;Keil MDK 5.392&#xff09;、电脑环境&#xff1a;Wind…

【Kubernetes】知识点4

36. 说明K8s中Pod级别的Graceful Shutdown。答&#xff1a;Graceful Shutdown&#xff08;优雅关闭&#xff09;是指当 Pod 需要终止时&#xff0c;系统给予运行中的容器一定的时间来等待业务的应用的正常关闭&#xff08;如保存数据、关闭连接、释放资源等&#xff09;&#x…

Paraverse平行云实时云渲染助力第82届威尼斯电影节XR沉浸式体验

今年&#xff0c;Paraverse平行云实时云渲染平台LarkXR&#xff0c;为享有盛誉的第82届威尼斯国际电影节&#xff08;8月27日至9月6日&#xff09;带来沉浸式体验。 LarkXR助力我们的生态伙伴FRENCH TOUCH FACTORY&#xff0c;实现ITHACA容积视频的XR交互演示&#xff0c;从意大…

大数据开发计划表(实际版)

太好了&#xff01;我将为你生成一份可打印的PDF版学习计划表&#xff0c;并附上项目模板与架构图示例&#xff0c;帮助你更直观地执行计划。 由于当前环境无法直接生成和发送文件&#xff0c;我将以文本格式为你完整呈现&#xff0c;你可以轻松复制到Word或Markdown中&#xf…

GitLab 18.3 正式发布,更新多项 DevOps、CI/CD 功能【二】

沿袭我们的月度发布传统&#xff0c;极狐GitLab 发布了 18.3 版本&#xff0c;该版本带来了通过直接转移进行迁移、CI/CD 作业令牌的细粒度权限控制、自定义管理员角色、Kubernetes 1.33 支持、通过 API 让流水线执行策略访问 CI/CD 配置等几十个重点功能的改进。下面是对部分重…

Docker学习笔记(二):镜像与容器管理

Docker 镜像 最小的镜像 hello-world 是 Docker 官方提供的一个镜像&#xff0c;通常用来验证 Docker 是否安装成功。 先通过 docker pull 从 Docker Hub 下载它。 [rootdocker ~]# docker pull hello-world Using default tag: latest latest: Pulling from library/hello-wor…

STM32F103C8T6开发板入门学习——寄存器和库函数介绍

学习目标&#xff1a;STM32F103C8T6开发板入门学习——寄存器和库函数介绍学习内容&#xff1a; 1. 寄存器介绍 1.1 存储器映射 存储器本身无固有地址&#xff0c;是具有特定功能的内存单元。它的地址是由芯片厂商或用户分配&#xff0c;给存储器分配地址的过程就叫做存储区映射…

【CouponHub项目开发】使用RocketMQ5.x实现延时修改优惠券状态,并通过使用模板方法模式重构消息队列发送功能

在上个章节中我实现了创建优惠券模板的功能&#xff0c;但是&#xff0c;优惠券总会有过期时间&#xff0c;我们怎么去解决到期自动修改优惠券状态这样一个功能呢&#xff1f;我们可以使用RocketMQ5.x新出的任意定时发送消息功能来解决。 初始方案&#xff1a;首先在创建优惠券…

Claude Code SDK 配置Gitlab MCP服务

一、MCP配置前期准备 &#xff08;一&#xff09;创建个人令牌/群组令牌 我这里是创建个人令牌&#xff0c;去到首页左上角&#xff0c;点击头像——>偏好设置——>访问令牌——>添加新令牌 &#xff08;二&#xff09;配置mcp信息 去到魔塔社区&#xff0c;点击mc…

Eclipse 常用搜索功能汇总

Eclipse 常用搜索功能汇总 Eclipse 提供了多种搜索功能&#xff0c;帮助开发者快速定位代码、文件、类、方法、API 等资源。以下是详细的使用方法和技巧。 一、常用搜索快捷键快捷键功能描述Ctrl H打开全局搜索对话框&#xff0c;支持文件、Java 代码、任务等多种搜索。Ctrl …