🔍 开始你的图像分类之旅:一步一步学习 CIFAR-10 分类

图像分类是计算机视觉中最基础的任务之一,如果你是初学者,那么以 CIFAR-10 为训练场是一个不错的选择。本文一步一步带你从零开始,学习如何用深度学习模型实现图像分类。


一、CIFAR-10 数据集是什么?

CIFAR-10 是一个小型图像分类数据集,共包括 10 个类别:✈ 飞机(airplane)🚗 汽车(automobile)🐦 鸟(bird)🐱 猫(cat)🦌 鹿(deer)🐶 狗(dog)🐸 青蛙(frog)🐴 马(horse)🚢 船(ship)🚚 卡车(truck)

每张图片都是 32x32 的小图,有 RGB 三个颜色通道。

总共有 60000 张图,其中:

  • 训练集: 50000 张
  • 测试集: 10000 张

这些图片内容丰富,分辨率低,适合初学者练手。


二、模型训练的整体流程

我们用以下流程完成图像分类:

  1. 数据加载和预处理
  2. 构建模型(CNN)
  3. 设置损失函数和优化器
  4. 训练模型(前向 + 反向传播 + 更新参数)
  5. 测试模型效果

🧠 类比理解: 把整个过程比作“学会识别水果”:

  • 数据加载:收集不同水果的照片
  • 模型:像是大脑处理这些图像的神经元网络
  • 损失函数:告诉我们判断错误的严重程度
  • 优化器:帮助我们不断修正错误,直到准确

三、数据加载和预处理

我们使用 PyTorch 中的 transforms 来将图片:

  • 转换成 Tensor(张量)
  • 正则化颜色值到 -1~1 之间,加快模型收敛
import torch
import torchvision
import torchvision.transforms as transforms# 定义图像转换操作:将图片转换为 Tensor,并进行标准化
transform = transforms.Compose([transforms.ToTensor(),  # 将图片转为 Tensor 类型,方便计算transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将图片标准化,均值和标准差都设为0.5
])# 加载 CIFAR-10 训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
# 使用 DataLoader 进行批量加载数据,batch_size 是每次加载的图片数量
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True)  # shuffle=True 表示打乱数据# 加载 CIFAR-10 测试数据集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False)  # 测试集不需要打乱数据

🎓 例子解释

  • 如果 batch_size=64,就意味着每次训练取 64 张图片
  • shuffle=True 可以打乱图片顺序,防止模型记住顺序而不是学习特征

🍎 通俗类比

  • ToTensor 就像是把一张照片转成表格(数字表示颜色)
  • Normalize 就像是统一标准,把所有颜色亮度调成统一区间,好比较

四、构建 CNN 模型

CNN(卷积神经网络)特别适合处理图像。我们构建一个简单 CNN 模型,包含:

  • 两个卷积层 + ReLU 激活
  • 两个最大池化层(缩小图片尺寸)
  • 一个全连接隐藏层 + 一个输出层(10类)
import torch.nn as nn
import torch.nn.functional as F# 定义简单的卷积神经网络
class SimpleCNN(nn.Module):def __init__(self):super().__init__()# 第一层卷积层,输入通道为3(RGB图像),输出通道为32,卷积核大小为3x3,padding=1 保证输出尺寸不变self.conv1 = nn.Conv2d(3, 32, 3, padding=1)# 最大池化层:2x2池化,用来减少图像尺寸self.pool = nn.MaxPool2d(2, 2)# 第二层卷积层,输入通道为32,输出通道为64self.conv2 = nn.Conv2d(32, 64, 3, padding=1)# 全连接层:将卷积层输出展平为一维向量,连接到一个128维的隐藏层self.fc1 = nn.Linear(64 * 8 * 8, 128)  # CIFAR-10图像尺寸32x32,经过两次池化后尺寸为8x8# 输出层:10类,CIFAR-10数据集包含10个类别self.fc2 = nn.Linear(128, 10)def forward(self, x):# 第一层卷积 + ReLU 激活 + 池化x = self.pool(F.relu(self.conv1(x)))# 第二层卷积 + ReLU 激活 + 池化x = self.pool(F.relu(self.conv2(x)))# 展平特征图为一维向量,便于输入全连接层x = x.view(-1, 64 * 8 * 8)# 全连接层 + ReLU 激活x = F.relu(self.fc1(x))# 输出层,返回每个类别的预测概率x = self.fc2(x)return x

📷 类比

  • 卷积操作就像“扫描照片”的滤镜,用来提取边缘、颜色块等图像特征
  • 最大池化像是“缩略图”,保留最显著的特征,减少计算量

五、设置损失函数和优化器

import torch.optim as optim# 初始化模型并将其放到 GPU 或 CPU 上
model = SimpleCNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 损失函数:交叉熵损失函数,用于多分类问题
criterion = nn.CrossEntropyLoss()# 优化器:Adam 优化器,适用于大部分情况,学习率设置为0.001
optimizer = optim.Adam(model.parameters(), lr=0.001)

🧠 例子类比

  • 损失函数就像考试成绩:越高说明你错越多
  • 优化器就像“老师指出你的错误”并教你怎么改正

📘 数值示例

  • 如果模型预测飞机的概率是 [0.1, 0.05, …, 0.7](第10类)
  • 但真实标签是第1类(飞机),交叉熵损失会很大
  • 优化器就会调整参数,使下一次飞机的概率尽量靠近第1类

六、训练模型:前向、反向、更新

# 训练过程:迭代多个 epoch,每个 epoch 会遍历所有训练数据
for epoch in range(10):  # 总共训练 10 个 epochrunning_loss = 0.0  # 每个 epoch 初始化损失值为 0for inputs, labels in trainloader:  # 遍历每个 batchinputs, labels = inputs.to(device), labels.to(device)  # 将数据移到 GPU 或 CPUoptimizer.zero_grad()  # 清除上一次的梯度outputs = model(inputs)  # 前向传播,得到每张图片的预测结果loss = criterion(outputs, labels)  # 计算损失值loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新模型参数running_loss += loss.item()  # 累加损失值print(f"Epoch {epoch + 1}, Loss: {running_loss:.3f}")  # 打印当前 epoch 的损失

📐 具体数值例子

  • 模型初始权重是 0.3,预测错误 → loss = 2.5
  • 反向传播算出权重梯度是 -0.8
  • 学习率为 0.01,更新后权重 = 0.3 - 0.01 × (-0.8) = 0.308

七、测试模型效果

correct = 0  # 初始化正确预测的个数
total = 0  # 初始化总预测的个数
model.eval()  # 设置模型为评估模式,关闭 Dropout 等训练时的特殊操作with torch.no_grad():  # 在测试时,不需要计算梯度,减少计算量for data in testloader:  # 遍历测试集中的数据images, labels = dataimages, labels = images.to(device), labels.to(device)  # 将数据移到 GPU 或 CPUoutputs = model(images)  # 获取模型的输出_, predicted = torch.max(outputs, 1)  # 获取预测结果,torch.max 返回最大值和其索引,这里我们只取索引total += labels.size(0)  # 累加总的测试样本数量correct += (predicted == labels).sum().item()  # 统计预测正确的样本数量print(f"测试准确率: {100 * correct / total:.2f}%")  # 打印测试集上的准确率

在这里插入图片描述

📊 例子

  • 如果 total=10000,correct=7000,准确率就是 70%

在这里插入图片描述

🏁 完整代码快速运行包

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt  # 添加matplotlib用于可视化
from matplotlib import rcParams  # 用于设置字体# 1. 数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]之间
])# 加载训练集与测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)# 2. 定义模型:简单的CNN模型
class SimpleCNN(nn.Module):def __init__(self):super().__init__()# 第一个卷积层:3个输入通道,32个输出通道,卷积核大小3x3,padding为1self.conv1 = nn.Conv2d(3, 32, 3, padding=1)# 最大池化层:2x2池化self.pool = nn.MaxPool2d(2, 2)# 第二个卷积层:32个输入通道,64个输出通道,卷积核大小3x3,padding为1self.conv2 = nn.Conv2d(32, 64, 3, padding=1)# 全连接层,输入大小64x8x8,输出128self.fc1 = nn.Linear(64 * 8 * 8, 128)# 最后一层全连接层,输出10个类别self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))  # 卷积+ReLU+池化x = self.pool(F.relu(self.conv2(x)))  # 卷积+ReLU+池化x = x.view(-1, 64 * 8 * 8)  # 展平数据,准备全连接x = F.relu(self.fc1(x))  # 全连接+ReLUx = self.fc2(x)  # 最后一层输出return x# 3. 初始化模型、损失函数与优化器
model = SimpleCNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPU
model.to(device)  # 将模型转移到GPU或CPU上criterion = nn.CrossEntropyLoss()  # 使用交叉熵作为损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器,学习率为0.001# 4. 训练模型
for epoch in range(10):  # 训练10个epochrunning_loss = 0.0for inputs, labels in trainloader:inputs, labels = inputs.to(device), labels.to(device)  # 数据转移到GPU或CPUoptimizer.zero_grad()  # 清除上一次的梯度outputs = model(inputs)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 优化器更新参数running_loss += loss.item()  # 累加损失print(f"Epoch {epoch + 1}, Loss: {running_loss:.3f}")  # 打印每个epoch的损失# 5. 测试模型效果
correct = 0
total = 0
model.eval()  # 将模型设置为评估模式with torch.no_grad():  # 禁止计算梯度,提高效率for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)  # 数据转移到GPU或CPUoutputs = model(images)_, predicted = torch.max(outputs, 1)  # 获取最大概率的类别total += labels.size(0)  # 累加总样本数correct += (predicted == labels).sum().item()  # 统计正确的样本数# 可视化一个批次的预测结果plt.figure(figsize=(10, 10))for i in range(8):  # 显示前8张图片plt.subplot(2, 4, i + 1)plt.imshow(images[i].cpu().permute(1, 2, 0) * 0.5 + 0.5)  # 反归一化并显示图片plt.title(f"Label: {labels[i].item()}\nPrediction: {predicted[i].item()}")plt.axis('off')plt.show()break  # 仅显示一个批次# 输出测试准确率
print(f"测试准确率: {100 * correct / total:.2f}%")

你可以将以下内容保存为 train_cifar10.py 并运行:

python train_cifar10.py

💡 不需要修改任何内容就能开始训练和测试!有 CUDA 就用 GPU,否则自动使用 CPU。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/79337.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/79337.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/79337.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.学习笔记--Spring-AOP总结(p39)-Spring事务简介(P40)-Spring事务角色(P41)-Spring事务属性(P42)

1.AOP总结:面向切面编程,在不惊动原始基础上为方法进行功能增强。 2.AOP核心概念: (1)代理:SpringAOP的核心是采用代理模式 (2)连接点:在SpringAOP中,理解为任…

数据库-day06

一、实验名称和性质 分类查询 验证 综合 设计 二、实验目的 1.掌握数据查询的Group by ; 2. 掌握聚集函数的使用方法。 三、实验的软硬件环境要求 硬件环境要求: PC机(单机) 使用的软件名称、版本号以及模块: …

看门狗定时器(WDT)超时

一、问题 Arduino 程序使用<Ticker.h>包时&#xff0c;使用不当情况下&#xff0c;会导致“看门狗WDT超时” 1.1问题控制台报错 在串口监视器显示 --------------- CUT HERE FOR EXCEPTION DECODER ---------------Soft WDT resetException (4): epc10x402077cb epc2…

AI在多Agent协同领域的核心概念、技术方法、应用场景及挑战 的详细解析

以下是 AI在多Agent协同领域的核心概念、技术方法、应用场景及挑战 的详细解析&#xff1a; 1. 多Agent协同的定义与核心目标 多Agent系统&#xff08;MAS, Multi-Agent System&#xff09;&#xff1a; 由多个独立或协作的智能体&#xff08;Agent&#xff09;组成&#xff…

Wireshark TS | 异常 ACK 数据包处理

问题背景 来自于学习群里群友讨论的一个数据包跟踪文件&#xff0c;在其中涉及到两处数据包异常现象&#xff0c;而产生这些现象的实际原因是数据包乱序。由于这两处数据包异常&#xff0c;都有点特别&#xff0c;本篇也就其中一个异常现象单独展开说明。 问题信息 数据包跟…

【React】项目的搭建

create-react-app 搭建vite 搭建相关下载 在Vue中搭建项目的步骤&#xff1a;1.首先安装脚手架的环境&#xff0c;2.通过脚手架的指令创建项目 在React中有两种方式去搭建项目&#xff1a;1.和Vue一样&#xff0c;先安装脚手架然后通过脚手架指令搭建&#xff1b;2.npx create-…

深入浅出 NVIDIA CUDA 架构与并行计算技术

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《深度探秘&#xff1a;AI界的007》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、引言 1、CUDA为何重要&#xff1a;并行计算的时代 2、NVIDIA在…

pytorch学习02

自动微分 自动微分模块torch.autograd负责自动计算张量操作的梯度&#xff0c;具有自动求导功能。自动微分模块是构成神经网络训练的必要模块&#xff0c;可以实现网络权重参数的更新&#xff0c;使得反向传播算法的实现变得简单而高效。 1. 基础概念 张量 Torch中一切皆为张…

Java虚拟机(JVM)平台无关?相关?

计算机的概念模型 计算机实际上就是实现了一个图灵机模型。即&#xff0c;输入参数&#xff0c;根据程序计算&#xff0c;输出结果。图灵机模型如图。 Tape是输入数据&#xff0c;Program是针对这些数据进行计算的程序&#xff0c;中间横着的方块表示的是机器的状态。 目前使…

satoken的奇奇怪怪的错误

发了 /user/getBrowseDetail和/user/getResponDetail&#xff0c;但为什么进入handle里面有三次&#xff1f;且第一次的handle类型是AbstractHandleMapping$PreFlightHttpRequestHandlerxxx,这一次进来的时候flag为false&#xff0c;StpUtils.checkLogin抛出了异常 第二次进来的…

【KWDB 创作者计划】_上位机知识篇---SDK

文章目录 前言一、SDK的核心组成API(应用程序接口)库文件(Libraries)开发工具文档与示例依赖项与环境配置二、SDK的作用简化开发流程确保兼容性与稳定性加速产品迭代功能扩展与定制三、SDK的典型应用场景硬件设备开发操作系统与平台云服务与API集成游戏与图形开发四、SDK与…

golang处理时间的包time一次性全面了解

本文旨在对官方time包有个全面学习了解。不钻抠细节&#xff0c;但又有全面了解&#xff0c;重点介绍常用的内容&#xff0c;一些低频的可能这辈子可能都用不上。主打一个花最少时间办最大事。 Duration对象: 两个time实例经过的时间,以长度为int64的纳秒来计数。 常见的durati…

PyCharm Flask 使用 Tailwind CSS 配置

使用 Tailwind CSS 步骤 1&#xff1a;初始化项目 在 PyCharm 终端运行&#xff1a;npm init -y安装 Tailwind CSS&#xff1a;npm install -D tailwindcss postcss autoprefixer初始化 Tailwind 配置文件&#xff1a;npx tailwindcss init这会生成 tailwind.config.js。 步…

【英语语法】基本句型

目录 前言一&#xff1a;主谓二&#xff1a;主谓宾三&#xff1a;主系表四&#xff1a;主谓双宾五&#xff1a;主谓宾补 前言 英语基本句型是语法体系的基石&#xff0c;以下是英语五大基本句型。 一&#xff1a;主谓 结构&#xff1a;主语 不及物动词 例句&#xff1a; T…

隔离DCDC辅助电源解决方案与产品应用科普

**“隔离”与“非隔离的区别** 隔离&#xff1a; 1、AC-DC&#xff0c;也叫“一次电源”&#xff0c;人可能会碰到的应用场合&#xff0c;起安全保护作用&#xff1b; 2、为了抗干扰&#xff0c;通过隔离能有效隔绝干扰信号传输。 非隔离&#xff1a; 1、“安全特低电压&#…

DS-SLAM 运动一致性检测的源码解读

运动一致性检测是Frame.cc的Frame::ProcessMovingObject(const cv::Mat &imgray)函数。 对应DS-SLAM流程图Moving consistency check的部分 把这个函数单独摘出来&#xff0c;写了一下对两帧检测&#xff0c;查看效果的程序&#xff1a; #include <opencv2/opencv.hpp…

安全测试的全面知识体系及实现路径

以下是安全测试的全面知识体系及实现路径,结合最新工具和技术趋势(截至2025年): 一、安全测试核心类型与工具 1. 静态应用安全测试(SAST) 知识点: 通过分析源代码、字节码或二进制文件识别漏洞(如SQL注入、缓冲区溢出)支持早期漏洞发现,减少修复成本,适合白盒测试场…

GPT-4o Image Generation Capabilities: An Empirical Study

GPT-4o 图像生成能力:一项实证研究 目录 介绍研究背景方法论文本到图像生成图像到图像转换图像到 3D 能力主要优势局限性与挑战对比性能影响与未来方向结论介绍 近年来,图像生成领域发生了巨大的变化,从生成对抗网络 (GAN) 发展到扩散模型,再到可以处理多种模态的统一生成架…

Redis之全局唯一ID

全局ID生成器 文章目录 全局ID生成器一、全局ID生成器的定义定义核心作用 二、全局ID生成器需满足的特征1. 唯一性&#xff08;Uniqueness&#xff09;​2. 高性能&#xff08;High Performance&#xff09;​3. 可扩展性&#xff08;Scalability&#xff09;​4. 有序性&#…

nginx中的代理缓存

1.缓存存放路径 对key取哈希值之后&#xff0c;设置cache内容&#xff0c;然后得到的哈希值的倒数第一位作为第一个子目录&#xff0c;倒数第三位和倒数第二位组成的字符串作为第二个子目录&#xff0c;如图。 proxy_cache_path /xxxx/ levels1:2 2.文件名哈希值