食物图像分类是计算机视觉的经典任务之一,其核心是让机器 “看懂” 图像中的食物类别。随着深度学习的发展,卷积神经网络(CNN)凭借强大的特征提取能力,成为图像分类的主流方案。本文将基于 PyTorch 框架,从代码实战出发,拆解食物图像分类项目中的核心知识点,包括环境搭建、数据预处理、数据集构建、CNN 模型设计、模型训练与测试、单图预测等,带大家从零搭建一个能识别 20 类食物的分类系统。

# 导入必要的库
import torch  # PyTorch核心库,用于构建和训练神经网络
from torch import nn  # 神经网络模块,包含各种层和损失函数
from torch.utils.data import Dataset, DataLoader  # 数据集和数据加载器,用于数据处理
import numpy as np  # 数值计算库,可用于数据预处理等
from PIL import Image  # 图像处理库,用于读取和处理图像
from torchvision import transforms  # 图像转换工具,用于数据增强和预处理
import os  # 操作系统接口,用于文件路径处理等# 定义数据转换策略:训练集使用数据增强,验证集/测试集保持一致的基础转换
data_transforms = {'train':  # 训练集转换(包含数据增强,增加样本多样性)transforms.Compose([transforms.Resize([300, 300]),  # 先将图像调整为300x300transforms.RandomRotation(45),  # 随机旋转(-45~45度),增强旋转不变性transforms.CenterCrop(256),  # 中心裁剪到256x256,去除旋转后的黑边transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # 随机调整亮度、对比度、饱和度和色调transforms.RandomGrayscale(p=0.1),  # 10%概率转为灰度图transforms.ToTensor(),  # 转为Tensor格式([C, H, W]),并将像素值归一化到[0,1]transforms.Normalize(  # 使用ImageNet的均值和标准差进行标准化[0.485, 0.456, 0.406],  # 均值(RGB三个通道)[0.229, 0.224, 0.225]  # 标准差(RGB三个通道))]),'valid':  # 验证集/测试集转换(无增强,保持数据一致性)transforms.Compose([transforms.Resize([256, 256]),  # 调整为256x256,与训练集裁剪后尺寸一致transforms.ToTensor(),  # 转为Tensor]),
}# -------------------------- 2. 自定义数据集类 --------------------------
class food_dataset(Dataset):"""自定义食物图像数据集类,继承自PyTorch的Dataset用于加载图像路径和对应标签,并进行预处理"""def __init__(self, file_path, transform=None):"""初始化数据集:param file_path: 存储图像路径和标签的文本文件路径:param transform: 图像转换函数(预处理/数据增强)"""self.file_path = file_path  # 文本文件路径self.transform = transform  # 转换函数self.imgs = []  # 存储所有图像路径self.labels = []  # 存储对应标签# 读取文件列表(每行格式:图片路径 数字标签)with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip()  # 去除首尾空格和换行符if not line:  # 跳过空行continue# 按空格分割路径和标签(假设格式严格,无多余空格)img_path, label = line.split(' ')self.imgs.append(img_path)self.labels.append(label)def __len__(self):"""返回数据集样本数量"""return len(self.imgs)def __getitem__(self, index):"""根据索引获取单个样本(图像和标签):param index: 样本索引:return: 处理后的图像张量和标签张量"""# 读取图片并强制转为RGB(避免灰度图导致的通道数不匹配问题)try:image = Image.open(self.imgs[index]).convert('RGB')  # 确保3通道输入except Exception as e:# 捕获读取错误,便于调试raise ValueError(f"读取图片 {self.imgs[index]} 失败:{e}")# 应用转换(预处理/数据增强)if self.transform:image = self.transform(image)# 处理标签:转为整数类型的张量(PyTorch分类任务要求标签为long类型)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label# 加载数据集
# 注意:需确保train.txt和test.txt文件存在,每行格式为「图片路径 数字标签」
try:# 加载训练集(使用训练集转换)training_data = food_dataset(file_path='./train.txt', transform=data_transforms['train'])# 加载测试集(使用验证集转换)test_data = food_dataset(file_path='./test.txt', transform=data_transforms['valid'])
except FileNotFoundError:# 捕获文件不存在错误,提示用户raise FileNotFoundError("请确保 train.txt 和 test.txt 文件在当前目录下")# 创建数据加载器(批量加载数据,支持打乱和多进程)
train_dataloader = DataLoader(training_data,batch_size=8,  # 批大小:每次加载8张图片shuffle=True  # 训练时打乱数据顺序,增强训练效果
)
test_dataloader = DataLoader(test_data,batch_size=8,  # 测试时也用相同批大小shuffle=True  # 测试时打乱不影响结果,主要便于观察不同样本
)# 设备配置:优先使用GPU(cuda),其次是Apple M系列芯片(mps),最后是CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')  # 打印使用的设备# 定义CNN模型(卷积神经网络)
class CNN(nn.Module):"""自定义卷积神经网络模型,用于食物图像分类包含4个卷积块和1个全连接输出层"""def __init__(self):super().__init__()  # 调用父类nn.Module的初始化方法# 第一个卷积块:1次卷积 + ReLU激活 + 最大池化self.conv1 = nn.Sequential(# 卷积层:输入3通道(RGB),输出16通道,卷积核5x5,步长1,填充2(保持尺寸)nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),  # 激活函数,引入非线性nn.MaxPool2d(kernel_size=2),  # 最大池化:尺寸减半(256→128))# 第二个卷积块:2次卷积 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),  # 输入16通道,输出32通道nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),  # 输入32通道,输出32通道nn.ReLU(),nn.MaxPool2d(2),  # 尺寸减半(128→64))# 第三个卷积块:2次卷积 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 输入32通道,输出64通道nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),  # 输入64通道,输出128通道nn.ReLU(),nn.MaxPool2d(2),  # 尺寸减半(64→32))# 第四个卷积块:1次卷积 + ReLU(无池化,保持尺寸)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),  # 输入128通道,输出128通道nn.ReLU(),  # 输出尺寸:32×32,通道数128)# 全连接输出层:将特征映射到20个类别(食物种类)# 输入尺寸计算:128通道 × 32高 × 32宽(经多次池化后的特征图尺寸)self.out = nn.Linear(128 * 32 * 32, 20)  # 20类:需与标签数量一致def forward(self, x):"""前向传播:定义数据在网络中的流动路径:param x: 输入张量,形状为[batch_size, 3, 256, 256]:return: 输出张量,形状为[batch_size, 20](各类别的预测分数)"""x = self.conv1(x)  # 经第一个卷积块处理x = self.conv2(x)  # 经第二个卷积块处理x = self.conv3(x)  # 经第三个卷积块处理x = self.conv4(x)  # 经第四个卷积块处理x = x.view(x.size(0), -1)  # 展平特征图:[batch_size, 128*32*32]output = self.out(x)  # 经全连接层输出预测结果return output# -------------------------- 训练与测试函数 --------------------------
def train(dataloader, model, loss_fn, optimizer):"""训练模型的函数:param dataloader: 训练数据集加载器:param model: 待训练的模型:param loss_fn: 损失函数(用于计算预测误差):param optimizer: 优化器(用于更新模型参数)"""model.train()  # 开启训练模式(启用Dropout、BatchNorm等训练特定行为)batch_size_num = 1  # 记录当前批次编号for X, y in dataloader:# 将数据移动到指定设备(GPU/CPU)X, y = X.to(device), y.to(device)# 前向传播:计算模型预测结果pred = model(X)# 计算损失(预测值与真实标签的差距)loss = loss_fn(pred, y)# 反向传播与参数更新optimizer.zero_grad()  # 清空上一轮的梯度(避免梯度累积)loss.backward()  # 反向传播计算梯度optimizer.step()  # 根据梯度更新模型参数# 打印损失(每2个batch打印一次,便于监控训练过程)loss_val = loss.item()  # 获取损失的标量值if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f}  [batch: {batch_size_num}]")batch_size_num += 1def test(dataloader, model, loss_fn):model.eval()  # 开启评估模式(关闭Dropout、固定BatchNorm参数等)size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)  # 测试集批次数test_loss, correct = 0, 0  # 总损失和正确预测数# 关闭梯度计算(测试时不需要更新参数,节省计算资源)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)  # 数据移至设备pred = model(X)  # 预测test_loss += loss_fn(pred, y).item()  # 累加损失# 统计正确预测数:取预测概率最大的类别与真实标签比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batches  # 平均损失correct /= size  # 准确率print(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")# -------------------------- 单张图片预测函数 --------------------------
def predict_single_image(image_path, model, transform, device, label_map):"""对单张图片进行预测:param image_path: 图片路径:param model: 训练好的模型:param transform: 图像预处理函数(与测试集一致):param device: 计算设备:param label_map: 标签映射字典(数字标签→食物名称):return: 预测的食物名称"""# 读取并预处理图片(与测试集预处理一致)image = Image.open(image_path).convert('RGB')  # 确保3通道image = transform(image)  # 应用预处理(Resize和ToTensor)# 增加batch维度(模型要求输入格式为[batch, C, H, W],这里batch=1)image = image.unsqueeze(0).to(device)# 模型预测model.eval()  # 开启评估模式with torch.no_grad():  # 关闭梯度计算pred_logits = model(image)  # 得到预测分数(logits)# 取概率最大的类别标签(argmax(1)按行取最大值索引)pred_label = pred_logits.argmax(1).item()# 映射为食物名称if pred_label not in label_map:raise KeyError(f"预测标签 {pred_label} 不在标签映射字典中")return label_map[pred_label]# -------------------------- 主程序 --------------------------
if __name__ == "__main__":# 初始化模型、损失函数、优化器model = CNN().to(device)  # 创建模型并移至设备loss_fn = nn.CrossEntropyLoss()  # 多分类问题常用交叉熵损失# Adam优化器:自适应学习率,训练效果较好,学习率0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型(100轮)epochs = 100for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n----------------------------")train(train_dataloader, model, loss_fn, optimizer)print("Training Done!")# 测试模型在测试集上的性能test(test_dataloader, model, loss_fn)# 定义标签映射字典:数字标签→食物名称# 需与数据集的标签完全对应(顺序和数量一致)label_to_food = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝",5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连", 9: "瓜子",10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"}# 输入图片路径并预测image_path = input("请输入图片路径:")  # 用户输入待预测图片路径true_food = input("请输入该图片的真实食物名称:")  # 用户输入真实标签(用于对比)# 执行预测并输出结果predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food)# 输出对比结果print("\n" + "-" * 50)print(f"预测结果:{predicted_food}")print(f"真实结果:{true_food}")print(f"判断:{'预测正确' if predicted_food == true_food else '预测错误'}")print("-" * 50)

二、数据预处理:让数据 “适配” 模型

在深度学习中,数据预处理的质量直接影响模型性能。原始图像可能存在尺寸不一、像素值范围差异大、样本数量不足等问题,需通过预处理将其转化为模型可接受的格式,并通过数据增强提升模型泛化能力。

本项目的预处理逻辑集中在data_transforms字典中,分 “训练集” 和 “验证集 / 测试集” 两种策略,我们逐一拆解其设计思路。

2.1 为什么要区分训练集与验证集预处理?

  • 训练集:需要通过 “数据增强” 增加样本多样性,避免模型过拟合(即模型只记住训练样本,对新样本识别能力差)。
  • 验证集 / 测试集:需保持数据的 “真实性”,仅进行基础预处理(如 Resize、ToTensor),确保评估结果能反映模型的实际泛化能力。

2.2 训练集数据增强:每一步的作用与原理

训练集的预处理链为:Resize → RandomRotation → CenterCrop → RandomHorizontalFlip → RandomVerticalFlip → ColorJitter → RandomGrayscale → ToTensor → Normalize,我们逐个解析:

(1)Resize ([300, 300]):统一初始尺寸

将所有图像调整为 300×300 像素。为什么不直接调整为最终的 256×256?因为后续会进行旋转和裁剪,预留一定尺寸可避免旋转后出现黑边。

(2)RandomRotation (45):随机旋转

随机将图像旋转 - 45°~45°。食物在拍摄时可能有不同角度(如躺着的汉堡、竖放的胡萝卜),旋转增强能让模型对角度不敏感,提升鲁棒性。

(3)CenterCrop (256):中心裁剪

将旋转后的图像从中心裁剪为 256×256。旋转会导致图像边缘出现黑边,裁剪可去除黑边,同时将图像尺寸统一为模型输入尺寸(256×256)。

(4)RandomHorizontalFlip (p=0.5) & RandomVerticalFlip (p=0.5):随机翻转
  • 水平翻转(50% 概率):模拟 “左右镜像” 的食物(如翻转后的草莓外观不变)。
  • 垂直翻转(50% 概率):模拟 “上下颠倒” 的场景(如掉落的薯条)。
    翻转操作不改变食物的核心特征,但能增加样本多样性,且计算成本低。
(5)ColorJitter (0.1, 0.1, 0.1, 0.1):随机颜色抖动

调整图像的亮度、对比度、饱和度、色调,各参数的取值范围为 0~1(0 表示不调整,1 表示最大调整幅度)。
食物图像的颜色易受光照影响(如白天和夜晚拍摄的青菜颜色不同),颜色抖动能让模型对光照变化不敏感。

(6)RandomGrayscale (p=0.1):随机灰度化

10% 概率将彩色图像转为灰度图。虽然食物的颜色是重要特征,但灰度化能迫使模型关注食物的形状、纹理等更本质的特征,避免过度依赖颜色信息(如红色的草莓和红色的圣女果,需通过形状区分)。

(7)ToTensor ():转为 Tensor 格式

将 PIL 图像(H×W×C,像素值 0~255)转为 PyTorch Tensor(C×H×W,像素值归一化到 0~1)。

  • 维度转换:模型要求输入为 “通道优先”(C×H×W),而 PIL 图像是 “高度优先”(H×W×C),需通过 ToTensor 调整。
  • 归一化:将像素值从 0~255 缩放到 0~1,避免大数值导致模型梯度爆炸。
(8)Normalize ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):标准化

使用 ImageNet 数据集的均值和标准差对 Tensor 进行标准化,公式为:
标准化后像素值 = (原始像素值 - 均值) / 标准差
为什么用 ImageNet 的参数?因为本项目后续可扩展为迁移学习(使用预训练模型),而预训练模型是在 ImageNet 上训练的,使用相同的标准化参数能让模型更快收敛。

transforms.Compose([transforms.Resize([300, 300]),  # 先将图像调整为300x300transforms.RandomRotation(45),  # 随机旋转(-45~45度),增强旋转不变性transforms.CenterCrop(256),  # 中心裁剪到256x256,去除旋转后的黑边transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.RandomVerticalFlip(p=0.5),  # 50%概率垂直翻转transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),  # 随机调整亮度、对比度、饱和度和色调transforms.RandomGrayscale(p=0.1),  # 10%概率转为灰度图transforms.ToTensor(),  # 转为Tensor格式([C, H, W]),并将像素值归一化到[0,1]transforms.Normalize(  # 使用ImageNet的均值和标准差进行标准化[0.485, 0.456, 0.406],  # 均值(RGB三个通道)[0.229, 0.224, 0.225]  # 标准差(RGB三个通道))

2.3 验证集预处理

验证集的预处理链为:Resize([256, 256]) → ToTensor(),仅保留基础操作:

  • Resize ([256, 256]):直接将图像调整为 256×256,无需旋转(避免引入非真实样本)。
  • ToTensor ():与训练集一致,确保数据格式统一。
    (注:代码中验证集未做 Normalize,实际项目中建议与训练集保持一致,此处可根据需求调整)

三、自定义 Dataset:PyTorch 数据加载的核心

PyTorch 通过DatasetDataLoader实现数据加载,其中Dataset负责 “定义数据来源和格式”,DataLoader负责 “批量加载和并行处理”。本项目自定义了food_dataset类,用于加载食物图像和对应标签,我们详细解析其实现逻辑。

3.1 Dataset 的核心作用

Dataset是一个抽象类,要求子类必须实现三个方法:

  1. __init__:初始化数据集(读取文件列表、加载预处理函数)。
  2. __len__:返回数据集的总样本数。
  3. __getitem__:根据索引返回单个样本(图像 + 标签)。
    这三个方法确保了 PyTorch 能高效地迭代访问数据。

3.2 food_dataset 类逐方法解析

(1)init:初始化数据列表
def __init__(self, file_path, transform=None):self.file_path = file_path  # 存储图像路径和标签的txt文件路径self.transform = transform  # 预处理函数self.imgs = []  # 存储所有图像路径self.labels = []  # 存储对应标签# 读取txt文件,解析图像路径和标签with open(self.file_path, 'r', encoding="utf-8") as f:for line in f.readlines():line = line.strip()  # 去除首尾空格和换行符if not line:  # 跳过空行(避免解析错误)continueimg_path, label = line.split(' ')  # 按空格分割路径和标签self.imgs.append(img_path)self.labels.append(label)

  • txt 文件格式要求:每行需包含 “图像路径” 和 “数字标签”,用空格分隔
    其中 “0” 对应 “八宝粥”,“1” 对应 “巴旦木”,需与后续label_to_food字典一致。
(2)len:返回样本总数
def __len__(self):return len(self.imgs)

简单直接,返回self.imgs的长度),DataLoader会通过该方法确定迭代次数。

(3)getitem:返回单个样本
def __getitem__(self, index):# 读取图像并强制转为RGB(避免灰度图通道数问题)try:image = Image.open(self.imgs[index]).convert('RGB')except Exception as e:raise ValueError(f"读取图片 {self.imgs[index]} 失败:{e}")# 应用预处理if self.transform:image = self.transform(image)# 处理标签:转为int64类型Tensor(PyTorch分类任务要求)label = torch.tensor(int(self.labels[index]), dtype=torch.int64)return image, label

这是Dataset的核心方法,需重点关注三个细节:

  1. 强制 RGB 格式convert('RGB')确保所有图像都是 3 通道(避免部分灰度图是 1 通道,导致模型输入维度不匹配)。
  2. 异常处理try-except捕获图像读取错误(如路径错误、图像损坏),并明确提示错误位置,便于调试。
  3. 标签类型:将标签转为torch.int64(即 LongTensor),因为 PyTorch 的CrossEntropyLoss要求标签为 Long 类型。

3.3 如何准备自己的数据集?

  1. 收集图像:每个食物类别收集至少 100 张图像(样本越多,模型性能越好),建议按类别分文件夹存储
  2. 生成 txt 文件:编写脚本遍历图像文件夹,生成train.txttest.txt

    import os# 数据集根目录
    train_root = "./dataset/train"
    test_root = "./dataset/test"
    # 标签映射(与后续一致)
    label_to_food = {0: "八宝粥", 1: "巴旦木", ..., 19: "炸鸡"}
    # 反向映射:食物名称→数字标签
    food_to_label = {v: k for k, v in label_to_food.items()}# 生成train.txt
    with open("train.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(train_root):food_dir = os.path.join(train_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")# 生成test.txt(逻辑同上)
    with open("test.txt", "w", encoding="utf-8") as f:for food_name in os.listdir(test_root):food_dir = os.path.join(test_root, food_name)if not os.path.isdir(food_dir):continuelabel = food_to_label[food_name]for img_name in os.listdir(food_dir):img_path = os.path.join(food_dir, img_name)f.write(f"{img_path} {label}\n")
    
  3. 检查路径:确保 txt 文件中的图像路径与实际文件路径一致

四、DataLoader:批量加载与并行处理

Dataset定义了数据的 “来源”,而DataLoader则负责将数据 “批量加载” 到模型中,并支持并行处理,提升数据加载速度。

4.1 DataLoader 的核心参数解析

本项目的DataLoader初始化代码如下:

train_dataloader = DataLoader(training_data, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=8, shuffle=True)

核心参数含义:

  • batch_size:每次加载的样本数量(批大小),需根据 GPU 显存调整
  • shuffle:是否打乱数据顺序。训练集设为True,验证集可设为False,本项目测试集设为True是为了观察不同样本的预测效果。

4.2 DataLoader 与 Dataset 的协作流程

DataLoader的工作流程可概括为:

  1. 调用Dataset.__len__()获取总样本数,计算总批次数(总样本数 //batch_size)。
  2. shuffle=True,则在每个 epoch(训练轮次)开始前打乱样本索引。
  3. 对每个批次,根据索引调用Dataset.__getitem__()获取单个样本,组装成一个批次的 Tensor(形状为 [batch_size, C, H, W])。
  4. 将批次数据移动到指定设备(CUDA/MPS/CPU),供模型训练或测试。

4.3 数据加载到设备的逻辑

X, y = X.to(device), y.to(device)

其中device是通过以下代码确定的:

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

  • 为什么要移动设备? 模型和数据必须在同一设备上才能进行计算(如模型在 CUDA 上,数据也需在 CUDA 上),否则会报错。
  • 设备优先级:优先使用 CUDA(NVIDIA GPU),其次是 MPS(Apple M 系列),最后是 CPU

五、CNN 模型构建:从卷积到全连接的特征提取

卷积神经网络(CNN)是图像分类的核心,其通过 “卷积层提取局部特征→池化层降维→全连接层分类” 的流程,实现对图像的识别。本项目的 CNN 模型包含 4 个卷积块和 1 个全连接层,我们逐一解析其设计思路和尺寸计算。

5.1 CNN 的核心组件与作用

在解析代码前,先回顾 CNN 的三个核心组件:

  1. 卷积层(Conv2d):通过卷积核滑动提取图像的局部特征(如边缘、纹理、形状),输出 “特征图”(Feature Map)。
  2. 激活函数(ReLU):引入非线性,让模型能拟合复杂的特征关系(避免线性模型的表达能力不足)。
  3. 池化层(MaxPool2d):对特征图进行下采样,降低维度和计算量,同时增强模型对特征位置的鲁棒性。

5.2 模型代码逐块解析

模型定义代码如下,我们按 “卷积块 1→卷积块 2→卷积块 3→卷积块 4→全连接层” 的顺序解析:

class CNN(nn.Module):def __init__(self):super().__init__()# 卷积块1:1次卷积 + ReLU + 最大池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 池化后尺寸:256→128)# 卷积块2:2次卷积 + ReLU + 最大池化self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 尺寸:128→64)# 卷积块3:2次卷积 + ReLU + 最大池化self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.Conv2d(64, 128, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 尺寸:64→32)# 卷积块4:1次卷积 + ReLU(无池化)self.conv4 = nn.Sequential(nn.Conv2d(128, 128, 5, 1, 2),nn.ReLU(),  # 输出尺寸:32×32,通道数128)# 全连接层:映射到20类self.out = nn.Linear(128 * 32 * 32, 20)
池化层参数计算:尺寸减半

MaxPool2d(kernel_size=2)表示池化核大小为 2×2,步长默认等于核大小(即 2),因此输出尺寸为输入尺寸的 1/2:

  • conv1池化前尺寸:256×256 → 池化后:128×128
  • conv2池化前尺寸:128×128 → 池化后:64×64
  • conv3池化前尺寸:64×64 → 池化后:32×32
各卷积块的输出特征
  • conv1:输出特征图形状为 [batch_size, 16, 128, 128],提取的是图像的低级特征(如边缘、颜色块)。
  • conv2:输出形状为 [batch_size, 32, 64, 64],通过 2 次卷积提取更复杂的特征(如食物的局部轮廓)。
  • conv3:输出形状为 [batch_size, 128, 32, 32],通道数增加到 128,特征更抽象(如食物的结构特征)。
  • conv4:输出形状为 [batch_size, 128, 32, 32],无池化,进一步细化特征(避免池化导致的特征损失)。
(4)全连接层:从特征到分类

全连接层self.out的输入维度是128×32×32,这是由conv4的输出特征图形状决定的:

  • conv4输出:[batch_size, 128, 32, 32] → 展平后为 [batch_size, 128×32×32](展平操作在forward中通过x.view(x.size(0), -1)实现)。
  • 输出维度:20,对应 20 种食物类别,每个维度输出该类别的 “预测分数”(后续通过argmax取分数最高的类别作为预测结果)。

5.3 forward 方法:定义数据流动路径

def forward(self, x):x = self.conv1(x)  # 经卷积块1处理x = self.conv2(x)  # 经卷积块2处理x = self.conv3(x)  # 经卷积块3处理x = self.conv4(x)  # 经卷积块4处理x = x.view(x.size(0), -1)  # 展平:[batch_size, 128*32*32]output = self.out(x)  # 全连接层输出return output
  • 展平操作x.view(x.size(0), -1)将 4 维特征图(batch, C, H, W)转为 2 维张量(batch, C×H×W),因为全连接层仅接受 2 维输入。
  • 数据流动:输入图像([batch, 3, 256, 256])→ 卷积块 1→2→3→4 → 展平 → 全连接层 → 输出([batch, 20])。

5.4 模型初始化与设备移动

模型初始化代码如下:

model = CNN().to(device)
  • CNN()创建模型实例,to(device)将模型参数移动到指定设备(CUDA/MPS/CPU),确保模型和数据在同一设备上计算。

六、模型训练与测试:从损失下降到性能评估

模型构建完成后,需通过训练让模型 “学习” 食物特征,再通过测试评估模型的泛化能力。本项目定义了traintest两个函数,分别实现训练和测试逻辑。

6.1 训练函数:让模型 “学习”

训练函数的核心是 “前向传播计算损失→反向传播更新参数”,代码如下:

def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启训练模式(启用Dropout、BatchNorm训练行为)batch_size_num = 1  # 批次编号,用于打印损失for X, y in dataloader:# 数据移动到设备X, y = X.to(device), y.to(device)# 1. 前向传播:计算预测结果pred = model(X)# 2. 计算损失:预测值与真实标签的差距loss = loss_fn(pred, y)# 3. 反向传播与参数更新optimizer.zero_grad()  # 清空上一轮梯度(避免累积)loss.backward()        # 反向传播计算梯度optimizer.step()       # 根据梯度更新模型参数# 打印损失(每2个批次打印一次)loss_val = loss.item()if batch_size_num % 2 == 0:print(f"loss: {loss_val:>7f}  [batch: {batch_size_num}]")batch_size_num += 1
(1)关键步骤解析
  1. model.train():开启训练模式,对含有 Dropout、BatchNorm 的模型至关重要:

    • Dropout:训练时随机 “关闭” 部分神经元,防止过拟合;测试时不关闭。
    • BatchNorm:训练时使用批次的均值和方差归一化;测试时使用训练阶段累积的均值和方差。
  2. 前向传播(Forward Pass)

    • pred = model(X):将批次数据输入模型,得到预测结果([batch, 20])。
    • loss = loss_fn(pred, y):计算损失,本项目使用CrossEntropyLoss(多分类任务的常用损失函数)。
  3. 反向传播(Backward Pass)与参数更新

    • optimizer.zero_grad():清空梯度。若不清空,梯度会累积到上一轮,导致参数更新错误。
    • loss.backward():根据损失计算各参数的梯度(
    • optimizer.step():根据梯度更新模型参数
(2)损失函数与优化器选择

本项目使用的损失函数和优化器如下:

loss_fn = nn.CrossEntropyLoss()  # 多分类交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器
  • CrossEntropyLoss:适用于多分类任务
  • Adam 优化器:自适应学习率优化器,收敛速度快,对学习率不敏感,是深度学习中最常用的优化器之一。lr=0.001是常用的初始学习率,可根据训练情况调整

6.2 测试函数:评估模型泛化能力

测试函数的核心是 “计算模型在测试集上的准确率和平均损失”,代码如下:

def test(dataloader, model, loss_fn):model.eval()  # 开启评估模式(关闭Dropout、固定BatchNorm)size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)    # 测试集批次数test_loss, correct = 0, 0        # 总损失和正确预测数# 关闭梯度计算(节省资源,避免参数更新)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累加损失test_loss += loss_fn(pred, y).item()# 累加正确预测数correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"\nTest Result: \n Accuracy: {(100 * correct):>5.2f}%, Avg Loss: {test_loss:>8f}\n")
(1)关键步骤解析
  1. model.eval():开启评估模式,与model.train()对应,确保模型在测试时的行为与训练时一致(如关闭 Dropout)。
  2. torch.no_grad():上下文管理器,关闭梯度计算。测试时无需更新参数,关闭梯度可大幅减少内存占用和计算时间。
  3. 准确率计算
    • pred.argmax(1):对每个样本,取预测分数最高的类别(维度 1 是类别维度,[batch, 20]→[batch, 1])。
    • (pred.argmax(1) == y):比较预测类别与真实标签,得到布尔张量(True = 正确,False = 错误)。
    • type(torch.float).sum().item():将布尔张量转为 float(True=1,False=0),求和得到正确预测数,再转为 Python 标量。
  4. 结果解读
    • Accuracy:准确率(正确预测数 / 总样本数),反映模型的整体识别能力,越高越好。
    • Avg Loss:平均损失,反映模型预测值与真实标签的平均差距,越低越好。

6.3 训练流程与轮次设置

训练流程代码如下:

# 训练轮次(epochs)
epochs = 100
for t in range(epochs):print(f"\nEpoch: {t + 1}/{epochs}\n")train(train_dataloader, model, loss_fn, optimizer)
print("Training Done!")# 测试模型
test(test_dataloader, model, loss_fn)
  • epochs(训练轮次):表示模型将遍历整个训练集的次数。本项目设为 100,可根据实际情况调整:
    • 若训练损失仍在下降,可增加 epochs;
    • 若训练损失下降但测试损失上升(过拟合),可减少 epochs 或加入早停机制。
  • 训练与测试顺序:每轮训练后可加入测试(如在train后调用test),便于监控模型是否过拟合;本项目在所有训练完成后测试,适用于快速验证。

七、单张图片预测:模型的实际应用

训练完成后,需将模型用于实际场景 —— 对单张食物图片进行分类。本项目定义了predict_single_image函数,实现从图像读取到类别输出的完整流程。

7.1 预测函数解析

def predict_single_image(image_path, model, transform, device, label_map):# 1. 读取并预处理图像(与测试集一致)image = Image.open(image_path).convert('RGB')  # 强制RGBimage = transform(image)  # 应用预处理(Resize + ToTensor)# 2. 增加batch维度(模型要求输入为[batch, C, H, W])image = image.unsqueeze(0).to(device)# 3. 模型预测model.eval()  # 开启评估模式with torch.no_grad():pred_logits = model(image)  # 预测分数(logits)pred_label = pred_logits.argmax(1).item()  # 取最高分数类别# 4. 映射为食物名称if pred_label not in label_map:raise KeyError(f"预测标签 {pred_label} 不在标签映射字典中")return label_map[pred_label]
(1)关键步骤解析
  1. 图像预处理一致性:预测时的预处理必须与测试集一致(本项目使用data_transforms['valid']),否则模型输入格式不匹配,预测结果会失真。
  2. 增加 batch 维度:模型训练和测试时输入都是批次数据([batch, C, H, W]),而单张图片是 [C, H, W],需通过unsqueeze(0)在第 0 维(batch 维)增加一个维度,变为 [1, C, H, W]。
  3. 标签映射label_map(如label_to_food)将数字标签(如 0)映射为食物名称(如 “八宝粥”),让预测结果更直观。

7.2 预测实战与结果展示

预测代码如下,用户输入图片路径和真实标签,模型输出预测结果并对比:

# 标签映射字典(与数据集标签对应)
label_to_food = {0: "八宝粥", 1: "巴旦木", 2: "白萝卜", 3: "板栗", 4: "菠萝",5: "草莓", 6: "蛋", 7: "蛋挞", 8: "骨肉相连", 9: "瓜子",10: "哈密瓜", 11: "汉堡", 12: "胡萝卜", 13: "火龙果", 14: "鸡翅",15: "青菜", 16: "生肉", 17: "圣女果", 18: "薯条", 19: "炸鸡"
}# 用户输入
image_path = input("请输入图片路径:")
true_food = input("请输入该图片的真实食物名称:")# 执行预测
predicted_food = predict_single_image(image_path=image_path,model=model,transform=data_transforms['valid'],device=device,label_map=label_to_food
)# 输出结果
print("\n" + "-" * 50)
print(f"预测结果:{predicted_food}")
print(f"真实结果:{true_food}")
print(f"判断:{'预测正确' if predicted_food == true_food else '预测错误'}")
print("-" * 50)
(1)预测示例

假设用户输入:

  • 图片路径:"D:\食物分类\food_dataset\test\八宝粥\img_八宝粥罐_81.jpeg"(一张八宝粥图片)
  • 真实食物名称:八宝粥

模型输出:

请输入图片路径:./test_images/hamburger.jpg
请输入该图片的真实食物名称:汉堡--------------------------------------------------
预测结果:汉堡
真实结果:汉堡
判断:预测正确
--------------------------------------------------
(2)预测错误原因

样本数量不足

训练的轮数过少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/97926.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/97926.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/97926.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python 值传递 (Pass by Value) 和引用传递 (Pass by Reference)

Python 值传递 {Pass by Value} 和引用传递 {Pass by Reference}1. Mutable Objects and Immutable Objects in Python (Python 可变对象和不可变对象)2. Pass by Value and Pass by Reference2.1. What is Pass by Value in Python?2.2. What is Pass by Reference in Python…

aippt自动生成工具有哪些?一文看懂,总有一款适合你!

在当今快节奏的工作与学习环境中,传统耗时的PPT制作方式已难以满足高效表达的需求。随着人工智能技术的发展,AI自动生成PPT工具应运而生,成为提升演示文稿制作效率的利器。这类工具通过自然语言处理和深度学习技术,能够根据用户输…

Langflow 框架中 Prompt 技术底层实现分析

Langflow 框架中 Prompt 技术底层实现分析 1. Prompt 技术概述 Langflow 是一个基于 LangChain 的可视化 AI 工作流构建框架,其 Prompt 技术是整个系统的核心组件之一。Prompt 技术主要负责: 模板化处理:支持动态变量替换的提示词模板变量验证…

前端、node跨域问题

前端页面访问node后端接口跨域报错 Access to XMLHttpRequest at http://192.18.31.75/api/get?namess&age19 from origin http://127.0.0.1:5500 has been blocked by CORS policy: No Access-Control-Allow-Origin header is present on the requested resource. 这个报…

超越马力欧:如何为经典2D平台游戏注入全新灵魂

在游戏开发的世界里,2D平台游戏仿佛是一位熟悉的老朋友。从《超级马力欧兄弟》开启的黄金时代到现在,这个类型已经经历了数十年的演变与打磨。当每个基础设计似乎都已被探索殆尽时,我们如何才能打造出一款令人耳目一新的平台游戏?…

基于Springboot + vue3实现的时尚美妆电商网站

项目描述本系统包含管理员和用户两个角色。管理员角色:商品分类管理:新增、查看、修改、删除商品分类。商品信息管理:新增、查看、修改、删除、查看评论商品信息。用户管理:新增、查看、修改、删除用户。管理员管理:查…

网络协议之https?

写在前面 https协议还是挺复杂的,本人也是经过了很多次的学习,依然感觉一知半解,无法将所有的知识点串起来,本次学习呢,也是有很多的疑惑点,但是还是尽量的输出内容,来帮助自己和在看文章的你来…

word运行时错误‘53’,文件未找到:MathPage.WLL,更改加载项路径完美解决

最简单的方法解决!!!安装Mathtype之后粘贴显示:运行时错误‘53’,文件未找到:MathPage.WLLwin11安装mathtype后会有这个错误,这是由于word中加载项加载mathtype路径出错导致的,这时候…

React实现列表拖拽排序

本文主要介绍一下React实现列表拖拽排序方法,具体样式如下图首先,简单展示一下组件的数据结构 const CodeSetting props > {const {$t, // 国际化翻译函数vm, // 视图模型数据vm: {CodeSet: { Enable [], …

将 MySQL 表数据导出为 CSV 文件

目录 一、实现思路 二、核心代码 1. 数据库连接部分 2. 数据导出核心逻辑 3. CSV文件写入 三、完整代码实现 五、输出结果 一、实现思路 建立数据库连接 查询目标表的数据总量和具体数据 获取表的列名作为CSV文件的表头 将查询结果转换为二维数组格式 使用Hutool工具…

一文读懂RAG:从生活场景到核心逻辑,AI“查资料答题”原来这么简单

一文读懂RAG:从生活场景到核心逻辑,AI“查资料答题”原来这么简单 要理解 RAG(Retrieval-Augmented Generation,检索增强生成),不需要先背复杂公式,我们可以从一个生活场景切入——它本质是让AI…

git将当前分支推送到远端指定分支

在 Git 中&#xff0c;将当前本地分支推送到远程仓库的指定分支&#xff0c;可以使用 git push 命令&#xff0c;并指定本地分支和远程分支的映射关系。 基本语法 git push <远程名称> <本地分支名>:<远程分支名><远程名称>&#xff1a;通常是 origin&…

【Linux】线程封装

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、为什么需要封装线程库&#xff1f; pthread的痛点&#xff1a; 封装带来的好处&#xff1a; 二、线程封装核心代码解析 1. 头文件定义&#xff08;Thread.hpp&a…

智慧交通管理信号灯通信4G工业路由器应用

在交通信号灯管理中传统的有线通讯&#xff08;光纤、网线&#xff09;存在部署成本高、偏远区域覆盖难、故障维修慢等问题&#xff0c;而4G工业路由器凭借无线化、高稳定、强适配的特性&#xff0c;成为信号灯与管控平台间的数据传输核心&#xff0c;适配多场景需求。智慧交通…

《Python Flask 实战:构建一个可交互的 Web 应用,从用户输入到智能响应》

《Python Flask 实战:构建一个可交互的 Web 应用,从用户输入到智能响应》 一、引言:从“Hello, World!”到“你好,用户” 在 Web 应用的世界里,最打动人心的功能往往不是炫酷的界面,而是人与系统之间的真实互动。一个简单的输入框,一句个性化的回应,往往能让用户感受…

开发效率翻倍:资深DBA都在用的MySQL客户端利器

MySQL 连接工具&#xff08;也称为客户端或图形化界面工具&#xff0c;GUI Tools&#xff09;是数据库开发、管理和运维中不可或缺的利器。它们比命令行更直观&#xff0c;能极大提高工作效率。以下是一份主流的 MySQL 连接工具清单&#xff0c;并附上了它们的优缺点和适用场景…

基于Docker和Kubernetes的CI/CD流水线架构设计与优化实践

基于Docker和Kubernetes的CI/CD流水线架构设计与优化实践 本文分享了在生产环境中基于Docker和Kubernetes构建高效可靠的CI/CD流水线的实战经验&#xff0c;包括业务场景、技术选型、详细方案、踩坑与解决方案&#xff0c;以及最终的总结与最佳实践&#xff0c;帮助后端开发者快…

Trae x 图片素描MCP一键将普通图片转换为多风格素描效果

目录前言一、核心工具与优势解析二、操作步骤&#xff1a;从安装到生成素描效果第一步&#xff1a;获取MCP配置代码第二步&#xff1a;下载第三步&#xff1a;在 Trae 中导入 MCP 配置并建立连接第四步&#xff1a;核心功能调用三、三大素描风格差异化应用四.总结前言 在设计创…

2 XSS

XSS的原理 XSS&#xff08;跨站脚本攻击&#xff09;原理 1. 核心机制 XSS攻击的本质是恶意脚本在用户浏览器中执行。攻击者通过向网页注入恶意代码&#xff0c;当其他用户访问该页面时&#xff0c;浏览器会执行这些代码&#xff08;没有对用户的输入进行过滤导致用户输入的…

GitHub每日最火火火项目(9.3)

1. pedroslopez / whatsapp-web.js 项目名称&#xff1a;whatsapp-web.js项目介绍&#xff1a;基于 JavaScript 开发&#xff0c;是一个用于 Node.js 的 WhatsApp 客户端库&#xff0c;通过 WhatsApp Web 浏览器应用进行连接&#xff08;A WhatsApp client library for NodeJS …