自动微分:基础概念与应用

自动微分(Autograd)是现代深度学习框架(如PyTorch、TensorFlow)中的一个核心功能。它通过构建计算图并在计算图上自动计算梯度,简化了反向传播算法的实现。以下是自动微分的基本概念及其操作。


1. 基础概念

自动微分指的是通过跟踪计算图中的每一步计算,自动计算目标函数相对于模型参数的梯度。这些计算图是在每次前向传播时动态构建的。基于这个图,系统可以在反向传播时自动计算梯度,而不需要手动推导每个梯度。

1.1 张量

torch中一切皆为张量,属性requires_grad决定是否对其进行梯度计算。默认是False,如需计算梯度则设置为True

1.2 计算图

torch.autograd通过创建一个动态计算图来跟踪张良的操作,每个张量是计算图中的一个节点,节点之间的操作构成图的边。

在Pytorch中,当张量的requiers_grad=Ture时,Pytorch会自动跟踪与该张量相关的所有操作,并构建计算图。每个操作都会生成一个新的张量,并记录其依赖关系。当设置为True时,表示该张量在计算图中需要参与梯度计算,即在反向传播(Backpropagation)过程中惠子dog计算其梯度;当设置为False时,不会计算梯度。
例如
z=x∗yloss=z.sum()z = x * y\\loss = z.sum()z=xyloss=z.sum()
在上述代码中,x 和 y 是输入张量,即叶子节点,z 是中间结果,loss 是最终输出。每一步操作都会记录依赖关系:

z = x * y:z 依赖于 x 和 y。

loss = z.sum():loss 依赖于 z。

这些依赖关系形成了一个动态计算图,如下所示:

	  x       y\     /\   /\ /z||vloss

叶子节点

在 PyTorch 的自动微分机制中,叶子节点(leaf node) 是计算图中:

  • 由用户直接创建的张量,并且它的 requires_grad=True。
  • 这些张量是计算图的起始点,通常作为模型参数或输入变量。

特征:

  • 没有由其他张量通过操作生成。
  • 如果参与了计算,其梯度会存储在 leaf_tensor.grad 中。
  • 默认情况下,叶子节点的梯度不会自动清零,需要显式调用 optimizer.zero_grad() 或 x.grad.zero_() 清除。

如何判断一个张量是否是叶子节点?

通过 tensor.is_leaf 属性,可以判断一个张量是否是叶子节点。

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # 叶子节点
y = x ** 2  # 非叶子节点(通过计算生成)
z = y.sum()print(x.is_leaf)  # True
print(y.is_leaf)  # False
print(z.is_leaf)  # False

叶子节点与非叶子节点的区别

特性叶子节点非叶子节点
创建方式用户直接创建的张量通过其他张量的运算生成
is_leaf 属性TrueFalse
梯度存储梯度存储在 .grad 属性中梯度不会存储在 .grad,只能通过反向传播传递
是否参与计算图是计算图的起点是计算图的中间或终点
删除条件默认不会被删除在反向传播后,默认被释放(除非 retain_graph=True)

detach():张量 x 从计算图中分离出来,返回一个新的张量,与 x 共享数据,但不包含计算图(即不会追踪梯度)。

特点

  • 返回的张量是一个新的张量,与原始张量共享数据。
  • 对 x.detach() 的操作不会影响原始张量的梯度计算。
  • 推荐使用 detach(),因为它更安全,且在未来版本的 PyTorch 中可能会取代 data。
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.detach()  # y 是一个新张量,不追踪梯度y += 1  # 修改 y 不会影响 x 的梯度计算
print(x)  # tensor([1., 2., 3.], requires_grad=True)
print(y)  # tensor([2., 3., 4.])

反向传播

使用tensor.backward()方法执行反向传播,从而计算张量的梯度。这个过程会自动计算每个张量对损失函数的梯度。例如:调用 loss.backward() 从输出节点 loss 开始,沿着计算图反向传播,计算每个节点的梯度。

梯度

计算得到的梯度通过tensor.grad访问,这些梯度用于优化模型参数,以最小化损失函数。

2. 计算梯度

2.1 标量梯度计算

标量梯度计算指的是计算标量(通常是损失函数)相对于模型参数的梯度。在深度学习中,常见的损失函数(如均方误差、交叉熵等)都是标量值。

import torch# 定义张量
x = torch.tensor(2.0, requires_grad=True)
y = x**2 + 3*x + 1  # 定义标量函数# 计算梯度
y.backward()  # 反向传播
print(x.grad)  # 输出x的梯度

为何需要标量梯度?

  • 在训练过程中,我们需要计算损失函数相对于各个参数的梯度,从而调整模型参数。标量梯度的计算是整个训练过程中优化模型的基础。
2.2 向量梯度计算

向量梯度计算用于计算多维向量函数相对于输入向量的梯度。例如,输出是一个向量时,我们希望计算每个分量的梯度。

# 定义张量
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x**2  # 计算每个元素的平方# 计算梯度
y.backward(torch.tensor([1.0, 1.0]))  # 向量梯度计算
print(x.grad)  # 输出x的梯度

为何需要向量梯度?

  • 在多输入多输出的情况下,向量梯度计算能有效地描述每个输入对于输出的影响。
2.3 多标量梯度计算

在一些复杂的场景中,损失函数可能有多个标量输出。我们需要计算每个标量输出对参数的梯度。

x = torch.tensor([2.0, 3.0], requires_grad=True)
y1 = x[0]**2 + 3*x[0] + 1
y2 = x[1]**3 + 2*x[1] - 5
y = y1 + y2  # 多标量函数y.backward()  # 计算梯度
print(x.grad)

为何需要多标量梯度?

  • 多标量梯度有助于处理多任务学习中的梯度计算,特别是当每个任务有不同的损失函数时。
2.4 多向量梯度计算

当输出是多个向量时,我们通常需要计算每个向量对每个输入的梯度。比如在生成对抗网络(GAN)或多任务学习中,常见这种情况。

x = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x[0]**2 + 3*x[0]
y2 = x[1]**3 + 2*x[1]grad_outputs = torch.tensor([1.0, 1.0])  # 指定多个输出梯度
y1.backward(grad_outputs)  # 分别计算y1和y2的梯度

为何需要多向量梯度?

  • 计算多个输出的梯度可以帮助我们进行多维度的优化,尤其是在复杂的网络结构中,多个输出有助于提高模型的多样性和鲁棒性。

3. 梯度上下文控制

在深度学习中,常常需要控制梯度计算的上下文,以节省内存或者针对性地优化某些参数。

3.1 控制梯度计算

我们可以通过torch.no_grad()with torch.set_grad_enabled(False)来临时停止梯度计算,这对于不需要计算梯度的操作(例如推理阶段)非常有用。

with torch.no_grad():y = x * 2  # 在此块中,不会计算梯度

为何控制梯度计算?

  • 在推理阶段,我们不需要梯度,这样可以节省计算资源和内存。
3.2 累计梯度

在某些情况下,梯度计算需要分多个小批次进行累计。例如,使用小批次训练时,梯度会在每个小批次上累加。

optimizer.zero_grad()  # 清空之前的梯度
y.backward()  # 累计梯度
optimizer.step()  # 更新参数

为何累计梯度?

  • 累计梯度可以使得模型在小批次上进行优化,而不丢失总体梯度信息,适用于大规模数据的训练。
3.3 梯度清零

在每次更新前,我们需要清空之前计算的梯度,否则它会在下一步的计算中累加。

optimizer.zero_grad()  # 清除上次计算的梯度

为何清零梯度?

  • 防止梯度计算的累积影响下一次计算,确保每次计算梯度时的准确性。

4. 案例分析

4.1 求函数最小值

通过计算梯度并使用优化算法(如梯度下降),我们可以找到函数的最小值。

x = torch.tensor(2.0, requires_grad=True)
for _ in range(100):y = x**2 + 3*x + 1y.backward()with torch.no_grad():x -= 0.1 * x.grad  # 使用梯度更新xx.grad.zero_()  # 清空梯度

为何使用梯度下降求解最小值?

  • 通过不断调整参数,沿着梯度方向前进,直到收敛到函数的最小值。
4.2 函数参数求解

如果已知函数并希望通过梯度来求解某些未知参数,可以使用反向传播来更新这些参数。

def func(x):return x**2 - 4*x + 3x = torch.tensor(3.0, requires_grad=True)
for i in range(100):y = func(x)y.backward()x.data -= 0.1 * x.gradx.grad.zero_()

为何求解函数参数?

  • 在机器学习中,模型的参数通过梯度计算来优化,进而提高模型的性能。

结论

自动微分的引入让深度学习框架大大简化了梯度计算过程。通过自动计算标量、向量梯度以及控制梯度的计算上下文,开发者可以专注于模型设计而非手动推导梯度公式。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94461.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94461.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/94461.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s原理及操作

简介 kubernetes的本质是一组服务器集群,它可以在集群的每个节点上运行特定的程序,来对节点中的容器 进行管理。目的是实现资源管理的自动化,主要提供了如下的主要功能: 自我修复:一旦某一个容器崩溃,能够在…

理解音频响度:LUFS 标准及其计算实现

LUFS 及其重要性 1.1、什么是 LUFS? LUFS(Loudness Units relative to Full Scale)是音频工程中用于测量感知响度的标准单位。它已成为广播、流媒体和音乐制作领域的行业标准,用于确保不同音频内容具有一致的响度水平。 LUFS 是 I…

【在ubuntu下使用vscode打开c++的make项目及编译调试】

在ubuntu下使用vscode打开c的make项目及编译调试第一步:安装必要的软件第二步:示例项目准备1. 创建C源文件: main.cpp2. 创建头文件: utils.h3. 创建实现文件: utils.cpp第三步:使用 VS Code 打开项目第四步…

3-2.Python 函数 - None(None 概述、None 应用场景)

一、None 概述在 Python 中,None 是一个特殊的常量,用于表示空值或无值None 是 Python 中唯一的一个 NoneType 类型的实例二、None 应用场景 1、定义变量 None 常用于初始化变量,表示该变量暂时不需要有具体值 name Noneprint(name) print(t…

js获取html元素并设置高度为100vh-键盘高度

获取HTML元素并设置高度为(100vh - 键盘高度) 我将设计一个页面,展示如何获取HTML元素并动态设置其高度为视口高度减去键盘高度,这在移动设备上特别有用,可以避免键盘遮挡内容。 设计思路 创建一个带有输入框的界面,模拟键盘弹…

基于SpringBoot的校园博客管理系统

🔗 目录 一. 前言   二. 前端框架、后端框架以及存储框架使用情况说明   三. 核心技术     1. ✅Java开发语言     2. ✅MyBatis     3. ✅Mysql     4. ✅Vue     5. ✅部署项目   四. 演示效果     1. 管理员功能模块       …

Nginx + Certbot配置 HTTPS / SSL 证书

前提条件: 1.已有域名 2.Nginx 已安装并正在运行,且有对应的 Server 配置 3.防火墙开放 80 和 443 端口 安装 EPEL 仓库: sudo yum install epel-release -y安装 Snapd sudo yum install snapd -y启用并启动 Snapd Socket sudo systemctl ena…

图结构使用 Louvain 社区检测算法进行分组

图结构使用 Louvain 社区检测算法进行分组 flyfish Louvain 算法是一种基于模块度最大化的社区检测算法,核心目标是在复杂网络中找到“内部连接紧密、外部连接稀疏”的社区结构。它的优势在于高效性(可处理百万级节点的大规模网络)和近似最优…

layui.formSelects自定义多选组件在layer.open中使用、获取、复现

layui.formSelects自定义多选组件在layer.open中使用、获取、复现 引入css和js //<th:block th:include"include :: layui-formSelects-css"/> <link th:href"{/ajax/libs/layui-formSelects/formSelects-v4.css}" rel"stylesheet"/>…

基于SpringBoot的社团管理系统【2026最新】

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

运行node18报错

又碰到一个奇葩的问题&#xff0c;报错如下> tigermes.vue30.1.0 serve > vue-cli-service serveBrowserslist: caniuse-lite is outdated. Please run:npx update-browserslist-dblatestWhy you should do it regularly: https://github.com/browserslist/update-db#rea…

Python第三方库IPFS-API使用详解:构建去中心化应用的完整指南

目录 Python第三方库IPFS-API使用详解&#xff1a;构建去中心化应用的完整指南 引言&#xff1a;IPFS与去中心化存储的革命 星际文件系统&#xff08;IPFS&#xff0c;InterPlanetary File System&#xff09;是一种革命性的点对点超媒体协议&#xff0c;旨在创建持久且分布式的…

ETL与iPaaS的融合方案:加速数据集成流程

在今天的商业世界里&#xff0c;数据几乎无处不在。企业每天都在产生和接收海量的数据——从CRM到ERP&#xff0c;从云端SaaS应用到本地数据库&#xff0c;来源越来越分散&#xff0c;集成也越来越复杂。 传统的ETL工具&#xff08;提取、转换、加载&#xff09;在处理结构化数…

详解flink SQL基础(四)

文章目录1.Flink SQL介绍2.streaming SQL&watermarks使用3.窗口聚合&#xff08;window aggregations&#xff09;4.over aggregations5.FlinkSQL 流连接&#xff08;Streaming join&#xff09;6.使用MATCH_RECOGNIZE 进行模式识别和复杂事件处理7.变更记录&#xff08;ch…

有鹿机器人:为城市描绘清洁新图景的智能使者

一、智慧清洁&#xff1a;科技赋能的环境革新每天清晨&#xff0c;当我沿着小区路径缓缓行驶&#xff0c;双激光雷达系统便开始精准测绘环境。我的专业清扫能力源自2cm精度死亡贴边技术&#xff0c;这项让同行惊叹的能力&#xff0c;可以轻松震出嵌了十年的烟头&#xff0c;彻底…

Tableau Server高危漏洞允许攻击者上传任意恶意文件

Tableau Server 存在一个严重安全漏洞&#xff0c;可能允许攻击者上传并执行恶意文件&#xff0c;最终导致系统完全沦陷。该漏洞编号为 CVE-2025-26496&#xff0c;CVSS 评分为 9.6 分&#xff0c;影响 Windows 和 Linux 平台上的多个 Tableau Server 和 Tableau Desktop 版本。…

数据结构07(Java)-- (堆,大根堆,堆排序)

前言 本文为本小白&#x1f92f;学习数据结构的笔记&#xff0c;将以算法题为导向&#xff0c;向大家更清晰的介绍数据结构相关知识&#xff08;算法题都出自&#x1f64c;B站马士兵教育——左老师的课程&#xff0c;讲的很好&#xff0c;对于想入门刷题的人很有帮助&#x1f4…

onnx入门教程(七)——如何添加 TensorRT 自定义算子

在前面的模型入门系列文章中&#xff0c;我们介绍了部署一个 PyTorch 模型到推理后端&#xff0c;如 ONNXRuntime&#xff0c;这其中可能遇到很多工程性的问题。有些可以通过创建 ONNX 节点来解决&#xff0c;该节点仍然使用后端原生的实现进行推理。而有些无法导出到后端的算法…

YggJS RButton 按钮组件 v1.0.0 使用教程

&#x1f4cb; 目录 简介核心特性快速开始安装指南基础使用主题系统高级功能API 参考最佳实践性能优化故障排除总结 &#x1f680; 简介 YggJS RButton 是一个专门为 React 应用程序设计的高性能按钮组件库。它提供了两套完整的设计主题&#xff1a;科技风主题和极简主题&…

Linux(二十)——SELinux 概述与状态切换

文章目录前言一、SELinux 概述1.1 SELinux 简介1.2 SELinux 特点1.2.1 MAC&#xff08;Mandatory Access Control&#xff09;1.2.2 RBAC&#xff08;Role-Based Access Control&#xff09;1.2.3 TE&#xff08;Type Enforcement&#xff09;1.3 SELinux 的执行模式1.4 SELinu…