目录
一、数据增强
1. 核心概念
2. 核心目的
3. 常用方法
4. 实现示例(基于 PyTorch)
5. 自定义数据集加载
二、保存最优模型
1. 核心概念
2. 实现步骤
(1)定义 CNN 模型
(2)定义训练与测试函数
(3)启动训练
3. 模型加载与使用
三、总结
在卷积神经网络(CNN)的训练过程中,数据增强和模型保存是提升性能与实用性的关键环节。以下结合理论与实例,详细解析其原理及实现方式。
一、数据增强
1. 核心概念
数据增强是通过对原始训练数据进行随机变换(如旋转、翻转、调整亮度等),生成新的训练样本的技术。其本质是扩展数据多样性,让模型在训练中接触更多 “变体”,从而提升泛化能力(减少过拟合)。
2. 核心目的
- 模拟真实场景中的变量(如光照变化、视角差异、遮挡等)。
- 解决训练数据不足的问题,通过 “人工扩充” 提升模型鲁棒性。
3. 常用方法
4. 实现示例(基于 PyTorch)
import torch
from torchvision import transforms# 定义训练集和验证集的数据增强策略
data_transforms = {'train': transforms.Compose([transforms.Resize([300, 300]), # 缩放图像transforms.RandomRotation(45), # 随机旋转(-45°~45°)transforms.CenterCrop(256), # 中心裁剪至256x256transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转transforms.ColorJitter(brightness=0.2, contrast=0.1), # 颜色调整transforms.ToTensor(), # 转换为Tensor(像素值归一化到[0,1])transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化]),'valid': transforms.Compose([transforms.Resize([256, 256]), # 验证集仅缩放,不做随机增强transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
5. 自定义数据集加载
通过继承Dataset
类,将增强策略应用于实际数据:
from torch.utils.data import Dataset
from PIL import Image
import numpy as npclass FoodDataset(Dataset):def __init__(self, file_path, transform=None):self.transform = transformself.imgs = []self.labels = []# 从txt文件读取图像路径和标签(格式:图像路径 标签)with open(file_path, 'r') as f:for line in f.readlines():img_path, label = line.strip().split(' ')self.imgs.append(img_path)self.labels.append(int(label))def __len__(self):return len(self.imgs)def __getitem__(self, idx):# 加载图像并应用增强image = Image.open(self.imgs[idx]).convert('RGB')if self.transform:image = self.transform(image)# 标签转换为Tensorlabel = torch.tensor(self.labels[idx], dtype=torch.long)return image, label# 加载训练集和验证集
train_dataset = FoodDataset('./train.1txt', transform=data_transforms['train'])
valid_dataset = FoodDataset('./test.1txt', transform=data_transforms['valid'])# 数据加载器(批量处理)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)
其中train.1txt文件内容为:
test.1txt文件内容:
其中的每个文件地址都有其对应的图片,数据量较大,训练时间会较长,如需使用,可私信发送打包文件。
整篇文章所有代码连接为一份完整代码。
二、保存最优模型
1. 核心概念
训练过程中,模型性能(如验证集准确率)会随迭代波动。保存最优模型指在训练中跟踪关键指标(如最高准确率),并保存对应状态,以便后续直接使用最佳模型。
2. 实现步骤
(1)定义 CNN 模型
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1:3通道输入→16通道输出,5x5卷积核self.conv1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2) # 池化后尺寸减半)# 卷积层2:16通道→32通道self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(2))# 卷积层3:32通道→128通道(无池化)self.conv3 = nn.Sequential(nn.Conv2d(32, 128, kernel_size=5, stride=1, padding=2),nn.ReLU())# 全连接层:输入为128×64×64(经3次卷积+池化后的尺寸),输出20类self.fc = nn.Linear(128 * 64 * 64, 20)def forward(self, x):x = self.conv1(x) # 输出:16×128×128x = self.conv2(x) # 输出:32×64×64x = self.conv3(x) # 输出:128×64×64x = x.view(x.size(0), -1) # 展平为向量x = self.fc(x)return x# 初始化模型并移动到设备(GPU/CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)
运行结果:
(2)定义训练与测试函数
# 损失函数与优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练函数
def train(dataloader, model, loss_fn, optimizer):model.train() # 开启训练模式(启用 dropout/batchnorm)for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad() # 清空梯度loss.backward() # 计算梯度optimizer.step() # 更新参数if batch % 100 == 0:print(f"Batch {batch}, Loss: {loss.item():.4f}")# 测试函数(含最优模型保存)
best_acc = 0.0 # 记录最佳准确率def test(dataloader, model, loss_fn):global best_accmodel.eval() # 开启评估模式(固定 dropout/batchnorm)size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad(): # 关闭梯度计算,节省内存for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test: Accuracy: {(100*correct):.1f}%, Avg loss: {test_loss:.4f}")# 保存最优模型(准确率提升时)if correct > best_acc:best_acc = correct# 保存完整模型(含结构和参数)torch.save(model, "best_model.pt")# 或仅保存参数(更轻量):torch.save(model.state_dict(), "best_model.pth")
(3)启动训练
epochs = 150 # 训练轮数
for t in range(epochs):print(f"\nEpoch {t+1}/{epochs}")train(train_loader, model, loss_fn, optimizer)test(valid_loader, model, loss_fn)
print("训练完成!最优模型已保存为 best_model.pt")
3. 模型加载与使用
训练结束后,可直接加载最优模型进行预测:
# 加载保存的模型
loaded_model = torch.load("best_model.pt").to(device)
loaded_model.eval() # 切换至评估模式# 示例:对单张图像预测
def predict(image_path):image = Image.open(image_path).convert('RGB')# 应用验证集的预处理transform = data_transforms['valid']image = transform(image).unsqueeze(0).to(device) # 增加批次维度with torch.no_grad():pred = loaded_model(image)return pred.argmax(1).item() # 返回预测类别# 测试预测
print("预测类别:", predict("test_image.jpg"))
训练结束得到当前训练的最优模型,其为pt\pth\t7文件,此时该文件即为当前模型,可直接调用该文件使用。
三、总结
- 数据增强通过模拟真实场景变化,提升模型泛化能力,需注意训练集用随机增强、验证集仅做标准化。
- 保存最优模型通过跟踪验证集指标(如准确率),保留性能最佳的模型状态,避免训练后期过拟合导致的性能下降。
以上方法可直接应用于图像分类、目标检测等 CNN 任务,实际使用时需根据数据集特点调整增强策略和模型结构。