1. 迁移学习

迁移学习(transfer learning)是一种机器学习方法,通过将源数据集(如ImageNet)上训练得到的模型知识迁移到目标数据集(如特定场景的椅子识别任务)。这种方法的核心在于利用预训练模型已学习到的通用特征,提升目标任务的性能,尤其是在数据量有限的情况下。

ImageNet等大规模数据集虽然包含多样化的图像类别(如动物、植物、人造物体等),但在此类数据集上训练的模型能够自动学习低层次和高层次的通用视觉特征。这些特征包括但不限于:

  • 边缘检测:识别图像中的轮廓和边界。
  • 纹理分析:捕捉不同材质的表面特性。
  • 形状建模:提取物体的几何结构。
  • 对象组合关系:理解多物体间的空间和语义关联。

尽管ImageNet可能不直接包含大量椅子图像,但模型学习到的上述特征仍然能够有效支持椅子识别任务。例如,椅子的边缘结构与动物轮廓可能共享相似的检测模式,而木质或布艺椅子的纹理特征可能与其他物体(如家具或服饰)的纹理分析机制高度相关。

通过迁移学习,目标任务无需从零开始训练模型,从而显著减少计算资源和时间成本。此外:

  • 在目标数据集较小的情况下,迁移学习能够避免过拟合问题
  • 预训练模型的高层次特征提取能力可提升目标任务的泛化性能
  • 适用于跨领域应用,如医学影像分析、自动驾驶等场景

2. 微调

迁移学习中的常见技巧:微调(fine-tuning)。

微调包括以下四个步骤。

  1. 在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型

  2. 创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。我们假定这些模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关;因此不在目标模型中使用该层。

  3. 向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。

  4. 在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调

微调中的 权重初始化

微调

  • 迁移学习将从源数据集中学到的知识迁移到目标数据集,微调是迁移学习的常见技巧

  • 除输出层外,目标模型从源模型中复制所有模型设计及其参数,并根据目标数据集对这些参数进行微调。

  • 但是,目标模型的输出层需要从头开始训练。

  • 通常,微调参数使用较小的学习率,而从头开始训练输出层可以使用更大的学习率。

3. 代码实现

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l#@save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

下面显示了前8个正类样本图片和最后8张负类样本图片。

正如所看到的,图像的大小和纵横比各有不同。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

数据增广:

对于RGB(红、绿和蓝)颜色通道,我们分别标准化每个通道。 具体而言,该通道的每个值减去该通道的平均值,然后将结果除以该通道的标准差。

# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize([256, 256]),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize])

查看源模型:

我们使用在ImageNet数据集上预训练的ResNet-18作为源模型。

在这里,我们指定pretrained=True以自动下载预训练的模型参数。 如果首次使用此模型,则需要连接互联网才能下载。

预训练的源模型实例包含许多特征层和一个输出层fc。 此划分的主要目的是促进对除输出层以外所有层的模型参数进行微调。

下面给出了源模型的成员变量fc

pretrained_net = torchvision.models.resnet18(pretrained=True)
pretrained_net.fc输出:
Linear(in_features=512, out_features=1000, bias=True)

定义目标模型finetune_net:【换输出层

  1. finetune_net = torchvision.models.resnet18(pretrained=True) 载入 在 ImageNet 上预训练好的 ResNet-18 网络(权重已固定,特征提取器已具备通用视觉能力)。

  2. finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2) 把 ResNet-18 最后的 1000 类全连接层 替换成 只有 2 个输出的新层(二分类任务)。

    1. in_features 保留原模型提取到的特征维度(512)。

    2. 输出维度改为 2,对应新任务的类别数。

finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight);

首先,我们定义了一个训练函数train_fine_tuning,该函数使用微调,因此可以多次调用。

其中特征提取层学习率低,最后一层输出层学习率高。

# 如果param_group=True,输出层中的模型参数将使用十倍的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_augs),batch_size=batch_size, shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction="none")if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ["fc.weight", "fc.bias"]]trainer = torch.optim.SGD([{'params': params_1x},{'params': net.fc.parameters(),'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)

我们使用较小的学习率,通过微调预训练获得的模型参数。

train_fine_tuning(finetune_net, 5e-5)输出:
loss 0.220, train acc 0.915, test acc 0.939
999.1 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

为了进行比较,我们定义了一个相同的模型,但是将其所有模型参数初始化为随机值。

由于整个模型需要从头开始训练,因此我们需要使用更大的学习率。

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)输出:
loss 0.374, train acc 0.839, test acc 0.843
1623.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/93789.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/93789.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/93789.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STL库——string(类函数学习)

ʕ • ᴥ • ʔ づ♡ど 🎉 欢迎点赞支持🎉 个人主页:励志不掉头发的内向程序员; 专栏主页:C语言; 文章目录 前言 一、STL简介 二、string类的优点 三、标准库中的string类 四、string的成员函数 4.1、构造…

登上Nature!清华大学光学神经网络研究突破

2025深度学习发论文&模型涨点之——光学神经网络光学神经网络的基本原理是利用光的传播、干涉、衍射等特性来实现神经网络中的信息处理和计算。在传统神经网络中,信息以电信号的形式在电子元件之间传输和处理,而在光学神经网络中,信息则以…

【java】对word文件设置只读权限

文件流输出时 template.getXWPFDocument().enforceCommentsProtection(); 文件输出时 //打开己创建的word文档 XWPFDocument document new XWPFDocument(new FileInputStream("output.docx")); //设置文档为只读 document.enforceReadonlyProtection(); //保存文…

Zookeeper 在 Kafka 中扮演了什么角色?

在 Apache Kafka 的早期架构中,ZooKeeper 扮演了分布式协调服务角色,负责管理和协调整个 Kafka 集群。 尽管新版本的 Kafka 正在逐步移除对 ZooKeeper 的依赖,但在许多现有和较早的系统中,了解 ZooKeeper 的作用仍然非常重要。 Zo…

什么叫做 “可迭代的产品矩阵”?如何落地?​

“可迭代的产品矩阵” 不是静态的产品组合,而是围绕用户需求与商业目标构建的动态生态。它以核心产品为根基,通过多维度延伸形成产品网络,同时具备根据市场反馈持续优化的弹性,让产品体系既能覆盖用户全生命周期需求,又…

Nginx代理配置详解:正向代理与反向代理完全指南

系列文章索引: 第一篇:《Nginx入门与安装详解:从零开始搭建高性能Web服务器》第二篇:《Nginx基础配置详解:nginx.conf核心配置与虚拟主机实战》第三篇:《Nginx代理配置详解:正向代理与反向代理…

Vue3 Element-plus 封装Select下拉复选框选择器

废话不多说&#xff0c;样式如下&#xff0c;代码如下&#xff0c;需要自取<template><el-selectv-model"selectValue"class"checkbox-select"multiple:placeholder"placeholder":style"{ width: width }"change"change…

jenkins 自动部署

一、win10 环境安装&#xff1a; 1、jdk 下载安装&#xff1a;Index of openjdk-local 2、配置环境变量&#xff1a; 3、jenkins 下载&#xff1a;Download and deploy 下载后的结果&#xff1a;jenkins.war 4、jenkins 启动&#xff1a; 5、创建管理员用户 admin 登录系统…

2020 GPT3 原文 Language Models are Few-Shot Learners 精选注解

本文为个人阅读GPT3&#xff0c;部分内容注解&#xff0c;由于GPT3原文篇幅较长&#xff0c;且GPT3无有效开源信息 这里就不再一一粘贴&#xff0c;仅对原文部分内容做注解&#xff0c;仅供参考 详情参考原文链接 原文链接&#xff1a;https://arxiv.org/pdf/2005.14165 语言模…

设计模式笔记_行为型_迭代器模式

1. 迭代器模式介绍迭代器模式&#xff08;Iterator Pattern&#xff09;是一种行为设计模式&#xff0c;旨在提供一种方法顺序访问一个聚合对象中的各个元素&#xff0c;而又不需要暴露该对象的内部表示。这个模式的主要目的是将集合的遍历与集合本身分离&#xff0c;使得用户可…

【Part 4 未来趋势与技术展望】第一节|技术上的抉择:三维实时渲染与VR全景视频的共生

《VR 360全景视频开发》专栏 将带你深入探索从全景视频制作到Unity眼镜端应用开发的全流程技术。专栏内容涵盖安卓原生VR播放器开发、Unity VR视频渲染与手势交互、360全景视频制作与优化&#xff0c;以及高分辨率视频性能优化等实战技巧。 &#x1f4dd; 希望通过这个专栏&am…

mac查看nginx安装位置 mac nginx启动、重启、关闭

安装工具&#xff1a;homebrew步骤&#xff1a;1、打开终端&#xff0c;习惯性命令&#xff1a;brew update //结果&#xff1a;Already up-to-date.2、终端继续执行命令&#xff1a;brew search nginx //查询要安装的软件是否存在3、执行命令&#xff1a;brew info nginx4. …

网络通信的基本概念与设备

目录 一、互联网 二、JAVA跨平台与C/C的原理 1、JAVA跨平台的原理 2、C/C跨平台的原理 三、网络互连模型 四、客户端与服务器 五、计算机之间的通信基础 1、IP地址与MAC地址 2、ARP与ICMP对比 ①ARP协议&#xff08;地址解析协议&#xff09; ②ICMP协议&#xff08…

云原生俱乐部-k8s知识点归纳(1)

这篇文章主要是讲讲k8s中的知识点归纳&#xff0c;以帮助理解。虽然平时也做笔记和总结&#xff0c;但是就将内容复制过来不太好&#xff0c;而且我比较喜欢打字。因此知识点归纳总结还是以叙述的口吻来说说&#xff0c;并结合我的理解加以论述。k8s和docker首先讲一讲docker和…

基于Node.js+Express的电商管理平台的设计与实现/基于vue的网上购物商城的设计与实现/基于Node.js+Express的在线销售系统

基于Node.jsExpress的电商管理平台的设计与实现/基于vue的网上购物商城的设计与实现/基于Node.jsExpress的在线销售系统

Git 对象存储:理解底层原理,实现高效排错与存储优化

### 探秘 Git 对象存储&#xff1a;底层原理与优化实践#### 一、Git 对象存储的底层原理 Git 采用**内容寻址文件系统**&#xff0c;核心机制如下&#xff1a; 1. **对象类型与存储** - **Blob 对象**&#xff1a;存储文件内容&#xff0c;通过 git hash-object 生成唯一 SHA-…

【2025CVPR-目标检测方向】RaCFormer:通过基于查询的雷达-相机融合实现高质量的 3D 目标检测

1. 研究背景与动机​ ​问题​:现有雷达-相机融合方法依赖BEV特征融合,但相机图像到BEV的转换因深度估计不准确导致特征错位;雷达BEV特征稀疏,相机BEV特征因深度误差存在畸变。 ​核心思路​:提出跨视角查询融合框架,通过对象查询(object queries)同时采样图像视角(原…

【每日一题】Day 7

560.和为K的子数组 题目&#xff1a; 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 示例 1&#xff1a; 输入&#xff1a;nums [1,1,1], k 2 输出&#xff1a;2 示例 2&#x…

3ds MAX文件/贴图名称乱码?6大根源及解决方案

在3ds MAX渲染阶段&#xff0c;文件或贴图名称乱码导致渲染失败&#xff0c;是困扰众多用户的常见难题。其背后原因多样&#xff0c;精准定位方能高效解决&#xff1a;乱码核心根源剖析字符编码冲突 (最常见)非ASCII字符风险&#xff1a; 文件路径或名称包含中文、日文、韩文等…

链路聚合路由器OpenMPTCProuter源码编译与运行

0.前言 前面写了两篇关于MPTCP的文章&#xff1a; 《链路聚合技术——多路径传输Multipath TCP(MPTCP)快速实践》《使用MPTCPBBR进行数据传输&#xff0c;让网络又快又稳》 对MPTCP有了基本的了解与实践&#xff0c;并在虚拟的网络拓扑中实现了链路带宽的叠加。 1.OpenMPTC…