残差神经网络(Residual Neural Network,简称 ResNet)是深度学习领域的里程碑式模型,由何凯明等人在 2015 年提出,成功解决了深层神经网络训练中的梯度消失 / 爆炸问题,使训练超深网络(如 152 层)成为可能。以下从核心原理、结构设计、优势与应用等方面进行详解。

一、核心问题:深层网络的训练困境

在 ResNet 提出前,随着网络层数增加,模型性能会先提升,然后迅速下降 —— 这种下降并非由过拟合导致,而是因为深层网络的梯度难以有效传递到浅层,导致浅层参数无法被充分训练(梯度消失 / 爆炸)。

ResNet 通过引入 “残差连接”(Residual Connection)解决了这一问题。

二、核心原理:残差连接与恒等映射

1. 传统网络的映射方式

传统深层网络中,每一层的目标是学习一个 “直接映射”(Direct Mapping):
设输入为x,经过多层非线性变换后,输出为H(x),即网络需要学习H(x)。

2. 残差网络的映射方式

ResNet 提出:不直接学习H(x),而是学习 “残差”F(x)=H(x)−x
此时,原映射可表示为:H(x)=F(x)+x
其中,F(x)是残差函数(由若干卷积层 / 激活函数组成),x通过 “跳跃连接”(Skip Connection)直接与F(x)相加,形成最终输出。

3. 为什么残差连接有效?
  • 梯度传递更顺畅:反向传播时,梯度可通过x直接传递到浅层(避免梯度消失)。例如,若F(x)=0,则H(x)=x,形成 “恒等映射”,网络可轻松学习到这种简单映射,再在此基础上优化残差。
  • 简化学习目标:学习残差F(x)比直接学习H(x)更简单。例如,当目标映射接近恒等映射时,F(x)接近 0,网络只需微调即可,无需重新学习复杂的映射。

三、ResNet 的基本结构:残差块(Residual Block)

残差块是 ResNet 的基本单元,分为两种类型:

1. 基本残差块(Basic Block,用于 ResNet-18/34)

由 2 个卷积层组成,结构如下:x→Conv2d(64,3x3)→BN→ReLU→Conv2d(64,3x3)→BN→(+x)→ReLU

  • 输入x先经过两个 3x3 卷积层(带批归一化 BN 和 ReLU 激活),得到残差F(x)。
  • 若输入x与F(x)的维度相同(通道数、尺寸一致),则直接相加(恒等映射);若维度不同(如 stride > 1 或通道数变化),则需通过 1x1 卷积调整x的维度(称为 “投影捷径”,Projection Shortcut):x→Conv2d(out_channels,1x1,stride)→BN→(+F(x))
2. 瓶颈残差块(Bottleneck Block,用于 ResNet-50/101/152)

为减少计算量,用 3 个卷积层(1x1 + 3x3 + 1x1)组成,结构如下:x→Conv2d(C,1x1)→BN→ReLU→Conv2d(C,3x3)→BN→ReLU→Conv2d(4C,1x1)→BN→(+x′)→ReLU

  • 1x1 卷积用于 “降维”(减少通道数),3x3 卷积用于提取特征,最后 1x1 卷积 “升维”(恢复通道数),显著降低计算量。
  • 同样支持投影捷径(当维度不匹配时)。

四、完整 ResNet 的网络架构

ResNet 通过堆叠残差块形成深层网络,不同层数的 ResNet 结构如下表:

网络类型残差块类型卷积层配置(每个阶段的残差块数量)总层数
ResNet-18基本块[2, 2, 2, 2]18
ResNet-34基本块[3, 4, 6, 3]34
ResNet-50瓶颈块[3, 4, 6, 3]50
ResNet-101瓶颈块[3, 4, 23, 3]101
ResNet-152瓶颈块[3, 8, 36, 3]152

  • 整体流程:输入图像 → 7x7 卷积(步长 2)+ 最大池化 → 4 个阶段的残差块堆叠(每个阶段通道数翻倍,尺寸减半) → 全局平均池化 → 全连接层(输出分类结果)。

五、ResNet 的优势

  1. 解决深层网络训练难题:通过残差连接实现梯度有效传递,可训练数百层甚至上千层的网络。
  2. 性能优异:在 ImageNet 等数据集上,ResNet 的错误率显著低于 VGG、GoogLeNet 等模型。
  3. 泛化能力强:残差结构可迁移到其他任务(如目标检测、语义分割),成为许多深度学习模型的基础组件(如 Faster R-CNN、U-Net++)。

六、ResNet 的变体与延伸

  1. ResNeXt:引入 “分组卷积”(Group Convolution),在保持性能的同时减少参数。
  2. DenseNet:将残差连接的 “相加” 改为 “拼接”(Concatenate),强化特征复用。
  3. Res2Net:在残差块中引入多尺度特征融合,提升细粒度特征提取能力。
  4. 应用扩展:从图像分类扩展到目标检测(如 FPN)、视频分析(如 I3D)、自然语言处理(如残差 LSTM)等领域。

七、总结

ResNet 通过残差连接的创新设计,突破了深层网络的训练瓶颈,不仅推动了计算机视觉的发展,也为其他领域的深层模型设计提供了重要思路。其核心思想 ——通过简化学习目标(学习残差)提升模型性能—— 已成为深度学习的经典范式。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):"""基本残差块,用于ResNet-18/34"""expansion = 1  # 输出通道数是输入的多少倍def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(BasicBlock, self).__init__()# 第一个卷积层self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 第二个卷积层(步长固定为1)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.downsample = downsample  # 用于调整输入x的维度以匹配残差def forward(self, x):identity = x  # 保存输入用于残差连接# 计算残差F(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# 如果需要调整维度,则对输入x进行下采样if self.downsample is not None:identity = self.downsample(x)# 残差连接:H(x) = F(x) + xout += identityout = self.relu(out)return outclass Bottleneck(nn.Module):"""瓶颈残差块,用于ResNet-50/101/152"""expansion = 4  # 输出通道数是中间层的4倍def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(Bottleneck, self).__init__()# 1x1卷积降维self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)# 3x3卷积提取特征self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 1x1卷积升维self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = x# 计算残差F(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)# 调整输入维度if self.downsample is not None:identity = self.downsample(x)# 残差连接out += identityout = self.relu(out)return outclass ResNet(nn.Module):"""ResNet主网络"""def __init__(self, block, layers, num_classes=1000):super(ResNet, self).__init__()self.in_channels = 64  # 初始输入通道数# 初始卷积层self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channels)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 四个阶段的残差块self.layer1 = self._make_layer(block, 64, layers[0], stride=1)self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 分类头self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)# 初始化权重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def _make_layer(self, block, out_channels, blocks, stride=1):"""构建一个由多个残差块组成的层"""downsample = None# 如果步长不为1或输入输出通道数不匹配,需要下采样调整维度if stride != 1 or self.in_channels != out_channels * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * block.expansion),)layers = []# 添加第一个残差块(可能包含下采样)layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channels * block.expansion# 添加剩余的残差块(步长固定为1)for _ in range(1, blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):# 初始特征提取x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)# 经过四个残差层x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 分类x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x# 定义不同层数的ResNet
def resnet18(num_classes=1000):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)def resnet34(num_classes=1000):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)def resnet50(num_classes=1000):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)def resnet101(num_classes=1000):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)def resnet152(num_classes=1000):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes)# 测试代码
if __name__ == "__main__":# 创建ResNet-18模型model = resnet18(num_classes=10)# 随机生成一个3通道输入(模拟224x224图像)x = torch.randn(2, 3, 224, 224)  # batch_size=2# 前向传播output = model(x)print(f"输入形状: {x.shape}")print(f"输出形状: {output.shape}")  # 应输出(2, 10)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/96306.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/96306.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/96306.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习嵌入式之驱动

一、基础搭建1.基础:c语言 软件编程语言 数据结构 软件编程思想2.驱动实现目标如果将Linux系统细致到开发板平台上? Liunx系统与硬件设备的适配3.自我能力的锻炼继续强化C语言锻炼大型代码阅读和分析能力学习大型项目的代码搭建和管理的能力…

在 Golang 中复用 HTTP 连接

问题提出最近在实现一个转发大模型调用请求的中转功能,涉及到要构造client发送请求的内容,一开始我每次都是新建一个client来发送请求,这样的代码实现存在一些问题——每次都要构造新的client,并且要重新建立连接。后面了解到在Go…

前端:el-upload文件上传与FormData 对象

<el-uploadclass"uploadDemo":limit"1"dragaccept".xls,.xlsx" <!-- 只保留Excel格式 -->:on-exceed"handleExceedFileLimit":on-change"handleChangeExcelFile":on-remove"handleRemoveExcelFile":bef…

自然处理语言NLP:One-Hot编码、TF-IDF、词向量、NLP特征输入、EmbeddingLayer实现、word2vec

文章目录自然语言处理&#xff08;NLP&#xff09;一、什么是自然语言处理&#xff08;NLP&#xff09;&#xff1f;二、NLP 的核心目标三、NLP 的主要应用方向&#xff08;应用场景&#xff09;四、NLP 的基本概念五、NLP 的基本处理流程1. 文本预处理2. 特征表示3. 模型选择与…

单词记忆-轻松记忆10个实用英语单词(13)

1. board含义&#xff1a;板子&#xff1b;董事会&#xff1b;登机 读音标注&#xff1a;/bɔːrd/ 例句&#xff1a;Write your name on the board. 译文&#xff1a;把你的名字写在板上。 衍生含义&#xff1a;董事会&#xff08;如“board of directors”&#xff09;&#…

Spring循环依赖源码调试详解,用两级缓存代替三级缓存

Spring循环依赖源码详解&#xff0c;改用两级缓存并实验 背景 最近一直在研究Spring的循环依赖&#xff0c;发现好像两级缓存也能解决循环依赖。 关于为何使用三级缓存&#xff0c;大致有两个原因 对于AOP的类型&#xff0c;保证Bean生命周期的顺序 对于有AOP代理增强的类型&am…

亚马逊BALL PIT球池外观专利侵权指控?不侵权意见书助力4条链接申诉成功!

儿童球池作为玩具品类中常见的一款产品&#xff0c;能够给儿童提供游乐的安全空间&#xff0c;深受亚马逊平台用户的喜爱。然而在近期&#xff0c;赛贝收到了部分亚马逊卖家的咨询&#xff0c;原因是他们在售的儿童球池产品链接被美国外观专利USD1009203S&#xff08;下称203专…

开源,LangExtract-Python库用LLM从非结构化文本提取结构化信息

摘要&#xff1a; LangExtract是一个Python库&#xff0c;利用大语言模型&#xff08;LLM&#xff09;根据用户定义指令从非结构化文本文档中提取结构化信息。它具备精确源定位、可靠结构化输出、长文档优化、交互式可视化、灵活LLM支持、适应任意领域等特点。可通过几行代码快…

如何根据团队技术能力选择最适合的PHP框架?

作为一名PHP开发者&#xff0c;面对众多的PHP框架&#xff0c;你是否曾感到选择困难&#xff1f;Laravel、Symfony、CodeIgniter、ThinkPHP…每个框架都有其特色和优势&#xff0c;但没有最好的框架&#xff0c;只有最适合的框架。而选择合适框架的关键因素之一&#xff0c;就是…

多人同时导出 Excel 导致内存溢出

1、问题根因分析多人同时导出Excel导致内存溢出&#xff08;OOM&#xff09;的核心原因是&#xff1a;在短时间内&#xff0c;大量数据被加载到JVM堆内存中&#xff0c;且创建了大量大对象&#xff08;如Apache POI的Cell、Row、Sheet对象&#xff09;&#xff0c;超过了堆内存…

深入 RAG(检索增强生成)系统架构:如何构建一个能查资料的大语言模型系统

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《深度探秘&#xff1a;AI界的007》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、前言 1、LLM 的局限&#xff1a;模型知识“封闭” vs 现实知识…

linux tftpboot烧写地址分析

1&#xff0c;loadaddr 是一个环境变量&#xff0c;用于指定文件&#xff08;如内核镜像、设备树等&#xff09;加载到内存的起始地址。setenv loadaddr 0x82000000setenv loadaddr 0x80008000saveenv //.保存配置将 loadaddr 设置为 0x82000000&#xff0c;表示后续文件将加载…

硬件工程师9月实战项目分享

目录 简介 人员情况 实战项目简介 功能需求 需求分析 方案设计 电源树设计 时钟树设计 主芯片外围设计 接口设计 模拟链路设计 PCB设计检查要点 测试方案设计 硬件测试培训 测试代码学习 培训目标 掌握基本的硬件设计流程 掌握以FPGA为核心的硬件设计业务知识 …

力扣刷题——59.螺旋矩阵II

力扣刷题——59.螺旋矩阵II 题目 给你一个正整数 n &#xff0c;生成一个包含 1 到 n2 所有元素&#xff0c;且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。示例 1&#xff1a;输入&#xff1a;n 3 输出&#xff1a;[[1,2,3],[8,9,4],[7,6,5]]示例 2&#xff1a; 输…

win11系统还原点恢复系统

背景 系统换位bug11后&#xff0c;真的是各种以前的操作和设置找不到&#xff0c;太烦了&#xff0c;我是没想到&#xff0c;连系统恢复还原点都这么难找。然后搜了一圈都是恢复系统之类的&#xff0c;真的崩溃。只好自己记录了。 ✍内容找到设置—>系统–>系统信息系统信…

DHCP 原理与配置(一)

应用场景随着网络规模的不断扩大&#xff0c;网络复杂度不断提升&#xff0c;网络中的终端设备例如主机、手机、 平板等&#xff0c;位置经常变化。终端设备访问网络时需要配置IP地址、网关地址、DNS服务器 地址等。采用手工方式为终端配置这些参数非常低效且不够灵活。 IETF于…

SARibbon的编译构建及详细用法

目录 1.1 源码构建 1.2 搭建项目 1.3 详细用法 1.4 不同风格 1.5 完整代码 引言:SARibbon是一个专门为Qt框架设计的开源Ribbon风格界面控件库,它模仿了微软Office和WPS的Ribbon UI风格,适用于需要复杂菜单和工具栏的大型桌面程序。本文从源码编译构建到详细使用,做了一…

CSS【详解】性能优化

精简 CSS移除未使用的 CSS&#xff08;“死代码”&#xff09;&#xff0c;可借助工具如 PurgeCSS、UnCSS 自动检测并删除未被页面使用的样式。避免重复样式&#xff0c;通过提取公共样式&#xff08;如 mixin 或公共类&#xff09;减少代码冗余。利用预处理器&#xff08;Sass…

Flutter 线程模型详解:主线程、异步与 Isolate

一、主线程&#xff1a;默认的执行环境 所有代码默认运行在主线程。下面的例子展示了一个会阻塞主线程的错误示范&#xff1a; import package:flutter/material.dart;void main() {runApp(const MyApp()); }class MyApp extends StatelessWidget {const MyApp({super.key});ov…

ChartDB:可视化数据库设计工具私有化部署

ChartDB:可视化数据库设计工具私有化部署一、什么是ChartDB ChartDB 是一款基于 Web 的开源数据库可视化工具&#xff0c;专为简化数据库设计与管理流程而开发。以下是其核心特性与功能概述: 1、核心功能 智能查询可视化‌&#xff1a;通过单条 SQL 查询即可生成数据库架构图&a…