這是我第三次評論這篇論文, 重要的報告我講三次, 很重要!很重要!很重要!我認為這個模型會超越Mamba!
https://arxiv.org/pdf/2402.19427.pdf
近年來,Transformer語言模型在多個自然語言處理任務上取得了巨大成功,但其全局注意力機制在處理超長序列時面臨計算瓶頸。另一方面,循環神經網絡(RNN)長期以來被認為更適合對長距離依賴進行建模,但其串行計算特性限制了訓練和推理效率。為了兼顧長程建模能力和計算效率,Google提出了Griffin——一種融合了門控線性遞迴單元(RG-LRU)和局部多頭注意力(Local Multi-head Attention)的混合語言模型。
Griffin模型
整體架構
Griffin的主體是一個堆疊的Transformer結構,包含L個殘差塊(L一般取12~40)。每個殘差塊包含兩個子結構:多層感知器(MLP)和時序混合塊。其中,MLP塊在所有殘差塊中共享,而時序混合塊在不同層中交替使用RG-LRU塊和局部注意力(Local Attention)塊。模型的輸入是一個長度為N的token序列,通過詞嵌入(Embedding)層映射為H維稠密向量。最終輸出通過嵌入矩陣的轉置進行復原,並計算softmax交叉熵損失。MLP塊
$$ f(x) = (W_1 x + b_1) \otimes \sigma(W_2 x + b_2) $$
其中$W_1,W_2 \in R^{H \times H}, b_1,b_2 \in R^H$分別為兩個全連接層的權重矩陣和偏置項,$\sigma$為Sigmoid函數,$\otimes$為按元素乘法。相比ReLU等激活函數,GLU可以更好地建模輸入之間的非線性交互作用。
時序混合塊
RG-LRU
$$r_t = \sigma(W_{xr} x_t + b_r)$$
$$i_t = \sigma(W_{xi} x_t + b_i)$$
$$\tilde{A}_t = \text{diag}(r_t) A^T \text{diag}(r_t)$$
$$h_t = \tilde{A}t h{t-1} + \sqrt{1-\tilde{A}_t^2} \odot (i_t \odot x_t)$$
其中$r_t,i_t \in R^H$分別為遞迴門和輸入門,$W_{xr},W_{xi} \in R^{H \times H}$和$b_r,b_i \in R^H$ 為門控單元的參數,$A \in R^{H \times H}$為一個對角矩陣,對角線元素在(0,1)範圍內。運算$\odot$表示按元素乘法。
RG-LRU引入了兩個關鍵的改進:
(1) 遞迴門$r_t$控制了過去狀態$h_{t-1}$的保留程度,其中$\tilde{A}_t$是對$A$應用$r_t$進行插值的結果。當$r_t$接近1時,模型傾向於保留過去信息;當$r_t$接近0時,模型傾向於遺忘過去信息,只關注當前輸入$x_t$。這種自適應的記憶機制使RG-LRU能靈活地應對不同時間尺度的依賴關係。
(2) 輸入門$i_t$控制了當前輸入$x_t$對隱狀態$h_t$的貢獻。與LSTM等傳統RNN不同,RG-LRU的輸入門不依賴上一時刻的隱狀態,從而實現了$O(1)$的順序計算復雜度。
局部注意力
$$o_t = \sum_{i=t-M}^{t+M} \alpha_{ti} (W_V h_i)$$
其中$M$為注意力窗口的半徑,$W_V \in R^{H \times H}$為值(Value)映射矩陣。注意力權重$\alpha_{ti}$通過查詢向量$q_t$和鍵向量$k_i$的內積計算:
$$\alpha_{ti} = \frac{\exp(q_t^T k_i)}{\sum_{j=t-M}^{t+M} \exp(q_t^T k_j)}$$
其中$q_t = W_Q h_t, k_i = W_K h_i$,對應的投影矩陣為$W_Q,W_K \in R^{H \times H}$。為了建模多種類型的依賴關係,Griffin使用多頭注意力機制,計算公式為:
$$o_t = W_O[o_t^{(1)},\dots,o_t^{(K)}] + b_o$$
其中$o_t^{(k)}$是第$k$個注意力頭的輸出,$W_O \in R^{H \times KH},b_o \in R^H$為最終的線性變換參數。
在實踐中,我們發現設置注意力窗口大小$2M+1=1024$,注意力頭數$K=32$可以取得最佳的性能。這種局部注意力機制顯著降低了計算複雜度(從$O(N^2)$減少到$O(NM)$),且與RG-LRU形成了很好的互補。
作為堅定支持者,Google 的 Griffin 論文完美地展示了Scaling laws。當他們將模型參數擴大 7 倍時,在各項任務上的表現大約提升了 10%。所有模型都是在相同的 3000 億 tokens 數據上訓練的。隨著參數的增加,模型變得更加樣本有效,外推能力也更強。
位置編碼
$$\alpha_{ti} = \frac{\exp(q_t^T k_i + q_t^T r_{i-t})}{\sum_{j=t-M}^{t+M} \exp(q_t^T k_j + q_t^T r_{j-t})}$$
其中$r_{\Delta t} \in R^H$是一個可學習的相對位置編碼向量,表示時間步$t$和$i$之間的距離$\Delta t = i-t$的影響。
這是 Google 的一項突破性研究成果。
Google 發布了具有新 Griffin 架構的模型,其性能優於 Transformer。
訓練與推理優化
參數高效性
得益於RG-LRU的線性計算特性,Griffin的參數量和計算復雜度與層數L呈線性關係。具體來說,設模型隱藏層維度為H,詞表大小為V,則Griffin的參數量近似為:
$$\text{Params} = 4LH^2 + 2VH$$
其中$4LH^2$項對應MLP塊、RG-LRU和注意力層的權重矩陣參數,$2VH$對應詞嵌入矩陣及其轉置。考慮到$H << V$,模型的參數量主要由詞嵌入矩陣主導。這意味著Griffin可以通過增加深度L來提高模型容量,而不會導致參數量過度膨脹。通過控制隱藏層維度H,在保持參數量不變的情況下比較Griffin和Transformer在不同層數L下的性能。
GPU/TPU並行化
雖然RG-LRU易於實現廉價的串行計算,但為了充分利用現代加速器(如GPU和TPU)的並行能力,還需要對其進行並行化改造。受益於RG-LRU的簡潔性,可以輕松地將其計算過程表示為一系列矩陣乘法和逐元素操作,從而實現高效的批量化(Batching)和張量化(Tensorization)加速。
在訓練時,將一個批次的輸入序列表示為形狀為$[B, N, H]$的三維張量,其中$B$為批次大小。通過將第二維(長度維)的計算映射到加速器的不同線程/內核上,Griffin可以實現與Transformer相當的訓練吞吐量。對於局部注意力層,使用了快速注意力(Fast Attention)算法,通過計算局部注意力權重的前綴和(Prefix Sum),將總體複雜度降低到$O(BNH)$。
在推理時,需要逐步生成輸出序列。與Transformer需要維護一個隨生成長度增長的鍵值緩存(Key-Value Cache)不同,RG-LRU只需要維護一個固定大小的隱狀態向量$h_t$,從而顯著節省了內存佔用。此外,Griffin還可以利用內存化(Memorization)技術,即將每一層的輸出都保存到一個大小固定的循環緩衝區中,供下一層計算時復用。這種做法避免了在深層模型中重複計算前幾層的結果,進一步提升了推理速度。
接下來,研究者還選取了三個專門用於評測長距離依賴建模能力的任務:LAMBADA、ListOps和Pathfinder。其中,LAMBADA是一個基於上下文的單詞預測任務,ListOps需要模型執行算術運算,而Pathfinder則考察模型在網格圖上進行推理的能力。實驗結果顯示,Griffin在這些任務上的表現也一致地超過了基線模型。值得注意的是,隨著序列長度的增加,Griffin的優勢變得更加明顯,這歸因於RG-LRU層強大的遞歸歸納能力。
在推理效率方面,Griffin展現出了明顯的優勢。得益於RG-LRU層恆定的內存佔用和局部注意力的線性計算複雜度,Griffin在生成超長文本時的推理速度比Transformer-XL等模型快2~3倍,且內存佔用減少了50%以上。這使得Griffin非常適合應用於資源受限的場景,如移動設備或實時系統。
混合精度訓練
為了進一步提高訓練效率並節省顯存,採用了混合精度(Mixed Precision)訓練策略。具體來說,模型的前向和反向傳播過程使用半精度浮點數(FP16),而模型權重的更新則在單精度(FP32)下進行。通過這種設置,可以將訓練所需的顯存減少近50%,同時保持模型的收斂性和最終性能。我們還發現,對梯度應用動態缩放(Dynamic Scaling)技術可以進一步提高混合精度訓練的穩定性。
為了深入分析Griffin的行為特徵,研究者還進行了一系列消融實驗。首先,他們探究了不同超參數(如層數、隱藏狀態維度、注意力頭數等)對模型性能的影響。結果表明,增加Griffin的層數和寬度都能帶來性能的提升,且效果優於同等大小的Transformer模型。其次,研究者還考察了不同注意力範圍對Griffin的影響。他們發現,當注意力窗口大小在1024左右時,模型在各個任務上的表現最優,且即使在更長的序列長度下也能保持優勢。
他們在WikiText-103、One Billion Word和PG-19等語言建模數據集上測試了Griffin的泛化能力。這些數據集包含了不同長度、主題和風格的文本,可以全面評估模型在長距離依賴關係上的建模能力。實驗中,研究者控制了模型的參數量,以公平地比較Griffin與其他基線模型的性能。結果表明,Griffin在所有數據集上都取得了最優的perplexity(PPL)分數,顯著優於Transformer-XL、Compressive Transformer等現有方法。
接下來,研究者還選取了三個專門用於評測長距離依賴建模能力的任務:LAMBADA、ListOps和Pathfinder。其中,LAMBADA是一個基於上下文的單詞預測任務,ListOps需要模型執行算術運算,而Pathfinder則考察模型在網格圖上進行推理的能力。實驗結果顯示,Griffin在這些任務上的表現也一致地超過了基線模型。值得注意的是,隨著序列長度的增加,Griffin的優勢變得更加明顯,這歸因於RG-LRU層強大的遞歸歸納能力。
為了深入分析Griffin的行為特徵,研究者還進行了一系列消融實驗。首先,他們探究了不同超參數(如層數、隱藏狀態維度、注意力頭數等)對模型性能的影響。結果表明,增加Griffin的層數和寬度都能帶來性能的提升,且效果優於同等大小的Transformer模型。其次,研究者還考察了不同注意力範圍對Griffin的影響。他們發現,當注意力窗口大小在1024左右時,模型在各個任務上的表現最優,且即使在更長的序列長度下也能保持優勢。
在推理效率方面,Griffin展現出了明顯的優勢。得益於RG-LRU層恆定的內存佔用和局部注意力的線性計算複雜度,Griffin在生成超長文本時的推理速度比Transformer-XL等模型快2~3倍,且內存佔用減少了50%以上。這使得Griffin非常適合應用於資源受限的場景,如移動設備或實時系統。
結論
1. RG-LRU(Real-Gated Linear Recurrent Unit):
RG-LRU是Griffin模型的核心組件之一。它是一種新型的門控線性遞迴層,其設計靈感來自於標準的LRU(Linear Recurrent Unit)和LSTM/GRU中使用的門控機制。
具體來說,RG-LRU引入了兩個門:遞迴門r_t和輸入門i_t,它們分別控制前一時刻隱狀態h_{t-1}和當前輸入x_t對新隱狀態h_t的貢獻。與LSTM/GRU不同的是,RG-LRU的門不依賴於前一時刻的隱狀態,這樣可以顯著提高計算效率。
遞迴門r_t進一步調節一個對角矩陣A_t,使其在保留歷史信息(r_t->1時A_t->I)和遺忘歷史信息(r_t->0時A_t->0)之間進行非線性插值。這種獨特的門控方式使RG-LRU能學會在長序列上選擇性地汲取信息。
RG-LRU還有一個復數版本CG-LRU(Complex-Gated Linear Recurrent Unit),它將輸入和隱狀態表示為復數,並使用復數矩陣參數,以提高层的表示能力。但實驗發現,在語言建模任務上,實數版本的RG-LRU性能並不亞於CG-LRU。
2. 時序混合(Temporal Mixing):
Griffin採用分層的時序混合方式,在12-40層的網路中,每3層就混合一個RNN塊(含3個RG-LRU層)和一個局部MQA(Multi-Query Attention)層。這樣的結構讓模型在捕捉長距離依賴的同時,也能很好地對局部範圍內的信息進行匹配。
實驗顯示,對於序列長度為2048的預訓練,局部MQA的最優窗口大小在1024左右。當序列長度增加到4096、8192時,1024的窗口大小仍然能取得比全局注意力Transformer更好的效果。這表明Griffin利用RNN積累長程信息的能力可以很好地彌補局部注意力的不足。
3. 硬體加速:
為了高效地在TPU-v3上訓練RG-LRU,研究者實現了定製的Pallas內核,將線性掃描運算中的內存讀寫次數降到最低。這使得RG-LRU在訓練加速比上可以達到Transformer的水平(記憶體帶寬約為900GB/s)。
同時,由於RG-LRU的隱狀態大小和局部MQA的KV cache大小遠小於Transformer的KV cache,Griffin在推理階段表現出了顯著的延遲優勢(低20%以上)和吞吐量優勢(高3倍以上)。這在生成超長序列時尤為明顯。
4. 外推和零樣本學習:
研究者在一個書籍語料(Books)和arXiv語料上評估了Griffin的外推(extrapolation)能力,即在遠長於訓練序列的長度上生成文本的能力。實驗表明,Griffin可以在長達訓練序列4倍的序列上穩定地利用更長上下文改進預測。而Transformer受限於其位置編碼,很難在超出訓練序列長度時維持較好表現。
此外,Griffin還在幾個合成的復制/檢索任務(Selective Copying、Induction Heads、Phone Number Lookup)上接受了測試。結果顯示,Griffin可以在監督訓練下快速學會這些任務所需的復制檢索技能。但在零樣本(zero-shot)設定下,預訓練的Griffin模型在這些任務上的外推能力仍然不及Transformer。這可能與局部注意力和RNN對字面匹配(verbatim match)的能力有限有關。
5. 訓練規模與下游任務性能:
研究者在100M到14B參數規模下訓練了Griffin,Hawk和Transformer模型,並在MMLU、HellaSwag、PIQA等7個下游任務上進行了評估。結果顯示,在只用300B tokens訓練的情況下:
- Griffin和Hawk在所有任務上的平均得分隨模型規模增大而穩定上升,其中Griffin在所有規模下都超過Transformer。
- Hawk-3B顯著超過了用600B tokens訓練的Mamba-3B。
- Griffin-7B/14B則在多數任務上達到甚至超過了用2T tokens訓練的Transformer模型Llama-2的水平。
這充分展現了Griffin超越Transformer的樣本效率。
綜上所述,Griffin通過巧妙融合RNN和局部注意力,在計算效率、樣本效率、長程建模等多方面取得了超越Transformer的效果,是一種非常有前景的新型語言模型架構。未來還可以進一步探索如何改進其在超長序列上的注意力機制,以及在零樣本語言任務上的外推和泛化能力。在工程實踐中,Griffin有望憑藉其出色的推理性能在長文本應用場景大放異彩。
沒有留言:
發佈留言