欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 时评 > 【速写】KV-cache与解码的再探讨(以束搜索实现为例)

【速写】KV-cache与解码的再探讨(以束搜索实现为例)

2025/5/13 9:31:42 来源:https://blog.csdn.net/CY19980216/article/details/147903242  浏览:    关键词:【速写】KV-cache与解码的再探讨(以束搜索实现为例)

文章目录

  • 1 Beam Search 解码算法实现
  • 2 实现带KV Cache的Beam Search解码
  • 3 关于在带kv-cache的情况下的use_cache参数


1 Beam Search 解码算法实现

下面是一个使用PyTorch实现的beam search解码算法:

几个小细节:

  • 束搜索可以加入length_penalty,目前model.generate也是有这个参数的,这个惩罚项直接是用来除生成概率的
  • 通常这种需要计算概率相乘的情况,都是避免做乘法,而是使用log p相加
  • 具体实现中应当考虑eos标识符导致的early stop的候选序列,需要提前存储到外面
  • 然后就是关于使用log softmax得到log概率后,这其实是一个负的概率,序列越长,log prob会越小,- log prob 才是越大的,因此在做惩罚的时候,应该是吧 prob / len(seq) ** penality,即长序列的 log prob 会被除掉更多,这是合理的,因为短序列的 - log prob 天然地比 长序列地 - log prob 要更小,这样量纲才是正确的
import torch
import torch.nn.functional as F
from typing import List, Tupledef beam_search(model: torch.nn.Module,initial_input: torch.Tensor,beam_width: int,max_length: int,vocab_size: int,device: torch.device,length_penalty: float = 1.0,early_stopping: bool = True
) -> Tuple[List[List[int]], List[float]]:"""Beam search 解码算法实现参数:model: 用于预测下一个token的模型initial_input: 初始输入张量 (shape: [1, seq_len])beam_width: beam大小max_length: 生成序列的最大长度vocab_size: 词汇表大小device: 使用的设备 (cpu/cuda)length_penalty: 长度惩罚系数 (α), 用于调整对长序列的偏好early_stopping: 是否在所有beam序列达到EOS时提前停止返回:Tuple[List[List[int]], List[float]]: (生成的序列列表, 对应的分数列表)"""# 初始化beamsequences = [[initial_input.tolist()[0]]]  # 初始序列scores = [0.0]  # 初始分数 (log概率)# 存储完整的beam (已经生成EOS的序列)completed_sequences = []completed_scores = []for step in range(max_length):# 如果所有beam都已完成,提前停止if early_stopping and len(sequences) == 0:break# 准备当前步的输入candidates = []for i, seq in enumerate(sequences):# 跳过已经完成的序列if len(seq) > 0 and seq[-1] == 2:  # 假设2是EOS tokencompleted_sequences.append(seq)completed_scores.append(scores[i])continue# 将序列转换为张量input_tensor = torch.tensor([seq], dtype=torch.long).to(device)# 获取模型预测with torch.no_grad():outputs = model(input_tensor)next_token_logits = outputs[:, -1, :]  # 取最后一个token的logitsnext_token_probs = F.log_softmax(next_token_logits, dim=-1)# 获取top-k tokens和它们的log概率topk_probs, topk_tokens = torch.topk(next_token_probs, beam_width, dim=-1)topk_probs = topk_probs.squeeze(0)topk_tokens = topk_tokens.squeeze(0)# 创建候选序列for j in range(beam_width):new_seq = seq.copy()new_seq.append(topk_tokens[j].item())new_score = scores[i] + topk_probs[j].item()candidates.append((new_seq, new_score))# 如果没有候选序列,提前停止if not candidates:break# 选择top-k候选序列,并剪枝candidates.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty), reverse=True)sequences, scores = zip(*candidates[:beam_width])sequences = list(sequences)scores = list(scores)# 添加剩余的未完成序列到完成列表中completed_sequences.extend(sequences)completed_scores.extend(scores)# 对完成的序列按分数排序sorted_sequences = [seq for seq, _ in sorted(zip(completed_sequences, completed_scores),key=lambda x: x[1] / (len(x[0]) ** length_penalty),reverse=True)]sorted_scores = sorted(completed_scores,key=lambda score: score / (len(sorted_sequences[completed_scores.index(score)]) ** length_penalty),reverse=True)return sorted_sequences, sorted_scores
  1. 模型要求:

    • 模型应接受形状为 [batch_size, seq_len] 的输入
    • 模型应输出形状为 [batch_size, seq_len, vocab_size] 的logits
  2. 参数说明:

    • initial_input: 初始输入序列(如开始token)
    • beam_width: 控制搜索宽度,值越大结果可能越好但计算成本越高
    • length_penalty: 控制生成长度的偏好(α>1鼓励更长序列,α<1鼓励更短序列)
    • early_stopping: 设为True时,当所有beam都生成EOS token时停止
  3. 返回值:

    • 返回按分数排序的序列列表和对应的分数列表

如何调用?

model = ...  # 你的PyTorch模型
initial_input = torch.tensor([[1]])  # 假设1是开始token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)sequences, scores = beam_search(model=model,initial_input=initial_input,beam_width=5,max_length=50,vocab_size=10000,  # 你的词汇表大小device=device,length_penalty=1.2
)print("Top sequence:", sequences[0])
print("Score:", scores[0])

2 实现带KV Cache的Beam Search解码

这个是关键,之前写过一个k步的贪心解码,使用了递归写法,但不知道怎么在递归中加入kv cache,其实感觉怎么写都不太好。

其实这样看只要写循环就可以了。

实际上transformers的model.generate方法里,对greedy、beam、以及其他采样的方法都是有做 kv cache 缓存优化的:

# 首次调用(处理初始输入)
outputs = model(input_ids, use_cache=True)
logits = outputs.logits
past_key_values = outputs.past_key_values  # 缓存KV# 后续步骤(生成时)
next_token_input = torch.tensor([[new_token]])
outputs = model(next_token_input, past_key_values=past_key_values, use_cache=True)
new_past_key_values = outputs.past_key_values  # 更新缓存

要在Beam Search中为每个候选序列维护独立的KV Cache,我们需要对标准实现进行一些修改。KV Cache可以显著提高自回归模型的推理效率,因为它避免了重复计算先前token的key和value。

以下是带有KV Cache的Beam Search实现:

  • 因为我们已经有cache了,因此实际上每步迭代生成下一个token时,只需要把当前序列的最后一个token输入即可,只是很棒的操作
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict, Anydef beam_search_with_kv_cache(model: torch.nn.Module,initial_input: torch.Tensor,beam_width: int,max_length: int,vocab_size: int,device: torch.device,length_penalty: float = 1.0,early_stopping: bool = True,use_kv_cache: bool = True
) -> Tuple[List[List[int]], List[float]]:"""带KV Cache的Beam Search解码算法参数:model: 用于预测下一个token的模型initial_input: 初始输入张量 (shape: [1, seq_len])beam_width: beam大小max_length: 生成序列的最大长度vocab_size: 词汇表大小device: 使用的设备 (cpu/cuda)length_penalty: 长度惩罚系数early_stopping: 是否在所有beam序列达到EOS时提前停止use_kv_cache: 是否使用KV Cache加速返回:Tuple[List[List[int]], List[float]]: (生成的序列列表, 对应的分数列表)"""# 初始化beamsequences = [[initial_input.tolist()[0]]]scores = [0.0]# 存储KV Cache (每个候选序列一个cache)kv_caches = [None]  # 初始cache为None# 存储完整的beamcompleted_sequences = []completed_scores = []for step in range(max_length):if early_stopping and len(sequences) == 0:breakcandidates = []new_kv_caches = []for i, (seq, score, kv_cache) in enumerate(zip(sequences, scores, kv_caches)):# 跳过已经完成的序列if len(seq) > 0 and seq[-1] == 2:  # 假设2是EOS tokencompleted_sequences.append(seq)completed_scores.append(score)continue# 准备输入 (只使用最后一个token,因为前面的已经cache了)input_tensor = torch.tensor([[seq[-1]]], dtype=torch.long).to(device)# 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input, use_cache=True)next_token_logits = outputs.logits[:, -1, :]new_kv_cache = outputs.past_key_valueselse:# 后续调用,使用KV Cacheoutputs = model(input_tensor, past_key_values=kv_cache, use_cache=True)next_token_logits = outputs.logits[:, -1, :]new_kv_cache = outputs.past_key_valueselse:# 不使用KV Cache的情况full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input, use_cache=False)next_token_logits = outputs.logits[:, -1, :]new_kv_cache = Nonenext_token_probs = F.log_softmax(next_token_logits, dim=-1)# 获取top-k tokenstopk_probs, topk_tokens = torch.topk(next_token_probs, beam_width, dim=-1)topk_probs = topk_probs.squeeze(0)topk_tokens = topk_tokens.squeeze(0)# 创建候选序列for j in range(beam_width):new_seq = seq.copy()new_seq.append(topk_tokens[j].item())new_score = score + topk_probs[j].item()candidates.append((new_seq, new_score, new_kv_cache))# 如果没有候选序列,提前停止if not candidates:break# 选择top-k候选序列candidates.sort(key=lambda x: x[1] / (len(x[0]) ** length_penalty), reverse=True)top_candidates = candidates[:beam_width]# 解包候选序列sequences = [cand[0] for cand in top_candidates]scores = [cand[1] for cand in top_candidates]kv_caches = [cand[2] for cand in top_candidates]# 添加剩余的未完成序列completed_sequences.extend(sequences)completed_scores.extend(scores)# 对完成的序列按分数排序sorted_pairs = sorted(zip(completed_sequences, completed_scores),key=lambda x: x[1] / (len(x[0]) ** length_penalty),reverse=True)sorted_sequences = [pair[0] for pair in sorted_pairs]sorted_scores = [pair[1] for pair in sorted_pairs]return sorted_sequences, sorted_scores

关键修改说明

  1. KV Cache维护:

    • 每个候选序列现在维护自己的KV Cache
    • 初始时KV Cache为None
    • 第一次处理序列时生成完整的KV Cache
    • 后续步骤只处理最后一个token并更新KV Cache
  2. 模型接口要求:

    • 模型需要支持past_key_values参数来接收先前的KV Cache
    • 模型需要返回past_key_values作为输出的一部分
    • 典型实现方式(如HuggingFace的transformers):
      outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
      next_token_logits = outputs.logits
      past_key_values = outputs.past_key_values
      
  3. 性能优化:

    • 使用KV Cache后,每次前向传播只处理最后一个token(这个很有趣,但是要设置use_cache=True
    • 避免了重复计算先前token的key和value
    • 对于长序列可以显著提高速度

一个简单的调用示例:

# 假设我们有一个支持KV Cache的模型
model = ...  # 例如HuggingFace的GPT2模型
initial_input = torch.tensor([[model.config.bos_token_id]])  # 开始token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 运行带KV Cache的beam search
sequences, scores = beam_search_with_kv_cache(model=model,initial_input=initial_input,beam_width=5,max_length=50,vocab_size=model.config.vocab_size,device=device,length_penalty=1.2,use_kv_cache=True  # 启用KV Cache
)print("Top sequence:", sequences[0])
print("Score:", scores[0])

补注:

在这个部分:

            # 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input, use_cache=True)

上,输出的full_input 的size是[1, 1, seqlen],理论上应该是[1, seqlen]才对,因此要么是

            # 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor(seq, dtype=torch.long).to(device)outputs = model(full_input, use_cache=True)

要么是:

            # 前向传播,使用或更新KV Cachewith torch.no_grad():if use_kv_cache:if kv_cache is None:# 第一次调用,处理整个初始序列full_input = torch.tensor([seq], dtype=torch.long).to(device)outputs = model(full_input.squeeze(0), use_cache=True)

这样测试跑通应该是没有问题的


3 关于在带kv-cache的情况下的use_cache参数

比如之前手写的一个贪心解码算法:

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport torch
import logging
from copy import deepcopy
from functools import wraps
from torch.nn import functional as Ffrom transformers import AutoTokenizer, AutoModelForCausalLM# Standard greedy decode
# @param model: Huggingface model object
# @param tokenizer: Huggingface tokenizer Object
# @param prompt: Str
# @param max_length: Int, the number of tokens to be generated
# @param device: Str, e.g. "cuda" or "cpu"
# @param kv_cache: Boolean, whether to use KV-cache to accelerate, if True then large memory will be consumed
# @return generated_text: Str
# @return generated_token_prob: List[Tuple(Int, Str, Float)], `len(generated_id_prob)` is `max_length`, indicating the generated probability of each token
# @return generated_logits: Tuple[FloatTensor(1, n_vocab)], `len(generated_logits)` is `max_length`, indicating the logits when each token is generated
def greedy_decode(model,tokenizer,prompt, max_length,device = "cuda",kv_cache = True,):inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)	# Str => Long(1, n_tokens)past_key_values = Nonegenerated_token_probs = list()generated_logits = list()model.gradient_checkpointing_enable()for i in range(max_length):logging.info(f"Round {i}: {past_key_values.key_cache[0].size() if past_key_values is not None else None}")outputs = model(inputs, past_key_values=past_key_values)logits = outputs.logits	# Float(1, n_tokens + i + 1, n_vocab), where `n_vocab` is 151936 in DeepSeek-R1-Distill-Qwenif kv_cache:past_key_values = outputs.past_key_values	# Dictlike[key_cache: Float(1, 2, X, hidden_size), value_cache: Float(1, 2, X, hidden_size)], where X = (i + 1) * (n_tokens + i / 2)next_token_probs = F.softmax(logits[:, -1, :], dim=-1)	# Float(1, n_tokens + i + 1, n_vocab) => Float(1, n_vocab)next_token_id = torch.argmax(next_token_probs, dim=-1)	# Float(1, n_vocab) => Long(1, )next_token_prob = next_token_probs[0, next_token_id].item()	# Float(1, n_vocab) => Float()next_token = tokenizer.decode(next_token_id[0].item(), skip_special_tokens=False)	# Long(1, ) => Strinputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1)	# Long(1, n_tokens + i) => Long(1, n_tokens + i + 1)generated_token_probs.append((next_token_id.item(), next_token, next_token_prob))generated_logits.append(logits[:, -1, :])generated_text = tokenizer.decode(token_ids = inputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True,)	# Long(1, n_tokens + max_length) => Strreturn generated_text, generated_token_probs, tuple(generated_logits)

实际上除了第一次输入外,接下来都可以用最后一个token作为输入,而不需要把之前整个一长串的input都输入到model中去:

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport torch
import logging
from copy import deepcopy
from functools import wraps
from torch.nn import functional as Ffrom transformers import AutoTokenizer, AutoModelForCausalLM# Standard greedy decode
# @param model: Huggingface model object
# @param tokenizer: Huggingface tokenizer Object
# @param prompt: Str
# @param max_length: Int, the number of tokens to be generated
# @param device: Str, e.g. "cuda" or "cpu"
# @param kv_cache: Boolean, whether to use KV-cache to accelerate, if True then large memory will be consumed
# @return generated_text: Str
# @return generated_token_prob: List[Tuple(Int, Str, Float)], `len(generated_id_prob)` is `max_length`, indicating the generated probability of each token
# @return generated_logits: Tuple[FloatTensor(1, n_vocab)], `len(generated_logits)` is `max_length`, indicating the logits when each token is generated
def greedy_decode(model,tokenizer,prompt, max_length,device = "cuda",kv_cache = True,):inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)	# Str => Long(1, n_tokens)past_key_values = Nonegenerated_token_probs = list()generated_logits = list()model.gradient_checkpointing_enable()for i in range(max_length):logging.info(f"Round {i}: {past_key_values.key_cache[0].size() if past_key_values is not None else None}")if kv_cache:if i == 0:outputs = model(inputs, past_key_values=past_key_values)else:outputs = model(inputs[:, -1].unsqueeze(0), past_key_values=past_key_values, use_cache=True)else:outputs = model(inputs, past_key_values=None)logits = outputs.logits	# Float(1, n_tokens + i + 1, n_vocab), where `n_vocab` is 151936 in DeepSeek-R1-Distill-Qwenif kv_cache:past_key_values = outputs.past_key_values	# Dictlike[key_cache: Float(1, 2, X, hidden_size), value_cache: Float(1, 2, X, hidden_size)], where X = (i + 1) * (n_tokens + i / 2)next_token_probs = F.softmax(logits[:, -1, :], dim=-1)	# Float(1, n_tokens + i + 1, n_vocab) => Float(1, n_vocab)next_token_id = torch.argmax(next_token_probs, dim=-1)	# Float(1, n_vocab) => Long(1, )next_token_prob = next_token_probs[0, next_token_id].item()	# Float(1, n_vocab) => Float()next_token = tokenizer.decode(next_token_id[0].item(), skip_special_tokens=False)	# Long(1, ) => Strinputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1)	# Long(1, n_tokens + i) => Long(1, n_tokens + i + 1)generated_token_probs.append((next_token_id.item(), next_token, next_token_prob))generated_logits.append(logits[:, -1, :])generated_text = tokenizer.decode(token_ids = inputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True,)	# Long(1, n_tokens + max_length) => Strreturn generated_text, generated_token_probs, tuple(generated_logits)

这个确实是很有帮助的,能加速推理很多。这个原理其实很简单,因为只需要KVcache与最后一个token就可以计算得到下一层的注意力权重(其实就是下一轮生成的KVcache),然后倒是发现deepseek在生成图像链接时出错了,难得逮到DeepSeek犯错的时候:

在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词