文章目录

    • 引言
    • 一、项目概述
    • 二、环境配置
    • 三、数据预处理
      • 3.1 数据转换设置
      • 3.2 数据集准备
    • 四、自定义数据集类
    • 五、CNN模型架构
    • 六、训练与评估流程
      • 6.1 训练函数
      • 6.2 评估与模型保存
    • 七、完整训练流程
    • 八、模型保存与加载
      • 8.1 保存模型
      • 8.2 加载模型
    • 九、优化建议
    • 十、常见问题解决
    • 十一、完整代码
    • 十二、总结

引言

本文将详细介绍如何使用PyTorch框架构建一个完整的食物图像分类系统,包含数据预处理、模型构建、训练优化以及模型保存等关键环节。与上一篇博客介绍的版本相比,本版本增加了模型保存与加载功能,并优化了测试评估流程。

一、项目概述

本项目的目标是构建一个能够识别20种不同食物的图像分类系统。主要技术特点包括:

  1. 简化但高效的数据预处理流程
  2. 三层CNN网络架构设计
  3. 训练过程中自动保存最佳模型
  4. 完整的训练-评估流程实现

二、环境配置

首先确保已安装必要的Python库:

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os

三、数据预处理

3.1 数据转换设置

我们为训练集和验证集定义了不同的转换策略:

data_transforms = {'train': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),'valid': transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}

简化说明

  • 本版本简化了数据增强,仅保留基本的resize和tensor转换
  • 实际应用中可根据需求添加更多增强策略

3.2 数据集准备

def train_test_file(root, dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots, directories, files in os.walk(path):if len(directories) != 0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()

该函数会生成包含图像路径和标签的文本文件,格式为:

path/to/image1.jpg 0
path/to/image2.jpg 1
...

四、自定义数据集类

我们继承PyTorch的Dataset类实现自定义数据集:

class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

关键改进

  • 更清晰的数据加载逻辑
  • 完善的类型转换处理
  • 支持灵活的数据变换

五、CNN模型架构

我们设计了一个三层CNN网络:

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.out = nn.Linear(64*32*32, 20)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)return self.out(x)

架构特点

  1. 每层包含卷积、ReLU激活和最大池化
  2. 使用padding保持特征图尺寸
  3. 最后通过全连接层输出分类结果

六、训练与评估流程

6.1 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()if batch_size_num % 1 == 0:print(f"loss: {loss.item():>7f} [batch:{batch_size_num}]")batch_size_num += 1

6.2 评估与模型保存

best_acc = 0def Test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size# 保存最佳模型if correct > best_acc:best_acc = correcttorch.save(model.state_dict(), "best_model.pth")print(f"\n测试结果: \n 准确率:{(100*correct):.2f}%, 平均损失:{test_loss:.4f}")

关键改进

  1. 增加全局变量best_acc跟踪最佳准确率
  2. 实现两种模型保存方式:
    • 只保存模型参数(state_dict)
    • 保存整个模型
  3. 更详细的测试结果输出

七、完整训练流程

# 初始化
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = CNN().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练循环
epochs = 10
for t in range(epochs):print(f"Epoch {t+1}\n{'-'*20}")train(train_dataloader, model, loss_fn, optimizer)# 最终评估
Test(test_dataloader, model, loss_fn)

八、模型保存与加载

8.1 保存模型

# 方法1:只保存参数
torch.save(model.state_dict(), "model_params.pth")# 方法2:保存完整模型
torch.save(model, "full_model.pt")

8.2 加载模型

# 方法1对应加载方式
model = CNN().to(device)
model.load_state_dict(torch.load("model_params.pth"))# 方法2对应加载方式
model = torch.load("full_model.pt").to(device)

九、优化建议

  1. 数据增强:添加更多变换提高模型泛化能力
  2. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率
  3. 早停机制:当验证集性能不再提升时停止训练
  4. 模型微调:使用预训练模型作为基础

十、常见问题解决

  1. 内存不足

    • 减小batch size
    • 使用梯度累积
    • 尝试混合精度训练
  2. 过拟合

    • 增加Dropout层
    • 添加L2正则化
    • 使用更多数据增强
  3. 训练不稳定

    • 检查数据标准化
    • 调整学习率
    • 检查损失函数

十一、完整代码

import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import osdata_transforms = { #字典'train':transforms.Compose([            #对图片预处理的组合transforms.Resize([256,256]),   #对数据进行改变大小transforms.ToTensor(),          #数据转换为tensor]),'valid':transforms.Compose([transforms.Resize([256,256]),transforms.ToTensor(),]),
}def train_test_file(root,dir):file_txt = open(dir+'.txt','w')path = os.path.join(root,dir)for roots,directories,files in os.walk(path):if len(directories) !=0:dirs = directorieselse:now_dir = roots.split('\\')for file in files:path_1 = os.path.join(roots,file)print(path_1)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()root = r'.\食物分类\food_dataset'
train_dir = 'train'
test_dir = 'test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)#Dataset是用来处理数据的
class food_dataset(Dataset):        # food_dataset是自己创建的类名称,可以改为你需要的名称def __init__(self,file_path,transform=None):    #类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f: #是把train.txt文件中的图片路径保存在self.imgssamples = [x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)  #图像的路径self.labels.append(label)   #标签,还不是tensor# 初始化:把图片目录加到selfdef __len__(self):  #类实例化对象后,可以使用len函数测量对象的个数return  len(self.imgs)#training_data[1]def __getitem__(self, idx):    #关键,可通过索引的形式获取每一个图片的数据及标签image = Image.open(self.imgs[idx])  #读取到图片数据,还不是tensor,BGRif self.transform:                  #将PIL图像数据转换为tensorimage = self.transform(image)   #图像处理为256*256,转换为tensorlabel = self.labels[idx]    #label还不是tensorlabel = torch.from_numpy(np.array(label,dtype=np.int64))    #label也转换为tensorreturn image,label
#training_data包含了本次需要训练的全部数据集
training_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['valid'])#training_data需要具备索引的功能,还要确保数据是tensor
train_dataloader = DataLoader(training_data,batch_size=16,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=16,shuffle=True)'''判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")   #字符串的格式化,CUDA驱动软件的功能:pytorch能够去执行cuda的命令
# 神经网络的模型也需要传入到GPU,1个batch_size的数据集也需要传入到GPU,才可以进行训练''' 定义神经网络  类的继承这种方式'''
class CNN(nn.Module): #通过调用类的形式来使用神经网络,神经网络的模型,nn.mdouledef __init__(self): #输入大小:(3,256,256)super(CNN,self).__init__()  #初始化父类self.conv1 = nn.Sequential( #将多个层组合成一起,创建了一个容器,将多个网络组合在一起nn.Conv2d(              # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=3,      # 图像通道个数,1表示灰度图(确定了卷积核 组中的个数)out_channels=16,     # 要得到多少个特征图,卷积核的个数kernel_size=5,      # 卷积核大小 3×3stride=1,           # 步长padding=2,          # 一般希望卷积核处理后的结果大小与处理前的数据大小相同,效果会比较好),                      # 输出的特征图为(16,256,256)nn.ReLU(),  # Relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2×2操作),输出结果为(16,128,128))self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  #输出(32,128,128)nn.ReLU(),  #Relu层  (32,128,128)nn.MaxPool2d(kernel_size=2),    #池化层,输出结果为(32,64,64))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 输出(64,64,64)nn.ReLU(),  # Relu层  (64,64,64)nn.MaxPool2d(kernel_size=2),  # 池化层,输出结果为(64,32,32))self.out = nn.Linear(64*32*32,20)  # 全连接层得到的结果def forward(self,x):   #前向传播,你得告诉它 数据的流向 是神经网络层连接起来,函数名称不能改x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)    # flatten操作,结果为:(batch_size,32 * 64 * 64)output = self.out(x)return output# 提取模型的2种方法:
#   1、读取参数的方法
model = CNN().to(device) #初始化模型,w都是随机初始化的
# model.load_state_dict(torch.load("best.pth"))
#   2、读取完整模型的方法,无需提前创建model
#   model = CNN().to(device)
#   model = torch.load('best.pt')#w,b,cnn
# 模型保存的对不对?
# model.eval() #固定模型参数和数据,防止后面被修改
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w,在训练过程中,w会被修改的
# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train() 和 mdoel.eval()
# 一般用法是:在训练开始之前写上model.train(),在测试时写上model.eval()batch_size_num = 1for X,y in dataloader:              #其中batch为每一个数据的编号X,y = X.to(device),y.to(device) #把训练数据集和标签传入cpu或GPUpred = model.forward(X)         # .forward可以被省略,父类种已经对此功能进行了设置loss = loss_fn(pred,y)          # 通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()           # 梯度值清零loss.backward()                 # 反向传播计算得到每个参数的梯度值woptimizer.step()                # 根据梯度更新网络w参数loss_value = loss.item()        # 从tensor数据种提取数据出来,tensor获取损失值if batch_size_num %1 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0def Test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)  # 打包的数量model.eval()        #测试,w就不能再更新test_loss,correct =0,0with torch.no_grad():       #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X) #等价于model.forward(X)test_loss += loss_fn(pred,y).item() #test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y) #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batches #能来衡量模型测试的好坏correct /= size  #平均的正确率# 保存最优模型的2种方法(模型的文件扩展名一般:pt\pth,t7) #opencvif correct > best_acc:best_acc = correct
#   1.保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)
#         print(model.state_dict().keys())    #输出模型参数名称   cnntorch.save(model.state_dict(), "best2025-04.pth")
#   2.保存完整模型(w,b,模型cnn)
#         torch.save(model,'best.pt')print(f"\n最终测试结果: \n 准确率:{(100*correct):.2f}%, 平均损失:{test_loss:.4f}")loss_fn = nn.CrossEntropyLoss()  #创建交叉熵损失函数对象,因为手写字识别一共有十种数字,输出会有10个结果optimizer = torch.optim.Adam(model.parameters(),lr=0.001) #创建一个优化器,SGD为随机梯度下降算法
# # params:要训练的参数,一般我们传入的都是model.parameters()
# # lr:learning_rate学习率,也就是步长
#
# # loss表示模型训练后的输出结果与样本标签的差距。如果差距越小,就表示模型训练越好,越逼近真实的模型
# train(train_dataloader,model,loss_fn,optimizer) #训练1次完整的数据。多轮训练
# Test(test_dataloader,model,loss_fn)epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)

十二、总结

本文详细介绍了使用PyTorch实现食物分类的完整流程,重点讲解了:

  1. 自定义数据集的处理方法
  2. CNN网络的设计与实现
  3. 训练过程中的模型保存策略
  4. 完整的训练-评估流程

通过本教程,读者可以掌握PyTorch图像分类的基本方法,并能够根据实际需求进行调整和优化。完整代码已包含在文章中,建议在实际应用中根据具体数据集调整相关参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85007.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85007.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/85007.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《棒球百科》棒球怎么玩·棒球9号位

用最简单的方式介绍棒球的核心玩法和规则,完全零基础也能看懂: 一句话目标 进攻方:用球棒把球打飞,然后拼命跑完4个垒包(逆时针绕一圈)得分。 防守方:想尽办法让进攻方出局,阻止他…

语言模型是怎么工作的?通俗版原理解读!

大模型为什么能聊天、写代码、懂医学? 我们从四个关键模块,一步步拆开讲清楚 👇 ✅ 模块一:模型的“本事”从哪来?靠训练数据 别幻想它有意识,它的能力,全是“喂”出来的: 吃过成千…

nrf52811墨水屏edp_service.c文件学习

on_connect函数 /**brief Function for handling the ref BLE_GAP_EVT_CONNECTED event from the S110 SoftDevice.** param[in] p_epd EPD Service structure.* param[in] p_ble_evt Pointer to the event received from BLE stack.*/ static void on_connect(ble_epd_t …

Nginx-2 详解处理 Http 请求

Nginx-2 详解处理 Http 请求 Nginx 作为当今最流行的开源 Web 服务器之一,以其高性能、高稳定性和丰富的功能而闻名。在处理 HTTP请求 的过程中,Nginx 采用了模块化的设计,将整个请求处理流程划分为若干个阶段,每个阶段都可以由特…

40-Oracle 23 ai Bigfile~Smallfile-Basicfile~Securefile矩阵对比

小伙伴们是不是在文件选择上还默认给建文件4G/个么,在oracle每个版本上系统默认属性是什么,选择困难症了没,一起一次性文件存储和默认属性看透。 基于Oracle历代在存储架构的技术演进分析,结合版本升级和23ai新特性,一…

【一】零基础--分层强化学习概览

分层强化学习(Hierarchical Reinforcement Learning, HRL)最早一般视为1993 年封建强化学习的提出. 一、HL的基础理论 1.1 MDP MDP(马尔可夫决策过程):MDP是一种用于建模序列决策问题的框架,包含状态&am…

Java延时

在 Java 中实现延时操作主要有以下几种方式,根据使用场景选择合适的方法: 1. Thread.sleep()(最常用) java 复制 下载 try {// 延时 1000 毫秒(1秒)Thread.sleep(1000); } catch (InterruptedExcepti…

电阻篇---下拉电阻的取值

下拉电阻的取值需要综合考虑电路驱动能力、功耗、信号完整性、噪声容限等多方面因素。以下是详细的取值分析及方法: 一、下拉电阻的核心影响因素 1. 驱动能力与电流限制 单片机 IO 口驱动能力:如 STM32 的 IO 口在输入模式下的漏电流通常很小&#xf…

NY271NY274美光科技固态NY278NY284

美光科技NY系列固态硬盘深度剖析:技术、市场与未来 技术前沿:232层NAND架构与性能突破 在存储技术的赛道上,美光科技(Micron)始终是行业领跑者。其NY系列固态硬盘(SSD)凭借232层NAND闪存架构的…

微信开发者工具 插件未授权使用,user uni can not visit app

参考:https://www.jingpinma.cn/archives/159.html 问题描述 我下载了一个别人的小程序,想运行看看效果,结果报错信息如下 原因 其实就是插件没有安装,需要到小程序平台安装插件。处理办法如下 在 app.json 里,声…

UE5 读取配置文件

使用免费的Varest插件,可以读取本地的json数据 获取配置文件路径:当前配置文件在工程根目录,打包后在 Windows/项目名称 下 读取json 打包后需要手动复制配置文件到Windows/项目名称 下

【kdump专栏】KEXEC机制中SME(安全内存加密)

【kdump专栏】KEXEC机制中SME&#xff08;安全内存加密&#xff09; 原始代码&#xff1a; /* Ensure that these pages are decrypted if SME is enabled. */ 533 if (pages) 534 arch_kexec_post_alloc_pages(page_address(pages), 1 << order, 0);&#x1f4cc…

C# vs2022 找不到指定的 SDK“Microsof.NET.Sdk

找不到指定的 SDK"Microsof.NET.Sdk 第一查 看 系统盘目录 C:\Program Files\dotnet第二 命令行输入 dotnet --version第三 检查环境变量总结 只要执行dotnet --version 正常返回版本号此问题即解决 第一查 看 系统盘目录 C:\Program Files\dotnet 有2种方式 去检查 是否…

Pytest断言全解析:掌握测试验证的核心艺术

Pytest断言全解析&#xff1a;掌握测试验证的核心艺术 一、断言的本质与重要性 什么是断言&#xff1f; 断言是自动化测试中的验证检查点&#xff0c;用于确认代码行为是否符合预期。在Pytest中&#xff0c;断言直接使用Python原生assert语句&#xff0c;当条件不满足时抛出…

【编译原理】题目合集(一)

未经许可,禁止转载。 文章目录 选择填空综合选择 将编译程序分成若干个“遍”是为了 (D.利用有限的机器内存,但降低了执行效率) A.提高程序的执行效率 B.使程序的结构更加清晰 C.利用有限的机器内存并提高执行效率 D.利用有限的机器内存,但降低了执行效率 词法分析…

uni-app项目实战笔记13--全屏页面的absolute定位布局和fit-content自适应内容宽度

本篇主要实现全屏页面的布局&#xff0c;其中还涉及内容自适应宽度。 创建一个preview.vue页面用于图片预览&#xff0c;写入以下代码&#xff1a; <template><view class"preview"><swiper circular><swiper-item v-for"item in 5&quo…

OVS Faucet Tutorial笔记(下)

官方文档&#xff1a; OVS Faucet Tutorial 5、Routing Faucet Router 通过控制器模拟三层网关&#xff0c;提供 ARP 应答、路由转发功能。 5.1 控制器配置 5.1.1 编辑控制器yaml文件&#xff0c;增加router配置 rootserver1:~/faucet/inst# vi faucet.yaml dps:switch-1:d…

PCB设计教程【大师篇】stm32开发板PCB布线(信号部分)

前言 本教程基于B站Expert电子实验室的PCB设计教学的整理&#xff0c;为个人学习记录&#xff0c;旨在帮助PCB设计新手入门。所有内容仅作学习交流使用&#xff0c;无任何商业目的。若涉及侵权&#xff0c;请随时联系&#xff0c;将会立即处理 1. 布线优先级与原则 - 遵循“重…

Phthon3 学习记录-0613

List&#xff08;列表&#xff09;、Tuple&#xff08;元组&#xff09;、Set&#xff08;集合&#xff09;和 Dictionary&#xff08;字典&#xff09; 在接口自动化测试中&#xff0c;List&#xff08;列表&#xff09;、Tuple&#xff08;元组&#xff09;、Set&#xff08…

UVa12298 3KP-BASH Project

UVa12298 3KP-BASH Project 题目链接题意输入格式输出格式 分析AC 代码 题目链接 UVa12298 3KP-BASH Project 题意 摘自 《算法竞赛入门经典&#xff1a;训练指南》刘汝佳&#xff0c;陈锋著。有删改。 你的任务是为一个假想的 3KP 操作系统编写一个简单的 Bash 模拟器。由于操…