0. 引出问题

在神经网络反向传播过程中 loss = [loss₁,loss₂, loss₃],为什么 ∂loss/∂w

∂loss₁/∂w 
∂loss₂/∂w
∂loss₃/∂w 

∂loss₁/∂w 和 loss 维度一样都是三位向量 ,[∂loss₁/∂w, ∂loss₂/∂w, ∂loss₃/∂w] 就变成3*3的矩阵
如下所示:

import torchw = torch.tensor([1.0, 2.0,3.0], requires_grad=True)
loss = w * 3  
print("loss: \n", loss)loss_m = []for i, val in enumerate(loss):w.grad = None  # 清零val.backward(retain_graph=True)print(f"∂loss{i+1}/∂w = {w.grad}")loss_m.append(w.grad.clone())print("loss_m: \n", torch.stack(loss_m))

输出结果:

loss: tensor([3., 6., 9.], grad_fn=<MulBackward0>)∂loss1/∂w = tensor([3., 0., 0.])
∂loss2/∂w = tensor([0., 3., 0.])
∂loss3/∂w = tensor([0., 0., 3.])loss_m: tensor([[3., 0., 0.],[0., 3., 0.],[0., 0., 3.]])

loss: tensor([3., 6., 9.]) 为向量,对w求导时为矩阵
但是 w.grad 必须 是标量或张量,不能是向量矩阵

1. 标量求导

import torchw = torch.tensor([1.0, 2.0,3.0], requires_grad=True)
loss = w * 3  
print("loss: \n", loss)loss_m = []
# 方法1:分别计算
for i, val in enumerate(loss):w.grad = None  # 清零val.backward(retain_graph=True)print(f"∂loss{i+1}/∂w = {w.grad}")loss_m.append(w.grad.clone())print("loss_m: \n", torch.stack(loss_m))grads = torch.autograd.grad(loss.sum(), w,retain_graph=True)
print("grads: \n", grads)  grads1 = torch.autograd.grad(loss.mean(), w)[0]
print("grads1: \n", grads1) 

输出;

loss: tensor([3., 6., 9.], grad_fn=<MulBackward0>)
∂loss1/∂w = tensor([3., 0., 0.])
∂loss2/∂w = tensor([0., 3., 0.])
∂loss3/∂w = tensor([0., 0., 3.])
loss_m: tensor([[3., 0., 0.],[0., 3., 0.],[0., 0., 3.]])
grads: (tensor([3., 3., 3.]),)
grads1: tensor([1., 1., 1.])

同样的例子:

import torch# 3个样本的真实数据
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)
y_true = torch.tensor([1.0, 2.0, 3.0])# 线性模型:y = w₁x₁ + w₂x₂
w = torch.tensor([0.5, 0.5], requires_grad=True)
predictions = (x @ w)  # [1.5, 3.5, 5.5]
print("预测值:", predictions)
# 计算每个样本的梯度
individual_grads = []
for i in range(3):loss = (predictions[i] - y_true[i])**2loss.backward(retain_graph=True)individual_grads.append(w.grad.clone())w.grad.zero_()print("样本1梯度:", individual_grads[0]) 
print("样本2梯度:", individual_grads[1])  
print("样本3梯度:", individual_grads[2])  # 标量梯度:自动综合
total_loss = ((predictions - y_true)**2).mean()
total_loss.backward()# 验证:标量梯度 = 向量梯度的平均
manual_average = (individual_grads[0] + individual_grads[1] + individual_grads[2]) / 3print("手动平均:", manual_average)  
print("标量结果:", w.grad)  

输出结果:

预测值: tensor([1.5000, 3.5000, 5.5000], grad_fn=<MvBackward0>)
样本1梯度: tensor([1., 2.])
样本2梯度: tensor([ 9., 12.])
样本3梯度: tensor([25., 30.])
手动平均: tensor([11.6667, 14.6667])
标量结果: tensor([11.6667, 14.6667])

训练神经网络是为了最小化整体损失,不是单独优化每个样本

# 实际训练:最小化平均损失
batch_loss = individual_losses.mean()  # 标量
batch_loss.backward()  # 得到平均梯度
optimizer.step()       # 朝平均最优方向更新

2. 什么时候需要向量梯度?

仅用于研究:分析样本敏感性

def compute_sample_gradients(model, x, y):"""仅用于分析,不用于训练"""grads = []for xi, yi in zip(x, y):model.zero_grad()pred = model(xi.unsqueeze(0))loss = ((pred - yi) ** 2)loss.backward()grads.append(model.weight.grad.clone())return grads  # 每个样本的单独梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/920008.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/920008.shtml
英文地址,请注明出处:http://en.pswp.cn/news/920008.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

tcpdump命令打印抓包信息

tcpdump命令打印抓包信息 下面是在服务器抓取打印服务端7701端口打印 rootgb:/home/gb# ifconfig -a eth0: flags4163<UP,BROADCAST,RUNNING,MULTICAST> mtu 1500inet 10.250.251.197 netmask 255.255.255.0 broadcast 10.250.251.255inet6 fe80::76fe:48ff:fe94:5a5 …

Mysql-经典实战案例(13):如何通过Federated实现跨实例访问表

实现原理&#xff1a;使用Federated引擎本创建一个链接表实现&#xff0c;但是Federated 引擎只是一个按列的顺序和类型解析远程返回的数据流准备工作&#xff1a; 1. 本地库启用 Federated 引擎查看是否已启用&#xff1a; SHOW ENGINES;如果Federated 引擎的 Support 是 YES …

Linux -- 动静态库

一、什么是库1、动静态库概念# 库是写好的现有的&#xff0c;成熟的&#xff0c;可以复⽤的代码。现实中每个程序都要依赖很多基础的底层库&#xff0c;不可能每个⼈的代码都从零开始&#xff0c;因此库的存在意义⾮同寻常。# 本质上来说库是⼀种可执⾏代码的⼆进制形式&#x…

Linux笔记---单例模式与线程池

1. 单例模式单例模式是一种常用的设计模式&#xff0c;它确保一个类只有一个实例&#xff0c;并提供一个全局访问点来获取这个实例。这种模式在需要控制资源访问、管理共享状态或协调系统行为时非常有用。单例模式的核心特点&#xff1a;私有构造函数&#xff1a;防止外部通过n…

Linux中的指令

1.adduseradduser的作用是创立一个新的用户。当我们在命令行中输入1中的指令后&#xff0c;就会弹出2中的命令行&#xff0c;让我们设立新的密码&#xff0c;紧接着就会让我们再次输入新的密码&#xff0c;对于密码的输入它是不会显示出来的&#xff0c;如果输入错误就会让我们…

【n8n】Docker容器中安装ffmpeg

容器化部署 n8n 时&#xff0c;常常会遇到一些环境依赖问题。缺少 docker 命令或无法安装 ffmpeg 是较为常见的场景&#xff0c;如果处理不当&#xff0c;会导致流程执行受限。 本文介绍如何在 n8n 容器中解决 docker 命令不可用和 ffmpeg 安装受限的问题&#xff0c;并给出多…

【基础算法】初识搜索:递归型枚举与回溯剪枝

文章目录一、搜索1. 什么是搜索&#xff1f;2. 遍历 vs 搜索3. 回溯与剪枝二、OJ 练习1. 枚举子集 ⭐(1) 解题思路(2) 代码实现2. 组合型枚举 ⭐(1) 解题思路请添加图片描述(2) 代码实现3. 枚举排列 ⭐(1) 解题思路(2) 代码实现4. 全排列问题 ⭐(1) 解题思路(2) 代码实现一、搜…

Node.js异步编程——async/await实现

一、async/await基础语法 在Node.Js编程中,async关键字用于定义异步函数,这个异步函数执行完会返回一个Promise对象,异步函数的内部可以使用await关键字来暂停当前代码的继续执行,直到Promise操作完成。 在用法上,async关键字主要用于声明一个异步函数,await关键字主要…

搭建一个简单的Agent

准备本案例使用deepseek&#xff0c;登录deepseek官网&#xff0c;登录账号&#xff0c;充值几块钱&#xff0c;然后创建Api key可以创建虚拟环境&#xff0c;python版本最好是3.12&#xff0c;以下是文件目录。test文件夹中&#xff0c;放一些txt文件做测试&#xff0c;main.p…

uv,下一代Python包管理工具

什么是uv uv&#xff08;Universal Virtual&#xff09;是由Astral团队&#xff08;知名Python工具Ruff的开发者&#xff09;推出的下一代Python包管理工具&#xff0c;使用Rust编写。它集成了包管理、虚拟环境、依赖解析、Python版本控制等功能&#xff0c;它聚焦于三个关键点…

单片机的输出模式推挽和开漏如何选择呢?

推挽和开漏是单片机的输出模式&#xff0c;属于I/O口配置的常见类型。开漏&#xff08;Open-Drain&#xff09;和推挽&#xff08;Push-Pull&#xff09;是两种根本不同的输出电路结构&#xff0c;理解它们的区别是正确使用任何单片机&#xff08;包括51和STM32&#xff09;GPI…

java18学习笔记-Simple Web Server

408:Simple Web Server Python、Ruby、PHP、Erlang 和许多其他平台提供从命令行运行的开箱即用服务器。这种现有的替代方案表明了对此类工具的公认需求。 提供一个命令行工具来启动仅提供静态文件的最小web服务器。没有CGI或类似servlet的功能可用。该工具将用于原型设计、即…

深度解析Atlassian 团队协作套件(Jira、Confluence、Loom、Rovo)如何赋能全球分布式团队协作

无穷无尽的聊天记录、混乱不堪的文档、反馈信息分散在各个不同时区……在全球分布式团队中开展真正的高效协作&#xff0c;就像是一场不可能完成的任务。 为什么会这样&#xff1f;因为即使是最聪明的团队&#xff0c;也会遇到类似的障碍&#xff1a; 割裂的工作流&#xff1a…

理解AI 智能体:智能体架构

1. 引言 智能体架构&#xff08;agent architecture&#xff09;是一份蓝图&#xff0c;它定义了AI智能体各组件的组织方式和交互机制&#xff0c;使智能体能够感知环境、进行推理并采取行动。本质上&#xff0c;它就像是智能体的数字大脑——整合了“眼睛”&#xff08;传感器…

Spring Cloud系列—SkyWalking链路追踪

上篇文章&#xff1a; Spring Cloud系列—Seata分布式事务解决方案TCC模式和Saga模式https://blog.csdn.net/sniper_fandc/article/details/149947829?fromshareblogdetail&sharetypeblogdetail&sharerId149947829&sharereferPC&sharesourcesniper_fandc&…

机器人领域的算法研发

研究生期间学习大模型&#xff0c;可投递机器人领域的算法研发、技术支持等相关岗位&#xff0c;以下是具体推荐&#xff1a; AI算法工程师&#xff08;大模型方向-机器人应用&#xff09;&#xff1a;主要负责大模型开发与优化&#xff0c;如模型预训练、调优及训练效率提升等…

深度学习入门:神经网络

文章目录一、深度学习基础认知二、神经网络核心构造解析2.1 神经元的基本原理2.2 感知器&#xff1a;最简单的神经网络2.3 多层感知器&#xff1a;引入隐藏层解决非线性问题2.3.1 多层感知器的结构特点2.3.2 偏置节点的作用2.3.3 多层感知器的计算过程三、神经网络训练核心方法…

mysql的索引有哪些?

1. 主键索引&#xff08;PRIMARY KEY&#xff09;主键索引通常在创建表时定义&#xff0c;确保字段唯一且非空&#xff1a;-- 建表时直接定义主键 CREATE TABLE users (id INT NOT NULL,name VARCHAR(50),PRIMARY KEY (id) -- 单字段主键 );-- 复合主键&#xff08;多字段组合…

【计算机视觉与深度学习实战】08基于DCT、DFT和DWT的图像变换处理系统设计与实现(有完整代码python3.13可直接粘贴使用)

1. 引言 数字图像处理作为计算机视觉和信号处理领域的重要分支,在过去几十年中得到了快速发展。图像变换技术作为数字图像处理的核心技术之一,为图像压缩、特征提取、去噪和增强等应用提供了强有力的数学工具。离散余弦变换(Discrete Cosine Transform, DCT)、离散傅里叶变…

使用Python实现DLT645-2007智能电表协议

文章目录&#x1f334;通讯支持&#x1f334; 功能完成情况服务端架构设计一、核心模块划分二、数据层定义三、协议解析层四、通信业务层&#xff08;以DLT645服务端为例&#xff09;五、通信层&#xff08;以TCP为例&#xff09;使用例子&#x1f334;通讯支持 功能状态TCP客…