一、任务描述
从手写数字图像中自动识别出对应的数字(0-9)” 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)
1、任务的核心定义:输入与输出
- 输入:28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 “3” 的形状,就是模型的输入。
- 输出:一个 “类别标签”,即从 10 个可能的类别(0、1、2、…、9)中选择一个,作为输入图像对应的数字,例如:输入 “3” 的图像,模型输出 “类别 3”,即完成一次正确识别。
- 目标:让模型在 “未见的手写数字图像” 上,尽可能准确地输出正确类别(通常用 “准确率” 衡量,即正确识别的图像数 / 总图像数)
2、任务的核心挑战
- 不同人书写习惯差异极大:有人写的 “4” 带弯钩,有人写的 “7” 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 “5”,可能是 “直笔 5”“圆笔 5”,也可能是倾斜 10° 或 20° 的 “5”—— 模型需要忽略这些 “风格差异”,抓住 “数字的本质特征”(如 “5 有一个上半圆 + 一个竖线”)。
- 图像噪声与干扰:手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 “0” 的图像,边缘有一小块污渍,模型需要判断 “这是噪声” 而不是 “0 的一部分”,避免误判为 “6” 或 “8”。
二、模型训练
1、MNIST数据集
MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 “基准数据集”,MNIST手写数字识别的核心是 “让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字”,它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。
- 数据量适中:包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
- 图像规格统一:所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
- 标注准确:每张图像都有明确的 “正确数字标签”(人工标注),无需额外标注成本。
2、代码
- 数据准备:使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;
- 模型定义:定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);
- 训练设置:使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;
- 训练过程:循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 1. 数据准备
# 定义数据变换
transform = transforms.Compose([transforms.ToTensor(), # 转换为Tensortransforms.Normalize((0.1307,), (0.3081,)) # 标准化,MNIST数据集的均值和标准差
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', # 数据保存路径train=True, # 训练集download=True, # 如果数据不存在则下载transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False, # 测试集download=True,transform=transform
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 输入层到隐藏层self.fc1 = nn.Linear(28 * 28, 128) # MNIST图像大小为28x28# 隐藏层到输出层self.fc2 = nn.Linear(128, 10) # 10个类别(0-9)def forward(self, x):# 将图像展平为一维向量x = x.view(-1, 28 * 28)# 隐藏层,使用ReLU激活函数x = torch.relu(self.fc1(x))# 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,适用于分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器# 4. 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train() # 设置为训练模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(data)loss = criterion(outputs, target)# 反向传播和优化loss.backward()optimizer.step()running_loss += loss.item()# 每100个批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 6. 运行训练和测试
if __name__ == '__main__':# 训练模型print("开始训练模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)print("模型训练完成...")# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存为 mnist_model.pth")
三、模型使用测试
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms # 修正transforms的导入方式# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 加载模型
def load_model(model_path='mnist_model.pth'):model = SimpleNN()# 加载模型时添加参数以避免潜在的Python 3兼容性问题model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))model.eval() # 设置为评估模式return model# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):# 打开图像并转换为灰度图img = Image.open(image_path).convert('L') # 'L'表示灰度模式# 调整大小为28x28img = img.resize((28, 28))# 转换为numpy数组并归一化img_array = np.array(img) / 255.0# 定义图像转换(使用torchvision的transforms)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 注意:这里需要先将numpy数组转换为PIL图像再应用transformimg_pil = Image.fromarray((img_array * 255).astype(np.uint8))img_tensor = transform(img_pil).unsqueeze(0) # 增加批次维度return img_tensor# 预测函数
def predict_digit(model, image_path):# 预处理图像img_tensor = preprocess_image(image_path)# 预测with torch.no_grad(): # 不计算梯度outputs = model(img_tensor)_, predicted = torch.max(outputs.data, 1)return predicted.item() # 返回预测的数字# 示例使用
if __name__ == '__main__':# 加载模型model = load_model('mnist_model.pth')# 预测示例图像test_image_path = 'test_digit.png' # 用户需要提供的测试图像路径try:predicted_digit = predict_digit(model, test_image_path)print(f"预测的数字是: {predicted_digit}")except Exception as e:print(f"预测出错: {str(e)}")
使用gpu0(第一块gpu)进行训练/推理:
torch.cuda.set_device(0)
model = model.cuda(0)
使用cpu记性训练/推理:
model = model.cpu()
怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制
YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析