知方号

知方号

社区供稿 <供稿是什么>

在推理阶段,当输入长度为   ,我们仅需使用   即可预测出下一个token,但模型却会并行计算出   ,这部分会产生大量的冗余计算。而实际上   可直接通过公式   算出,即   的计算只与   、所有   和   有关。

KV Cache的本质是以空间换时间,它将历史输入的token的   和   缓存下来,避免每步生成都重新计算历史token的   和   以及注意力表示   ,而是直接通过   的方式计算得到   ,然后预测下一个token。

举个例子,用户输入“中国的首都”,模型续写得到的输出为“是北京”,KV Cache每一步的计算过程如下。

第一步生成时,缓存   均为空,输入为“中国的首都”,模型将按照常规方式并行计算:

并行计算得到每个token对应的   ,以及注意力表示   。

使用   预测下一个token,得到“是”。

更新缓存,令   ,   。

第二步生成时,计算流程如下:

仅将“是”输入模型,对其词向量进行映射,得到   。

更新缓存,令   ,   。

计算   ,预测下一个token,得到“北”

第三步生成时,计算流程如下:

仅将“北”输入模型,对其词向量进行映射,得到   。

更新缓存,令   ,   。

计算   ,预测下一个token,得到“京”。

上述生成流程中,只有在第一步生成时,模型需要计算所有token的   ,并且缓存下来。此后的每一步,仅需计算当前token的  、 、  ,更新缓存  、  ,然后使用  、 、  即可算出当前token的注意力表示,最后用来预测一下个token。

Hungging Face对于KV Cache的实现代码如下,结合注释可以更加清晰地理解其运算过程:

query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)query = self._split_heads(query, self.num_heads, self.head_dim) # 当前token对应的querykey = self._split_heads(key, self.num_heads, self.head_dim) # 当前token对应的keyvalue = self._split_heads(value, self.num_heads, self.head_dim) # 当前token对应的valueif layer_past is not None: past_key, past_value = layer_past # KV Cache key = torch.cat((past_key, key), dim=-2) # 将当前token的key与历史的K拼接 value = torch.cat((past_value, value), dim=-2) # 将当前token的value与历史的V拼接if use_cache is True: present = (key, value)else: present = None# 使用当前token的query与K和V计算注意力表示if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)else: attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

KV Cache是以空间换时间,当输入序列非常长的时候,需要缓存非常多k和v,显存占用非常大。为了缓解该问题,可以使用MQA、GQA、Page Attention等技术,在后续的文章中,我们也将对这些技术进行介绍。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至lizi9903@foxmail.com举报,一经查实,本站将立刻删除。