一、多头注意力机制概述

多头注意力(Multi-Head Attention)是Transformer模型的核心组件,其核心思想是通过 ‌并行处理多个子空间‌ 来捕捉序列中不同位置间的复杂依赖关系。主要特点:

  • 并行计算:将高维向量拆分为多个低维子空间
  • 多视角学习:每个注意力头关注不同特征模式
  • 高效性:矩阵运算高度可并行化

在这里插入图片描述

二、代码实现

1. pyTorch 实现
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):"""Args:embed_dim: 词向量维度(如512)num_heads: 注意力头数量(如8)"""super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads  # 每个头的维度(如512//8=64)assert self.head_dim * num_heads == embed_dim, "维度不可整除"# 定义线性变换层self.query = nn.Linear(embed_dim, embed_dim)  # Q矩阵self.key = nn.Linear(embed_dim, embed_dim)    # K矩阵self.value = nn.Linear(embed_dim, embed_dim)  # V矩阵self.out = nn.Linear(embed_dim, embed_dim)    # 输出层def transpose_for_scores(self, x):"""拆分多头并调整维度顺序输入: [batch_size, seq_len, embed_dim]输出: [batch_size, num_heads, seq_len, head_dim]"""new_shape = x.size()[:-1] + (self.num_heads, self.head_dim)x = x.view(*new_shape)  # 新增头维度return x.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]def forward(self, query, key, value, mask=None):"""前向传播流程输入形状: [batch_size, seq_len, embed_dim]输出形状: [batch_size, seq_len, embed_dim]"""batch_size = query.size(0)# 1. 线性变换Q = self.query(query)  # [N, seq, D]K = self.key(key)      # [N, seq, D]V = self.value(value)  # [N, seq, D]# 2. 拆分多头Q = self.transpose_for_scores(Q)  # [N, h, seq, d]K = self.transpose_for_scores(K)  # [N, h, seq, d] V = self.transpose_for_scores(V)  # [N, h, seq, d]# 3. 计算注意力分数scores = torch.matmul(Q, K.transpose(-2, -1))  # [N, h, seq_q, seq_k]scores /= math.sqrt(self.head_dim)  # 缩放# 4. 应用掩码(可选)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# 5. 计算注意力权重attn_weights = F.softmax(scores, dim=-1)  # [N, h, seq_q, seq_k]# 6. 应用权重到Valueout = torch.matmul(attn_weights, V)  # [N, h, seq_q, d]# 7. 合并多头out = out.permute(0, 2, 1, 3).contiguous()  # [N, seq_q, h, d]out = out.view(batch_size, -1, self.embed_dim)  # [N, seq, D]# 8. 输出层return self.out(out), attn_weights
2. tensorFlow实现
# TensorFlow (兼容TF2.x)import tensorflow as tf
from tensorflow.keras.layers import Layer, Denseclass MultiHeadAttention(Layer):def __init__(self, embed_dim, num_heads):"""Args:embed_dim: 词向量维度(如512)num_heads: 注意力头数量(如8)"""super(MultiHeadAttention, self).__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsassert self.head_dim * num_heads == embed_dim, "维度不可整除"# 定义线性变换层self.query_dense = Dense(embed_dim)self.key_dense = Dense(embed_dim)self.value_dense = Dense(embed_dim)self.output_dense = Dense(embed_dim)def split_heads(self, x, batch_size):"""拆分多头并调整维度顺序输入: [batch_size, seq_len, embed_dim]输出: [batch_size, num_heads, seq_len, head_dim]"""x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))return tf.transpose(x, perm=[0, 2, 1, 3])def call(self, query, key, value, mask=None):batch_size = tf.shape(query)# 1. 线性变换Q = self.query_dense(query)  # [N, seq, D]K = self.key_dense(key)      # [N, seq, D]V = self.value_dense(value)  # [N, seq, D]# 2. 拆分多头Q = self.split_heads(Q, batch_size)  # [N, h, seq, d]K = self.split_heads(K, batch_size)  # [N, h, seq, d]V = self.split_heads(V, batch_size)  # [N, h, seq, d]# 3. 计算注意力分数matmul_qk = tf.matmul(Q, K, transpose_b=True)  # [N, h, seq_q, seq_k]scaled_attention_logits = matmul_qk / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))# 4. 应用掩码(可选)if mask is not None:scaled_attention_logits += (mask * -1e9)  # 添加极大负值# 5. 计算注意力权重attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)# 6. 应用权重到Valueoutput = tf.matmul(attention_weights, V)  # [N, h, seq_q, d]# 7. 合并多头output = tf.transpose(output, perm=[0, 2, 1, 3])  # [N, seq_q, h, d]concat_attention = tf.reshape(output, (batch_size, -1, self.embed_dim))# 8. 输出层return self.output_dense(concat_attention), attention_weights

三、维度变化全流程详解

1. 参数设定
  • batch_size = 2
  • seq_len = 5
  • embed_dim = 512
  • num_heads = 8
  • head_dim = 512 // 8 = 64
2. 维度变化流程图
原始输入: [2, 5, 512]│├─线性变换───────保持形状→ [2, 5, 512]│├─拆分多头──────→ [2, 8, 5, 64]│                (拆分512为8个64维头)│├─计算注意力分数──→ [2, 8, 5, 5]│                (每个头计算5x5的注意力矩阵)│├─Softmax───────→ [2, 8, 5, 5]│                (最后一维归一化)│├─应用权重到Value→ [2, 8, 5, 64]│                (每个头输出新的序列表示)│├─合并多头───────→ [2, 5, 512]│                (拼接8个64维头恢复512维)│└─输出层────────→ [2, 5, 512]
3. 关键步骤维度变化

在这里插入图片描述

四、关键实现细节解析

1. 多头拆分与合并
# 拆分多头(核心代码)
new_shape = x.size()[:-1] + (num_heads, head_dim)
x = x.view(*new_shape).permute(0, 2, 1, 3)# 合并多头(逆过程)
x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embed_dim)
  • 为什么要permute:将num_heads维度提前,便于后续矩阵乘法并行处理多个头
2. 注意力分数计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
  • 转置维度‌:将K的seq_len和head_dim维度交换,使矩阵乘法满足[seq_q, d] x [d, seq_k] → [seq_q, seq_k]
  • 缩放因子‌:防止点积结果过大导致softmax梯度消失
3. 掩码处理技巧

python

scores = scores.masked_fill(mask == 0, -1e9)
  • 作用‌:将填充位置(如)的注意力权重趋近于0
  • 为什么用-1e9‌:经过softmax后,exp(-1e9) ≈ 0

五、完整运行示例

# 测试用例
embed_dim = 512
num_heads = 8
model = MultiHeadAttention(embed_dim, num_heads)# 生成测试数据
batch_size = 2
seq_len = 5
inputs = torch.randn(batch_size, seq_len, embed_dim)# 前向传播
output, attn = model(inputs, inputs, inputs)# 验证输出形状
print(output.shape)  # torch.Size([2, 5, 512])
print(attn.shape)    # torch.Size([2, 8, 5, 5])

六、总结与常见问题

1. 核心优势
  • 并行计算效率‌:通过矩阵运算同时处理所有位置和注意力头
  • 多视角学习‌:不同注意力头可关注语法、语义等不同特征
  • 长距离依赖‌:直接计算任意两个位置间的关联
2. FAQ
  • Q1:为什么需要多个注意力头?‌

  • A:类比CNN中多个卷积核,不同头可以捕捉不同类型的特征依赖

  • Q2:head_dim为什么要设置为embed_dim/num_heads?‌

  • A:保持总参数量不变,确保拆分前后的维度乘积相等(num_heads * head_dim = embed_dim)

  • Q3:permute之后为什么要调用contiguous()?‌

  • A:确保张量在内存中连续存储,避免后续view操作报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/73577.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/73577.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/73577.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Interview preparation.md

Vue 1.1 响应式系统 Vue 3 使用 Proxy 代替 Vue 2 中的 Object.defineProperty 来实现响应式系统。Proxy 可以监听对象的所有操作,包括属性的添加和删除,从而解决了 Vue 2 的一些局限性。 Vue 2:使用 Vue.set 添加响应式属性 new Vue({el…

2.8滑动窗口专题:最小覆盖子串

1. 题目链接 LeetCode 76. 最小覆盖子串 2. 题目描述 给定字符串 s 和 t,要求找到 s 中最小的窗口,使得该窗口包含 t 的所有字符(包括出现次数)。若不存在,返回空字符串。 示例: 输入:s &quo…

【数据分析大屏】基于Django+Vue汽车销售数据分析可视化大屏(完整系统源码+数据库+开发笔记+详细部署教程+虚拟机分布式启动教程)✅

目录 一、项目背景 二、项目创新点 三、项目功能 四、开发技术介绍 五、项目功能展示 六、权威视频链接 一、项目背景 汽车行业数字化转型加速,销售数据多维分析需求激增。本项目针对传统报表系统交互性弱、实时性差等痛点,基于DjangoVue架构构建…

cyberstrikelab lab2

lab2 重生之我是渗透测试工程师,被公司派遣去测试某网络的安全性。你的目标是成功获取所有服务器的权限,以评估网络安全状况。 先扫一下 ​ ​ 192.168.10.10 ​ ​ 骑士cms 先找后台路径 http://192.168.10.10:808/index.php?madmin&cind…

在 Ubuntu 服务器上使用宝塔面板搭建博客

📌 介绍 在本教程中,我们将介绍如何在 Ubuntu 服务器 上安装 宝塔面板,并使用 Nginx PHP MySQL 搭建一个博客(如 WordPress)。 主要步骤包括: 安装宝塔面板配置 Nginx PHP MySQL绑定域名与 SSL 证书…

PTA7-13 统计工龄

题目描述 给定公司 n 名员工的工龄,要求按工龄增序输出每个工龄段有多少员工。 输入格式: 输入首先给出正整数 n(≤105),即员工总人数;随后给出 n 个整数,即每个员工的工龄,范围在 [0, 50]。…

【 <一> 炼丹初探:JavaWeb 的起源与基础】之 Servlet 3.0 新特性:异步处理与注解配置

<前文回顾> 点击此处查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、Servle…

电子电气架构 --- 汽车电子硬件架构

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 人生是一场骗局,最大的任务根本不是什么买车买房,也不是及时行乐,这就是欲望,不是理想,是把自己对生命的希望寄托在外物上,正确的做法应该是内…

使用 Homebrew 安装 OpenJDK 并配置环境变量

在 macOS 上使用 Homebrew 安装 OpenJDK 是一种简单而高效的方式。本文将使用 Homebrew 安装 OpenJDK&#xff0c;并设置环境变量以便 Java 能够正确运行。 1. 安装 Homebrew 首先&#xff0c;确保你的 macOS 系统已经安装了 Homebrew。如果没有安装&#xff0c;可以通过以下…

Java集合简单理解

Java 的集合框架&#xff08;Java Collections Framework, JCF&#xff09;是 Java 中用于存储和操作数据结构的核心库&#xff0c;提供了丰富的接口和实现类&#xff0c;用于处理不同类型的集合数据。以下是详细的介绍&#xff1a; 一、集合框架的体系结构 Java 集合主要分为…

群体智能优化算法-旗鱼优化算法 (Sailfish Optimizer, SFO,含Matlab源代码)

摘要 旗鱼优化算法&#xff08;Sailfish Optimizer, SFO&#xff09;是一种模拟旗鱼&#xff08;Sailfish&#xff09;和沙丁鱼&#xff08;Sardine&#xff09;之间捕食关系的新型元启发式算法。通过在搜索过程中模拟旗鱼对沙丁鱼的捕食行为&#xff0c;以及沙丁鱼群的逃逸与…

【C语言】编译和链接详解

hi&#xff0c;各位&#xff0c;让我们开启今日份博客~ 小编个人主页点这里~ 目录 一、翻译环境和运行环境1、翻译环境1.1预处理&#xff08;预编译&#xff09;1.2编译1.2.1词法分析1.2.2语法分析1.2.3语义分析 1.3汇编1.4链接 2.运行环境 一、翻译环境和运行环境 在ANSI C…

VIC模型率定验证

在气候变化问题日益严重的今天&#xff0c;水文模型在防洪规划&#xff0c;未来预测等方面发挥着不可替代的重要作用。目前&#xff0c;无论是工程实践或是科学研究中都存在很多著名的水文模型如SWAT/HSPF/HEC-HMS等。虽然&#xff0c;这些软件有各自的优点&#xff1b;但是&am…

【AWS入门】AWS云计算简介

【AWS入门】AWS云计算简介 A Brief Introduction to AWS Cloud Computing By JacksonML 什么是云计算&#xff1f;云计算能干什么&#xff1f;我们如何利用云计算&#xff1f;云计算如何实现&#xff1f; 带着一系列问题&#xff0c;我将做一个普通布道者&#xff0c;引领广…

Flutter_学习记录_ ImagePicker拍照、录制视频、相册选择照片和视频、上传文件

插件地址&#xff1a;https://pub.dev/packages/image_picker 添加插件 添加配置 android无需配置开箱即用&#xff0c;ios还需要配置info.plist <key>NSPhotoLibraryUsageDescription</key> <string>应用需要访问相册读取文件</string> <key>N…

蓝桥与力扣刷题(蓝桥 星期计算)

题目&#xff1a;已知今天是星期六&#xff0c;请问 20^22 天后是星期几? 注意用数字 1 到 7 表示星期一到星期日。 本题为填空题&#xff0c;只需要算出结果后&#xff0c;在代码中使用输出语句将所填结果输出即可。 解题思路&#xff0b;代码&#xff1a; 代码&#xff1…

向量数据库原理及选型

向量数据库 什么是向量什么是向量数据库原理应用场景 向量数据库的选型主流向量数据库介绍向量数据库对比主流向量数据库对比表 选型建议 什么是向量 向量是一组有序的数值&#xff0c;表示在多维空间中的位置或方向。向量通常用一个列或行的数字集合来表示&#xff0c;这些数…

以实现生产制造、科技研发、人居生活等一种或多种复合功能的智慧油站开源了

AI视频监控平台简介 AI视频监控平台是一款功能强大且简单易用的实时算法视频监控系统。它的愿景是最底层打通各大芯片厂商相互间的壁垒&#xff0c;省去繁琐重复的适配流程&#xff0c;实现芯片、算法、应用的全流程组合&#xff0c;从而大大减少企业级应用约95%的开发成本。用…

小程序网络大文件缓存方案

分享一个小程序网络大图加载慢的解决方案 用到的相关api getSavedFileList 获取已保存的文件列表&#xff1b;getStorageSync 获取本地缓存&#xff1b;downloadFile 下载网络图片&#xff1b;saveFile 保存文件到本地&#xff1b;setStorage 将数据储存到小程序本地缓存&…