在后续的系列文章中,我们将逐步深入探讨 VGG16 相关的核心内容,具体涵盖以下几个方面:

  1. 卷积原理篇:详细剖析 VGG 的 “堆叠小卷积核” 设计理念,深入解读为何 3×3×2 卷积操作等效于 5×5 卷积,以及 3×3×3 卷积操作等效于 7×7 卷积。

  2. 架构设计篇:运用 PyTorch 精确定义 VGG16 类,深入解析 “Conv - BN - ReLU - Pooling” 这一标准模块的构建原理与实现方式。

3. 训练实战篇:在小规模医学影像数据集上对 VGG16 模型进行严格验证,并精心调优如 batch_size、学习率等关键超参数,以实现模型性能的最优化。

若您希望免费获取本系列文章的完整代码,可通过添加 V 信:18983376561 来获取。

一、VGG16 架构

VGG16 作为卷积神经网络中的经典架构,其结构清晰且具有强大的特征提取能力。下面是 VGG16 的架构图:

二、训练流程与代码解析

1. 数据预处理:让图像适应模型输入

CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10 是由Hinton 的学生Alex Krizhevsky 和Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含10 个类别的RGB 彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。每个图片的尺寸为32 × 32 ,每个类别有6000个图像,数据集中一共有50000 张训练图片和10000 张测试图片。

然而,VGG16 模型原设计是针对 224x224 的图像输入。为了使 CIFAR10 数据集能够适配 VGG16 模型,我们需要对图像进行预处理。具体而言,通过transforms.Resize((224, 224))将图像缩放至 224x224 的尺寸,再利用Normalize进行标准化处理,将均值和标准差均设为 0.5,从而使像素值归一化到 [-1, 1] 区间。以下是关键代码片段:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summaryfrom VGG16 import VGG16
device = torch.device('cuda')transform_train = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

2. 数据加载:高效读取与批量处理

为了实现数据的高效读取与批量处理,我们使用DataLoader来加载数据。设置batch_size = 128,以平衡内存使用和训练效率;同时,设置num_workers = 12,利用多线程技术加速数据读取过程。对于训练集,我们将shuffle参数设置为True,打乱数据顺序,避免模型记忆数据顺序而导致过拟合;对于测试集,将shuffle参数设置为False,保持数据顺序,便于结果的复现和评估。以下是具体代码:

train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=12)test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False, num_workers=12)

3. 模型构建:调用自定义 VGG16 网络

在代码中,我们假设VGG16类已经被正确定义,该类应包含 16 层卷积层和全连接层结构。通过model.to(device)将模型部署到 GPU 上进行训练,以加速训练过程。由于 CIFAR10 是一个 10 分类任务,因此模型的最终全连接层输出维度应为 10。如果没有可用的 GPU,需要将device设置为cpu,但训练速度会显著降低。

4. 训练配置:损失函数与优化策略

在训练过程中,我们需要选择合适的损失函数和优化策略来指导模型的学习。具体配置如下:

  • 损失函数:选用CrossEntropyLoss来处理多分类问题,该损失函数会自动整合 Softmax 计算,简化了代码实现。
  • 优化器:选择随机梯度下降(SGD)作为优化器,设置学习率lr = 0.1,动量momentum = 0.9以加速收敛过程,同时设置权重衰减weight_decay = 0.0001,采用 L2 正则化防止模型过拟合。
  • 学习率调度器:使用ReduceLROnPlateau根据验证损失自动调整学习率。当验证损失连续 5 个 epoch 未下降时,学习率将乘以 0.1(factor = 0.1),这样可以避免模型陷入局部最优解。以下是相关代码:
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
model = VGG16().to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.1, patience=5)EPOCHS = 200
for epoch in range(EPOCHS):losses = []running_loss = 0for i, inp in enumerate(trainloader):inputs, labels = inpinputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)losses.append(loss.item())loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 0 and i > 0:print(f'Loss [{epoch + 1}, {i}](epoch, minibatch): ', running_loss / 100)running_loss = 0.0avg_loss = sum(losses) / len(losses)scheduler.step(avg_loss)

5. 训练循环:迭代优化与监控

在 200 个 epoch 的训练过程中,我们每 100 个批次打印一次平均损失,以便实时监控模型的训练进度。从输出日志可以看出,模型初始损失较高(第 1 个 epoch 约为 2.3),随着训练的不断进行,损失逐渐下降,最终损失趋近于 0.001 左右,这表明模型对训练数据的拟合效果良好。

print('Training Done')
# Loss [1, 100](epoch, minibatch):  3.8564858746528627
# Loss [1, 200](epoch, minibatch):  2.307221896648407
# Loss [1, 300](epoch, minibatch):  2.304955897331238
# Loss [2, 100](epoch, minibatch):  2.3278213500976563
# Loss [2, 200](epoch, minibatch):  2.3041475653648376
# Loss [2, 300](epoch, minibatch):  2.3039899492263793
# ...
# Loss [199, 100](epoch, minibatch):  0.001291145777431666
# Loss [199, 200](epoch, minibatch):  0.0017596399529429619
# Loss [199, 300](epoch, minibatch):  0.0013808918403083225
# Loss [200, 100](epoch, minibatch):  0.0013940928343799896
# Loss [200, 200](epoch, minibatch):  0.0011531753832969116
# Loss [200, 300](epoch, minibatch):  0.001689423452335177

三、训练结果与问题分析

在训练完成后,我们可以对模型进行保存和加载操作,以便后续的使用和评估。以下是保存和加载模型的代码示例:

# 保存整个模型
torch.save(model, 'VGG16.pth')# 或者只保存模型的参数
torch.save(model.state_dict(), 'VGG16_params.pth')# 加载整个模型
loaded_model = torch.load('VGG16.pth')# 或者加载模型的参数
loaded_params = torch.load('VGG16_params.pth')# 如果只加载了模型的参数,需要先将参数加载到模型对象中
# 假设我们有一个新的模型实例
new_model = VGG16(num_classes=10)
new_model.load_state_dict(loaded_params)correct = 0
total = 0with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy on 10,000 test images: ', 100 * (correct / total), '%')

通过测试集计算模型的准确率,我们得到约 86.5% 的结果。然而,需要注意以下两个问题:

  • CIFAR10 的挑战:CIFAR10 数据集中的图像分辨率较低(32x32),图像细节较少,并且部分类别之间存在一定的相似性(如狗与猫),这对模型的特征提取能力提出了较高的要求。
  • 过拟合风险:训练损失极低,但测试准确率未能达到 90% 以上,这可能表明模型存在过拟合现象,即模型在训练集上的表现远好于在测试集上的表现。

四、优化方向:如何让模型更上一层楼

1. 数据增强:对抗过拟合的 “核武器”

原代码未使用数据增强技术,为了提高模型的泛化能力,我们可以添加以下数据增强操作:

  • 随机裁剪与翻转:使用transforms.RandomCrop(32, padding = 4)transforms.RandomHorizontalFlip(),增加数据的多样性,使模型能够学习到更多不同视角和位置的特征。
  • 颜色扰动:通过transforms.ColorJitter(brightness = 0.1, contrast = 0.1, saturation = 0.1),增强模型对色彩变化的鲁棒性,使其能够适应不同光照和色彩条件下的图像。
  • Cutout/MixUp:采用随机遮挡图像区域(Cutout)或混合样本(MixUp)的方法,进一步提升模型的泛化能力。

2. 模型调整:更适配小数据集的设计

  • 使用预训练模型:可以将在 ImageNet 上预训练的 VGG16 模型权重迁移到 CIFAR10 任务中。但需要注意输入尺寸的差异(从 224 调整为 32),可以尝试冻结部分卷积层,只对后续层进行微调。
  • 轻量化改进:VGG16 模型的参数量较大(约 1.38 亿),对于 CIFAR10 这样的小数据集可能会导致过拟合。可以考虑改用更小的网络,如 VGG11、ResNet18,或者减少通道数(如将起始通道从 64 减少到 32)。
  • 添加 Dropout:在全连接层前插入nn.Dropout(0.5),抑制神经元之间的共适应现象,降低模型过拟合的风险。

3. 优化策略升级

  • 学习率策略:可以改用余弦退火(Cosine Annealing)或周期性学习率(CLR)策略,动态调整学习率,帮助模型逃离鞍点,提高收敛速度和性能。
  • 优化器选择:尝试使用 AdamW(结合权重衰减的 Adam)或 RMSprop 等优化器,这些优化器在处理稀疏梯度场景时可能更有效。
  • 混合精度训练:使用 PyTorch 的torch.cuda.amp模块进行混合精度训练,减少显存占用并加速训练过程,尤其适用于长周期的训练任务。

4. 训练技巧与调参

  • 早停(Early Stopping):监控验证集损失,若连续多个 epoch 验证集损失未提升,则提前终止训练,避免无效的训练过程。
  • 标签平滑(Label Smoothing):在损失函数中引入标签平滑技术,防止模型对单一类别过度自信,提高模型的泛化能力。
  • 调整批量大小:尝试使用更小的batch_size(如 64)以增加梯度更新的频率,或者使用更大的批量(如 256)以充分利用 GPU 的并行计算能力。

5. 测试阶段优化

  • 测试时增强(TTA):在测试阶段,对测试图像进行多尺度裁剪、翻转等操作,然后取预测结果的平均值,提升预测的鲁棒性。
  • 集成学习:训练多个不同初始化的 VGG 模型,通过投票或平均法融合这些模型的预测结果,降低模型的随机性影响,提高整体性能。

五、总结与实践建议

本次实战通过在 CIFAR10 数据集上训练 VGG16 模型,全面展示了深度学习从数据预处理到模型部署的完整流程。86.5% 的准确率仅仅是一个起点,通过采用数据增强、模型轻量化、优化策略调整等一系列优化手段,完全有能力将模型的准确率提升至 90% 以上(CIFAR10 的当前最优模型准确率可达 95% 以上)。

深度学习的学习过程需要理论与实践紧密结合,希望大家能够动手实践,亲自体验模型优化的过程。如果您需要完整代码或希望进行进一步的讨论,欢迎在评论区留言。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/81639.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/81639.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/81639.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 20.04之Docker安装ES7.17.14和Kibana7.17.14

你需要已经安装如下运行环境: Ubuntu 20.04 docker 28 docker-compose 1.25 一、手动拉取镜像 docker pull docker.elastic.co/kibana/kibana:7.17.14docker pull docker.elastic.co/elasticsearch/elasticsearch:7.17.14 或者手动导入镜像 docker load -i es7.17.14.ta…

实时技术方案对比:SSE vs WebSocket vs Long Polling

早期网站仅展示静态内容,而如今我们更期望:实时更新、即时聊天、通知推送和动态仪表盘。 那么要如何实现实时的用户体验呢?三大经典技术各显神通: SSE(Server-Sent Events):轻量级单向数据流WebSocket:双向全双工通信Long Polling(长轮询):传统过渡方案假设目前有三…

测试开发面试题:Python高级特性通俗讲解与实战解析

前言:为什么测试工程师必须掌握Python高级特性? 通俗比喻: 基础语法就像“锤子”,能敲钉子;高级特性就像“瑞士军刀”,能应对复杂场景(如自动化框架、高并发测试)。面试官考察点&a…

C语言-9.指针

9.1指针 9.1-1取地址运算:&运算符取得变量的地址 运算符& scanf(“%d”,&i);里的&获取变量的地址,它们操作数必须是变量int i;printf(“%x”,&i);地址的大小是否与int相同取决于编译器int i;printf(“%p”,&i); &不能取的地址不能对没有地址的…

【C++】Vcpkg 介绍及其常见命令

Vcpkg 简介 Vcpkg 是微软开发的一个跨平台的 C/C 依赖管理工具,用于简化第三方库的获取、构建和管理过程。 主要特点 跨平台支持:支持 Windows、Linux 和 macOS开源免费:MIT 许可证大型库集合:包含超过 2000 个开源库简化集成&…

Unity3D 动画文件优化总结

前言 在Unity3D中,动画文件的压缩和优化是提升性能的重要环节,尤其在移动端或复杂场景中。以下是针对Animation Clip和Animator Controller的优化方法总结: 对惹,这里有一个游戏开发交流小组,希望大家可以点击进来一…

前端工程的相关管理 git、branch、build

环境配置 标准环境打包 测试版:npm run build-test 预生产:npm run build-preview 正式版:npm run build 建议本地建里一个 .env.development.local 方便和后端联调时修改配置相关信息。 和 src 同级有一下区分环境的文件: .env.d…

VAPO:视觉-语言对齐预训练(对象级语义)详解

简介 多模态预训练模型(Vision-Language Pre-training, VLP)近年来取得了飞跃发展。在视觉-语言模型中,模型需要同时理解图像和文本,这要求模型学习二者之间的语义对应关系。早期方法如 VisualBERT、LXMERT 等往往使用预先提取的图像区域特征和文本词嵌入拼接输入,通过 T…

docker运行Redis

创建目录 mkdir -p /home/jie/docker/redis/{conf,data,logs}添加权限 chmod -R 777 /home/jie/docker/redis创建配置文件 cat > /home/jie/docker/redis/conf/redis.conf << EOF # 基本配置 bind 0.0.0.0 protected-mode yes port 6379# 安全配置 密码是root require…

初识 java

目录 前言 一、jdk&#xff0c;JRE和JVM之间的关系 二、JVM的内存划分 前言 初步了解 jdk&#xff0c;JRE&#xff0c;JVM 之间的关系&#xff0c;JVM 的内存划分。 一、jdk&#xff0c;JRE和JVM之间的关系 jdk 是 java 开发工具集&#xff0c;包含JRE&#xff1b; JRE 是…

关于百度地图JSAPI自定义标注的图标显示不完整的问题(其实只是因为图片尺寸问题)

下载了几个阿里矢量图标库里的图标作为百度地图的自定义图标&#xff0c;结果百度地图显示的图标一直不完整。下载的PNG图标已经被正常引入到前端代码&#xff0c;anchor也设置为了图标底部中心&#xff0c;结果还是显示不完整。 if (iconUrl) {const icon new mapClass.Icon(…

系统安全及应用深度笔记

系统安全及应用深度笔记 一、账号安全控制体系构建 &#xff08;一&#xff09;账户全生命周期管理 1. 冗余账户精细化治理 非登录账户基线核查 Linux 系统默认创建的非登录账户&#xff08;如bin、daemon、mail&#xff09;承担系统服务支撑功能&#xff0c;其登录 Shell 必…

02-前端Web开发(JS+Vue+Ajax)

介绍 在前面的课程中&#xff0c;我们已经学习了HTML、CSS的基础内容&#xff0c;我们知道HTML负责网页的结构&#xff0c;而CSS负责的是网页的表现。 而要想让网页具备一定的交互效果&#xff0c;具有一定的动作行为&#xff0c;还得通过JavaScript来实现。那今天,我们就来讲…

AXXI4总线协议 ------ AXI_FULL协议

https://download.csdn.net/download/mvpkuku/90855619 一、AXI_FULL协议的前提知识 1. 各端口的功能 2. 4K边界问题 3. outstanding 4.时序仿真体验 可通过VIVADO自带ADMA工程观察仿真波形图 二、FPGA实现 &#xff08;主要用于读写DDR&#xff09; 1.功能模块及框架 将…

React系列——nvm、node、npm、yarn(MAC)

nvm&#xff0c;node&#xff0c;npm之间的区别 1、nvm&#xff1a;nodejs版本管理工具。nvm 可以管理很多 node 版本和 npm 版本。 2、nodejs&#xff1a;在项目开发时的所需要的代码库 3、npm&#xff1a;nodejs包管理工具。nvm、nodejs、npm的关系 nvm 管理 nodejs 和 npm…

2025年AI与网络安全的终极博弈:冲击、重构与生存法则

引言 2025年&#xff0c;生成式AI的推理速度突破每秒千万次&#xff0c;网络安全行业正经历前所未有的范式革命。攻击者用AI批量生成恶意代码&#xff0c;防御者用AI构建智能护盾&#xff0c;这场技术军备竞赛正重塑行业规则——60%的传统安全岗位面临转型&#xff0c;70%的防…

【Android】Android 实现一个依赖注入的注解

Android 实现一个依赖注入的注解 &#x1f3af; 目标功能 自定义注解 Inject创建一个 Injector 类&#xff0c;用来扫描并注入对象支持 Activity 或其他类中的字段注入 &#x1f9e9; 步骤一&#xff1a;定义注解 import java.lang.annotation.ElementType; import java.lan…

Spring Boot与Kafka集成实践:从入门到实战

Spring Boot与Kafka集成实践 引言 在现代分布式系统中&#xff0c;消息队列是不可或缺的组件之一。Apache Kafka作为一种高吞吐量的分布式消息系统&#xff0c;广泛应用于日志收集、流处理、事件驱动架构等场景。Spring Boot作为Java生态中最流行的微服务框架&#xff0c;提供…

ubuntu的虚拟机上的网络图标没有了

非正常的关机导致虚拟机连接xshell连接不上&#xff0c;ping也ping不通。网络的图标也没有了。 记录一下解决步骤 1、重启服务 sudo systemctl restart NetworkManager 2、图标显示 sudo nmcli network off sudo nmcli network on 3、sudo dhclient ens33 //(网卡) …

生产者 - 消费者模式实现方法整理

一、Channels &#xff08;一&#xff09;使用场景 适用于高并发、大数据量传输&#xff0c;且需要异步操作的场景&#xff0c;如实时数据处理系统。 &#xff08;二&#xff09;使用方法 创建 Channel<T>&#xff08;无界&#xff09;或 BoundedChannel<T>&…