一 标准Attention的计算

1.1 标准Attention机制详解

标准Attention(注意力)机制是深度学习,尤其是在自然语言处理领域中一项革命性的技术,它允许模型在处理序列数据时,动态地将焦点放在输入序列的不同部分,从而赋予模型更强的表征能力。其核心思想源于人类的视觉注意力,即我们关注信息的特定部分,而非一次性处理所有信息。在模型中,这意味着在生成一个输出时,可以有选择地关注输入序列中与之最相关的部分。

标准Attention,特别是 “缩放点积注意力”(Scaled Dot-Product Attention),其计算过程可以分解为几个关键步骤,涉及三个核心概念:查询(Query)键(Key)值(Value)。 我们可以将这个过程类比于一个信息检索系统。当你提出一个查询(Query)时,系统会将你的查询与数据库中所有项目的键(Key)进行匹配,以确定相似度,然后根据这些相似度分数,从对应项目的值(Value)中提取信息。 在实践中,输入序列中的每个元素的词嵌入向量会被乘以三个不同的、在训练过程中学习到的权重矩阵(Wq,Wk,WvW_q,W_k,W_vWq,Wk,Wv),从而为每个输入元素生成相应的Q,K,VQ,K,VQ,K,V向量。

计算的第一步是计算注意力得分(Attention Score)。 这是通过计算特定位置的Query向量与序列中所有位置的Key向量的点积来完成的。 如果一个Query和一个Key在向量空间中方向相近,它们的点积会很大,这表明它们之间的相关性很高。 这个得分矩阵(通常称为Attention Matrix)揭示了序列中每个词对其他所有词的关注程度。

接下来是对得分进行缩放和归一化。计算出的点积得分会除以一个缩放因子,通常是Key向量维度dkd_kdk的平方根 dk\sqrt{d_k}dk。 这一步至关重要,因为当向量维度较大时,点积的结果可能会变得非常大,从而将Softmax函数的梯度推向极小的区域,导致训练不稳定。 经过缩放后,使用Softmax函数对得分进行归一化,将其转换为一组总和为1的非负权重,即注意力权重(Attention Weights)。 这些权重可以被看作是一个概率分布,表示在当前查询下,应该将多少注意力分配给每个输入位置。

最后一步是利用得到的注意力权重来计算输出。将这些权重分别与每个位置的Value向量相乘,然后将所有加权后的Value向量求和。 这个过程产生一个上下文向量(Context Vector),它是一个包含了整个输入序列信息、并根据相关性进行了加权的动态表征。 最终,这个上下文向量就是该位置的注意力层输出。

整个缩放点积注意力的计算过程可以用以下简洁的数学公式来表达:
Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
其中,Q、K、V分别是查询、键和值的矩阵表示。

1.2 标准Attention掩码处理

在实际应用中,我们经常需要对注意力得分进行控制,以适应不同的任务需求,这时就需要引入掩码(Masking)机制。掩码是一种向注意力机制指示哪些词元应该被忽略或屏蔽的技术。其核心原理是在Softmax归一化操作之前,将注意力得分矩阵中需要屏蔽的位置加上一个非常大的负数(理论上是负无穷,实践中通常是像 -1e9float('-inf') 这样的一个极小值)。这样一来,这些位置经过Softmax函数计算后,其对应的注意力权重就会变得极其接近于零,从而在后续的加权求和中被有效忽略。引入掩码后,标准Attention的计算公式被修正为:
Attention(Q,K,V)=softmax(QKTdk+Mask)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + \text{Mask}\right)V Attention(Q,K,V)=softmax(dkQKT+Mask)V
其中 Mask 矩阵的形状与 QKTQK^TQKT 相同,其元素在需要保留的位置为0,在需要屏蔽的位置为一个极大的负数。

一种至关重要的掩码是因果掩码(Causal Mask),它主要应用于自回归的解码任务中,例如语言模型的生成过程。在这些任务里,模型在预测第 t 个位置的输出时,只能依赖于 t 时刻及之前已经生成的序列信息,绝对不能“看到”或利用任何未来(t+1 及之后)的信息。因果掩码就是为了强制实现这一限制。具体来说,它是一个方阵,其形状与注意力得分矩阵相同,在这个矩阵中,主对角线及其下方的元素为0(允许关注当前及过去的位置),而所有主对角线上方的元素均为负无穷大(禁止关注未来的位置)。这种结构确保了信息流只能从过去流向未来,维持了模型的自回归特性。

另一种广泛使用的掩码是填充掩码(Padding Mask)。在自然语言处理任务中,为了利用GPU进行高效的并行计算,我们通常会将多个不同长度的句子打包成一个批次(batch)。为了使批次内的所有序列长度统一,会用一个特殊的“填充”符号(例如<PAD>)将较短的序列补齐到与最长序列相同的长度。然而,这些填充符号本身不包含任何有意义的语义信息,模型不应该将注意力分配给它们。填充掩码的作用就是识别出这些填充位置,并在计算注意力时将它们屏蔽掉。这个掩码的生成通常基于输入序列,对于序列中每个是填充符号的位置,掩码矩阵中对应的行或列(取决于填充的是Key还是Query)都会被设置为负无穷大。

在更复杂的解码策略中,还可能用到树形掩码(Tree Mask)。与强制线性顺序的因果掩码不同,树形掩码允许注意力遵循一个预定义的树状或图状结构。这在一些非自回归或半自回归的生成任务中非常有用,例如当模型不是严格地从左到右生成文本,而是按照一个层次结构(如句法树)或者以多路并行的方式生成内容时。树形掩码会根据这个预设的生成结构来决定注意力关系,一个节点(词元)只能关注其在树中的祖先节点,而屏蔽掉所有其他节点。这为控制生成过程提供了更大的灵活性,允许模型利用更复杂的依赖关系。

当多种掩码需要同时生效时,例如在处理一个批次的解码任务时,我们既需要屏蔽掉未来的词元(因果掩码),也需要屏蔽掉序列末尾的填充词元(填充掩码),处理方式通常非常直接。由于掩码矩阵中非屏蔽位置的值是0,而屏蔽位置是一个极大的负数,我们可以简单地将不同的掩码矩阵进行逐元素相加。例如,将因果掩码矩阵和填充掩码矩阵相加,得到的新掩码矩阵中,一个位置只要在任意一个原始掩码中被屏蔽(值为负无穷),其最终的值就会是负无穷(因为 0 + (-inf) = -inf(-inf) + (-inf) = -inf)。因此,通过简单的矩阵加法,就可以将多个约束条件合并到一个统一的掩码中,并应用到注意力得分上,从而同时满足所有的屏蔽要求。

二 Roof-Line模型与GPU

2.1 Roof-Line模型详解

Roof-Line模型是一个直观的性能分析模型,旨在揭示一个给定的计算核心(如CPU或GPU)在运行特定计算任务时的潜在性能瓶颈。它通过一张二维图表,将硬件的理论性能极限与应用程序的实际特性联系起来,从而帮助开发者理解他们的代码性能受限于计算能力还是内存带宽,并指导优化方向。该模型的核心思想是,一个程序的实际性能(以每秒浮点运算次数,即FLOP/s为单位)上限,被硬件的峰值计算性能峰值内存带宽这两个因素共同决定。

构建Roof-Line图表需要两个关键的硬件参数和一个关键的软件参数:硬件方面,第一个参数是理论峰值计算性能(Peak Performance),通常用 GFLOP/s(每秒十亿次浮点运算)表示。这是处理器在理想情况下所能达到的最高浮点运算速度,它构成了图表中的一条水平线,代表了性能的绝对计算上限。第二个硬件参数是理论峰值内存带宽(Peak Bandwidth),通常用 GB/s(每秒千兆字节)表示。这是处理器与主内存之间数据传输的最大速率,带宽限制在图表中表现为一条斜线。软件方面,关键参数是算法的“算术强度”(Arithmetic Intensity),其单位是 FLOPs/Byte。算术强度衡量的是一个程序在执行过程中,平均每从内存中读取或写入一字节数据所进行的浮点运算次数。它是算法自身的一个内在属性,反映了其计算量与访存量的比例。

在Roof-Line图中,横轴代表算术强度,纵轴代表可实现的性能(GFLOP/s)。图中的“屋顶”由上述两条线共同构成:性能首先随着算术强度的增加,沿着由内存带宽决定的斜线线性上升,当达到与峰值计算性能水平线相交的点后,性能将不再增长,维持在峰值计算性能的水平上。这条斜线的斜率等于峰值内存带宽。这两条线的交点被称为“屋脊点”(Ridge Point),它所对应的算术强度是一个关键的临界值。当一个应用的算术强度低于这个临界值时,它位于“倾斜屋顶”之下,意味着程序的性能主要受到内存带宽的限制,被称为“访存密集型”或“带宽受限”(Memory-Bound)应用。反之,如果应用的算术强度高于该临界值,它则位于“平顶屋顶”之下,其性能主要受限于处理器的计算能力,被称为“计算密集型”或“计算受限”(Compute-Bound)应用。

通过将特定应用程序或计算核心(Kernel)的实际性能和其算术强度作为一个点绘制在Roof-Line图上,开发者可以清晰地看到其性能表现。这个点与上方“屋顶”之间的垂直距离,代表了进一步优化的潜力。如果一个点位于“倾斜屋顶”下方,优化策略应聚焦于减少内存访问或提高数据的复用率,以提升其算术强度,从而将该点向右移动,使其性能沿着斜线向上提升。例如,可以通过改进数据局部性、使用缓存分块(Cache Tiling)等技术来实现。如果一个点位于“平顶屋顶”下方,则说明内存带宽已不再是瓶颈,优化方向应转向提高计算效率,比如通过指令级并行(Instruction-Level Parallelism)、向量化(Vectorization)或更好的资源利用来接近硬件的峰值计算性能。因此,Roof-Line模型不仅提供了一个性能评估的框架,更重要的是,它为性能优化提供了明确、可视化的指导。

2.2 GPU的存储体系结构

图形处理器(GPU)的存储体系结构是一个分层的系统,其设计旨在平衡海量数据存储与超高速数据访问之间的矛盾,从而支持大规模并行计算。这个体系结构可以按照存储单元是否位于核心计算芯片上,划分为片下内存(Off-chip Memory)片上内存(On-chip Memory) 两大类。片下内存,通常指我们所说的显存,例如高带宽内存(HBM),它的主要特点是拥有巨大的存储容量,如NVIDIA A100 GPU可提供40GB的HBM,但其访问速度(即带宽)相对较慢,大约为1.5TB/s。与之相对,片上内存,如静态随机存取存储器(SRAM),其特点是存储空间非常有限,通常只有几十兆字节(例如20MB),但拥有惊人的访问带宽,可以达到19TB/s甚至更高,这种巨大的带宽差异是理解GPU性能优化的关键。

GPU的整个计算流程都围绕着这个分层存储体系展开。当一个计算任务(在CUDA中称为一个Kernel)开始执行时,所需的数据首先会从容量巨大但速度较慢的片下显存(HBM)中,通过总线加载到速度极快但容量有限的片上内存(SRAM)中。真正执行计算的核心单元,被称为流式多处理器(Streaming Multiprocessor, SM),会直接从片上内存读取数据进行处理。SM可以被看作是GPU的计算大脑,每个SM内部包含大量的CUDA核心,能够同时处理成百上千个线程。计算完成后,产生的结果会再次通过片上内存写回到片下显存中,以便长期存储或用于后续的计算任务。

为了更精细地管理数据流,片上存储本身也具有层次。每个独立的SM都拥有自己私有的L1缓存和共享内存(Shared Memory)。在现代NVIDIA架构(如Volta及之后)中,L1缓存和共享内存被整合在一起,共同构成了我们之前提到的高速SRAM。共享内存是一个可由程序员直接控制的暂存空间,允许同一SM内的所有线程高效地共享和交换数据,而L1缓存则更像传统的CPU缓存,自动地缓存数据以减少延迟。所有SM还会共享一个更大容量的L2缓存,它作为各SM与主显存HBM之间的中间缓冲层。整个数据通路因此是:数据从HBM加载到所有SM共享的L2缓存,再进入特定SM的L1缓存/共享内存(SRAM),最后被该SM内的CUDA核心使用。由于从HBM到SRAM的每一次数据传输都相对耗时,这构成了性能上的一个主要瓶颈。

基于这种存储架构的特性,一个核心的GPU编程优化思想应运而生:即最大限度地减少与慢速显存的通信次数。既然SRAM的容量有限,那么策略就是在单次数据加载中,尽可能地将SRAM填满,并让加载到SRAM中的数据被充分利用。这催生了所谓的“Kernel融合”(Kernel Fusion)技术。它指的是将多个原本需要分步执行、并需要将中间结果反复读写显存的计算任务,合并成一个单一的、更复杂的计算核心(Kernel)。通过这种方式,第一个计算步骤产生的中间结果可以直接保留在高速的SRAM中,并立即被第二个计算步骤使用,从而完全避免了将中间结果写回慢速HBM再重新读出的高昂开销。这种优化手段通过减少耗时的内存I/O操作,能够极大地提升程序的整体执行效率。

三 FlashAttention:Softmax处理

3.1 标准Safe Softmax处理

Safe Softmax 是一种在计算上更加稳定的方法,用于实现标准的Softmax函数,其目的是为了有效避免在处理极大或极小的输入值时可能出现的数值溢出问题。标准的Softmax函数对于一个向量 z=(z1,z2,…,zk)\mathbf{z} = (z_1, z_2, \ldots, z_k)z=(z1,z2,,zk),其第 iii 个元素的计算公式为:
Softmax(zi)=ezi∑j=1kezj \text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{k} e^{z_j}} Softmax(zi)=j=1kezjezi
这个公式在理论上是完美的,但在计算机的浮点数表示下存在潜在风险。当输入向量 z\mathbf{z}z 中包含一个非常大的正数时,计算 ezie^{z_i}ezi 可能会导致上溢(overflow),即结果超出了计算机能表示的最大数值范围,变成无穷大(inf)。这会使得分母也变为无穷大,最终导致整个计算结果变成非数值(NaN),从而使计算过程崩溃。

为了解决这个问题,Safe Softmax利用了Softmax函数的一个重要性质:平移不变性。该性质指出,给输入向量的所有元素同时加上或减去同一个常数 CCC,Softmax的输出结果保持不变。这可以通过数学推导证明:
Softmax(zi+C)=ezi+C∑j=1kezj+C=ezieC∑j=1kezjeC=ezieCeC∑j=1kezj \text{Softmax}(z_i + C) = \frac{e^{z_i + C}}{\sum_{j=1}^{k} e^{z_j + C}} = \frac{e^{z_i} e^C}{\sum_{j=1}^{k} e^{z_j} e^C} = \frac{e^{z_i} e^C}{e^C \sum_{j=1}^{k} e^{z_j}}Softmax(zi+C)=j=1kezj+Cezi+C=j=1kezjeCezieC=eCj=1kezjezieC
=ezi∑j=1kezj=Softmax(zi)= \frac{e^{z_i}}{\sum_{j=1}^{k} e^{z_j}} = \text{Softmax}(z_i) =j=1kezjezi=Softmax(zi)
Safe Softmax巧妙地运用了这一性质,它在计算指数之前,从输入向量 z\mathbf{z}z 的每一个元素中减去该向量中的最大值。令 C=−max⁡(z)C = -\max(\mathbf{z})C=max(z),则新的输入向量变为 z′=z−max⁡(z)\mathbf{z}' = \mathbf{z} - \max(\mathbf{z})z=zmax(z)。此时,z′\mathbf{z}'z 中的最大元素值变成了0,而其他所有元素都将是负数或零。对这个新的向量 z′\mathbf{z}'z 应用Softmax函数,即:
SafeSoftmax(zi)=ezi−max⁡(z)∑j=1kezj−max⁡(z) \text{SafeSoftmax}(z_i) = \frac{e^{z_i - \max(\mathbf{z})}}{\sum_{j=1}^{k} e^{z_j - \max(\mathbf{z})}} SafeSoftmax(zi)=j=1kezjmax(z)ezimax(z)
通过这种方式,指数函数的输入值被有效地“归一化”到了一个不会导致上溢的安全区间(小于等于0)。由于 e0=1e^0=1e0=1 且对于任何负数 xxx, exe^xex 的值都在 (0, 1] 之间,因此指数运算的结果永远不会上溢。同时,由于分母至少包含一个值为1的项(即对应最大输入值的项),也避免了因所有指数项都极小而可能导致下溢(underflow)使分母为零的风险。最终,这种处理在完全不改变最终输出概率分布的情况下,极大地增强了Softmax计算的数值稳定性。

3.2 分块Safe Softmax处理

在标准的Attention机制中,为了计算注意力权重,需要显式地生成一个大小为 N×NN \times NN×N 的注意力得分矩阵 S=QKTS = QK^TS=QKT(其中 NNN 是序列长度)。这种做法在处理长序列时会带来严重问题。首先,该矩阵的内存占用与序列长度的平方成正比,即 O(N2)O(N^2)O(N2),当 NNN 很大时(例如上万),这个矩阵会消耗海量的显存,甚至可能超出GPU的容量上限,导致峰值显存占用过高而无法运行。更关键的是,这种朴素的实现方式对GPU的存储体系结构极其不友好。由于这个巨大的 N×NN \times NN×N 矩阵无法完全装入高速但容量有限的片上SRAM中,它必须被完整地写入并存储在速度慢得多的片下显存(HBM)中。随后,为了进行Softmax计算,又需要将这个矩阵从HBM中重新读取出来。这种往返于慢速HBM的大量数据读写(I/O)操作,成为了计算的性能瓶颈,远比实际的浮点运算(FLOPs)耗时,因此,即使GPU拥有强大的计算能力,其性能也严重受限于内存带宽。FlashAttention的核心动机正是要通过减少对HBM的访存来打破这一瓶颈,从而实现计算加速。

FlashAttention通过一种精巧的分块(Tiling)处理方式,结合对在线Softmax算法的应用,实现了在不生成完整 N×NN \times NN×N 矩阵的情况下,计算出与标准Softmax完全一致的结果。其基本思想是将输入矩阵Q、K、V沿序列长度维度切分成多个小块,这些小块的大小足以被载入到高速的SRAM中。然后,算法在一个外层循环中遍历Q的各个块,在内层循环中遍历K和V的各个块。

# 外层循环遍历Q的各个块,比如一个块放64个q向量
for Q_block in Q_blocks:# 内层循环遍历K和V的各个块,比如一个块放64个(k,v)向量for K_block, V_block in zip(K_blocks, V_blocks):# 在SRAM中对当前块执行融合计算 (矩阵乘、在线Softmax更新等)pass 

在每次内层循环中,它只计算一个Q块和一个K块之间的得分矩阵块 Sij=QiKjTS_{ij} = Q_i K_j^TSij=QiKjT。关键的创新在于,它不需要等待所有得分块都计算完毕再进行Softmax,而是在处理每个块的同时,以一种在线(online)的方式更新Softmax的统计量。为了保证最终结果的正确性,算法为当前Q块的每一行都维护了两个核心的统计值:到目前为止所处理过的所有得分块中该行元素的最大值 mmm,以及用这个最大值归一化后的指数和 lll

当计算完一个新的得分块时,算法会更新这两个统计量。假设处理完前 jjj 个块后,某行的统计量为 m(j)m^{(j)}m(j)l(j)l^{(j)}l(j)。当处理第 j+1j+1j+1 个块时,会先计算出这个新块内的行最大值 m(j+1)m^{(j+1)}m(j+1) 和指数和 l(j+1)l^{(j+1)}l(j+1)。然后,通过以下公式将新旧统计量合并,得到截至第 j+1j+1j+1 块的全局统计量 mnewm_{new}mnewlnewl_{new}lnew(备注:lnewl_{new}lnew为什么这样更新,如果这个看不懂也没关系的,不会太影响后续的理解):
mnew=max⁡(m(j),m(j+1)) m_{new} = \max(m^{(j)}, m^{(j+1)}) mnew=max(m(j),m(j+1))
lnew=em(j)−mnewl(j)+em(j+1)−mnewl(j+1) l_{new} = e^{m^{(j)} - m_{new}} l^{(j)} + e^{m^{(j+1)} - m_{new}} l^{(j+1)} lnew=em(j)mnewl(j)+em(j+1)mnewl(j+1)
与此同时,之前根据旧统计量计算出的中间输出向量 OOO 也会被相应地进行缩放和更新。通过这种方式,算法迭代地处理完所有K/V的块,最终得到的输出与一次性对完整矩阵进行标准Softmax计算的结果在数值上是等价的。整个过程所有的中间计算,包括块状矩阵乘法、Softmax更新和与V矩阵的乘法,都在高速的SRAM内以“算子融合”(Kernel Fusion)的方式完成,从而极大地减少了对HBM的读写次数,实现了显著的性能提升。

四 FlashAttention:矩阵分块处理

4.1 矩阵分块循环计算

在计算注意力得分矩阵 S=QKTS=QK^TS=QKT 时,FlashAttention系列算法的核心创新在于其分块计算策略,该策略避免了在内存中生成并存储完整的 N×NN \times NN×N 得分矩阵。v1和v2版本都遵循这一原则,但它们通过不同的分块与循环方式来组织计算。

FlashAttention v1采用了一种以Q矩阵为中心的嵌套循环结构。其外层循环沿着矩阵Q的行维度(即序列长度)进行分块,内层循环则相应地遍历矩阵K的所有行块。在这种模式下,一个Q的块(记为QiQ_iQi)会被首先加载,然后内层循环会依次加载K的每一个块(记为KjK_jKj),并计算出对应的得分矩阵子块Sij=QiKjTS_{ij} = Q_i K_j^TSij=QiKjT。这意味着一个QiQ_iQi块会与所有的KjK_jKj块依次进行计算,从而提高了QiQ_iQi块的数据复用率。

# FlashAttention v1 伪代码
for Qi in iter_blocks(Q):# Qi 块保持不变,遍历所有K块for Kj in iter_blocks(K):# 计算得分矩阵的一个子块Sij = Qi @ Kj.T

FlashAttention v2则对这个计算流程进行了重组,以实现更优的计算模式。v2版本颠倒了循环的顺序,将外层循环改为遍历K矩阵的块,内层循环则遍历Q矩阵的块。虽然单步计算仍然是Sij=QiKjTS_{ij} = Q_i K_j^TSij=QiKjT,但这种循环顺序的改变,从根本上改变了计算任务的组织方式。它不再是围绕一个固定的QiQ_iQi块进行计算,而是将整个QKTQK^TQKT的计算分解为一系列更独立的子任务,每个子任务对应一个(Qi,Kj)(Q_i, K_j)(Qi,Kj)对。这种结构使得不同的计算任务可以更灵活地被调度和并行执行。

# FlashAttention v2 伪代码
for Kj in iter_blocks(K):# Kj 块保持不变,遍历所有Q块for Qi in iter_blocks(Q):# 计算得分矩阵的一个子块Sij = Qi @ Kj.T

4.2 硬件分块调度情况

在FlashAttention算法的实际硬件执行中,其核心的分块计算逻辑被巧妙地映射到了GPU的并行计算架构上,其中流式多处理器(SM)和线程块(Thread Block)是基本的调度单元。v1和v2版本在如何将分块计算任务分配给这些硬件单元上存在本质区别,v2的改进正是其性能实现巨大飞跃的关键。

在FlashAttention v1中,硬件调度与分块逻辑的对应关系较为直接。通常,一个线程块会被分配去计算最终输出矩阵O的一个行块(a block of rows)。当这个线程块被调度到某个SM上执行时,它首先负责从慢速的HBM显存中加载其对应负责的QiQ_iQi块到该SM高速的片上共享内存(SRAM)中。随后,该线程块进入一个内层循环,在此循环中它会依次从HBM中加载K矩阵的所有块(K1,K2,…,KnK_1, K_2, \ldots, K_nK1,K2,,Kn)。每加载一个KjK_jKj块,就在SRAM内完成Sij=QiKjTS_{ij} = Q_i K_j^TSij=QiKjT的计算以及后续的在线Softmax更新。在这种模式下,并行性主要体现在GPU可以同时调度多个线程块,在不同的SM上并行地处理不同的QiQ_iQi块。其优点是QiQ_iQi块在SRAM中的数据复用率很高,但缺点是单个线程块内部的工作是串行的,它必须等待漫长的内层循环完成(即遍历完所有K块),这可能导致工作负载不均和部分硬件资源在某些时刻的闲置(比如长样本和短样本一起计算)。

FlashAttention v2对这种调度方式进行了根本性的重构,以实现更高程度的并行化和更优的硬件资源利用率。v2的核心改进在于它改变了线程块的任务分配和循环结构,不再让单个线程块包揽一整个输出行块的完整计算。取而代之的是,它将整个QKTQK^TQKT的计算任务在更高层级(Grid级别)上进行划分。外层循环现在由K的块主导,而针对Q块的循环任务则被更细粒度地拆分,并分配给线程块内的多个线程束(Warps)并行处理。这意味着一个线程块的工作不再是固定一个QiQ_iQi然后遍历所有KjK_jKj,而是而是变成了固定一个KjK_jKj然后遍历所有QiQ_iQi

一个线程块由多个线程束(Warps)组成,在v2中,处理不同QiQ_iQi块的内层循环任务可以被分配给不同的线程束来并行或流水线式地执行。例如,当一个线程束在等待从内存加载Q1Q_1Q1块时,另一个线程束可以利用计算单元去处理已经加载好的Q2Q_2Q2块。这种任务的交错执行,极大地掩盖了内存访问延迟,使得计算单元几乎总是有事可做,从而将硬件利用率推向了新的高度。而在v1中,不同的KjK_jKj是串行执行的,只能按顺序来,做不到v2这种并发执行。

v2相对v1的改进所带来的好处是双重的。首先,它极大地提升了并行度和负载均衡。通过将计算任务分解为更小、更独立的单元,GPU的调度器可以更灵活、更均匀地将工作分配到所有的SM上,从而显著提高硬件的占用率(Occupancy)。这避免了v1中可能出现的某些线程块因内层循环过长而成为瓶颈,而其他SM却处于空闲状态的现象。其次,也是更精妙的一点,v2的调度方式减少了对SRAM的访问。通过精心重排计算顺序,v2能够让更多的中间结果直接保存在线程束更快、更私有的寄存器(Registers)中,避免了频繁地将数据写入SRAM再读出的开销。总而言之,v2的改进从v1的“优化HBM访存”层面,深入到了“优化片上计算与存储”的层面,通过实现更优的并行调度和更少的片上数据移动,从而更极致地压榨了GPU的计算潜力。

4.3 batch_size > 1时v2的调度

在FlashAttention-v2中,当处理一个包含不同序列长度样本的批次时,其硬件调度策略更加精妙,旨在最大化数据局部性和硬件并行效率。整个计算任务依然通过一次统一的kernel launch来处理,但线程块(Thread Block)的划分与分配逻辑遵循了以K块为中心的原则。

我们来分析一个具体的例子:一个批次包含两个样本,样本1的KV缓存长度为1024、Q查询数量为16;样本2的KV缓存长度为2048、Q查询数量为32。首先,我们定义计算块的大小,例如K块大小(BLOCK_K)为128,Q块大小(BLOCK_Q)为16。

根据v2的设计,GPU线程块的逻辑网格(Grid)是围绕着批次(B)、注意力头(H)和K块的索引来构建的。一个线程块会被分配一个形如 (batch_idx, head_idx, k_block_idx) 的任务。

  • 对于样本1,其KV缓存长度为1024,因此它有 1024 / 128 = 8个K块。它的Q查询数量为16,因此有 16 / 16 = 1个Q块
  • 对于样本2,其KV缓存长度为2048,因此它有 2048 / 128 = 16个K块。它的Q查询数量为32,因此有 32 / 16 = 2个Q块

现在,我们来看线程块是如何被调度的:

  • 线程块分配:GPU调度器会启动总共 (8 + 16) = 24个线程块(对于单个注意力头而言)。其中8个线程块会处理样本1,16个线程块会处理样本2。例如,一个线程块A可能被分配任务 (b=0, h=h, k_idx=0),即处理样本1的第0个K块。另一个线程块B可能被分配任务 (b=1, h=h, k_idx=10),即处理样本2的第10个K块。

  • 线程块内部执行

    • 线程块A(处理样本1)在某个SM上执行时,它首先从HBM中加载样本1的第0个K块(以及对应的V块)到高速的SRAM中。然后,它进入一个内层循环,该循环将遍历该样本所有的Q块。由于样本1只有一个Q块,这个内层循环只会迭代1次
    • 线程块B(处理样本2)在另一个SM上执行时,它会从HBM加载样本2的第10个K块(和V块)到SRAM。随后,它也进入遍历Q块的内层循环。但由于样本2有两个Q块,这个内层循环需要迭代2次。在每次迭代中,它会加载相应的Q块并执行计算。

v2这种以K块为中心的调度方式,相比v1具有显著优势。在自回归解码这类Q序列很短(通常为1)而KV缓存很长的场景下,v1的模式(固定Q块,遍历所有K块)会导致极长的内层循环和较低的硬件利用率。而v2的模式下,一个线程块固定一个K块,然后遍历数量很少的Q块,这使得每个线程块的工作量更加均衡和短暂。这种方式不仅极大地提升了并行度(因为可以为更多的K块启动独立的线程块),还显著改善了访存局部性——较大的K/V块数据一旦被加载到SRAM,可以与所有相关的Q块进行计算,从而减少了对HBM的昂贵读写,最终实现了更高的计算效率和吞吐量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/90203.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/90203.shtml
英文地址,请注明出处:http://en.pswp.cn/diannao/90203.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C/C++ inline-hook(x86)高级函数内联钩子

&#x1f9f5; C/C inline-hook&#xff08;x86&#xff09;高级函数内联钩子 引用&#xff1a; fetch-x86-64-asm-il-sizeC i386/AMD64平台汇编指令对齐长度获取实现 &#x1f9e0; 一、Inline Hook技术体系架构 Inline Hook是一种二进制指令劫持技术&#xff0c;通过修改目…

云服务器的安全防护指南:从基础安全设置到高级威胁防御

随着云计算的广泛应用&#xff0c;云服务器已成为企业和个人存储数据、运行应用的重要基础设施。然而&#xff0c;随之而来的安全威胁也日益增多——从常见的网络攻击&#xff08;如 DDoS、SQL 注入&#xff09;到复杂的恶意软件和零日漏洞&#xff0c;无一不考验着系统的安全性…

状态机管家:MeScroll 的交互秩序维护

一、核心架构设计与性能基石 MeScroll作为高性能滚动解决方案&#xff0c;其架构设计遵循"分层解耦、精准控制、多端适配"的原则&#xff0c;通过四大核心模块实现流畅的滚动体验&#xff1a; 事件控制层&#xff1a;精准捕获触摸行为&#xff0c;区分滚动方向与距…

数据出海的隐形冰山:企业如何避开跨境传输的“合规漩涡”?

首席数据官高鹏律师数字经济团队创作&#xff0c;AI辅助凌晨三点的写字楼&#xff0c;某跨境电商的技术总监盯着屏幕上的报错提示&#xff0c;指尖悬在键盘上迟迟没落下。刚从新加坡服务器调取的用户行为数据&#xff0c;在传输到国内分析系统时被拦截了——系统提示“不符合跨…

【Rust base64库】Rust bas64编码解码详细解析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Rust开发…

如何利用AI大模型对已有创意进行评估,打造杀手级的广告创意

摘要 广告创意是影响广告效果的最重要的因素之一&#xff0c;但是如何评估和优化广告创意&#xff0c;一直是一个难题。传统的方法&#xff0c;如人工评审、A/B测试、点击率等&#xff0c;都有各自的局限性和缺陷。本文将介绍一种新的方法&#xff0c;即利用人工智能大模型&am…

OSCP - HTB - Cicada

主要知识点 SMB 用户爆破Backup Operator 组提权 具体步骤 nmap扫描一下先&#xff0c;就像典型的windows 靶机一样&#xff0c;开放了N多个端口 Nmap scan report for 10.10.11.35 Host is up (0.19s latency). Not shown: 65522 filtered tcp ports (no-response) PORT …

10046 解决 Oracle error

How to Offline a PDB Datafile in NOARCHIVELOG mode CDB which is not Open in Read Write (Doc ID 2240730.1)1. pdb 下的datafile 只能在pdb下操作&#xff0c;不能在cdb下操作For the purposes of this document, the following fictitious environment is used as an exa…

在HP暗影精灵Ubuntu20.04上修复IntelAX211Wi-Fi不可用的全过程记录——系统安装以后没有WIFI图标无法使用无线网

在 HP 暗影精灵 Ubuntu 20.04 上修复 Intel AX211 Wi-Fi 不可用的全过程记录 2025 年 7 月初 系统环境&#xff1a;HP OMEN&#xff08;暗影精灵&#xff09;笔记本 | 双系统 Windows 11 & Ubuntu 20.04 | 内核 5.15 / 6.15 mainline 问题关键词&#xff1a;Intel AX21…

Sql server 中关闭ID自增字段(SQL取消ID自动增长)

sql server在导入数据的时候&#xff0c;有时候要考虑id不变&#xff0c;就要先取消自动增长再导入数据&#xff0c;导完后恢复自增。 比如网站改版从旧数据库导入新数据库&#xff0c;数据库结构不相同&#xff0c;可能会使用insert into xx select ..from yy的语句导入数据。…

Python实现文件夹中文件名与Excel中存在的文件名进行对比,并进行删除操作

以下python程序版本为Python3.13.01.请写一个python程序&#xff0c;实现以下逻辑&#xff1a;从文件夹获取所有文件名&#xff0c;与Excel中的fileName列进行对比&#xff0c;凡是不在该文件夹下的文件名&#xff0c;从Excel文档中删除后&#xff0c;并将Excel中fileName和fil…

广告业务动态查询架构设计:从数据建模到可视化呈现

在数字化营销领域&#xff0c;广告主每天面临着海量数据带来的分析挑战&#xff1a;从账户整体投放效果&#xff0c;到分渠道、分地域的精细化运营&#xff0c;每一层级的数据洞察都需要灵活高效的查询能力。我们的广告业务动态查询系统&#xff0c;正是为解决这类需求而生 &am…

pytorch、torchvision与python版本对应关系

pytorch、torchvision与python版本对应关系 可以查看官网&#xff1a; https://github.com/pytorch/vision#installation

【机器学习笔记 Ⅲ】3 异常检测算法

异常检测算法&#xff08;Anomaly Detection&#xff09;详解 异常检测是识别数据中显著偏离正常模式的样本&#xff08;离群点&#xff09;的技术&#xff0c;广泛应用于欺诈检测、故障诊断、网络安全等领域。以下是系统化的解析&#xff1a;1. 异常类型类型描述示例点异常单个…

【ssh】在 Windows 上生成 SSH 公钥并实现免密登录 Linux

在 Windows 上生成 SSH 公钥并实现免密登录 Linux&#xff0c;可以使用 ssh-keygen 命令&#xff0c;这是 Windows 10 和 Windows 11 中默认包含的 OpenSSH 工具的一部分。下面是详细步骤&#xff1a; 在 Windows 上生成 SSH 公钥 打开 PowerShell 或命令提示符&#xff1a; 在…

MS51224 一款 16 位、3MSPS、双通道、同步采样模数转换器(ADC)

MS51224 是一款 16 位、3MSPS、双通道、同步采样模数转换器&#xff08;ADC&#xff09;&#xff0c;具有集成的内部参考和参考电压缓冲器。芯片可由 5V 单电源供电&#xff0c;支持单极性和全差分模拟信号输入&#xff0c;具有出色的直流和交流性能。芯片模拟输入信号频率高达…

WPF学习(四)

文章目录一、用户控价1.1 依赖属性的注册1.2 具体使用一、用户控价 1.1 依赖属性的注册 using System.Windows; using System.Windows.Controls;namespace WpfApp {public partial class MyUserControl : UserControl{// 依赖属性&#xff1a;外部可绑定的文本public static …

vue3+typescript项目配置路径别名@

1. vite.config.ts配置//方法1 import { defineConfig } from vite; import vue from vitejs/plugin-vue; import path from path;export default defineConfig({plugins: [vue()],resolve: {alias: {: path.resolve(__dirname, src)}} });//方法2,需要执行npm install -D type…

MySql 常用SQL语句、 SQL优化

✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨SQL语句主要分为哪几类 SQL&#xff08;结构化查询语言&#xff09;是用于管理和操作关系型数据库的标准语言&#xff0c;其语句通常根据功能划分为以下几大类&#xff0c;每类包含不同的子句和命令&#xff0c;用于实现特定的数据库操作需求&am…

代理模式实战指南:打造高性能RPC调用与智能图片加载系统

代理模式实战指南&#xff1a;打造高性能RPC调用与智能图片加载系统 &#x1f31f; 嗨&#xff0c;我是IRpickstars&#xff01; &#x1f30c; 总有一行代码&#xff0c;能点亮万千星辰。 &#x1f50d; 在技术的宇宙中&#xff0c;我愿做永不停歇的探索者。 ✨ 用代码丈量…