code window

2024年5月1日星期三

Meta簡單而有效的語言模型訓練方法—多token預測

Meta提出了一種透過多token預測(Multi-token Prediction)來訓練更好、更快的大型語言模型的方法。這篇論文的重點如下:

  • 訓練語言模型同時預測多個未來的token,可以提高樣本效率(sample efficiency)。
  • 在推論階段,使用多token預測可以達到最高3倍的加速。


論文的主要貢獻包括:

  1. 提出了一種簡單的多token預測架構,在訓練時間和記憶體使用上沒有額外開銷。
  2. 實驗證明,這種訓練範式在大規模模型(最高達130億參數)上是有效的,平均可以解決大約15%以上的編程問題。
  3. 多token預測使得自我推測解碼(self-speculative decoding)成為可能,在各種批次大小下將模型的推論速度提高了最多3倍。

    https://arxiv.org/pdf/2404.19737

動機與目的

傳統的語言模型通常使用下一個token預測(next-token prediction)的方式進行訓練,即根據前面的token序列,預測下一個最可能出現的token。然而,這種訓練方式可能導致模型過度關注局部的模式,忽略了長程的依賴關係。為了解決這個問題,本文提出了多token預測(multi-token prediction)的訓練方法,同時預測未來的多個token,以提升語言模型的訓練效率和性能。



方法原理

模型架構

  • 語言模型使用一個共享的模型主體(shared model trunk),並在其上添加n個獨立的輸出頭(output head),分別預測未來的n個token。
  • 在訓練時,模型在每個位置同時預測未來的n個token,使用n個獨立的loss項。
  • 為了減少GPU記憶體用量,作者巧妙地調整了前向/反向傳播的順序。模型依序計算每個輸出頭的前向和反向傳播,同時累積主體的梯度,避免同時儲存所有n個龐大的logit向量。
  • 推論時,可以只用第一個輸出頭(也就是下一個token的預測),其餘輸出頭可選擇性地用於加速推論(稱為self-speculative decoding)。


訓練目標
在訓練時,模型在每個位置同時預測未來的\(n\)個token,使用\(n\)個獨立的cross-entropy loss項。假設輸入的token序列為\(x_1, x_2, ..., x_t,\)模型的訓練目標可以表示為:

$$L_n = - Σ_t log P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t)$$

其中,\(P(x_{t+1}, ..., x_{t+n} | x_1, ..., x_t)\)表示在給定前\(t\)個token的條件下,未來\(n\)個token的聯合概率分佈。將這個聯合概率分解為\(n\)個條件概率的乘積,可以得到:

\(L_n = - Σ_t [log P(x_{t+1} | x_1, ..., x_t) + log P(x_{t+2} | x_1, ..., x_t) + ... + log P(x_{t+n} | x_1, ..., x_t)]\)

每個條件概率\(P(x_{t+i} | x_1, ..., x_t)\)由一個獨立的輸出頭計算得到。


訓練技巧

為了減少GPU記憶體的使用量,作者巧妙地調整了前向/反向傳播的順序。模型依序計算每個輸出頭的前向和反向傳播,同時累積主體的梯度,避免同時儲存所有n個龐大的logit向量。這種技巧使得多token預測模型的訓練幾乎不增加額外的計算和存儲開銷。



推論過程

在推論階段,可以只使用第一個輸出頭(即下一個token的預測),其餘輸出頭可選擇性地用於加速推論。這種加速技術稱為self-speculative decoding,通過並行計算多個輸出頭的預測結果,可以提高推論的效率。


實驗結果

作者在多個編碼和自然語言任務上評估了多token預測模型的性能,並與傳統的下一個token預測模型進行了比較。


編碼任務

在HumanEval和MBPP兩個編碼資料集上,多token預測模型顯著優於基準模型,尤其在大模型(如13B參數)上提升更加明顯。4個token的預測在綜合表現上最佳,在HumanEval上pass@100提升了4.1%,在MBPP上pass@1提升了3.8%。此外,訓練多個epoch時,多token預測的優勢仍然存在。

自然語言任務

在自然語言任務上,多token預測也帶來了改進,特別是在需要生成較長文本的摘要和自然語言數學任務。在8個摘要資料集上,2個token的預測平均將ROUGE-L提升了0.51,4個token的預測平均提升了0.46。在GSM8K自然語言數學資料集上,2個token的預測模型顯著優於基準模型。

字元級訓練


為了驗證多token預測有助於學習更長程的依賴關係,作者進行了字元級(byte-level)的訓練實驗。結果表明,8個字元的多token預測模型在HumanEval上pass@1的表現比下一個字元預測模型高出20%,在MBPP上高出67%。這說明多token預測能夠捕捉更長距離的模式和依賴關係。

模型微調

使用預訓練的多token預測模型進行微調,也能在下游任務上取得優於基準模型的成果。在CodeContests資料集上,4個token預訓練的模型在pass@k上全面超過了下一個token預訓練的模型。



  • 在編碼(coding)任務上,多token預測模型在HumanEval和MBPP資料集上的表現顯著優於基準模型,尤其在大模型(如13B參數)上提升更加明顯。
  • 在自然語言任務上,多token預測也帶來了改進,特別是在需要生成較長文本的摘要和自然語言數學任務。
  • 多token預測有助於模型學習更長程的依賴關係。在字元級(byte-level)的訓練中,8個字元的多token預測大幅優於下一個字元預測。
  • 實驗顯示,4個token的預測在綜合表現上最佳。此外,訓練多個epoch時,多token預測的優勢仍然存在。
  • 使用訓練好的多token預測模型進行微調(如在CodeContests資料集上),也能取得優於基準模型的成果。

  • 額外的輸出頭可用於self-speculative decoding,在推論階段提供最高3倍的加速。


結論與討論

本文提出了一種簡單而有效的語言模型訓練方法——多token預測,通過同時預測未來的多個token,促進模型學習更長程的依賴關係。實驗結果表明,這種方法在編碼和自然語言任務上帶來了顯著的性能提升,尤其對大模型和較長文本的生成任務效果更佳。多token預測幾乎不增加訓練成本,卻能提高訓練和推論效率,值得進一步探索。



作者認為,這項工作為尋找更有效的語言模型訓練方法開闢了新的方向。未來的研究可以探索以下幾個方面:

  1. 在更大規模的數據集和模型上驗證多token預測的有效性。
  2. 研究最優的token預測數量n,以及如何自適應地選擇n。
  3. 設計更高效的多token預測架構,如使用單一的輸出頭來預測多個token。
  4. 將多token預測與其他輔助訓練目標結合,如掩碼語言建模(masked language modeling)。

多token預測是一種前景廣闊的語言模型訓練方法,有望幫助構建更強大、更連貫的語言模型,推動自然語言處理領域的發展。


以下是我對這項工作的一些想法:

多token預測利用了語言的長程依賴關係,通過同時預測多個未來的token,促使模型學習更全面、更連貫的表示。這種方法與人類語言學習的過程更為相似,因為我們在理解和生成語言時,也是基於對未來一段文本的預期,而不僅僅依賴於前一個詞。

該方法在編程任務上取得了顯著的性能提升,這可能是因為編程語言具有更強的結構性和邏輯性,多token預測更容易捕捉到其中的模式和依賴關係。在自然語言任務上的改進相對較小,可能是因為自然語言的不確定性和靈活性更高,單純增加預測的token數量效果有限,需要更細緻的建模方法。

多token預測在推論階段帶來的加速效果非常可觀,這對於實際應用中的延遲敏感場景(如實時對話、同步翻譯等)具有重要價值。不過,這種加速方法對模型性能的影響還需要進一步評估,確保生成質量不會顯著下降。

論文中的實驗主要集中在編程和自然語言文本上,未來可以考慮將多token預測應用於其他類型的序列數據,如時間序列、生物序列等,探索它在更廣泛領域的有效性。

多token預測作為一種輔助的訓練目標,與其他方法(如對比學習、知識蒸餾等)結合使用,可能會產生更好的協同效果。探索多種訓練策略的組合,有望進一步提升語言模型的性能和泛化能力。

我認為這項工作為改進大型語言模型的訓練和推理效率提供了一個簡單而有效的思路,具有廣闊的應用前景。未來可以在更大規模的資料集和模型上驗證這種方法的有效性,並探索與其他技術結合的可能性,推動語言模型的進一步發展。

沒有留言:

發佈留言

SambaNova SN40L: 利用Dataflow和專家組合(COE)來克服AI記憶牆的大模型

摘要 GPT-4等整體式大型語言模型(LLM)為現代生成AI應用鋪路。然而,大規模訓練、服務及維護整體式LLM仍然極其昂貴和充滿挑戰。現代AI加速器計算能力與記憶體比例的不成比例增長已經造成了記憶體壁障,需要新的方法來部署AI。最近的研究顯示,許多小型專家模型的組合,每個模型參數...