在 PyTorch 中,torch.gather() 是一个非常实用的张量操作函数,主要用于根据索引从输入张量中选择特定位置的值。它常用于注意力机制、序列处理等场景。


函数定义

torch.gather(input, dim, index) → Tensor
  • input:待提取数据的张量。
  • dim:在哪个维度上进行索引选择。
  • index:一个与 input 在除了 dim 维度外相同形状的张量,其值指定了从 input 中提取的索引位置。
  • 返回值:从 input 的指定维度 dim 上根据 index 提取出的新张量。

形象理解

举个简单的例子:

示例 1:二维张量,按列(dim=1)提取

import torchinput = torch.tensor([[10, 20, 30],[40, 50, 60]])
index = torch.tensor([[2, 1, 0],[0, 1, 2]])output = torch.gather(input, dim=1, index=index)
print(output)

解释

  • 对于第一行:从 [10, 20, 30] 中提取位置 [2,1,0],结果是 [30, 20, 10]
  • 对于第二行:从 [40, 50, 60] 中提取位置 [0,1,2],结果是 [40, 50, 60]

输出

tensor([[30, 20, 10],[40, 50, 60]])

示例 2:按行(dim=0)提取

input = torch.tensor([[1, 2],[3, 4],[5, 6]])index = torch.tensor([[0, 1],[1, 2],[2, 0]])output = torch.gather(input, dim=0, index=index)
print(output)

解释

  • 每个位置从第 dim=0 维度提取对应的元素。例如:

    • 第 (0,0) 位置:从 [1,3,5] 中取第 0 行,值为 1
    • 第 (1,0) 位置:从 [1,3,5] 中取第 1 行,值为 3
    • 第 (2,1) 位置:从 [2,4,6] 中取第 0 行,值为 2

输出

tensor([[1, 4],[3, 6],[5, 2]])

应用场景

  1. 注意力机制中的权重选择
  2. 序列解码中的 beam search
  3. 从嵌套表示中根据索引获取嵌套内容

实战场景举例

假设有一个 batch 的 BERT 输出,想从每个句子中提取第 N 个 token(如 [CLS]、某个关键词)的表示向量


假设数据

import torch
from transformers import BertModel, BertTokenizertokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")sentences = ["I love World", "Transformers are powerful"]
inputs = tokenizer(sentences, padding=True, return_tensors="pt")# 获取 BERT 输出
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)print(last_hidden_state.shape)
# torch.Size([2, 5, 768])  假设 padding 后为长度 5,hidden size 为 768

场景 1:提取每个句子的第一个 token(通常是 [CLS])

cls_embeddings = last_hidden_state[:, 0, :]  # shape: (batch_size, hidden_size)

这个可以直接使用切片完成,不需要 gather


场景 2:提取每个句子中 指定位置的 token 表示(如“love”或“are”)

假设我们事先知道每个句子中感兴趣 token 的位置:
# 每个句子中我们想要提取的 token 索引
# 假设我们想提取第 2 个 token
token_indices = torch.tensor([2, 1])  # shape: (batch_size,)

使用 gather 抽取对应 token 的向量:

# last_hidden_state: (batch_size, seq_len, hidden_size)
batch_size, seq_len, hidden_size = last_hidden_state.size()# 将 token_indices 转成 index 用于 gather: shape (batch_size, 1, 1)
token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, hidden_size)  # (batch_size, 1, hidden_size)# gather on dim=1(seq_len)
token_embeddings = torch.gather(last_hidden_state, dim=1, index=token_indices)  # (batch_size, 1, hidden_size)# squeeze 掉中间的维度
token_embeddings = token_embeddings.squeeze(1)  # (batch_size, hidden_size)print(token_embeddings.shape)

小结

操作需求用法
取所有句子的第一个 tokenoutput[:, 0, :]
取所有句子的第 N 个 tokenoutput[:, N, :]
取每个句子的指定 token(不同位置)torch.gather()(如上所示)

注意事项

  • index 必须与 input 的 shape 一致,除了在指定的 dim 维度上的大小。
  • index 的值必须小于 inputdim 维度上的长度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85683.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85683.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/85683.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 微信小程序在线引入字体图标

在线引入字体图标,出现体验版,真机调试字体图标不出来,模拟器上是好的 由于字体图标和小程序域名不在同一个,所以出现了跨域问题,将字体图标文件放到小程序同一个域名下就好了

macOS版的节点小宝上架苹果APP Store了

前言 前段时间很多小伙伴按照小白的教程在飞牛NAS部署了节点小宝之后,Windows的小伙伴玩得不亦乐乎! 反观macOS用户……因为没有#macOS版本的节点小宝,就算是在飞牛NAS上部署了节点小宝,却一点也开心不起来。 毕竟iOS版本的节点…

tensor向量按任意维度进行切片、拆分、组合

torch.index_select(input_tensor, 切片维度, 切片索引) 注意:切完之后,转onnx时会生成Gather节点; torch自带切片操作: start : end : step: 范围前闭后开,将其放在哪个维度上,就对那个维度…

(八)Linux进程程序替换

1 进程替换 进程替换是为了让程序能在不创建新进程的情况下&#xff0c;让父进程和子进程执行不同的代码&#xff0c;以实现控制清晰、执行高效的程序调度机制。 1.1 先看效果 #include <stdio.h> #include <unistd.h> int main() {printf("before:I am a p…

支持 TDengine 的数据库管理工具—qStudio

qStudio qStudio 是一款免费的多平台 SQL 数据分析工具&#xff0c;可以轻松浏览数据库中的表、变量、函数和配置设置。最新版本 qStudio 内嵌支持 TDengine。 前置条件​ 使用 qStudio 连接 TDengine 需要以下几方面的准备工作。 安装 qStudio。qStudio 支持主流操作系统包…

破解 VMP+OLLVM 混淆:通过 Hook jstring 快速定位加密算法入口

版权归作者所有&#xff0c;如有转发&#xff0c;请注明文章出处&#xff1a;https://cyrus-studio.github.io/blog/ VMP 壳 OLLVM 的加密算法 某电商APP的加密算法经过dex脱壳分析&#xff0c;找到参数加密的方法在 DuHelper.doWork 中 package com.shizhuang.duapp.common…

Automatisch:开源的工作流自动化利器

在当今数字化的时代,企业和个人都在寻找高效的方式来自动化业务流程,减少手动操作带来的时间和成本消耗。Automatisch 作为一款开源的 Zapier 替代方案,为我们提供了一个强大而灵活的工具,让工作流自动化变得更加简单和可控。 一、Automatisch 简介 Automatisch 是一个商…

RAG应用效果评估框架与优化指南

1. 引言:为何RAG评估至关重要? 一个RAG系统通常包含多个可调参数和可替换组件(如不同的嵌入模型、向量数据库、LLM、Prompt模板等)。没有有效的评估机制,优化过程就像“盲人摸象”,难以判断改动是否带来了真正的提升。 RAG评估的核心目的: 量化系统性能:将RAG的“好坏…

豆包大模型应用场景

豆包作为通用大模型&#xff0c;应用场景其实覆盖了个人和企业两端。个人端要突出生活化功能——比如帮学生解题、帮上班族写周报&#xff1b;企业端则要强调降本增效&#xff0c;比如客服自动化、代码生成这些硬需求。用户没指定角度&#xff0c;那就都覆盖吧。 注意到用户用“…

OSITCP/IP

模型&协议 在互联网发展的早期,不同的计算机厂商有不同的网络传输协议,例如:IBM的SNA协议、苹果的AppleTalk协议等,这些协议互不兼容,导致虽然不同的产商计算机在物理层面是链接的,但是在网络上基本无法完成正常通信。这就导致一个用户如果使用了某个厂商的某个网络…

店匠科技闪耀“跨博会”,技术+生态打造灵活出海能力

2025年6月16日至18日&#xff0c;第八届全球跨境电商节暨第十届深圳国际跨境电商贸易博览会&#xff08;简称“跨博会”&#xff09;在深圳会展中心举行。作为全球跨境电商行业的年度盛会&#xff0c;本届展会以“文化跨境、品牌出海、智量强国”为主题&#xff0c;汇聚近 1500…

selenium弹框元素定位-冻结界面

有些网站上面的元素&#xff0c;我们鼠标放在上面&#xff0c;会动态弹出一些内容。 但是当我们的鼠标从音乐图标移开&#xff0c;这个栏目就整个消失了&#xff0c;就没法查看其对应的HTML。 怎么办&#xff1f;在开发者工具栏console里面执行如下js代码 &#xff1a; setTi…

美学心得(第二百七十九集)罗国正

美学心得&#xff08;第二百七十九集&#xff09; 罗国正 &#xff08;2025年6月&#xff09; 3299、分清不同本体、主体及其之间的关系&#xff0c;是 正确的审美、判断首先的关键 罗国正 &#xff08;2025年6月11日于广州&#xff09; “人也按照美的规律来建造。”这句话…

云祺容灾备份系统公有云备份与恢复实操-AWS

1、创建访问密钥 访问并登录AWS控制台&#xff0c;点击右上角用户名、安全凭证&#xff0c;在我的安全凭证窗口中&#xff0c;下拉找到访问密钥&#xff0c;并点击创建访问密钥&#xff0c;选择其他&#xff0c;点击下一步&#xff0c;即可获得密钥信息如图1至图6。 注意&…

windows内网穿透

内网穿透&#xff08;NAT穿透&#xff09;是一种通过技术手段将局域网&#xff08;内网&#xff09;中的服务暴露到公网&#xff08;外网&#xff09;的方法&#xff0c;使外部用户能够访问内网资源。其核心是解决因NAT&#xff08;网络地址转换&#xff09;或防火墙限制导致的…

threejs 实现720°全景图,;两种方式:环境贴图、CSS3DRenderer渲染

前提 有一个前提条件&#xff1a;六张大小一致的图片&#xff0c;六个图片分别对应的是720全景图的六个面&#xff1a;上、下、左、右、前、后。 这个不是那种无人机拍摄的全景图&#xff0c;是六个图片拼起来的&#xff0c;这样的取景方式要比无人机的要经济一些。 ---…

老牌软件 Ghost 备份还原操作基础

一、Ghost 简介 Symantec Ghost&#xff08;也称为 Norton Ghost&#xff09; 是一款强大的磁盘克隆和备份还原工具&#xff0c;广泛用于系统部署、数据恢复和灾难恢复。其主要功能包括&#xff1a; 创建磁盘镜像&#xff08;.GHO文件&#xff09;备份/还原分区或整个硬盘支持…

SSH连接服务器并同步本地文件

SSH连接服务器并同步本地文件 1. 复制本地公钥 cat ~/.ssh/id_rsa.pub如果不确定本地是否有公钥 ls ~/.ssh/id_rsa.pub# 如果出现如下&#xff0c;则说明你本地存在公钥 # /Users/username/.ssh/id_rsa.pub若没有公钥&#xff0c;需生成 # 使用下面命令&#xff0c;然后一路回…

中英泰马来语订货系统:助力东南亚批发贸易企业数字化转型升级

随着全球数字化转型浪潮的推进&#xff0c;东南亚地区的批发贸易企业也正逐步迈向数字化发展道路。特别是在中英泰马来语订货系统的推动下&#xff0c;东南亚的批发商和零售商能够更高效、便捷地开展跨国贸易与供应链管理。这不仅帮助传统企业提高了运营效率&#xff0c;还助力…

微信小程序获取指定元素,滚动页面到指定位置

微信小程序获取指定元素&#xff0c;滚动页面到指定位置 微信小程序获取指定元素的宽高等信息,并滚动页面到指定位置 微信小程序获取指定元素的宽高等信息,并滚动页面到指定位置 注&#xff1a;原生小程序开发&#xff1a; createSelectorQuery() 创建一个选择器查询实例。 sel…