一.前言

本章节我们是要学习梯队计算,⾃动微分(Autograd)模块对张量做了进⼀步的封装,具有⾃动求导功能。⾃动微分模块是构成神经⽹络 训练的必要模块,在神经⽹络的反向传播过程中,Autograd 模块基于正向计算的结果对当前的参数进⾏微 分计算,从⽽实现⽹络权重参数的更新。

二.梯度基本计算

我们使⽤ backward ⽅法、grad 属性来实现梯度的计算和访问.

import torch# 1. 单标量梯度的计算
# y = x**2 + 20
def test01():# 定义需要求导的张量# 张量的值类型必须是浮点类型x = torch.tensor(10, requires_grad=True, dtype=torch.float64)# 变量经过中间运算f = x ** 2 + 20# ⾃动微分f.backward()# 打印 x 变量的梯度# backward 函数计算的梯度值会存储在张量的 grad 变量中print(x.grad)# 2. 单向量梯度的计算# y = x**2 + 20def test02():# 定义需要求导张量x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)# 变量经过中间计算f1 = x ** 2 + 20# 注意:# 由于求导的结果必须是标量# ⽽ f 的结果是: tensor([120., 420.])# 所以, 不能直接⾃动微分# 需要将结果计算为标量才能进⾏计算f2 = f1.mean()  # f2 = 1/2 * x   2x/4# ⾃动微分f2.backward()# 打印 x 变量的梯度print(x.grad)if __name__ == '__main__':test01()test02()

tensor(20., dtype=torch.float64)
tensor([ 5., 10., 15., 20.], dtype=torch.float64) 

三.控制梯度计算 

我们可以通过⼀些⽅法使得在 requires_grad=True 的张量在某些时候计算不进⾏梯度计算。 

import torch# 1. 控制不计算梯度
def test01():x = torch.tensor(10, requires_grad=True, dtype=torch.float64)print(x.requires_grad)# 第⼀种⽅式: 对代码进⾏装饰with torch.no_grad():y = x ** 2print(y.requires_grad)# 第⼆种⽅式: 对函数进⾏装饰@torch.no_grad()def my_func(x):return x ** 2print(my_func(x).requires_grad)# 第三种⽅式torch.set_grad_enabled(False)y = x ** 2print(y.requires_grad)# 2. 注意: 累计梯度
def test02():# 定义需要求导张量x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)for _ in range(3):f1 = x ** 2 + 20f2 = f1.mean()# 默认张量的 grad 属性会累计历史梯度值# 所以, 需要我们每次⼿动清理上次的梯度# 注意: ⼀开始梯度不存在, 需要做判断if x.grad is not None:x.grad.data.zero_()f2.backward()print(x.grad)# 3. 梯度下降优化最优解
def test03():# y = x**2x = torch.tensor(10, requires_grad=True, dtype=torch.float64)for _ in range(5000):# 正向计算f = x ** 2# 梯度清零if x.grad is not None:x.grad.data.zero_()# 反向传播计算梯度f.backward()# 更新参数x.data = x.data - 0.001 * x.gradprint('%.10f' % x.data)if __name__ == '__main__':test01()# print('--------------------')# test02()# print('--------------------')# test03()

这里得分开打印,就不在展示结果了,大家打印一下看看。

四.梯度计算注意

当对设置 requires_grad=True 的张量使⽤ numpy 函数进⾏转换时, 会出现如下报错:

Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead. 

此时, 需要先使⽤ detach 函数将张量进⾏分离, 再使⽤ numpy 函数. 

注意: detach 之后会产⽣⼀个新的张量, 新的张量作为叶⼦结点,并且该张量和原来的张量共享数据, 但是分 离后的张量不需要计算梯度。 

import torch# 1. detach 函数⽤法
def test01():x = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)# Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.# print(x.numpy())  # 错误print(x.detach().numpy())  # 正确# 2. detach 前后张量共享内存
def test02():x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)# x2 作为叶⼦结点x2 = x1.detach()# 两个张量的值⼀样: 140421811165776 140421811165776print(id(x1.data), id(x2.data))x2.data = torch.tensor([100, 200])print(x1)print(x2)# x2 不会⾃动计算梯度: Falseprint(x2.requires_grad)if __name__ == '__main__':test01()test02()

结果展示: 

[10. 20.]
1834543349008 1834543349008
tensor([10., 20.], dtype=torch.float64, requires_grad=True)
tensor([100, 200])
False 

五.总结 

本⼩节主要讲解了 PyTorch 中⾮常重要的⾃动微分模块的使⽤和理解。我们对需要计算梯度的张量需要设 置 requires_grad=True 属性,并且需要注意的是梯度是累计的,在每次计算梯度前需要先进⾏梯度清零。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914587.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914587.shtml
英文地址,请注明出处:http://en.pswp.cn/news/914587.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习·目标检测和语义分割基础

边缘框 不是标准的x,y坐标轴。边缘框三种表示:左上右下下坐标,左上坐标长宽,中心坐标长宽 COCO 目标检测数据集的格式:注意一个图片有多个物体,使用csv或者文件夹结构的格式不可取。 锚框算法 生成很多…

ffmpeg音视频处理大纲

FFmpeg 是一个功能强大的开源音视频处理工具集,其核心代码以 C 语言实现。下面从源码角度分析 FFmpeg 如何实现转码、压缩、提取、截取、拼接、合并和录屏等功能: 一、FFmpeg 核心架构与数据结构 FFmpeg 的源码结构围绕以下核心组件展开: lib…

网络安全小练习

一、docker搭建 1.安装 2.改变镜像源(推荐国内镜像源:阿里云镜像源) 登录阿里云容器镜像源服务( 阿里云登录 - 欢迎登录阿里云,安全稳定的云计算服务平台 ) 复制系统分配的专属地址 配置 sudo mkdir …

数据结构——顺序表的相关操作

一、顺序表基础认知​1.顺序表的定义与特点​顺序表是数据结构中一种线性存储结构,它将数据元素按照逻辑顺序依次存储在一片连续的物理内存空间中。简单来说,就是用一段地址连续的存储单元依次存放线性表的元素,且元素之间的逻辑关系通过物理…

2025最新国产用例管理工具评测:Gitee Test、禅道、蓝凌测试、TestOps 哪家更懂研发协同?

在快节奏的 DevOps 时代,测试用例管理已不再是 QA 的独角戏,而是穿透需求—开发—测试—交付全流程的核心枢纽。想象一下,如果用例结构混乱,覆盖不全,甚至丢失版本变更历史,不仅协作乱,还影响交…

在线评测系统开发交流

https://space.bilibili.com/700332132?spm_id_from333.788.0.0 实验内容爬虫Web系统设计数据分析实验指导爬虫Web系统设计自然语言处理与信息检索数据可视化评分标准FAQ实验二:在线评测系统实验概述实验内容Step1:题目管理Step2:题目评测S…

Linux操作系统从入门到实战(十)Linux开发工具(下)make/Makefile的推导过程与扩展语法

Linux操作系统从入门到实战(十)Linux开发工具(下)make/Makefile的推导过程与扩展语法前言一、 make/Makefile的推导过程1. 先看一个完整的Makefile示例2. make的工作流程(1)寻找Makefile文件(2&…

NFS磁盘共享

步骤:注意事项‌:确保服务端防火墙关闭,或者允许2049端口通信,客户端需具备读写权限。服务器端安装NFS服务器:sudo apt-get install nfs-kernel-server # Debian/Ubuntu sudo yum install nfs-utils # Ce…

ORA-06413: 连接未打开

System.Data.OracleClient.OracleException:ORA-06413: 连接未打开 oracle 报错 ORA-06413: 连接未打开 db.Open();的报错链接未打开,System.Data.OracleClient.OracleException HResult0x80131938 MessageORA-06413: 连接未打开 关于ORA-06413错误(…

【PCIe 总线及设备入门学习专栏 5.1.2 -- PCIe EP core_rst_n 与 app_rst_n】

文章目录 app_rst_n 和 core_rst_n 的作用1. core_rst_n — PCIe 控制器内部逻辑复位作用控制方式2. app_rst_n — 应用层/用户逻辑复位作用特点两者关系图示:示例流程(Synopsys EP)rst_sync[3] 的作用详解(复位同步逻辑)为什么使用 rst_sync[3]?图示说明Synopsys 官方手…

Python初学者笔记第二十期 -- (文件IO)

第29节课 文件IO 在编程中,文件 I/O(输入/输出)允许程序与外部文件进行数据交互。Python 提供了丰富且易用的文件 I/O 操作方法,能让开发者轻松实现文件的读取、写入和修改等操作。 IO交互方向 从硬盘文件 -> 读取数据 -> 内…

Java JUC包概述

Java 的 java.util.concurrent(简称 JUC)包是 JDK 5 及以后引入的并发编程工具包,旨在解决传统线程模型(如 synchronized、wait/notify)的局限性,提供更灵活、高效、可扩展的并发编程组件。它极大简化了多线…

LeetCode--44.通配符匹配

前言:不知不觉又断更一天了,其实昨天就把这道题写得差不多了,只是刚好在力扣里面看见了一种新的解法,本来想写出来的,但是我把它推到今天了,因为太晚了,但是今天又睡懒觉了,所以我直…

WHAT - 依赖管理工具 CocoaPods

文章目录1. 什么是 CocoaPods?2. 如何安装 CocoaPods?(1) 确保已安装 Ruby(macOS 默认自带)(2) 安装 CocoaPods(3) 验证安装3. 在 React Native 项目中使用 CocoaPods(1) 进入 iOS 目录(2) 初始化 Podfile(如果不存在&…

C++ Boost Aiso TCP 网络聊天(服务端客户端一体化)

代码功能说明: 程序模式: 主动连接模式:当用户指定对端 IP 和端口时,尝试连接到对端被动监听模式:当用户未指定对端 IP 时,等待其他节点连接线程模型: 主线程:处理用户输入和消息发送接收线程:后台接收并显示对端消息关键组件: std::atomic<bool> connected:原…

WeakAuras 5.12.9 Ekkles lua

3.45猎人宝宝狼 技能恢复宏已知3.45BUG RL技能位会清空&#xff0c;小退大退 BB技能全部激活&#xff0c;修复以前可用宏一键恢复状态-------方法一&#xff1a;宏命令---------------------------------------------------------#showtooltip 狂怒之嚎 /petautocaston [btn:1]…

对于编写PID过程中的问题

当stm32RCT6使用位置环pid控制麦轮转动一定路程时&#xff0c;在这个时间段内想让一边轮胎速度加大应该怎么做&#xff1f;比如我pid的目标脉冲值为9000&#xff0c;在运行到3000的时候车偏左了&#xff0c;那我应该怎样让他回正&#xff0c;我想到的办法是增加其最大的脉冲值&…

LeetCode|Day13|88. 合并两个有序数组|Python刷题笔记

LeetCode&#xff5c;Day13&#xff5c;88. 合并两个有序数组&#xff5c;Python刷题笔记 &#x1f5d3;️ 本文属于【LeetCode 简单题百日计划】系列 &#x1f449; 点击查看系列总目录 >> &#x1f4cc; 题目简介 题号&#xff1a;88. 合并两个有序数组 难度&#xf…

【C++】初识C++(1)

个人主页&#xff1a;我要成为c嘎嘎大王 希望这篇小小文章可以让你有所收获&#xff01; 目录 前言 一、C的第一个程序 二、命名空间 2.1 namespace 的价值 2.2 namespace 的定义 2.2.1 正常的命名空间定义 2.2.2 命名空间可以嵌套 2.2.3 匿名命名空间 2.2.4 同名的name…

在新闻资讯 APP 中添加不同新闻分类页面,通过 ViewPager2 实现滑动切换

在新闻资讯 APP 中添加不同新闻分类页面&#xff0c;通过 ViewPager2 实现滑动切换 核心组件的作用 ViewPager2&#xff1a;是 ViewPager 的升级版&#xff0c;基于RecyclerView实现&#xff0c;支持水平 / 垂直滑动、RTL&#xff08;从右到左&#xff09;布局&#xff0c;且修…