目录

一、数据增强

1. 核心概念

2. 核心目的

3. 常用方法

4. 实现示例(基于 PyTorch)

5. 自定义数据集加载

二、保存最优模型

1. 核心概念

2. 实现步骤

(1)定义 CNN 模型

(2)定义训练与测试函数

(3)启动训练

3. 模型加载与使用

三、总结


在卷积神经网络(CNN)的训练过程中,数据增强和模型保存是提升性能与实用性的关键环节。以下结合理论与实例,详细解析其原理及实现方式。

一、数据增强
1. 核心概念

数据增强是通过对原始训练数据进行随机变换(如旋转、翻转、调整亮度等),生成新的训练样本的技术。其本质是扩展数据多样性,让模型在训练中接触更多 “变体”,从而提升泛化能力(减少过拟合)。

2. 核心目的
  • 模拟真实场景中的变量(如光照变化、视角差异、遮挡等)。
  • 解决训练数据不足的问题,通过 “人工扩充” 提升模型鲁棒性。
3. 常用方法

4. 实现示例(基于 PyTorch)
import torch
from torchvision import transforms# 定义训练集和验证集的数据增强策略
data_transforms = {'train': transforms.Compose([transforms.Resize([300, 300]),  # 缩放图像transforms.RandomRotation(45),  # 随机旋转(-45°~45°)transforms.CenterCrop(256),     # 中心裁剪至256x256transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转transforms.ColorJitter(brightness=0.2, contrast=0.1),  # 颜色调整transforms.ToTensor(),  # 转换为Tensor(像素值归一化到[0,1])transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化]),'valid': transforms.Compose([transforms.Resize([256, 256]),  # 验证集仅缩放,不做随机增强transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
5. 自定义数据集加载

通过继承Dataset类,将增强策略应用于实际数据:

from torch.utils.data import Dataset
from PIL import Image
import numpy as npclass FoodDataset(Dataset):def __init__(self, file_path, transform=None):self.transform = transformself.imgs = []self.labels = []# 从txt文件读取图像路径和标签(格式:图像路径 标签)with open(file_path, 'r') as f:for line in f.readlines():img_path, label = line.strip().split(' ')self.imgs.append(img_path)self.labels.append(int(label))def __len__(self):return len(self.imgs)def __getitem__(self, idx):# 加载图像并应用增强image = Image.open(self.imgs[idx]).convert('RGB')if self.transform:image = self.transform(image)# 标签转换为Tensorlabel = torch.tensor(self.labels[idx], dtype=torch.long)return image, label# 加载训练集和验证集
train_dataset = FoodDataset('./train.1txt', transform=data_transforms['train'])
valid_dataset = FoodDataset('./test.1txt', transform=data_transforms['valid'])# 数据加载器(批量处理)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)

其中train.1txt文件内容为:

test.1txt文件内容:

  其中的每个文件地址都有其对应的图片,数据量较大,训练时间会较长,如需使用,可私信发送打包文件。

        整篇文章所有代码连接为一份完整代码。

二、保存最优模型
1. 核心概念

训练过程中,模型性能(如验证集准确率)会随迭代波动。保存最优模型指在训练中跟踪关键指标(如最高准确率),并保存对应状态,以便后续直接使用最佳模型。

2. 实现步骤
(1)定义 CNN 模型
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1:3通道输入→16通道输出,5x5卷积核self.conv1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2)  # 池化后尺寸减半)# 卷积层2:16通道→32通道self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(2))# 卷积层3:32通道→128通道(无池化)self.conv3 = nn.Sequential(nn.Conv2d(32, 128, kernel_size=5, stride=1, padding=2),nn.ReLU())# 全连接层:输入为128×64×64(经3次卷积+池化后的尺寸),输出20类self.fc = nn.Linear(128 * 64 * 64, 20)def forward(self, x):x = self.conv1(x)  # 输出:16×128×128x = self.conv2(x)  # 输出:32×64×64x = self.conv3(x)  # 输出:128×64×64x = x.view(x.size(0), -1)  # 展平为向量x = self.fc(x)return x# 初始化模型并移动到设备(GPU/CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)

运行结果:

(2)定义训练与测试函数
# 损失函数与优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练函数
def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启训练模式(启用 dropout/batchnorm)for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()  # 清空梯度loss.backward()        # 计算梯度optimizer.step()       # 更新参数if batch % 100 == 0:print(f"Batch {batch}, Loss: {loss.item():.4f}")# 测试函数(含最优模型保存)
best_acc = 0.0  # 记录最佳准确率def test(dataloader, model, loss_fn):global best_accmodel.eval()  # 开启评估模式(固定 dropout/batchnorm)size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():  # 关闭梯度计算,节省内存for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test: Accuracy: {(100*correct):.1f}%, Avg loss: {test_loss:.4f}")# 保存最优模型(准确率提升时)if correct > best_acc:best_acc = correct# 保存完整模型(含结构和参数)torch.save(model, "best_model.pt")# 或仅保存参数(更轻量):torch.save(model.state_dict(), "best_model.pth")
(3)启动训练
epochs = 150  # 训练轮数
for t in range(epochs):print(f"\nEpoch {t+1}/{epochs}")train(train_loader, model, loss_fn, optimizer)test(valid_loader, model, loss_fn)
print("训练完成!最优模型已保存为 best_model.pt")
3. 模型加载与使用

训练结束后,可直接加载最优模型进行预测:

# 加载保存的模型
loaded_model = torch.load("best_model.pt").to(device)
loaded_model.eval()  # 切换至评估模式# 示例:对单张图像预测
def predict(image_path):image = Image.open(image_path).convert('RGB')# 应用验证集的预处理transform = data_transforms['valid']image = transform(image).unsqueeze(0).to(device)  # 增加批次维度with torch.no_grad():pred = loaded_model(image)return pred.argmax(1).item()  # 返回预测类别# 测试预测
print("预测类别:", predict("test_image.jpg"))

 训练结束得到当前训练的最优模型,其为pt\pth\t7文件,此时该文件即为当前模型,可直接调用该文件使用。

三、总结
  • 数据增强通过模拟真实场景变化,提升模型泛化能力,需注意训练集用随机增强、验证集仅做标准化。
  • 保存最优模型通过跟踪验证集指标(如准确率),保留性能最佳的模型状态,避免训练后期过拟合导致的性能下降。

以上方法可直接应用于图像分类、目标检测等 CNN 任务,实际使用时需根据数据集特点调整增强策略和模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/97859.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/97859.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/97859.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

tcpdump用法

tcpdump用法tcpdump一、什么是tcpdump二、命令格式与参数三、参数列表四、过滤规则组合逻辑运算符过滤器关键字理解 Flag 标识符五、常用例子tcpdump 一、什么是tcpdump 二、命令格式与参数 option 可选参数:将在后边一一解释。 proto 类过滤器:根据协…

平衡车 - 电机调速

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 在我们的这篇文章当中,我们主要想要实现的功能的是电机调速功能。在我们的这篇文章中,主要实现的是开环的功能,而非闭环,也就是不加…

从利润率看价值:哪些公司值得长期持有?

💡 为什么盯紧利润率? 投资者常常盯着营收增长,却忽略了一个更关键的指标——利润率。 收入可以靠规模“堆”出来,但利润率却是企业护城河的真实体现。心理学研究表明:当一个产品或服务被消费者认定为“不可替代”&a…

小迪web自用笔记25

传统文件上传:上传至服务器本身硬盘。云存储:借助云存储oss对象存储(只能被访问,不可解析)Oss云存储Access key与Access ID:有了这两个东西之后就可以操作云存储,可以向里面发数据了。这玩意儿泄…

分发饼干——很好的解释模板

好的,孩子,我们来玩一个“喂饼干”的游戏。 0. 问题的本质是什么? 想象一下,你就是个超棒的家长,手里有几块大小不一的饼干,而面前有几个饿着肚子的小朋友。每个小朋友都有一个最小的“胃口”值&#xff0c…

场景题:如果一个大型项目,某一个时间所有的CPU的已经被占用了,导致服务不可用,我们开发人员应该如何使服务器尽快恢复正常

问:如果一个大型项目,某一个时间所有的CPU的 已经被占用了,导致服务不可用,我们开发人员 应该如何使服务器尽快恢复正常答:应对CPU 100%导致服务不可用的紧急恢复流程面试官,如果遇到这种情况,我会立即按照…

Docker 安装 RAGFlow保姆教程

前提条件 Ubuntu 服务器(20.04 或 22.04 LTS 推荐) 已安装 Docker 和 Docker Compose 如果尚未安装,请先运行以下命令:# 安装 Docker curl -fsSL https://get.docker.com -o get-docker.sh sudo sh get-docker.sh # 将当前用户加入 docker 组,避免每次都要 sudo sudo user…

为什么实际工程里 C++ 部署深度学习模型更常见?为什么大家更爱用 TensorRT?

很多人刚接触深度学习模型部署的时候,都会习惯用 Python,因为训练的时候就是 PyTorch、TensorFlow 啊,写起来方便。但一到 实际工程,特别是工业设备、医疗影像、上位机系统这种场景,你会发现大多数人都转向了 C 部署。…

深入理解 Java 集合框架:底层原理与实战应用

在日常开发中,集合是 Java 中使用频率最高的工具之一。从最常见的 ArrayList、HashMap 到更复杂的并发集合,几乎每一个 Java 程序员都离不开集合框架。集合框架不仅提供了丰富的数据结构实现,还封装了底层复杂的逻辑,让开发者能够…

爬取m3u8视频完整教程

爬取步骤:1.先找到网页源代码2.从网页源代码中拿到m3u83.下载m3u84.读取m3u8文件,下载视频5.合并视频首先我们来爬取一个星辰影院的电影:下面我以这个为例:我们需要在源代码中找到m3u8这个url:紧接着我们利用下面的方法…

Python爬虫实战: 基于Scrapy的Amazon跨境电商选品数据爬虫方案

概述与设计思路 利用Python的Scrapy框架进行大规模页面抓取和结构化数据提取,配合aiohttp实现高并发请求,从而高效获取Amazon平台上的商品列表、详情、评论等公开信息。通过对这些数据进行清洗与分析,可以识别出有潜力的商品,评估市场竞争程度,并跟踪竞争对手的动态,为跨…

稳定版IM即时通讯 仿默往APP即时通讯im源码聊天社交源码支持二开原生开发独立部署 含搭建教程

内容目录一、详细介绍二、效果展示1.部分代码2.效果图展示三、学习资料下载一、详细介绍 技术开发语言: 后台管理端:Java GO Mysql数据库 安卓端:Java iOS端:ob PC端:c 功能简单介绍: 单聊&#xff…

封装一个redis获取并解析数据的工具类

redis获取并解析数据工具类实现代码使用示例实现代码 import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.TypeReference; import lom…

23种设计模式——策略模式 (Strategy Pattern)​详解

✅作者简介:大家好,我是 Meteors., 向往着更加简洁高效的代码写法与编程方式,持续分享Java技术内容。 🍎个人主页:Meteors.的博客 💞当前专栏:设计模式 ✨特色专栏:知识分享 &#x…

CI(持续集成)、CD(持续交付/部署)、CT(持续测试)、CICD、CICT

目录 **CI、CD、CT 详解与关系** **1. CI(Continuous Integration,持续集成)** **2. CD(Continuous Delivery/Deployment,持续交付/部署)** **持续交付(Continuous Delivery)** **持续部署(Continuous Deployment)** **3. CT(Continuous Testing,持续测试)** **4.…

【音视频】WebRTC ICE 模块深度剖析

原文链接: https://mp.weixin.qq.com/s?__bizMzIzMjY3MjYyOA&mid2247498075&idx2&sn6021a2f60b1e7c71ce4d7af6df0b9b89&chksme893e540dfe46c56323322e780d41aec1f851925cfce8b76b3f4d5cfddaa9c7cbb03a7ae4c25&scene178&cur_album_id314699…

linux0.12 head.s代码解析

重新设置IDT和GDT,为256个中断门设置默认的中断处理函数检查A20地址线是否启用设置数学协处理器将main函数相关的参数压栈设置分页机制,将页表映射到0~16MB的物理内存上返回main函数执行 源码详细注释如下: /** linux/boot/head.s** (C) 1991 Linus T…

Maven动态控制版本号秘籍:高效发包部署,版本管理不再头疼!

作者:唐叔在学习 专栏:唐叔的Java实践 关键词:Maven版本控制、versions插件、动态版本号、持续集成、自动化部署、Java项目管理 摘要:本文介绍如何使用Maven Versions插件动态控制项目版本号和依赖组件版本号,实现无需…

简述:普瑞时空数据建库软件(国土变更建库)之一(变更预检查部分规则)

简述:普瑞时空数据建库软件(国土变更建库)之一(变更预检查部分规则) 主要包括三种类型:常规检查、行政区范围检查、20X异常灭失检查 本blog地址:https://blog.csdn.net/hsg77

shell中命令小工具:cut、sort、uniq,tr的使用方式

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言一、cut —— 按列或字符截取1. 常用选项2. 示例二、sort —— 排序(默认按行首字符升序)1. 常用选项常用 sort 命令选项三、uniq —— 去…