论文解读:"Gradient Surgery for Multi-Task Learning"

1. 论文标题直译
  • Gradient Surgery: 梯度手术
  • for Multi-Task Learning: 应用于多任务学习

合在一起就是:为多任务学习量身定制的梯度手术。这个名字非常形象地概括了它的核心思想。

2. 它要解决的核心问题:多任务学习中的“梯度冲突”

想象一下,你正在训练一个AI模型来开一辆车,它需要同时完成两个任务:

  • 任务A: 识别红绿灯(要求模型关注图像上方的颜色区域)。
  • 任务B: 保持在车道线内(要求模型关注图像下方的白色线条)。

在训练时,模型会根据任务A的错误计算出一个梯度 g_A,根据任务B的错误计算出另一个梯度 g_B。梯度本质上是告诉模型参数“应该朝哪个方向更新才能做得更好”。

问题来了: 如果某次更新中,g_A 说“参数应该向东调整”,而 g_B 恰好说“参数应该向西调整”,那么把它们简单相加(g_A + g_B)的结果可能接近于零,模型几乎学不到任何东西。

更常见的情况是,g_A 想让参数向东走,g_B 想让参数向西北走。它们的合力会是一个“折衷”的方向,这个方向可能对两个任务都不是最优的,甚至可能提升一个任务的性能却损害了另一个。

这种现象就叫做梯度冲突 (Gradient Conflict) 或 负迁移 (Negative Transfer)。这是多任务学习中一个长期存在的痛点,它会导致训练不稳定,模型性能难以提升。

3. PCGrad 的解决方案:“梯度手术”

PCGrad (Projected Gradient Descent) 提出了一种非常聪明的解决方案,就像一个外科医生一样,在更新模型参数之前,先对这些相互冲突的梯度做一次“手术”。

手术流程如下:

第1步:分别计算每个任务的梯度 和传统方法不同,它不把所有损失加起来,而是为每个任务的损失 loss_Aloss_B... 单独计算梯度 g_Ag_B...

第2步:诊断是否存在“冲突” PCGrad 遍历所有梯度对(如 g_A 和 g_B),并通过计算它们的点积 (dot product) 来判断它们是否冲突。

  • 如果 dot(g_A, g_B) > 0: 说明两个梯度的夹角小于90度,它们大方向一致,是“盟友”。无需手术
  • 如果 dot(g_A, g_B) < 0: 说明两个梯度的夹角大于90度,它们的方向是“敌对”的。诊断为冲突,需要手术!

第3步:执行“手术”——投影和矫正 当检测到 g_A 和 g_B 冲突时,PCGrad 会执行以下操作:

  1. 投影 (Project):将梯度 g_A 投影到梯度 g_B 的方向上,得到一个分量 proj_B(g_A)。这个分量可以被理解为 g_A 中与 g_B “正面冲突”的那一部分。
  2. 矫正 (Correct):从原始梯度 g_A 中减去这个冲突分量:g_{A_{new}} = g_A - proj_B(g_A)

手术效果: 经过手术后的新梯度 g_{A_{new}} 与 g_B 变成了正交的(夹角为90度)。这意味着,g_{A_{new}} 的更新方向中,已经完全剔除了与 g_B 直接对抗的部分。它只保留了对自己有益,且不伤害对方的部分。

PCGrad 会对所有发生冲突的梯度对都执行这个“手术”。

第4步:合并与更新 将所有经过“手术”矫正后的新梯度相加,得到最终的、和谐的、没有内斗的梯度,然后用这个梯度去更新模型参数。

4. TensorFlow 实现中的 PCGrad

你在代码中看到的 PCGrad 通常是一个优化器包装器 (Optimizer Wrapper)。它的用法一般是这样的:

  1. 首先,定义一个基础的优化器,比如 Adam。

    base_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
  2. 然后,用 PCGrad 包装它

    from .PCGrad import PCGrad
    optimizer = PCGrad(base_optimizer)
  3. 在训练循环中,用法会稍有不同。 你不再是计算一个总的 loss 然后调用 apply_gradients。而是:

    # 1. 分别计算每个任务的 loss
    loss_A = compute_loss_A(y_true_A, y_pred_A)
    loss_B = compute_loss_B(y_true_B, y_pred_B)
    list_of_losses = [loss_A, loss_B]# 2. PCGrad 优化器会接管梯度的计算和矫正
    # 这一步是 PCGrad 内部实现的,它会:
    #   - 为每个 loss 计算梯度
    #   - 执行梯度手术
    #   - 返回最终的梯度
    # 通常会通过一个自定义的 train_step 来实现
    final_gradients = optimizer.get_gradients(list_of_losses, model.trainable_variables)# 3. 应用经过手术后的梯度
    optimizer.apply_gradients(zip(final_gradients, model.trainable_variables))

总结

方面解释
它是什么?PCGrad 是一种优化策略,而非损失函数或模型架构。
解决什么问题?解决多任务学习中的梯度冲突 (Gradient Conflict) 问题。
核心思想?梯度手术 (Gradient Surgery):在更新模型前,先检测并消除梯度之间的冲突部分。
如何实现?通过向量投影,将冲突的梯度分量从原始梯度中移除,使它们变得正交
最终效果?1. 训练过程更稳定。 2. 避免了任务间的“内耗”,有助于所有任务性能的同步提升。

因此,当你看到代码中使用了 PCGrad,就可以立刻明白:这个项目正在处理一个多任务学习的场景,并且使用了一种相当先进的技术来确保不同任务能够“和平共处”,协同进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/97851.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/97851.shtml
英文地址,请注明出处:http://en.pswp.cn/pingmian/97851.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nvidia显卡架构解析与cuda应用生态浅析

文章目录 0. Nvidia显卡简介 一、主要显卡系列 二、主要GPU架构与代表产品 1.main 1.1 CUDA 13.0 的重大变化 1.2 V100 的硬件短板已显现 1.3 这意味着什么? 1.4 写在后面 彩蛋:V100 0. Nvidia显卡简介 一、主要显卡系列 GeForce 系列(消费级) 用途:游戏、创作、日常图形…

开发指南:使用 MQTTNet 库构建 .Net 物联网 MQTT 应用程序

一、背景介绍 随着物联网的兴起&#xff0c;.Net 框架在构建物联网应用程序方面变得越来越流行。微软的 .Net Core 和 .Net 框架为开发人员提供了一组工具和库&#xff0c;以构建可以在 Raspberry Pi、HummingBoard、BeagleBoard、Pine A64 等平台上运行的物联网应用程序。 MQT…

突破性能瓶颈:基于腾讯云EdgeOne的AI图片生成器全球加速实践

1. 项目背景与挑战 1.1 开发背景 随着AIGC技术爆发&#xff0c;我们团队决定开发一款多模型支持的AI图片生成器&#xff0c;主要解决以下痛点&#xff1a; 不同AI模型的参数规范不统一生成结果难以系统化管理缺乏企业级的安全水印方案全球用户访问延迟高&#xff0c;中国用户…

一、Java 基础入门:从 0 到 1 认识 Java(详细笔记)

1.1 Java 语言简介与发展历程 Java 是一门面向对象的高级编程语言&#xff0c;以“跨平台、安全、稳定”为核心特性&#xff0c;自诞生以来长期占据编程语言排行榜前列&#xff0c;广泛应用于后端开发、移动端开发、大数据等领域。 1.1.1 起源与核心人物 起源背景&#xff1…

uniapp:根据目的地经纬度,名称,唤起高德/百度地图来导航,兼容App,H5,小程序

1、需要自行申请高德地图的key,配置manifest.json 2、MapSelector选择组件封装 <template><view><u-action-sheet :list="mapList" v-model="show" @click="changeMap"></u-action-sheet></view> </template&…

我对 WPF 动摇时的选择:.NET Framework 4.6.2+WPF+Islands+UWP+CompostionApi

目录 NET Framework 4.6.2的最大亮点 为什么固守462不升级 WPF-开发体验的巅峰 为什么对WPF动摇了 基于IslandsUWP的滤镜尝试 总结 NET Framework 4.6.2的最大亮点 安全性能大提升&#xff1a; 默认启用TLS1.2协议&#xff0c;更安全&#xff0c;它为后续的版本提供了重…

SpringBoot大文件下载失败解决方案

SpringBoot大文件下载失败解决方案 后端以文件流方式给前端接收下载文件,文件过大时出现下载失败的情况或者打开后提示文件损坏,实际是字节未完全读取写入。 针对大文件下载失败的情况,以下是详细的解决方案: 大文件下载失败的主要原因 内存溢出:一次性加载大文件到内存…

torch.gather

torch.gather 介绍 torch.gather(input, dim, index, *, sparse_gradFalse, outNone) → Tensor 沿由 dim 指定的轴收集值。 对于三维张量&#xff0c;输出按如下方式确定&#xff1a; out[i][j][k] input[index[i][j][k]][j][k] # 如果 dim 0 out[i][j][k] input[i][i…

Golang | http/server Gin框架简述

http/server http指的是Golang中的net/http包&#xff0c;这里用的是1.23.10。 概览 http包的作用文档里写的很简明&#xff1a;Package http provides HTTP client and server implementations. 主要是提供http的客户端和服务端&#xff0c;也就是能作为客户端发http请求&a…

Vision Transformer (ViT) :Transformer在computer vision领域的应用(三)

Experiment 上来的一段话就概括了整章的内容。 We evaluate the representation learning capabilities of ResNet, Vision Transformer (ViT), and the hybrid. 章节的一开头就说明了,对比的模型就是 ResNet,CNN领域中的代码模型。 ViT。 上一篇中提到的Hybrid模型,也就是…

5-12 WPS JS宏 Range数组规范性测试

Range()数组是JS宏中不缺少的组成部分,了解Range()数组的特性必不可少,下面我们一起测试一下各种Range()数组。 1.Range()数组特性 单元格区域:Range("a2:m2")与Range("a2","m2")的类型都是:Range/Object,功能都为单元格区域,功能…

uniapp微信小程序保存海报到手机相册canvas

在uniapp中实现微信小程序保存海报到手机相册&#xff0c;主要涉及Canvas绘制和图片保存。以下是关键步骤和代码示例&#xff1a; 一、关键代码展示&#xff1a; 1. 模板配置&#xff1a;页面展示该海报&#xff0c;可直接查看&#xff0c;也可下载保存到手机相册&#xff0c;h…

glib2-2.62.5-7.ky10.x86_64.rpm怎么安装?Kylin Linux RPM包安装详细步骤

一、准备工作 ​确认系统版本​ 这个包是 ky10的&#xff08;也就是 openEuler 20.03 LTS SP3 或类似版本&#xff09;&#xff0c;而且是 ​x86_64 架构&#xff08;就是常见的64位电脑&#xff09;​。 你要先确认你的系统是不是这个版本&#xff0c;不然可能装不上或者出问题…

webrtc之语音活动下——VAD人声判定原理以及源码详解

文章目录前言一、高斯混合模型介绍1.高斯模型举例1&#xff09;定义2&#xff09;举例说明2.高斯混合模型(GMM)1&#xff09;定义2&#xff09;举例说明3&#xff09;一维曲线二、VAD高斯混合模型1.模型训练介绍1&#xff09;训练方法2&#xff09;训练结果2.噪声高斯模型分布1…

【Redis】-- 主从复制

文章目录1. 主从复制1.1 主从复制是怎么个事&#x1f914;1.2 拓扑结构1.2.1 一主一从拓扑1.2.2 一主多从拓扑1.2.3 树形拓扑1.3 主从复制原理1.3.1 复制过程1.3.2 数据同步PSYNC1.3.2.1 replicationid/replid (复制id)1.3.2.2 复制偏移量维护1.3.3 psync运行流程1.3.4 全量复制…

开源炸场!阿里通义千问Qwen3-Next发布:80B参数仅激活3B,训练成本降90%,长文本吞吐提升10倍​

开源炸场&#xff01;阿里通义千问Qwen3-Next发布&#xff1a;80B参数仅激活3B&#xff0c;训练成本降90%&#xff0c;长文本吞吐提升10倍​ 开源世界迎来震撼突破&#xff01; 通义千问团队最新发布的Qwen3-Next架构&#xff0c;以其独创的"小而精"设计理念&#x…

【C++入门】C++基础

目录 1. 命名空间 1.1 命名空间的创建和使用 2. 输入输出 2.1 输出 2.2 输入 3. 缺省参数 3.1 全缺省 3.2 半缺省 4.函数重载 4.1 为什么C支持重载而C语言不支持&#xff1f; 4.1.2 编译的四个过程 4.2 extern是什么 5.引用 5.1 引用的特性 5.1.1 引用的“隐式类…

如何往mp4视频添加封面图和获取封面图?

前言&#xff1a;大家好&#xff0c;之前有给大家分享过mp4录像的方案&#xff0c;今天给大家分享的内容是&#xff1a;如何在添加自定义的封面图到mp4里面去&#xff0c;以及在进入回放mp4视频列表的时候&#xff0c;怎么获取mp4视频里面的封面图&#xff0c;当然这个获取到的…

你的第一个Transformer模型:从零实现并训练一个迷你ChatBot

点击 “AladdinEdu&#xff0c;同学们用得起的【H卡】算力平台”&#xff0c;注册即送-H卡级别算力&#xff0c;80G大显存&#xff0c;按量计费&#xff0c;灵活弹性&#xff0c;顶级配置&#xff0c;学生更享专属优惠。 引言&#xff1a;破除神秘感&#xff0c;拥抱核心思想 …

【20期】沪深指数《实时交易数据》免费获取股票数据API:PythonJava等5种语言调用实例演示与接口API文档说明

​ 随着量化投资在金融市场的快速发展&#xff0c;高质量数据源已成为量化研究的核心基础设施。本文将系统介绍股票量化分析中的数据获取解决方案&#xff0c;涵盖实时行情、历史数据及基本面信息等关键数据类型。 本文将重点演示这些接口在以下技术栈中的实现&#xff1a; P…