Self-Attention 和 KV Cache 是如何工作的

Attention Is All You Need

目前大部分 LLM 基于 decoder 架构 [3],推理过程分为两步:(1)使用 self-attention 向量化语句输入;(2)使用自回归(autoregressive)decoder 预测下一个 token。

Self-Attention

那篇著名的论文提出了 transformer 模型和 self-attention 结构。它利用 multi-head 和 scaled dot-product 技术,使得模型能够更好的高亮语句中的关键信息(忽略无关信息),使生成的语句向量能更好的表达它本身的意思。

它为啥能 work?本质上它在句子内部通过加权平均的方式,实现了句子内关键内容高亮的效果,原理请参考 [1]。

Multi-Head Attention (MHA)

只看论文比较干,浏览下代码,搞明白它们的 shapes 如何变换,很快就能搞明白其中的原理。代码源自于 [4],我又重新加了些 comments [5]。

假设用户的输入是 $X$,它的 shape 是 (batch_size, seq_len, hidden_size),其中

  • batch_size 是批处理长度,batch size 越大并发越大,消耗的内存数量也越多,推理过程中 batch size 是动态的,如果短期内有大量请求进来则用最大的 batch size,反之则在一个周期内有多少请求,batch size 就设置为多少。
  • seq_len 是输入的总长度,一般 tokenizer 会把文本拆分为 token,每个 token 都对应一个 embedding。
  • hidden_size 是 embedding 的 size,有的地方也称为 d_model

训练阶段系统随机初始化 $W_q$$W_k$$W_v$,它们的尺寸都是 (hidden_size, hidden_size)。推理阶段它们已经被训练模型调教过了(由于我不训练模型,代码里它们都是随机生成的)。

所谓的 self-attention 就是在自己的输入($X$)中提取关键信息。 此时的 query ($Q = X \cdot W_q$),key ($K = X \cdot W_k$) 和 value ($V = X \cdot W_v$),它们的 shape 都是 (batch_size, seq_len, hidden_size)。所谓的 self-attention 就是用自己的输入 $X$ 计算自己。

所谓的 multi-head 就是把上述几个参数拆为 N 个小矩阵(N 个头),从不同维度提取语句中的关键信息。比如 $Q = (Q_{1}, Q_{2}, \dots, Q_{N})$,每一个小矩阵 $Q_i$ 的尺寸是 (batch_size, seq_len, 1, d_k),其中 d_k = hidden_size / N ($d_k$),拼起来的 $Q$ 的尺寸是 (batch_size, seq_len, N, d_k)$K$ & $V$ 同理。

为了方便点乘,这里 $Q$、$K$ 和 $V$ 都做了一次转置,shape 被修改为了 (batch, N, seq_len, d_k)

Attention scores 的计算是 $(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V$,其中 $\sqrt{d_k}$ 的作用是缩放因子,避免训练不稳定。本质是上面三个矩阵相乘,shape 变换过程是

  • $K^T$: (batch, N, d_k, seq_len)
  • $\frac{Q \cdot K^T}{\sqrt{d_k}}$: (batch, N, seq_len, seq_len)
  • $(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V$: (batch_size, N, seq_len, d_k)

计算后的 scores 通过转置把 shape 变为 (batch_size, seq_len, N, d_k)。让 Nd_k 相乘 reshape 后,shape 变为 (batch_size, seq_len, hidden_size)。最后执行了 softmax 归一化,目的是让列元素相加为 1,但是不改变最终结果的 shape。

输出的就是向量化后的语句。 到了这一步只是把语句内重点提取了出来,然后把句子按照重点生成了 embedding,最终的文本输出需要 decoder 解析语句 embedding 实现文本输出。

KV Cache

那么 KV cache 顾名思义,它的作用是缓存已经计算过的 $K$ 和 $V$,避免重复计算,提升性能。

为什么可以通过 KV cache 加速计算?

一个语句由若干个 token 组成,当我们计算第 i 个 token 的 embedding 的时候,它只与 $Q$、部分 $K = (k_0, \cdots, k_i)$ 和部分 $V = (v_0, \cdots, v_i)$ 有关。你要是感兴趣其中的数学原理,参见 [2]。其中 $(k_0, \cdots, k_{i-1})$ 和 $(v_0, \cdots, v_{i-1})$ 在生成前面 tokens 的时候被计算过了,本轮只需要计算 $q_i$、$k_i$ 和 $v_i$。

所谓的 KV cache 就是两个连续的 GPU 内存区域(代码 [6]),分别保存着历史计算的 $k_0, k_1, \cdots$ 和 $v_0, v_1, \cdots$。

不启用 KV cache 时输入的是整个句子 $X$ (seq_len 个 tokens),而启用时只需要提供一个 token $x_i$,其 shape 是 (batch_size, 1, hidden_size)

Query、key 和 value 除了是逐 token 计算以外,其他的跟之前一样($q_i = x_i \cdot W_q$),其 shape 是 (batch_size, 1, hidden_size)。应用 N-heads,其 shape 变为 (batch_size, 1, N, d_k),同理 $k_i$ 和 $v_i$。

把 $k_i$ 和 $v_i$ 分别保存到 K cache 和 V cache。以 K cache 为例,它的 shape 是 (batch_size, i, N, d_k),V cache 同理。

代码里似乎会在当前基础上再拓展 seq_len 个空 tensors(why?),代码中的 shape 是 (batch_size, (i + seq_len), N, d_k),但是不影响最终结果。

这里会把 query、keys 和 values 做一次统一的转置,query 的 shape 为 (batch_size, N, 1, d_k),keys 和 values 的 shape 为 (batch_size, N, (i + seq_len), d_k)

计算 attention scores 时 shape 的变换过程:

  • $K^T$: (batch_size, N, d_k, (i + seq_len))
  • $\frac{q_i \cdot K_T}{\sqrt{d_k}}$: (batch_size, N, 1, (i + seq_len))
  • $\frac{q_i \cdot K_T}{\sqrt{d_k}} \cdot V$: (batch_size, N, 1, d_k)

随之而来是 softmax 归一化,重新调整 tensor shape,最终输出的 shape 是 (batch_size, 1, hidden_size)

Grouped Multi-Query Attention (GQA)

TODO: 等有时间再写

References

  1. https://www.zhihu.com/question/298810062/answer/2274132657
  2. https://zhuanlan.zhihu.com/p/662498827
  3. https://huggingface.co/learn/llm-course/chapter1/6
  4. https://github.com/hkproj/pytorch-transformer/blob/main/model.py#L83
  5. https://github.com/justxuewei/playground/blob/main/attention/attention.py
  6. https://github.com/justxuewei/playground/blob/main/attention/attention_kv.py
All rights reserved
Except where otherwise noted, content on this page is copyrighted.