前言
本文將與大家一起梳理從MHA、MQA、GQA到MLA的演變歷程,並重點介紹MLA的設計思路。MHA
MHA(多頭注意力)是開創性論文《Attention is all you need》中提出的一種注意力形式,可以說是當前主流大型語言模型(LLM)的基礎工作。數學上,多頭注意力是多個獨立的單頭注意力的拼接。假設輸入的向量序列為 \(x_1, x_2, \ldots, x_l\),其中 \(x_i \in \mathbb{R}^d\),那麼MHA可以記為:$$
\begin{aligned}
o_t &= [o_t^{(1)}, o_t^{(2)}, \ldots, o_t^{(h)}] \\
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_{\leq t}^{(s)}, v_{\leq t}^{(s)}) \\
&\equiv \sum_{i \leq t} \frac{\exp(q_t^{(s)} k_i^{(s)\top})}{\sum_{i \leq t} \exp(q_t^{(s)} k_i^{(s)\top})} v_i^{(s)} \\
q_t^{(s)} &= x_t W_q^{(s)} \in \mathbb{R}^{d_k}, \quad W_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
k_i^{(s)} &= x_i W_k^{(s)} \in \mathbb{R}^{d_k}, \quad W_k^{(s)} \in \mathbb{R}^{d \times d_k} \\
v_i^{(s)} &= x_i W_v^{(s)} \in \mathbb{R}^{d_v}, \quad W_v^{(s)} \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$
為了簡化,這裡省略了注意力矩陣的縮放因數。在實踐中,常見的設置是 \(d_k = d_v = \frac{d}{h}\)。例如,對於LLAMA2-7b來說,有 \(d = 4096, h = 32, d_k = d_v = 128\);對於LLAMA2-70b則是 \(d = 8192, h = 64, d_k = d_v= 128\)。這裡只考慮了主流自回歸大型語言模型所使用的Causal Attention,這意味著在逐個token生成時,新預測出來的第 \( t+1 \) 個token並不會影響已經計算好的 \( k_{\leq t}^{(s)}, v_{\leq t}^{(s)} \),因此這部分結果我們可以快取下來供後續生成調用,避免不必要的重複計算,這就是所謂的KV快取(KV Cache)。
瓶頸
一個自然的問題是:為什麼降低KV快取的大小如此重要?眾所周知,一般情況下大型語言模型的推理都是在GPU上進行,單張GPU的顯存是有限的,一部分我們要用來存放模型的參數和前向計算的啟動值,這部分依賴於模型的體量,選定模型後它就是個常數;另外一部分我們要用來存放模型的KV快取,這部分不僅依賴於模型的體量,還依賴於模型的輸入長度,也就是在推理過程中是動態增長的,當Context長度足夠長時,它的大小就會占主導地位,可能超出一張卡甚至一台機(8張卡)的總顯存量。
在GPU上部署模型的原則是:能一張卡部署的,就不要跨多張卡;能一台機部署的,就不要跨多台機。這是因為「卡內通信帶寬 > 卡間通信帶寬 > 機間通信帶寬」,由於「木桶效應」,模型部署時跨的設備越多,受設備間通信帶寬的拖累就越大,事實上即便是單卡H100內SRAM與HBM的帶寬已經達到了3TB/s,但對於短Context來說這個速度依然還是推理的瓶頸,更不用說更慢的卡間、機間通信了。
所以,減少KV快取的目的就是要實現在更少的設備上推理更長的Context,或者在相同的Context長度下讓推理的batch size更大,從而實現更快的推理速度或者更大的吞吐總量。當然,最終目的都是為了實現更低的推理成本。
MQA
MQA(多查詢注意力),是減少KV快取的一次非常樸素的嘗試,首次提出自《Fast Transformer Decoding: One Write-Head is All You Need》這篇2019年的論文,這也意味著早在LLM成為熱門話題之前,減少KV快取就已經是研究人員非常關注的一個議題了。MQA的思路很簡單,直接讓所有注意力頭(Attention Head)共用同一個K(Key)、V(Value),用公式來說,就是取消了MHA中所有的 \( k \)、\( v \) 的上標 \( (s) \):
$$
\begin{aligned}
o_t &= [o_t^{(1)}, o_t^{(2)}, \ldots, o_t^{(h)}] \\
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_{\leq t}, v_{\leq t}) \\
&\equiv \sum_{i \leq t} \frac{\exp(q_t^{(s)} k_i^\top)}{\sum_{i \leq t} \exp(q_t^{(s)} k_i^\top)} v_i \\
q_t^{(s)} &= x_t W_q^{(s)} \in \mathbb{R}^{d_k}, \quad W_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
k_i &= x_i W_k \in \mathbb{R}^{d_k}, \quad W_k \in \mathbb{R}^{d \times d_k} \\
v_i &= x_i W_v \in \mathbb{R}^{d_v}, \quad W_v \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$
\begin{aligned}
o_t &= [o_t^{(1)}, o_t^{(2)}, \ldots, o_t^{(h)}] \\
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_{\leq t}, v_{\leq t}) \\
&\equiv \sum_{i \leq t} \frac{\exp(q_t^{(s)} k_i^\top)}{\sum_{i \leq t} \exp(q_t^{(s)} k_i^\top)} v_i \\
q_t^{(s)} &= x_t W_q^{(s)} \in \mathbb{R}^{d_k}, \quad W_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
k_i &= x_i W_k \in \mathbb{R}^{d_k}, \quad W_k \in \mathbb{R}^{d \times d_k} \\
v_i &= x_i W_v \in \mathbb{R}^{d_v}, \quad W_v \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$
使用MQA的模型包括PaLM、StarCoder、Gemini等。很明顯,MQA直接將KV快取減少到了原來的 \( \frac{1}{h} \),這是非常顯著的,單從節省顯存角度看已經是極限了。
效果方面,目前看來大部分任務的損失都比較有限,且MQA的支持者相信這部分損失可以通過進一步訓練來彌補回來。此外,注意到MQA由於共用了K、V,將會導致注意力的參數量減少了將近一半,而為了模型總參數量的不變,通常會相應地增大FFN(Feed Forward Network)/GLU(Gated Linear Unit)的規模,這也能彌補一部分效果損失。
GQA
然而,也有人擔心MQA對KV快取的壓縮太嚴重,以至於會影響模型的學習效率以及最終效果。為此,一個MHA與MQA之間的過渡版本GQA(分組查詢注意力)應運而生,出自論文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》,是去年的工作。事後看來,GQA的思想也很樸素,它就是將所有Head分為 \( g \) 個組( \( g \) 可以整除 \( h \)),每組共用同一對K、V,用數學公式表示為:
$$
\begin{aligned}
o_t &= [o_t^{(1)}, o_t^{(2)}, \ldots, o_t^{(h)}] \\
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_{\leq t}^{(\lceil s\frac{g}{h} \rceil)}, v_{\leq t}^{(\lceil s\frac{g}{h} \rceil)}) \\
&\equiv \sum_{i \leq t} \frac{\exp(q_t^{(s)} k_{(\lceil s\frac{g}{h} \rceil)i}^\top)}{\sum_{i \leq t} \exp(q_t^{(s)} k_{(\lceil s\frac{g}{h} \rceil)i}^\top)} v_{(\lceil s\frac{g}{h} \rceil)i} \\
q_t^{(s)} &= x_t W_q^{(s)} \in \mathbb{R}^{d_k}, \quad W_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
k_i^{(\lceil s\frac{g}{h} \rceil)} &= x_i W_k^{(\lceil s\frac{g}{h} \rceil)} \in \mathbb{R}^{d_k}, \quad W_k^{(\lceil s\frac{g}{h} \rceil)} \in \mathbb{R}^{d \times d_k} \\
v_i^{(\lceil s\frac{g}{h} \rceil)} &= x_i W_v^{(\lceil s\frac{g}{h} \rceil)} \in \mathbb{R}^{d_v}, \quad W_v^{(\lceil s\frac{g}{h} \rceil)} \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$
這裡的 \( \lceil \cdot \rceil \) 是向上取整符號。GQA提供了從MHA到MQA的自然過渡,當 \( g = h \) 時就是MHA,當 \( g = 1 \) 時就是MQA。當 \( 1 < g < h \) 時,它只將KV快取壓縮到 \( \frac{g}{h} \),壓縮率不如MQA,但同時也提供了更大的靈活性,效果上更有保障。
GQA最知名的使用者,大概是Meta開源的LLAMA2-70B,以及LLAMA3全系列,此外使用GQA的模型還有TigerBot、DeepSeek-V1、StarCoder2、Yi、ChatGLM2、ChatGLM3等,相比使用MQA的模型更多(ChatGLM雖然在它的介紹中說自己是MQA,但實際是 \( g=2 \) 的GQA)。
在llama2/3-70B中,GQA的 \( g=8 \),其他用了GQA的同體量模型基本上也保持了這個設置,這並非偶然,而是同樣出於推理效率的考慮。我們知道,70B這個體量的模型,如果不進行極端的量化,那麼不可能部署到單卡(A100/H100 80G)上。單卡不行,那麼就能單機了,一般情況下一台機可以裝8張卡,剛才我們說了,Attention的每個Head實際上是獨立運算然後拼接起來的,當 \( g=8 \) 時,正好可以每張卡負責計算一組K、V對應的Attention Head,這樣可以在盡可能保證K、V多樣性的同時最大程度上減少卡間通信。
MLA
有了MHA、MQA、GQA的鋪墊,我們理解MLA(多頭潛在注意力)就相對容易一些了。DeepSeek-V2的技術報告裡是從低秩投影的角度引入MLA的,以至於有部分讀者提出「為什麼LoRA提出這麼久了,直到MLA才提出對KV快取低秩分解的做法」之類的疑問。然而,筆者認為低秩投影這個角度並不貼近本質,因為要說低秩投影的話,事實上只要我們將GQA的所有K、V疊在一起,就會發現GQA也相當於在做低秩投影:
$$
\begin{aligned}
[k_i^{(1)}, \dots, k_i^{(g)}, v_i^{(1)}, \dots, v_i^{(g)}] &\equiv c_i \in \mathbb{R}^{g(d_k+d_v)} \\
c_i &= x_i[W_k^{(1)}, \dots, W_k^{(g)}, W_v^{(1)}, \dots, W_v^{(g)}] \in \mathbb{R}^{d \times g(d_k+d_v)}
\end{aligned}
$$
這裡我們將所有\( k_i^{(s)} \)和\( v_i^{(s)} \)拼在一起記為\( c_i \),相應的投影矩陣也拼在一起記為\( W_c \)。注意到一般都有\( d_c = g(d_k + d_v) < d \),所以\( x_i \)到\( c_i \)的轉換就是一個低秩投影。因此,MLA的本質改進不是低秩投影,而是在低秩投影之後的處理。
Part 1
GQA在投影之後做了什麼呢?首先它將向量對半分為兩份分別作為K、V,然後每一份又均分為\( g \)份,每一份複製\( h/g \)次,以此來「湊」夠\( h \)個Attention Head所需要的K、V。我們知道分割、複製都是簡單的線性變換,所以MLA的第一個想法是將這些簡單的線性變換換成一般的線性變換,以增強模型的能力:$$
\begin{aligned}
o_t &= [o_t^{(1)}, o_t^{(2)}, \ldots, o_t^{(h)}] \\
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_t^{(s)}, v_t^{(s)}) \\
&\equiv \sum_{i \leq t} \frac{\exp(q_t^{(s)} k_i^{(s)\top})}{\sum_{i \leq t} \exp(q_t^{(s)} k_i^{(s)\top})} v_i^{(s)} \\
k_i^{(s)} &= c_i W_k^{(s)} \in \mathbb{R}^{d_k}, \quad W_k^{(s)} \in \mathbb{R}^{d_c \times d_k} \\
v_i^{(s)} &= c_i W_v^{(s)} \in \mathbb{R}^{d_v}, \quad W_v^{(s)} \in \mathbb{R}^{d_c \times d_v} \\
q_t^{(s)} &= x_t W_q^{(s)} \in \mathbb{R}^{d_k}, \quad W_q^{(s)} \in \mathbb{R}^{d \times d_k}
\end{aligned}
$$
理論上這樣能增加模型能力,但別忘了GQA的主要目的是減少KV快取。MLA的這個做法,通過不同的投影矩陣讓所有的K、V Head都變得各不相同,那麼KV快取的大小就恢復成跟MHA一樣大了,違背了GQA的初衷。
對此,MLA發現,我們可以結合Dot-Attention的具體形式,通過一個簡單但不失巧妙的恒等變換來規避這個問題。首先,在訓練階段還是照常進行,此時優化空間不大;然後,在推理階段,我們利用
$$
q_t^{(s)}k_i^{(s)\top} = (x_t W_q^{(s)})(c_i W_k^{(s)})^\top = x_t (W_q^{(s)} W_k^{(s)\top}) c_i^\top
$$
這意味著推理階段,我們可以將\( W_q^{(s)} W_k^{(s)\top} \)合併起來作為Q的投影矩陣,那麼\( c_i \)則取代了原本的\( k_i^{(s)} \),同理,在\( o_t \)後面我們還有一個投影矩陣,於是\( v_i^{(s)} \)的\( W_v^{(s)} \)也可以吸收到後面的投影矩陣中去,於是等效地\( v_i^{(s)} \)也可以用\( c_i \)代替,也就是說此時KV快取只需要存下所有的\( c_i \)就行,而不必存下所有的\( k_i^{(s)} \)、\( v_i^{(s)} \)。注意到\( c_i \)跟\( s \)無關,也就是說是所有頭共用的,即MLA在推理階段它可以恒等變換為一個MQA。
再次強調,本文的主題一直都是減少KV快取,那到目前為止,MLA做到了什麼呢?答案是通過不同的投影矩陣來增強了GQA的能力,並且推理時可以保持同樣大小的KV快取。那麼反過來,如果我們只需要跟GQA相近的能力,那麼是不是就可以再次減少KV快取了?換言之,\( d_c \)沒必要取\( g(d_k + d_v) \),而是取更小的值(DeepSeek-V2取了512),從而進一步壓縮KV快取,這就是MLA的核心思想。
(注:這裡有一個細節,就是\( W_q^{(s)} W_k^{(s)\top} \)合併成一個矩陣的恒等變換,理論上只有在無限精度下才成立,實際上如果我們使用單精度尤其是BF16的話,經過變換後的精度損失往往還是挺明顯的,經過多層累積後可能放大到比較可觀的程度,這裡可能要根據實際誤差看要不要做一些後處理。)
Part 2
一切似乎都很完美,看上去一個又好又省的理想設計就要出爐了。不過別急,當我們再深入思考一下就會發現,到目前為止的MLA有一個難以繞開的缺陷——不相容RoPE(旋轉位置編碼)。剛才我們提到,MLA之所以能保持與GQA相同大小的KV緩存,其關鍵一步是“將\(W_q^{(s)} W_k^{(s)^\top}\)”合併成一個與位置無關的矩陣作為Q的投影矩陣。但是,如果加入RoPE(旋轉位置編碼),這一步就無法實現了。RoPE是一個與位置相關的\(d_k \times d_k\)的分塊對角矩陣\(R_m\),滿足\(R_m R_n^\top = R_{m-n}\)。在MLA中加入RoPE後,會使得\(W_q^{(s)} W_k^{(s)^\top}\)之間多插入了一個項\(R_{t-i}\):
$$
q_i^{(s)} = x_i W_q^{(s)} R_i, \quad k_i^{(s)} = c_i W_k^{(s)} R_i \\
q_t^{(s)} k_i^{(s)^\top} = (x_t W_q^{(s)} R_t)(c_i W_k^{(s)} R_i)^\top = x_t (W_q^{(s)} R_{t-i} W_k^{(s)^\top}) c_i^\top
$$
這裡的\(W_q^{(s)} R_{t-i} W_k^{(s)^\top}\)就無法合併為一個固定的投影矩陣了(與位置差\(t-i\)相關),從而MLA的想法無法與RoPE結合實現。
前段時間,我也很榮幸地與DeepSeek團隊討論過這個問題,但這個問題可以說非常本質,所以當時我實際上也沒能提出什麼有效的建議。最簡單的方式是放棄RoPE,換用其他基於Attention Bias的位置編碼,如ALIBI,但DeepSeek的實驗顯示它明顯不如RoPE(注意,MLA不是不能加RoPE,而是加了RoPE之後無法用恒等變換技巧來減少KV緩存)。我也提議過換Sandwich,它不像ALIBI單調衰減到負無窮,估計效果會好些,但感覺是治標不治本。還有一個折中的辦法是將\(q_i^{(s)}\)的輸入也改為\(c_i\),然後RoPE加在\(c_i\)之後,即:
$$
q_i^{(s)} = c_i R_i W_q^{(s)}, \quad k_i^{(s)} = c_i R_i W_k^{(s)}
$$
這樣\(R_i\)就可以吸收到\(c_i\)中去,但這樣就沒有\(R_m R_n^\top = R_{m-n}\)的運算了,此時的RoPE不再是通過絕對位置實現相對位置,而單純是加在Q、K上的絕對位置資訊,讓模型自己想辦法提煉相對位置資訊。
最後發布的MLA,採取了一種混合的方法——每個注意力頭的Q、K新增\( d_r \)個維度用來添加RoPE,其中K新增的維度每個頭共用:
$$
o_t^{(s)} = \text{Attention}(q_t^{(s)}, k_{\leq t}^{(s)}, v_{\leq t}^{(s)}) \\
q_t^{(s)} = [x_t W_{qc}^{(s)}, x_t W_{qr}^{(s)} R_t] \in \mathbb{R}^{d_k + d_r}, \quad W_{qc}^{(s)} \in \mathbb{R}^{d \times d_k}, \quad W_{qr}^{(s)} \in \mathbb{R}^{d \times d_r} \\
k_i^{(s)} = [c_i W_{kc}^{(s)}, x_i W_{kr}^{(s)} R_i] \in \mathbb{R}^{d_k + d_r}, \quad W_{kc}^{(s)} \in \mathbb{R}^{d_c \times d_k}, \quad W_{kr}^{(s)} \in \mathbb{R}^{d \times d_r} \\
v_i^{(s)} = c_i W_{vc}^{(s)} \in \mathbb{R}^{d_v}, \quad W_{vc}^{(s)} \in \mathbb{R}^{d_c \times d_v}
$$
o_t^{(s)} = \text{Attention}(q_t^{(s)}, k_{\leq t}^{(s)}, v_{\leq t}^{(s)}) \\
q_t^{(s)} = [x_t W_{qc}^{(s)}, x_t W_{qr}^{(s)} R_t] \in \mathbb{R}^{d_k + d_r}, \quad W_{qc}^{(s)} \in \mathbb{R}^{d \times d_k}, \quad W_{qr}^{(s)} \in \mathbb{R}^{d \times d_r} \\
k_i^{(s)} = [c_i W_{kc}^{(s)}, x_i W_{kr}^{(s)} R_i] \in \mathbb{R}^{d_k + d_r}, \quad W_{kc}^{(s)} \in \mathbb{R}^{d_c \times d_k}, \quad W_{kr}^{(s)} \in \mathbb{R}^{d \times d_r} \\
v_i^{(s)} = c_i W_{vc}^{(s)} \in \mathbb{R}^{d_v}, \quad W_{vc}^{(s)} \in \mathbb{R}^{d_c \times d_v}
$$
這樣一來,沒有RoPE的維度就可以重複「Part 1」的操作,在推理時KV快取只需要存下所有的\( c_i \),新增的帶RoPE的維度就可以用來補充位置資訊,並且由於所有頭共用,所以也就只有在K快取這裡增加了\( d_r \)個維度,原論文取了\( d_r = \frac{d_k}{2} = 64 \),相比原本的\( d_c = 512 \),增加的幅度不大。
Part 3
最後有一個細節,就是MLA的最終版本,還將Q的輸入也改為了低秩投影形式,這與減少KV快取無關,主要是為了減少訓練期間參數量和相應的梯度(原論文說的是啟動值,個人表示不大理解)所占的顯存:$$
\begin{aligned}
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_{\leq t}^{(s)}, v_{\leq t}^{(s)}) \\
q_t^{(s)} &= [c'_i W_{qc}^{(s)}, c'_i W_{qr}^{(s)} R_i] \in \mathbb{R}^{d_k + d_r}, \quad W_{qc}^{(s)} \in \mathbb{R}^{d'_c \times d_k}, \quad W_{qr}^{(s)} \in \mathbb{R}^{d'_c \times d_r} \\
k_i^{(s)} &= [c_i W_{kc}^{(s)}, x_i W_{kr}^{(s)} R_i] \in \mathbb{R}^{d_k + d_r}, \quad W_{kc}^{(s)} \in \mathbb{R}^{d_c \times d_k}, \quad W_{kr}^{(s)} \in \mathbb{R}^{d \times d_r} \\
v_i^{(s)} &= c_i W_{vc}^{(s)} \in \mathbb{R}^{d_v}, \quad W_{vc}^{(s)} \in \mathbb{R}^{d_c \times d_v} \\
c'_i &= x_i W'_{c} \in \mathbb{R}^{d'_c}, \quad W'_{c} \in \mathbb{R}^{d \times d'_c}
\end{aligned}
$$
注意\( k_i^{(s)} \)中的第二項,帶RoPE的部分,其輸入還是\( x_i \)而不是\( c_i \),這裡保持了原論文的設定,不是筆誤,\( d'_c \)原論文的取值是1536,跟\( d_c = 512 \)不同。同時,我們把帶RoPE的MHA放在下面,方便大家對比:
$$
\begin{aligned}
o_t^{(s)} &= \text{Attention}(q_t^{(s)}, k_{\leq t}^{(s)}, v_{\leq t}^{(s)}) \\
q_t^{(s)} &= x_i W_q^{(s)} R_i \in \mathbb{R}^{d_k}, \quad W_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
k_i^{(s)} &= x_i W_k^{(s)} R_i \in \mathbb{R}^{d_k}, \quad W_k^{(s)} \in \mathbb{R}^{d \times d_k} \\
v_i^{(s)} &= x_i W_v^{(s)} \in \mathbb{R}^{d_v}, \quad W_v^{(s)} \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$
可以發現,其實在訓練階段,除了多了一步低秩投影以及只在部分維度加RoPE外,MLA與Q、K的Head Size由\( d_k \)換成\( d_k + d_r \)的MHA基本無異。
小結
本文簡單概述了多頭注意力的演變歷程,特別是從MHA向MQA、GQA,最終到MLA的變化理念,最後詳細展開了對MLA的介紹。在本文中,MLA被視為GQA的一般化,它用投影矩陣的方式替代了GQA的分割、重複,並引入了一個恆等變換技巧來進一步壓縮KV快取,同時採用了一種混合方法來相容RoPE。總的來說,MLA稱得上是一種非常實用的注意力變體,其創新在於如何在保持推理效率的同時減少存儲和計算資源的需求。這種設計不僅顯示了深入的技術見解,也反映出當前大型語言模型推理優化的趨勢——即在最大限度減少資源消耗的同時,盡可能保持或提升模型的性能。未來的研究可能會繼續在這條路上進行探索,尋找更有效的方法來解決推理時的瓶頸問題,特別是在處理更大規模的數據和模型時。這些技術的進步將對AI的實用性和可達性產生深遠的影響。
沒有留言:
發佈留言