DAY 50: 预训练模型与 CBAM 模块的融合与微调

今天,我们将把之前学到的知识融会贯通,探讨如何将 CBAM 这样的注意力模块应用到强大的预训练模型(如 ResNet)中,并学习如何高效地对这些模型进行微调,以适应我们自己的任务。

知识点回顾

  1. ResNet 结构解析:深入理解 ResNet 的核心思想——残差连接,并剖析其经典模型 ResNet18 的具体结构。
  2. CBAM 放置位置的思考:探讨在 ResNet 这类复杂结构中,将 CBAM 模块放置在何处才能最大化其效果。
  3. 针对预训练模型的训练策略:学习两种高级微调(Fine-tuning)技巧:
    • 差异化学习率 (Differential Learning Rates)
    • 三阶段微调 (Progressive Unfreezing)

1. ResNet18 模型结构解析

在深入研究如何修改模型之前,我们必须先透彻理解其内部构造。以 ResNet18 为例,它的成功主要归功于解决了深度网络训练中的一个关键问题:网络退化

1.1 核心思想:残差学习 (Residual Learning)

随着网络层数的增加,模型的性能非但没有提升,反而出现了下降,这被称为“退化”现象。ResNet 的作者提出,让网络直接学习输入与输出之间的残差 (Residual),会比学习完整的输出映射更容易。

这就是跳跃连接 (Skip Connection) 的由来:将输入 x 直接加到网络层的输出 F(x) 上,得到最终结果 H(x) = F(x) + x。这样,如果某一网络层 F(x) 发现自己是多余的,它只需将自己的权重学习为0,输出就直接等于输入 x,实现了恒等映射,保证了网络性能不会因为加深而退化。

1.2 ResNet18 的基本构建块 (BasicBlock)

ResNet18 由多个 BasicBlock 堆叠而成。每个 BasicBlock 包含:

  • 两个 3x3 的卷积层。
  • 每个卷积层后接一个批归一化 (BatchNorm) 和 ReLU 激活函数。
  • 一个跳跃连接,将输入 x 与第二个卷积层的输出相加。
  • 如果输入的通道数或尺寸与输出不匹配,跳跃连接会通过一个 1x1 的卷积 (downsample) 来进行适配。
1.3 ResNet18 整体结构

ResNet18 的结构可以清晰地分为几个部分:

  1. 初始卷积层 (conv1):一个 7x7 的大卷积核,用于初步提取图像的宏观特征,并进行第一次下采样。
  2. 四个残差层 (layer1-4):由多个 BasicBlock 组成。每个 layer 的第一个 BasicBlock 可能会进行下采样(步长为2),以减小特征图尺寸并加倍通道数。
    • layer1: 2个 BasicBlock, 64通道
    • layer2: 2个 BasicBlock, 128通道
    • layer3: 2个 BasicBlock, 256通道
    • layer4: 2个 BasicBlock, 512通道
  3. 全局平均池化 (avgpool):将最后的特征图在空间维度上进行平均,得到一个向量。
  4. 全连接分类层 (fc):根据任务类别数进行最终的分类。

2. CBAM 放置位置的思考

将 CBAM 这样的“即插即用”模块添加到现有模型中时,其放置位置至关重要,直接影响模型的性能。

核心原则在特征提取最充分的地方应用注意力

在 ResNet 的 BasicBlock 中,特征经过两个卷积层 conv1conv2 的连续提取。因此,最佳的放置位置是在第二个卷积层之后,与跳跃连接的输入相加之前

这样做的逻辑是:

  1. BasicBlock 内的卷积网络充分提取特征。
  2. 使用 CBAM 对这些提取出的特征进行通道和空间上的“精炼”,增强重要特征,抑制无关特征。
  3. 最后,将这个“精炼”过的特征 F'(x) 与原始输入 x 通过跳跃连接相加。

3. 针对预训练模型的训练策略

直接使用一个大型的预训练模型并在我们自己的(通常较小的)数据集上从头开始训练,既耗时又容易过拟合。更高效的方法是微调 (Fine-tuning),即利用预训练模型已经学到的通用特征,只对模型进行微小的调整以适应新任务。

a. 差异化学习率

思想:对模型的不同部分使用不同的学习率。

  • 特征提取层 (Backbone):这些层在 ImageNet 等大型数据集上已经学习到了非常通用的特征(如边缘、纹理、形状)。我们只需要对它们进行微调,因此使用一个较小的学习率(如 1e-4)。
  • 分类层 (Classifier/Head):这是我们为了适应新任务而新添加的层,其权重是随机初始化的,需要从头开始学习。因此,我们为它设置一个较大的学习率(如 1e-3)。

这样做可以防止较大的学习率破坏已经训练好的骨干网络权重,同时保证新加的分类层能够快速收敛。

b. 三阶段微调 (Progressive Unfreezing)

这是一种更稳定、更精细的微调策略,通过“渐进式解冻”来训练模型。

  • 第一阶段:只训练分类层

    • 操作:冻结骨干网络的所有参数,只更新我们新添加的分类层的参数。
    • 目的:让随机初始化的分类层快速学习,以适应新数据集的特征分布,而不会因其初始梯度过大而破坏骨干网络。
  • 第二阶段:微调部分骨干网络

    • 操作:解冻骨干网络的后面几层(如 ResNet 的 layer3, layer4),并与分类层一起训练。此时,骨干网络的学习率应设得非常小。
    • 目的:让网络的高层语义特征(更接近特定任务)也开始适应新数据。
  • 第三阶段:微调整个网络

    • 操作:解冻所有网络层,以一个极小的学习率(如 1e-5)对整个模型进行训练。
    • 目的:对整个网络进行整体的、细微的调整,以达到最佳性能。

4. 实践:对 VGG16 + CBAM 模型进行微调

虽然上面我们分析了 ResNet,但这些思想是通用的。下面我们以 VGG16 为例,演示如何为其添加 CBAM 模块并应用微调策略。

步骤

  1. 加载预训练的 VGG16 模型
  2. 修改模型结构:遍历 VGG16 的特征提取层 (features),在每个 MaxPool2d 层之后插入一个 CBAM 模块。
  3. 替换分类头:将 VGG16 原本用于1000类分类的全连接层替换为适用于我们自己任务(如 CIFAR-10 的10类分类)的新层。
  4. 设置差异化学习率:为骨干网络和新的分类层分别设置不同的学习率。
  5. (可选)实现冻结训练:通过设置参数的 requires_grad 属性为 False 来冻结特定层。

核心代码示例(仅展示结构修改与学习率设置)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg16# 假设 CBAM 模块已经定义好 (如 DAY 49 的代码)
# class CBAM(nn.Module): ...# 1. 加载预训练VGG16
model = vgg16(pretrained=True)
features = list(model.features)# 2. 在VGG的MaxPool2d层后插入CBAM模块
vgg_cbam_features = []
channel_map = {4: 64, 9: 128, 16: 256, 23: 512, 30: 512} # VGG16 MaxPool层索引 -> 通道数for i, layer in enumerate(features):vgg_cbam_features.append(layer)if isinstance(layer, nn.MaxPool2d):# 插入CBAMin_channels = channel_map.get(i)if in_channels:vgg_cbam_features.append(CBAM(in_channels))model.features = nn.Sequential(*vgg_cbam_features)# 3. 替换分类头 (假设为CIFAR-10)
num_features = model.classifier[0].in_features
model.classifier = nn.Sequential(nn.Linear(num_features, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 10), # 新的分类头,10个类别
)# 4. 设置差异化学习率
# 将参数分为两组:骨干网络和分类头
backbone_params = model.features.parameters()
classifier_params = model.classifier.parameters()optimizer = optim.Adam([{'params': backbone_params, 'lr': 1e-4},     # 骨干网络使用小学习率{'params': classifier_params, 'lr': 1e-3}  # 分类头使用大学习率
])# 5. (可选)冻结训练示例
# 冻结所有骨干网络参数
for param in model.features.parameters():param.requires_grad = False# 此时,只有分类头的参数会被更新
optimizer_frozen = optim.Adam(model.classifier.parameters(), lr=1e-3)# 在训练一段时间后,可以解冻
# for param in model.features.parameters():
#     param.requires_grad = True

通过以上步骤,我们就成功地将 CBAM 模块集成到了 VGG16 中,并为其设置了高效的微调策略。这种“理解架构 -> 策略性修改 -> 智能训练”的流程,是提升模型性能和训练效率的核心方法。


@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/919820.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/919820.shtml
英文地址,请注明出处:http://en.pswp.cn/news/919820.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北极圈边缘生态研究:从数据采集到分析的全流程解析

原文链接:https://onlinelibrary.wiley.com/doi/10.1111/1744-7917.70142?afR北极圈边缘生态研究:从数据采集到分析的全流程解析简介本教程基于一项在俄罗斯摩尔曼斯克州基洛夫斯克市开展的长期生态学研究,系统讲解如何对高纬度地区特定昆虫…

Excel处理控件Aspose.Cells教程:使用Python将 Excel 转换为 NumPy

使用 Python 处理 Excel 数据非常常见。这通常涉及将数据从 Excel 转换为可高效操作的形式。将 Excel 数据转换为可分析的格式可能非常棘手。在本篇教程中,您将学习借助强大Excel处理控件Aspose.Cells for Python,如何仅用几行代码将 Excel 转换为 NumPy…

python 字典有序性的实现和OrderedDict

文章目录 一、Python 3.7+ 字典有序性的验证 二、如何在字典头部插入键值对 方法 1:创建新字典(推荐) 方法 2:使用 `collections.OrderedDict`(适合频繁头部插入场景) 方法 3:转换为列表操作(不推荐,效率低) 底层核心结构:双数组哈希表 有序性的实现原理 与旧版本(…

JVM 调优全流程案例:从频繁 Full GC 到百万 QPS 的实战蜕变

🔥 JVM 调优全流程案例:从频繁 Full GC 到百万 QPS 的实战蜕变 文章目录🔥 JVM 调优全流程案例:从频繁 Full GC 到百万 QPS 的实战蜕变🧩 一、调优本质:性能瓶颈的破局之道💡 为什么JVM调优如此…

基于TimeMixer现有脚本扩展的思路分析

文章目录1. 加入数据集到data_loader.py和data_factory.py2. 参照exp_classification.py写自定义分类任务脚本(如exp_ADReSS.py)3. 接一个MLP分类头4. 嵌入指标计算、绘图、保存训练历史的函数5. 开始训练总结**一、可行性分析****二、具体实现步骤****1…

技术演进中的开发沉思-75 Linux系列:中断和与windows中断的区分

作为一名从 2000 年走过来的老程序员,看着 IT 技术从桌面开发迭代到微服务时代,始终觉得好技术就像老故事 —— 得有骨架(知识点),更得有血肉(场景与感悟)。我想正是我的经历也促成了我想写这个…

【8位数取中间4位数】2022-10-23

缘由请输入一个8位的十进制整数,编写程序取出该整数的中间4位数,分别输出取出的这4位数以及该4位数加上1024的得数。 输入:一个整数。 输出:两个整数,用空格分隔-编程语言-CSDN问答 int n 0;std::cin >> n;std:…

mac电脑使用(windows转Mac用户)

首先,我们学习mac的键盘复制 command c 粘贴 command v 剪切 command xlinux命令行 退出中止 control c 退出后台 control d中英文切换大小写,按住左边向上的箭头 字母鼠标操作 滚轮:2个指头一起按到触摸板,上滑,…

项目中优惠券计算逻辑全解析(处理高并发)

其实这个部分的代码已经完成一阵子了,但是想了一下决定还是整理一下这部分的代码,因为最开始做的时候业务逻辑还是感觉挺有难度的整体流程概述优惠方案计算主要在DiscountServiceImpl类的findDiscountSolution方法中实现。整个计算过程可以分为以下五个步…

支持电脑课程、游戏、会议、网课、直播录屏 多场景全能录屏工具

白鲨录屏大师:支持电脑课程、游戏、会议、网课、直播录屏 多场景全能录屏工具,轻松捕捉每一刻精彩 在数字化学习、娱乐与办公场景中,高质量的录屏需求日益增长。无论是课程内容的留存、游戏高光的记录,还是会议要点的复盘、网课知…

LeetCode算法日记 - Day 20: 两整数之和、只出现一次的数字II

目录 1. 两数之和 1.1 题目解析 1.2 解法 1.3 代码实现 2. 只出现一次的数字II 2.1 题目解析 2.2 解法 2.3 代码实现 1. 两数之和 371. 两整数之和 - 力扣(LeetCode) 给你两个整数 a 和 b ,不使用 运算符 和 - ,计算并…

Spring AI 快速接入 DeepSeek 大模型

Spring AI 快速接入 DeepSeek 大模型 文章目录Spring AI 快速接入 DeepSeek 大模型Spring AI 框架概述核心特性适用场景官网与资源AI 提供商与模型类型模型类型(Model Type)AI提供商(Provider)两者的关系Spring AI 框架支持哪些 A…

jQuery 知识点复习总览

文章目录jQuery 知识点复习总览一、jQuery 基础1. jQuery 简介2. jQuery 引入3. jQuery 核心函数二、选择器1. 基本选择器2. 层级选择器3. 过滤选择器4. 表单选择器三、DOM 操作1. 内容操作2. 属性操作3. CSS 操作4. 元素操作四、事件处理1. 事件绑定2. 事件对象3. 自定义事件五…

博客系统接口自动化练习

框架图: 详细代码地址:gitee仓库 博客系统接口自动化文档请看文章顶部。

智慧矿山误报率↓83%!陌讯多模态融合算法在矿用设备监控的落地优化

原创声明:本文为原创技术解析文章,核心技术参数与架构设计引用自 “陌讯技术白皮书(智慧矿山专项版)”,算法部署相关资源适配参考aishop.mosisson.com平台的陌讯视觉算法专项适配包,禁止未经授权的转载与二…

Laravel 使用阿里云OSS S3 协议文件上传

1. 安装 S3 软件包 composer require league/flysystem-aws-s3-v3 "^3.0" --with-all-dependencies2. 配置.env 以阿里云 OSS 地域华东2 上海为例: FILESYSTEM_DISKs3 //设置默认上传到S3AWS_ACCESS_KEY_ID***…

UVM一些不常用的功能

uvm_coreservice_t是什么AI:在 UVM(Universal Verification Methodology)中,uvm_coreservice_t 是一个核心服务类,它扮演着UVM 框架内部核心服务的 “管理者” 和 “统一入口” 的角色。其主要作用是封装并提供对 UVM …

怎么确定mongodb是不是链接上了?

现有mongosh链接了MongoDB,里面能操作,但是想python进行链接,因为代码需要,现在测试下链接成功了没有。如下: 要确认你的 MongoDB 连接是否成功,可以通过以下方法检查: 1. 使用 list_database_names 方法【测试成功】 python import asyncioasync def test_connecti…

Unity 二进制读写小框架

文章目录前言框架获取与集成使用方法基本配置自动生成序列化方法实战示例技术原理与优势二进制序列化的优势SJBinary的设计特点最佳实践建议适用场景总结前言 在Unity开发过程中,与后台交互时经常需要处理大型数据文件。当遇到一个近2MB的本地JSON文件需要解析为对…

​Kubernetes 详解:云原生时代的容器编排与管理

一 Kubernetes 简介及部署方法 1.1 应用部署方式演变 在部署应用程序的方式上,主要经历了三个阶段: 传统部署:互联网早期,会直接将应用程序部署在物理机上 优点:简单,不需要其它技术的参与 缺点&#xf…