大语言模型加速技术之KV Cache

  • Why we need KV Cache ?
  • Self-Attention Without Cache
  • Self-Attention With Cache
  • Huggingface 官方代码实现

Why we need KV Cache ?

生成式generative模型的推理过程很有特点,我们给一个输入文本,模型会输出一个回答(长度为N),其实该过程中执行了N次推理过程。即GPT类模型一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

如上描述是我们通常认知的GPT推理过程。代码描述如下:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizerdef main():# 加载模型和 tokenizermodel = GPT2LMHeadModel.from_pretrained("gpt2").eval()tokenizer = GPT2Tokenizer.from_pretrained("gpt2")# 初始输入in_text = "Open AI is a"in_tokens = torch.tensor(tokenizer.encode(in_text)).unsqueeze(0)  # [1, seq_len]token_eos = torch.tensor([198])  # line break symbolout_token = Nonei = 0with torch.no_grad():while out_token != token_eos:outputs = model(in_tokens)logits = outputs.logitsout_token = torch.argmax(logits[0, -1, :], dim=-1, keepdim=True).unsqueeze(0)  # [1, 1]in_tokens = torch.cat((in_tokens, out_token), dim=1)text = tokenizer.decode(in_tokens[0])print(f'step {i} input: {text}', flush=True)i += 1out_text = tokenizer.decode(in_tokens[0])print(f'\nInput: {in_text}')print(f'Output: {out_text}')if __name__ == "__main__":main()

输出:

step 0 input: Open AI is a new
step 1 input: Open AI is a new way
step 2 input: Open AI is a new way to
step 3 input: Open AI is a new way to build
step 4 input: Open AI is a new way to build AI
step 5 input: Open AI is a new way to build AI that
step 6 input: Open AI is a new way to build AI that is
step 7 input: Open AI is a new way to build AI that is more
step 8 input: Open AI is a new way to build AI that is more efficient
step 9 input: Open AI is a new way to build AI that is more efficient and
step 10 input: Open AI is a new way to build AI that is more efficient and more
step 11 input: Open AI is a new way to build AI that is more efficient and more efficient
step 12 input: Open AI is a new way to build AI that is more efficient and more efficient than
step 13 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional
step 14 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI
step 15 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI.
step 16 input: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI.Input: Open AI is a
Output: Open AI is a new way to build AI that is more efficient and more efficient than traditional AI.

在上面的推理过程中,每 step 内,输入一个 token序列,经过Embedding层将输入token序列变为一个三维张量 [b, s, h],经过一通计算,最后经 logits 层将计算结果映射至词表空间,输出张量维度为 [b, s, vocab_size]。

当前轮输出token与输入tokens拼接,并作为下一轮的输入tokens,反复多次。可以看出第 i+1 轮输入数据只比第 i 轮输入数据新增了一个 token,其他全部相同!

因此第 i+1 轮推理时必然包含了第 i 轮的部分计算。KV Cache 的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果。

上面所举例子并没有使用KV Cache进行推理,请注意。

Self-Attention Without Cache

下图给出了无 Cache 情况下,类GPT式生成式模型进行推理的过程:

在这里插入图片描述

这种方式的问题是: 每生成一个 token,就要重新计算所有之前 token 的 Q/K/V + Attention + FFN

Self-Attention With Cache

下图给出了有 Cache 情况下,类GPT式生成式模型进行推理的过程:

在这里插入图片描述

Huggingface 官方代码实现

本节将根据 Huggingface 官方代码实现进行 KV Cache 实现讲解 (只展示核心代码,移除了大量与本文无关的逻辑)。

官方代码链接: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py

下面将给出使用了 KV Cache 进行推理的代码:

import torch
from transformers import GPT2Tokenizer, GPT2Config
from modeling_gpt2 import GPT2LMHeadModel  # copy from huggingface , 删除了大量无关代码def generate_text(model, tokenizer, prompt, max_new_tokens=50, eos_token_id=198):model.eval()input_ids = tokenizer.encode(prompt, return_tensors="pt")past_key_values = Noneoutput_ids = input_ids.clone()with torch.no_grad():for step in range(max_new_tokens):outputs = model(input_ids=input_ids,past_key_values=past_key_values,use_cache=True)logits = outputs.logitspast_key_values = outputs.past_key_valuesnext_token_logits = logits[:, -1, :]next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)output_ids = torch.cat([output_ids, next_token], dim=-1)if next_token.item() == eos_token_id:breakinput_ids = next_token  # 采用KV Cache后,推理过程修改的关键: 下一步只送入新 tokenprint(f"step {step}: {tokenizer.decode(output_ids[0])}", flush=True)return tokenizer.decode(output_ids[0])def main():config = GPT2Config()tokenizer = GPT2Tokenizer.from_pretrained("gpt2")model = GPT2LMHeadModel(config)prompt = "Once upon a time"output = generate_text(model, tokenizer, prompt)print("\nFinal output:")print(output)if __name__ == "__main__":main()

KV Cache 的引入是为了加速自回归模型的推理速度,具体体现在:

  1. 每轮推理时,只需要计算当前轮新增 token 的 Q/K/V,而不需要重新计算所有之前 token 的 Q/K/V。

  2. 缓存当前轮计算结果,下一轮推理时直接读取缓存结果。

在首轮推理的过程中,我们传入的是 promt 提示词列表,并且 KV Cache 此时为空,还未进行初始化。因此首轮推理过程需要完成 promt 提示词列表的 keys 和 values 的缓存;由于 GPT2 由多层 GPT2Block 堆叠而成,而每一层 GPT2Block 都有一个 GPT2Attention 模块, 因此 KV Cache 需要准备好每一层 GPT2Attention 模块的 keys 和 values 缓存 (分层Cache - legacy_cache)。

class GPT2Model(GPT2PreTrainedModel):def forward(self,input_ids=None,past_key_values=None, cache_position=None,attention_mask=None,position_ids=None,head_mask=None,use_cache=None,):          return_legacy_cache = Falseif use_cache:# 1. 首轮推理,先进行 Legacy Cache 初始化if past_key_values is None:return_legacy_cache = Truepast_key_values = DynamicCache()# 2. 后续推理,将模型以元组形式返回的缓存重新封装为Legacy Cache形式elif not isinstance(past_key_values, Cache):return_legacy_cache = Truepast_key_values = DynamicCache.from_legacy_cache(past_key_values)# 3. 词嵌入 inputs_embeds = self.wte(input_ids)# 4. 位置编码计算if cache_position is None:# 4.1 已经缓存的词序列长度past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0# 4.2 只为当前传入的词生成位置序列cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)    if position_ids is None:position_ids = cache_position.unsqueeze(0) # 添加batch维度# 4.3 生成位置编码position_embeds = self.wpe(position_ids)# 5. 词嵌入 + 位置编码hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)# 6. 进入堆叠GPT2Block模块前向传播流程for i, block in enumerate(self.h):hidden_states = block(hidden_states,past_key_values if not (self.gradient_checkpointing and self.training) else None, # 训练时,不启用KV Cachecache_position,causal_mask,use_cache=use_cache,)hidden_states = self.ln_f(hidden_states)hidden_states = hidden_states.view(output_shape)# 7. 将KV Cache用元组的形式进行返回 past_key_values = past_key_values if use_cache else Noneif return_legacy_cache:past_key_values = past_key_values.to_legacy_cache()return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states,past_key_values=past_key_values,hidden_states=all_hidden_states,attentions=all_self_attentions,cross_attentions=all_cross_attentions,)

下图展示的是步骤7中以元组形式返回的KV Cache结构:

在这里插入图片描述

下面将展示GPT2Block模块的实现逻辑,由于不涉及KV Cache的实现细节,所以不过多展开:

class GPT2Block(GradientCheckpointingLayer):def forward(self,hidden_states: Optional[tuple[torch.FloatTensor]],past_key_value: Optional[Cache] = None,cache_position: Optional[torch.LongTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,use_cache: Optional[bool] = False,) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:# 1. 归一化residual = hidden_stateshidden_states = self.ln_1(hidden_states)# 2. 自注意力运算attn_output, self_attn_weights = self.attn(hidden_states,past_key_value=past_key_value,cache_position=cache_position,attention_mask=attention_mask,use_cache=use_cache,)# 3. residual connectionhidden_states = attn_output + residual# 4. 归一化 + MLP +  residual connectionresidual = hidden_stateshidden_states = self.ln_2(hidden_states)feed_forward_hidden_states = self.mlp(hidden_states)hidden_states = residual + feed_forward_hidden_statesreturn hidden_states

推理时的常规流程(无 KV Cache), 每生成一个新 token,都要:

  • 重新输入全部历史 token

  • 对所有历史 token 重新计算 key 和 value

  • 这意味着重复计算,效率低,计算开销线性增长


有了 KV Cache 后的改进:

  1. 第一次输入完整句子,计算并缓存其 key/value;

  2. 后续每次生成新 token 时:

    • 只计算新 token 的 query、key、value;

    • 把新 token 的 key/value 插入缓存中(代码中用 past_key_value.update(...) 完成);

    • attention 直接使用「历史缓存 key/value + 当前新 token 的 key/value」来完成;

  3. 整个注意力的 query 只有一个(当前 token),key/value 是历史缓存 + 当前 token

class GPT2Attention(nn.Module):def __init__(self, config, is_cross_attention=False, layer_idx=None):self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) # 输入维度: (batch,seq_len,embed_dim) , 变换后的输出维度: (batch,seq_len,3*embed_dim)self.c_proj = Conv1D(self.embed_dim, self.embed_dim)def forward(self,hidden_states: Optional[tuple[torch.FloatTensor]],past_key_value: Optional[Cache] = None,cache_position: Optional[torch.LongTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:# 1. 一维卷积进行线性变换和升维,然后切分成query,key,valuequery_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)# 2. (batch,seq_len,-1,head_dim) , head_dim 是多头自注意力中每个头切分到的维度 shape_q = (*query_states.shape[:-1], -1, self.head_dim)shape_kv = (*key_states.shape[:-1], -1, self.head_dim)# 3. 维度统一: (batch,heads,seq_len,head_dim)query_states = query_states.view(shape_q).transpose(1, 2)key_states = key_states.view(shape_kv).transpose(1, 2)value_states = value_states.view(shape_kv).transpose(1, 2)# 4. KV Cache 不为空 if past_key_value is not None:# 4.1 cache_position 记录当前词对应输入词序列中的索引cache_kwargs = {"cache_position": cache_position}# 4.2 将当前词的key和val进行缓存,根据所在GPTBlock层级(layer_idx说明),和位于词序列的索引(cache_kwargs说明),插入对应层的list缓存中去,同时返回对应的key和val listkey_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs)# 5. 进行经典的多头自注意力运算(不展开细聊) attn_output, attn_weights = attention_interface(self,query_states, # 当前输入词的querykey_states,   # cache key list + 输入词的keyvalue_states,  # cache val list + 输入词的valattention_mask, # padding maskdropout=self.attn_dropout.p if self.training else 0.0,)attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()attn_output = self.c_proj(attn_output)attn_output = self.resid_dropout(attn_output)return attn_output, attn_weights

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/90813.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/90813.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/90813.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第五十三天|图论part4

110.字符串接龙 题目链接&#xff1a;110. 字符串接龙文章讲解&#xff1a;代码随想录思路&#xff1a; 把每个字符串看成图的一个节点。 转换为求无权图两节点的的最短路径。求最短路径用bfs #include <string> #include <vector> #include <iostream> #i…

Java进阶4:泛型、序列化和反序列化

Java泛型 Java泛型是JDK5引入的一个新的特性&#xff0c;泛型提供了编译时的类型安全检测机制&#xff0c;这个机制运行程序员在编译的时候检测到非法的类型。泛型的本质是参数化类型&#xff0c;也就是所操作的数据类型被指定为一个参数。 泛型方法 可以写一个泛型方法&#x…

RAG实战指南 Day 24:上下文构建与提示工程

【RAG实战指南 Day 24】上下文构建与提示工程 文章内容 开篇 欢迎来到"RAG实战指南"系列的第24天&#xff01;今天我们将深入探讨RAG系统中至关重要的上下文构建与提示工程技术。在检索增强生成系统中&#xff0c;如何有效地组织检索到的文档片段&#xff0c;并将…

AWD的攻击和防御手段

一、AWD相关介绍 AWD&#xff08;Attack With Defence&#xff09;是 CTF 线下赛中最接近真实攻防场景、观赏性和对抗性最强的赛制之一。 赛制本质 人人对抗&#xff1a;所有战队互为攻击者与防守者。 零和记分&#xff1a;你拿到的每一分都是别人的失分&#xff0c;总积分恒…

泛微OA8前台SQL注入

漏洞URL&#xff1a; http://106.15.190.147/js/hrm/getdata.jsp?cmdgetSelectAllId&sql***注入点 在getdata.jsp中&#xff0c;直接将request对象交给 weaver.hrm.common.AjaxManager.getData(HttpServletRequest, ServletContext) : 方法处理 在getData方法中&#xff0…

Android 蓝牙学习

在Android中&#xff0c;进行蓝牙设备扫描startDiscovery需要申请位置权限&#xff0c;但有的时候并不需要申请位置权限&#xff0c;就有了android:usesPermissionFlags"neverForLocation"&#xff0c;设置这个就不用申请位置权限来进行蓝牙扫描。 android:usesPerm…

Earth靶机攻略

一.环境准备 1.1Earth靶机环境准备 首先将我们解压好的的Earth.ova放入虚拟机里&#xff0c;并配置环境 将网络连接换成NET连接 二.信息搜集 2.1ip搜集 2.1.1使用netdiscover命令扫描靶机的ip地址,发现地址为192.168.182.137 2.2端口扫描 2.2.1使用nmap工具对目标机进行端…

java8 List常用基本操作(去重,排序,转换等)

参考简述网址: java8 List 根据对象某个字段或多个字段去重、筛选、List转Map、排序、分组、统计计数等等 list简单方法示例如下&#xff1a; 一、先定义一个订单对象&#xff08;Order&#xff09; public class Order {private Long id;private Long userId;private Strin…

【linux】HAProxy七层代理

HAPorxy简介HAProxy&#xff08;High Availability Proxy&#xff09;是一款高性能、开源的负载均衡器与反向代理服务器&#xff0c;主要用于 TCP 和 HTTP 协议的流量分发&#xff0c;广泛应用于高并发、高可用的网络架构中核心特性超高性能与低资源消耗基于单进程、多线程模型…

Android 的16 KB内存页设备需要硬件支持吗,还是只需要手机升级到Android15系统就可以

Android 的16 KB内存页设备需要硬件支持吗&#xff0c;还是只需要手机升级到Android15系统就可以&#xff1f; deepseek回答&#xff1a; DeepSeek R1 思考 3 秒 思考过程 好的&#xff0c;用户现在问的是Android的16 KB内存页设备是否需要硬件支持&#xff0c;还是只需要升级到…

相机内外参矩阵:从3D世界坐标到2D像素坐标变换

相机内外参矩阵&#xff1a;从3D世界坐标到2D像素坐标变换介绍**1. 内参矩阵&#xff08;Intrinsic Matrix, K&#xff09;****2. 外参矩阵&#xff08;Extrinsic Matrix, [R|t]&#xff09;****3. 完整投影过程&#xff08;世界坐标 → 像素坐标&#xff09;****步骤1&#xf…

哈希指针与数据结构:构建可信数字世界的基石

一、哈希指针的核心原理哈希指针是一种创新型数据结构&#xff0c;融合了传统指针的定位功能与密码学哈希的验证能力&#xff1a;双重功能&#xff1a;既存储数据地址&#xff0c;又包含该数据的哈希值&#xff0c;实现数据定位与完整性验证的统一。抗篡改机制&#xff1a;数据…

java实现一个方法,isTure则程序继续往下,为false则return的链式写法

以下是实现链式条件检查的Java方法&#xff0c;采用函数式风格设计。代码包含一个Chainable类&#xff0c;支持连续的check方法和多个终止操作&#xff08;如then, orElse等&#xff09;&#xff0c;满足在条件为false时中断链式调用并返回默认值的需求&#xff1a;import java…

数据结构学习之堆

本篇我们将学习新的数据结构——二叉树。 作者的个人gitee&#xff1a;楼田莉子 (riko-lou-tian) - Gitee.com 目录 树的概念 树形结构 非树形结构 树的相关术语 树的表示 树在实际生活上的应用 二叉树 慢二叉树 完全二叉树 二叉树的储存结构 二叉树的存储结构 顺序结构…

【csdn问答社区分析】前端开发热点问题全解析

前端时间我在csdn问答社区的前端部分"视察”了一圈发现了大家的问题主要集中在以下方面一、框架与组件库使用问题 Vue相关问题 组件化开发&#xff1a;如avue-crud组件自定义样式不生效、el-select大数据分页懒加载、element-plus表格动态列校验等。功能实现&#xff1a;包…

Pycharm2025 安装教程 免费分享 没任何套路

Pycharm 安装也是很简单的&#xff0c;简单过一下流程&#xff0c;如果需要的可以转存下载到自己电脑上。我用夸克网盘分享了「pycharm2025」&#xff0c;复制链接浏览器打开转存后即可下载。链接&#xff1a;https://pan.quark.cn/s/4bb74a939332备注&#xff1a;附带2023-202…

Javaweb————什么是超文本传输协议?

&#x1f3cd;️&#x1f3cd;️&#x1f3cd;️引言&#xff1a;什么是协议&#xff1f; 协议是一种约定&#xff0c;规定好一种信息的格式&#xff0c;如果发送方按照这种请求格式发送信息,那么接 收端就要按照这样的格式解析数据,否则就会出错&#xff0c;这就是协议 常用协…

UniappDay03

1.热门推荐-准备工作// 用defineProps获取页面参数,query const query defineProps<{type: string }>() const currHot hotMap.find((v) > v.type query.type) // 动态设置标题 uni.setNavigationBarTitle({ title: currHot!.title }) </script>2.获取热门推…

基于动态增强的 LLM 置信度方法研究

基于动态增强的 LLM 置信度方法研究 一、引言(Introduction) 大型语言模型(LLM)的性能提升高度依赖于对模型内部表征的精准调控 —— 表征工程通过优化模型中间层隐藏状态的传递规律,能够在不改变模型参数的前提下显著提升任务适应性(Wei et al., 2022)。当前主流方法中…

ComfyUI中运行Wan 2.1工作流,电影级视频,兼容Mac Windows

魔当(LM Downloader)是一个大模型应用下载工具 &#xff0c;目前 魔当 已经支持ComfyUI下载Wan 2.1视频模型。 魔当下载地址 https://seemts.com/ 先看生成效果 原始图片&#xff0c;你可以保存到自己电脑上测试 生成视频&#xff1a; 推荐提示词&#xff1a; A futurist…