目录

简介

一、迁移学习

1.什么是迁移学习

2. 迁移学习的步骤

二、残差网络ResNet

1.了解ResNet

2.ResNet网络---残差结构

三、代码分析

1. 导入必要的库

2. 模型准备(迁移学习)

3. 数据预处理

4. 自定义数据集类

5. 数据加载器

6. 设备配置

7. 训练函数

8. 测试函数

9. 训练配置和执行

整体流程总结


简介

        经过长久的卷积神经网络的学习、我们学习了如何提高模型的准确率,但是最终我们的准确率还是没达到百分之八十。原因是因为我们本身模型的局限,面对现有很多成熟的模型,它们有很好的效果,都是经过多次训练选取了最佳的参数,那我们能不能去使用哪些大佬的模型呢?

        答案是可以的,这就使用到迁移学习的知识。

深度学习之第五课卷积神经网络 (CNN)如何训练自己的数据集(食物分类)

深度学习之第六课卷积神经网络 (CNN)如何保存和使用最优模型

深度学习之第七课卷积神经网络 (CNN)调整学习率

一、迁移学习

1.什么是迁移学习

        迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2. 迁移学习的步骤

        1、选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

        2、冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

        3、在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

        4、微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

        5、评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

太多概念,我们直接使用残差网络进行迁移学习。

二、残差网络ResNet

1.了解ResNet

        ResNet 网络是在 2015年 由微软实验室中的何凯明等几位大神提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。

传统卷积神经网络存在的问题?

卷积神经网络都是通过卷积层和池化层的叠加组成的。 在实际的试验中发现,随着卷积层和池化层的叠加,学习效果不会逐渐变好,反而出现2个问题:

        1、梯度消失和梯度爆炸 梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0 梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

        2、退化问题

如何解决问题?

为了解决梯度消失或梯度爆炸问题,论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。 为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)。

                                        实线为测试集错误率 虚线为训练集错误率

2.ResNet网络---残差结构

ResNet的经典网络结构有:ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152几种,其中,ResNet-18和ResNet-34的基本结构相同,属于相对浅层的网络,后面3种的基本结构不同于ResNet-18和ResNet-34,属于更深层的网络。

不论是多少层的ResNet网络,它们都有以下共同点:

  • 网络一共包含5个卷积组,每个卷积组中包含1个或多个基本的卷积计算过程(Conv-> BN->ReLU)
  • 每个卷积组中包含1次下采样操作,使特征图大小减半,下采样通过以下两种方式实现:
    • 最大池化,步长取2,只用于第2个卷积组(Conv2_x)
    • 卷积,步长取2,用于除第2个卷积组之外的4个卷积组
  • 第1个卷积组只包含1次卷积计算操作,5种典型ResNet结构的第1个卷积组完全相同,卷积核均为7x7, 步长为均2
  • 第2-5个卷积组都包含多个相同的残差单元,在很多代码实现上,通常把第2-5个卷积组分别叫做Stage1、Stage2、Stage3、Stage4
  • 首先是第一层卷积使用kernel 7∗7,步长为2,padding为3。之后进行BN,ReLU和maxpool。这些构成了第一部分卷积模块conv1。
  • 然后是四个stage,有些代码中用make_layer()来生成stage,每个stage中有多个模块,每个模块叫做building block,resnet18= [2,2,2,2],就有8个building block。注意到他有两种模块BasicBlockBottleneck。resnet18和resnet34用的是BasicBlock,resnet50及以上用的是Bottleneck。无论BasicBlock还是Bottleneck模块,都用到了残差连接(shortcut connection)方式:

下图以ResNet18为例介绍一下它的网络模型

layer1

        ResNet18 ,使用的是 BasicBlocklayer1,特点是没有进行降采样,卷积层的 stride = 1,不会降采样。在进行 shortcut 连接时,也没有经过 downsample 层。

layer2,layer3,layer4

而 layer2layer3layer4 的结构图如下,每个 layer 包含 2 个 BasicBlock,但是第 1 个 BasicBlock 的第 1 个卷积层的 stride = 2,会进行降采样。在进行 shortcut 连接时,会经过 downsample 层,进行降采样和降维

        residual结构使用了一种shortcut的连接方式,也可理解为捷径。让特征矩阵隔层相加,注意F(X)和X形状要相同,所谓相加是特征矩阵相同位置上的数字进行相加。

        一个残差块有2条路径 F(x)和 x,F(x) 路径拟合残差,可称之为残差路径; 路径为`identity mapping`恒等映射,可称之为`shortcut`。图中的⊕为`element-wise addition`,要求参与运算的F(x)  和 x的尺寸要相同。

其中关键技术 Batch Normalization是对每一个卷积后进行标准化

        Batch Normalization目的:使所有的feature map满足均值为0,方差为1的分布规律

三、代码分析

1. 导入必要的库

import torch
from torch.utils.data import DataLoader,Dataset  # 数据加载相关
from PIL import Image  # 图像处理
from torchvision import transforms  # 数据预处理
import numpy as np
from torch import nn  # 神经网络模块
import torchvision.models as models  # 预训练模型

2. 模型准备(迁移学习)

这部分是迁移学习的重点,

# 加载预训练的ResNet-18模型
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)# 冻结所有预训练参数(迁移学习常用策略)
for param in resnet_model.parameters():print(param)  # 打印参数(实际应用中可删除)param.requires_grad = False  # 冻结参数,不参与训练# 获取原模型最后一层的输入特征数
in_features = resnet_model.fc.in_features  # ResNet18的fc层输入是512# 替换最后一层全连接层,输出类别数为20(根据实际任务调整)
resnet_model.fc = nn.Linear(in_features, 20)# 收集需要更新的参数(只有新替换的全连接层参数)
params_to_update = []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

        这里采用了迁移学习策略:冻结预训练模型的大部分参数,只训练最后一层的分类器,这样可以加快训练速度并提高效果。

  • models.resnet18():创建 ResNet-18 网络结构
  • weights=models.ResNet18_Weights.DEFAULT:使用在 ImageNet 数据集上预训练好的权重初始化模型
  • 迁移学习的关键操作:保留预训练模型学到的特征提取能力
  • requires_grad = False:告诉 PyTorch 不需要计算这些参数的梯度
  • 原 ResNet-18 用于 1000 类分类,这里替换为 20 类分类
  • 只训练新替换的全连接层参数,大大减少计算量

3. 数据预处理

data_transforms = {'train': transforms.Compose([  # 训练集的数据增强transforms.Resize([300, 300]),  # 调整大小transforms.RandomRotation(45),  # 随机旋转transforms.CenterCrop(224),  # 中心裁剪transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转transforms.ToTensor(),  # 转为Tensor# 归一化,使用ImageNet的均值和标准差transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([  # 验证集不做数据增强,只做必要处理transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])]),
}

4. 自定义数据集类

class food_dataset(Dataset):  # 继承Dataset类def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []  # 存储图像路径self.labels = []  # 存储标签self.transform = transform# 从文件中读取图像路径和标签with open(file_path, 'r') as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):  # 返回数据集大小return len(self.imgs)def __getitem__(self, idx):  # 获取单个样本image = Image.open(self.imgs[idx])  # 打开图像if self.transform:  # 应用预处理image = self.transform(image)# 处理标签,转为Tensorlabel = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

5. 数据加载器

# 创建训练集和测试集
train_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['train'])  # 注意这里可能应该用'valid'# 创建数据加载器,用于批量加载数据
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

6. 设备配置

# 自动选择可用的计算设备(GPU优先)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")# 将模型移动到选定的设备
model = resnet_model.to(device)

7. 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 切换到训练模式batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)  # 将数据移动到设备# 前向传播pred = model.forward(X)loss = loss_fn(pred, y)  # 计算损失# 反向传播和参数更新optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数# 打印训练信息loss = loss.item()if batch_size_num % 64 == 0:print(f"loss: {loss:>7f} [number: {batch_size_num}]")batch_size_num += 1

8. 测试函数

best_acc = 0  # 记录最佳准确率def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 切换到评估模式test_loss, correct = 0, 0with torch.no_grad():  # 关闭梯度计算,节省内存for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()# 计算正确预测的数量correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"Test result:\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")# 保存最佳模型global best_accif correct > best_acc:best_acc = correcttorch.save(model, 'best3.pt')  # 保存整个模型

9. 训练配置和执行

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失,适用于分类任务
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam优化器# 学习率调度器,每10个epoch学习率乘以0.5
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 训练轮次
epochs = 20
acc_s = []
loss_s = []# 开始训练
for t in range(epochs):print(f"Epoch {t+1}\n-----------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)scheduler.step()  # 更新学习率
print("Done!")
print(f"最佳的结果:\n Accuracy: {(100*best_acc):>0.1f}%")

整体流程总结

  1. 加载预训练的 ResNet-18 模型并修改最后一层以适应新任务
  2. 定义数据预处理和增强方法
  3. 创建自定义数据集类来读取图像和标签
  4. 设置训练设备(GPU 或 CPU)
  5. 定义训练和测试函数
  6. 配置优化器、损失函数和学习率调度器
  7. 执行多轮训练,每轮结束后在测试集上评估并保存最佳模型

最后我们都结果可以达到百分之90左右,效果得到很大的提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/98137.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/98137.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/98137.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pinia 两种写法全解析:Options Store vs Setup Store(含实践与场景对比)

目标:把 Pinia 的两种写法讲透,写明“怎么写、怎么用、怎么选、各自优缺点与典型场景”。全文配完整代码与注意事项,可直接当团队规范参考。一、背景与准备 适用版本:Vue 3 Pinia 2.x安装与初始化: # 安装 npm i pini…

setup函数相关【3】

目录1.setup函数:1.概述:2.案例分析:2.setup函数的优化:(setup语法糖)优化1:优化2:安装插件:安装指令:只对当前项目安装配置vite.config.ts:代码编…

如何通过AI进行数据资产梳理

最终产出 数据资产清单 包含所有数据资产的详细目录,列出数据集名称、描述、所有者、格式、存储位置和元数据。 用途:帮助政府部门清晰了解数据资产分布和状态。 数据质量报告 数据质量评估结果,记录准确性、完整性、一致性等问题及改进建议,基于政府认可的数据质量框架(如…

【传奇开心果系列】Flet框架结合pillow实现的英文文字倒映特效自定义模板特色和实现原理深度解析

Flet框架结合pillow实现的英文文字倒映特效自定义模板特色和实现原理深度解析 一、效果展示截图 二、使用场景 三、特色说明 四、概括说明 五、依赖文件列表 六、安装依赖命令 七、 项目结构建议 八、注意事项 九、Flet 文字倒影效果实现原理分析 (一)组件结构与功能 1. 图像…

2025最新深度学习面试必问100题--理论+框架+原理+实践 (下篇)

2025最新深度学习面试必问100题–理论框架原理实践 (下篇) 在上篇中,我们已经深入探讨了机器学习基础、CNN、RNN及其变体,以及模型优化的核心技巧。 在下篇中,我们将把目光投向更远方,聚焦于当今AI领域最炙手可热的前沿。我们将深…

原子工程用AC6编译不过问题

…\Output\atk_h750.axf: Error: L6636E: Pre-processor step failed for ‘…\User\SCRIPT\qspi_code.scf.scf’修改前: #! armcc -E ;#! armclang -E --targetarm-arm-none-eabi -mcpucortex-m7 -xc /* 使用说明 ! armclang -E --targetarm-arm-none-eabi -mcpuco…

Python有哪些经典的常用库?(第一期)

目录 1、NumPy (数值计算基础库) 核心特点: 应用场景: 代码示例: 2、Pandas (数据分析处理库) 应用场景: 代码示例: 3、Scikit-learn (机器学习库) 核心特点: 应用场景: 代码示例&am…

现代 C++ 高性能程序驱动器架构

🧠 现代 C 高性能程序驱动器架构M/PA(多进程)是隔离的“孤岛”,M/TA(多线程)是共享的“战场”,EDSM(事件驱动)是高效的“反应堆”,MDSM(消息驱动&…

投资储能项目能赚多少钱?小程序帮你测算

为解决电网负荷平衡、提升新能源消纳等问题,储能项目的投资开发越来越多。那么,投资储能项目到底能赚多少钱?适不适合投资?用“绿虫零碳助手”3秒钟精准测算。操作只需四步,简单易懂:1.快速登录&#xff1a…

Mac 能够连Wife,但是不能上网问题解决

请按照以下步骤从最简单、最可能的原因开始尝试: 第一步:基础快速排查 这些步骤能解决大部分临时性的小故障。 重启设备:关闭您的 Mac 和路由器,等待一分钟后再重新打开。这是解决网络问题最有效的“万能药”。检查其他设备&am…

基于SpringBoot的旅游管理系统的设计与实现(代码+数据库+LW)

摘要 本文阐述了一款基于SpringBoot框架的旅游管理系统设计与实现。该系统整合了用户信息管理、旅游资源展示、订单处理流程及安全保障机制等核心功能,专为提升旅游行业的服务质量和运营效率而设计。 系统采用前后端分离架构,前端界面设计注重跨设备兼…

Springboot乐家流浪猫管理系统16lxw(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表项目功能:领养人,流浪猫,领养申请开题报告内容基于Spring Boot的乐家流浪猫管理系统开题报告一、研究背景与意义随着城市化进程加速和人口增长,流浪猫问题已成为全球性社会挑战。据统计,全球每年约有1.5亿只无家可归的宠物&a…

函数定义跳转之代码跳转

相信大家在开发的过程中都有用到函数定义跳转的功能,在 IDE 中,如果在函数调用的地方停留光标,可能会提示对应的函数定义,在 GitHub 中也是如此,对于一些仓库来说,我们可以直接查看对应的函数定义了&#x…

探讨Xsens在人形机器人研发中的四个核心应用

探索Xsens动作捕捉如何改变人形机器人研发——使机器人能够从人类运动中学习、更直观地协作并弥合模拟与现实世界之间的差距。人形机器人技术是当今世界最令人兴奋且最复杂的前沿领域之一。研究人员不仅致力于开发能够像人类一样行走和行动的机器人,还致力于开发能够…

C语言高级编程:一文读懂数据结构的四大逻辑与两大存储

各类资料学习下载合集 ​​https://pan.quark.cn/s/8c91ccb5a474​ 作为一名程序员,我们每天都在与“数据”打交道。但你是否想过,这些数据在计算机中是如何被“整理”和“安放”的?为什么有些操作快如闪电,而有些则慢如蜗牛? 答案就藏在数据结构之中。 如果说算法是…

MySQL问题4

MySQL中varchar和char的区别 在 MySQL 中,VARCHAR 和 CHAR 都是用于存储字符串类型的字段,但它们在存储方式、性能、适用场景等方面存在明显区别:1. 存储方式类型说明CHAR(n)定长字符串,始终占用固定 n 个字符空间。不足的会自动在…

Web3 出海香港 101 |BuildSpace AMA 第一期活动高亮观点回顾

香港政府在 2022-2023 年之间已经开始布局 Web3,由香港政府全资拥有的数码港也进行了持续两年多的深耕。目前数码港已有接近 300 家企业入驻于此,包括 Animoca Brands、HashKey Group、CertiK 等行业知名独角兽公司。此外,如 Cobo、OneKey、D…

LTE CA和NR CA的区别和联系

LTE CA(Carrier Aggregation)和NR CA(New Radio Carrier Aggregation)都是载波聚合技术,它们的核心目标都是通过组合多个频段的带宽来提高数据传输速率,增强无线网络的吞吐量。尽管它们的功能相似&#xff…

VBA 中的 Excel 工作表函数

一、引言 在使用VBA进行Excel自动化处理时,我们经常需要调用Excel内置的工作表函数来完成复杂的计算或数据处理任务。然而,很多VBA初学者并不清楚如何正确地在VBA中调用这些函数,甚至重复造轮子。本文将从基础到进阶,系统介绍如何…

老年公寓管理系统设计与实现(代码+数据库+LW)

摘要 随着老龄化社会的不断发展,老年人群体的生活质量和管理需求逐渐引起社会的广泛关注。为了提高老年公寓的管理效率与服务质量,开发了一种基于SpringBoot框架的老年公寓管理系统。该系统充分利用了SpringBoot框架的快速开发优势,结合现代…