文章目录

    • 引言
    • 1. 环境准备和数据加载
      • 1.1 下载MNIST数据集
      • 1.2 数据可视化
    • 2. 数据预处理
    • 3. 设备配置
    • 4. 构建卷积神经网络模型
    • 5. 训练和测试函数
      • 5.1 训练函数
      • 5.2 测试函数
    • 6. 模型训练和评估
      • 6.1 初始化损失函数和优化器
      • 6.2 训练过程
    • 7. 关键点解析
    • 8. 完整代码
    • 9. 总结

引言

手写数字识别是计算机视觉和深度学习领域的经典入门项目。本文将详细介绍如何使用PyTorch框架构建一个卷积神经网络(CNN)来实现MNIST手写数字识别任务。我们将从数据加载、模型构建到训练和测试,一步步解析整个过程。

1. 环境准备和数据加载

首先,我们需要导入必要的PyTorch模块:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

1.1 下载MNIST数据集

MNIST数据集包含60,000个训练样本和10,000个测试样本,每个样本都是一个28x28像素的灰度手写数字图像。

# 下载训练数据集
training_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 下载测试数据集
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)

1.2 数据可视化

我们可以使用matplotlib库来查看数据集中的一些样本:

from matplotlib import pyplot as pltfigure = plt.figure()
for i in range(9):img, label = training_data[i+59000]  # 提取后几张图片figure.add_subplot(3,3,i+1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2. 数据预处理

为了高效训练模型,我们需要使用DataLoader将数据集分批次加载:

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

3. 设备配置

PyTorch支持在CPU、NVIDIA GPU和苹果M系列芯片上运行,我们可以自动检测最佳可用设备:

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

4. 构建卷积神经网络模型

我们定义一个CNN类来实现手写数字识别:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 8, 3, 1, 1),  # (8,28,28)nn.ReLU(),nn.MaxPool2d(2),           # (8,14,14))self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 3, 1, 1), # (16,14,14)nn.ReLU(),nn.MaxPool2d(2),           # (16,7,7))self.out = nn.Linear(16*7*7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)      # flatten操作output = self.out(x)return outputmodel = CNN().to(device)

这个CNN模型包含:

  • 两个卷积层,每个卷积层后接ReLU激活函数和最大池化层
  • 一个全连接输出层
  • 输入大小:(1,28,28)
  • 输出大小:10(对应0-9的数字类别)

5. 训练和测试函数

5.1 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 100 == 0:print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")batch_size_num += 1

5.2 测试函数

def Test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")

6. 模型训练和评估

6.1 初始化损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

6.2 训练过程

# 初始训练和测试
train(train_dataloader, model, loss_fn, optimizer)
Test(test_dataloader, model, loss_fn)# 多轮训练
epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader, model, loss_fn, optimizer)
print("Done!")# 最终测试
Test(test_dataloader, model, loss_fn)

7. 关键点解析

  1. 数据转换:使用ToTensor()将图像数据转换为PyTorch张量,并自动归一化到[0,1]范围。

  2. 批处理:DataLoader的batch_size参数控制每次训练使用的样本数量,影响内存使用和训练速度。

  3. 模型结构

    • 卷积层提取空间特征
    • ReLU激活函数引入非线性
    • 最大池化层降低特征图尺寸
    • 全连接层输出分类结果
  4. 训练模式切换model.train()model.eval()分别用于训练和测试阶段,影响某些层(如Dropout和BatchNorm)的行为。

  5. 优化过程:Adam优化器结合了动量法和自适应学习率的优点,通常能获得较好的训练效果。

8. 完整代码

import torch
from torch import nn    #导入神经网络模块
from torch.utils.data import DataLoader  #数据包管理工具,打包数据
from torchvision import  datasets  #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor  #数据转换,张量,将其他类型的数据转换为tensor张量,numpy array'''下载训练数据集(包含训练图片+标签)'''
training_data = datasets.MNIST( #跳转到函数的内部源代码,pycharm按下ctrl + 鼠标点击root="data", #表示下载的手写数字  到哪个路径。60000train=True, #读取下载后的数据中的训练集download=True, #如果你之前已经下载过了,就不用下载transform=ToTensor(), #张量,图片是不能直接传入神经网络模型)   #对于pytorch库能够识别的数据一般是tensor张量'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST( #跳转到函数的内部源代码,pycharm按下ctrl + 鼠标点击root="data", #表示下载的手写数字  到哪个路径。60000train=False, #读取下载后的数据中的训练集download=True, #如果你之前已经下载过了,就不用下载transform=ToTensor(), #Tensor是在深度学习中提出并广泛应用的数据类型)   #Numpy数组只能在CPU上运行。Tensor可以在GPU上运行。这在深度学习应用中可以显著提高计算速度。
print(len(training_data))'''展示手写数字图片,把训练集中的59000张图片展示'''
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = training_data[i+59000] #提取第59000张图片figure.add_subplot(3,3,i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片plt.title(label)plt.axis("off")  #plt.show(I) 显示矢量plt.imshow(img.squeeze(),cmap="gray") #plt.imshow()将Numpy数组data中的数据显示为图像,并在图形窗口中显示a = img.squeeze()  #img.squeeze()从张量img中去掉维度为1的,如果该维度的大小不为1,则张量不会改变
plt.show()'''创建数据DataLoader(数据加载器)'''
# batch_size:将数据集分为多份,每一份为batch_size个数据
#       优点:可以减少内存的使用,提高训练速度train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)'''判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驱动软件的功能:pytorch能够去执行cuda的命令
# 神经网络的模型也需要传入到GPU,1个batch_size的数据集也需要传入到GPU,才可以进行训练''' 定义神经网络  类的继承这种方式'''
class CNN(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型,nn.mdouledef __init__(self): #输入大小:(1,28,28)super(CNN,self).__init__()  #初始化父类self.conv1 = nn.Sequential(      #将多个层组合成一起,创建了一个容器,将多个网络组合在一起nn.Conv2d(              # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=1,      # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数)out_channels=8,     # 要得到多少个特征图,卷积核的个数kernel_size=3,      # 卷积核大小 3×3stride=1,           # 步长padding=1,          # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好),                      # 输出的特征图为(8,28,28)nn.ReLU(),  # Relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2×2操作),输出结果为(8,14,14))self.conv2 = nn.Sequential(nn.Conv2d(8,16,3,1,1),  #输出(16,14,14)nn.ReLU(),  #Relu层  (16,14,14)nn.MaxPool2d(kernel_size=2),    #池化层,输出结果为(16,7,7))self.out = nn.Linear(16*7*7,10)  # 全连接层得到的结果def forward(self,x):   #前向传播,你得告诉它 数据的流向 是神经网络层连接起来,函数名称不能改x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0),-1)    # flatten操作,结果为:(batch_size,64 * 7 * 7)output = self.out(x)return output
model = CNN().to(device) #把刚刚创建的模型传入到GPU
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w,在训练过程中,w会被修改的
# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train() 和 mdoel.eval()
# 一般用法是:在训练开始之前写上model.train(),在测试时写上model.eval()batch_size_num = 1for X,y in dataloader:              #其中batch为每一个数据的编号X,y = X.to(device),y.to(device) #把训练数据集和标签传入cpu或GPUpred = model.forward(X)         # .forward可以被省略,父类种已经对此功能进行了设置loss = loss_fn(pred,y)          # 通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()           # 梯度值清零loss.backward()                 # 反向传播计算得到每个参数的梯度值woptimizer.step()                # 根据梯度更新网络w参数loss_value = loss.item()        # 从tensor数据种提取数据出来,tensor获取损失值if batch_size_num %100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1def Test(dataloader,model,loss_fn):size = len(dataloader.dataset)  #10000num_batches = len(dataloader)  # 打包的数量model.eval()        #测试,w就不能再更新test_loss,correct =0,0with torch.no_grad():       #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss += loss_fn(pred,y).item() #test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能来衡量模型测试的好坏correct /= size  #平均的正确率print(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")loss_fn = nn.CrossEntropyLoss()  #创建交叉熵损失函数对象,因为手写字识别一共有十种数字,输出会有10个结果
#
optimizer = torch.optim.Adam(model.parameters(),lr=0.01) #创建一个优化器,SGD为随机梯度下降算法
# # params:要训练的参数,一般我们传入的都是model.parameters()
# # lr:learning_rate学习率,也就是步长
#
# # loss表示模型训练后的输出结果与样本标签的差距。如果差距越小,就表示模型训练越好,越逼近真实的模型
train(train_dataloader,model,loss_fn,optimizer) #训练1次完整的数据。多轮训练
Test(test_dataloader,model,loss_fn)epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)

9. 总结

通过本文,我们学习了如何使用PyTorch实现一个完整的手写数字识别项目。从数据加载、模型构建到训练和评估,每个步骤都展示了PyTorch框架的简洁和强大。这个简单的CNN模型在MNIST数据集上可以达到很高的准确率,为进一步学习更复杂的计算机视觉任务打下了良好基础。

未来可以尝试:

  • 调整网络结构(增加层数、改变通道数)
  • 尝试不同的优化器和学习率
  • 添加数据增强技术
  • 在更复杂的数据集上应用类似方法

希望这篇教程能帮助你入门PyTorch和计算机视觉领域!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/84835.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/84835.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/84835.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Activiti初识

文章目录 1 工作流介绍1_工作流概念介绍2 工作流系统3 适用行业4 具体应用5 实现方式 2 Activiti介绍1_BPM2 BPM 软件3 BPMN 3 使用步骤1_部署 activiti2 流程定义3 流程定义部署4 启动一个流程实例5 用户查询待办任务(Task)6 用户办理任务7 流程结束 4 Activiti应用1_Activiti…

CyclicBarrier入门代码解析

文章目录 核心思想:组队出游,人到齐了才出发 🚌最简单易懂的代码示例代码解析运行效果分析CyclicBarrier vs CountDownLatch 的关键区别CyclicBarrier在业务系统里面通常有什么常用的应用场景核心应用模式1. 数据并行处理与ETL(最…

Maven 配置中绕过 HTTP 阻断机制的完整解决方案

Maven 配置中绕过 HTTP 阻断机制的完整解决方案 一、背景与问题分析 自 Maven 3.8.1 版本起&#xff0c;出于安全考虑&#xff0c;默认禁止了对 HTTP 仓库的访问。这一机制通过 <mirror> 配置中的 maven-default-http-blocker 实现&#xff0c;其作用是拦截所有使用 HT…

【大厂机试题解法笔记】恢复数字序列

题目 对于一个连续正整数组成的序列&#xff0c;可以将其拼接成一个字符串&#xff0c;再将字符串里的部分字符打乱顺序。如序列8 9 10 11 12,拼接成的字符串为89101112,打乱一部分字符后得到90811211,原来的正整数10就被拆成了0和1。 现给定一个按如上规则得到的打乱字符的字…

MongoDB 事务有哪些限制和注意事项?

MongoDB 的多文档 ACID 事务虽然强大&#xff0c;但在使用时确实有一些限制和需要特别注意的事项。 以下是主要的限制和注意事项&#xff1a; 1. 性能开销 (Performance Overhead) 额外协调: 事务需要额外的协调工作&#xff0c;包括跟踪事务状态、管理锁&#xff08;即使是乐…

CTF实战技巧:获取初始权限后如何高效查找Flag

CTF实战技巧&#xff1a;获取初始权限后如何高效查找Flag 在CTF比赛中&#xff0c;获得初始访问权限只是开始&#xff0c;真正的挑战在于如何在系统中高效定位Flag。本文将分享我在渗透测试中总结的系统化Flag搜索方法&#xff0c;涵盖Linux和Windows双平台。 引言&#xff1a;…

kafka Tool (Offset Explorer)使用SASL Plaintext进行身份验证

一、前面和不需要认证的情况相同&#xff1a; 1、填写Properties中的cluster name和版本&#xff0c;以及zk的ip和port 2、Advanced中填写bootstrap servers 二、和不需要认证时不同的点&#xff1a; 1、Security的Type&#xff0c;不需要认证时选plaintext&#xff0c;需要认…

最小费用最大流算法

最小费用最大流算法 原理 问题:网络中有源点(起点)和汇点(终点),每条边有流量上限和单位流量费用。求: 从源点到汇点的最大流量在流量最大的前提下,总费用最小核心思想:在找增广路时,选择单位费用之和最小的路径(使用SPFA找最短路) 实现步骤 建图:使用链式前向…

从汇编的角度揭开C++ this指针的神秘面纱(上)

C中的this指针一直比较神秘。任何类的对象&#xff0c;都有一个this指针&#xff0c;无处不在。那么this指针的本质究竟是什么&#xff1f;this指针什么时候会被用到&#xff1f;今天通过几段简单的代码&#xff0c;来揭秘一下。 要先揭秘this指针&#xff0c;先来说一下函数调…

18 - GCNet

论文《GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond》 1、作用 GCNet通过聚合每个查询位置的全局上下文信息来捕获长距离依赖关系&#xff0c;从而改善了图像/视频分类、对象检测和分割等一系列识别任务的性能。非局部网络&#xff08;NLNet&…

人工智能学习17-Pandas-查看数据

人工智能学习概述—快手视频 人工智能学习17-Pandas-查看数据—快手视频

RV1126+OPENCV在视频中添加LOGO图像

一.RV1126OPENCV在视频中添加LOGO图像大体流程图 主要是利用RV1126的视频流结合OPENCV的API在视频流里面添加LOGO图像&#xff0c;换言之就是在RV1126的视频流里面叠加图片。大体流程我们来看上图&#xff0c;要完成这个功能我们需要创建两个线程(实际上还有初始化过程&#xf…

汽车制造通信革新:网关模块让EtherCAT成功对接CCLINK

‌在现代工业自动化生产领域&#xff0c;不同品牌和类型的设备往往采用不同的通信协议&#xff0c;这给设备之间的互联互通带来了挑战。某汽车制造企业的生产线上&#xff0c;采用了三菱FX5U PLC作为主站进行整体生产流程的控制和调度&#xff0c;同时配备了库卡机器人作为从站…

vue父类跳转到子类带参数,跳转完成后去掉参数

当通过路由导航的时候&#xff0c;由于父类页面带参数到子类&#xff0c;导致路径上面有参数 这样不仅不美观&#xff0c;而且在点击导航菜单按钮时还会有各种问题&#xff0c;这时我们只需要将路由后面的参数去掉就好了&#xff0c;在子页面mounted()函数里面获取到父类的参数…

纯 CSS 实现的的3种扫光效果

介绍一个比较常见的动画效果。 在日常开发中&#xff0c;为了强调凸显某些文本或者元素&#xff0c;会加一些扫光动效&#xff0c;起到吸引眼球的效果&#xff0c;比如文本的 或者是一个卡片容器&#xff0c;里面可能是图片或者文本或者任意元素 除此之外&#xff0c;还有那…

如何在FastAPI中构建一个既安全又灵活的多层级权限系统?

title: 如何在FastAPI中构建一个既安全又灵活的多层级权限系统? date: 2025/06/14 12:43:05 updated: 2025/06/14 12:43:05 author: cmdragon excerpt: FastAPI通过依赖注入系统和OAuth2、JWT等安全方案,支持构建多层级权限系统。系统设计包括基于角色的访问控制、细粒度权…

大模型_Ubuntu24.04安装RagFlow_使用hyper-v虚拟机_超级详细--人工智能工作笔记0251

因为之前使用dify搭建了一个知识库&#xff0c;但是dify的效果&#xff0c;尤其是在文档解析方面是非常不友好的&#xff0c;虽然测试了&#xff0c;纳米的效果非常好&#xff0c;但是纳米只能容纳2000个文件&#xff0c;如果 你的知识库中有代码&#xff0c;sql文件等等&…

LeetCode - LCR 173. 点名

题目 LCR 173. 点名 - 力扣&#xff08;LeetCode&#xff09; 思路 首先对数组进行排序&#xff0c;使学号按顺序排列 在排序后的数组中&#xff0c;如果没有缺失的学号&#xff0c;那么每个元素应该等于其索引值 使用二分查找找到第一个不等于其索引的元素位置&#xff1…

VSCode如何优雅的debug python文件,包括外部命令uv run main.py等等

debug程序的方式有很多种。每一种方式都各有缺点:有的方式虽然优雅,但是局限性很大;有的方式麻烦,但是局限性小。 常规方式: 优点:然后可以观察所有线程。一劳永逸。缺点:就是写参数很麻烦,但是你可以让chatgpt等大模型帮你写。最最最优雅的方式: 优点:就是需要在代码…

[调试技巧]VS Code如何在代理模式下使用 MCP 工具?

在开发环境调试MCP&#xff0c;通过agent模式与大模型对话&#xff0c;并不能保证每次均正确调用tool。在阅读官方文档之后&#xff0c;得知以下小技巧。 添加 MCP 服务器后&#xff0c;您可以在代理模式下使用它提供的工具。要在代理模式下使用 MCP 工具 打开聊天视图 (CtrlAl…