在 Hacker News 及 Twitter 等社交網(wǎng)絡(luò)上,該論文都反響熱烈,有網(wǎng)友表示差分 Transformer 提出的改進(jìn)簡(jiǎn)單又美麗,而帶來(lái)的提升又非常顯著。

甚至已有開發(fā)者做出了差分 Transformer 的輕量實(shí)現(xiàn)!

那么差分 Transformer 彌補(bǔ)了原生 Transformer 的哪些問題呢?如下圖所示,Transformer 往往會(huì)過度關(guān)注不相關(guān)的上下文,該團(tuán)隊(duì)將此稱為注意力噪聲(attention noise)。而差分 Transformer 則能放大對(duì)答案范圍的注意力并消除噪音,從而增強(qiáng)上下文建模的能力。這就要用到該團(tuán)隊(duì)新提出的差分注意力機(jī)制(differential attention mechanism)了。

差分注意力機(jī)制可以消除注意力噪聲,鼓勵(lì)模型重點(diǎn)關(guān)注關(guān)鍵信息。該方法有些類似于電氣工程中的降噪耳機(jī)和差分放大器。

下面我們就來(lái)詳細(xì)了解一下差分 Transformer 的設(shè)計(jì)思路。

差分 Transformer

差分 Transformer 是一種用于序列建模的基礎(chǔ)模型架構(gòu)。為了方便說(shuō)明,他們使用了僅解碼器(decoder-only)模型作為示例來(lái)描述該架構(gòu)。

該模型堆疊了 L 個(gè) Diff Transformer 層。給定一個(gè)輸入序列 x,將輸入嵌入打包成 X^0。輸入會(huì)被進(jìn)一步上下文化來(lái)獲得輸出 X^L。每一層都由兩個(gè)模塊組成:一個(gè)差分注意力模塊和之后的前向網(wǎng)絡(luò)模塊。

相比于 Transformer,差分 Transformer 的主要差別在于使用差分注意力替換了傳統(tǒng)的 softmax 注意力,同時(shí)保持整體宏觀布局不變。此外,他們也參考 LLaMA 采用了 pre-RMSNorm 和 SwiGLU 這兩項(xiàng)改進(jìn)措施。

差分注意力

差分注意力機(jī)制的作用是將查詢、鍵和值向量映射成輸出。這里使用查詢和鍵向量來(lái)計(jì)算注意力分?jǐn)?shù),然后計(jì)算值向量的加權(quán)和。

此處的關(guān)鍵設(shè)計(jì)是使用一對(duì) softmax 函數(shù)來(lái)消除注意力分?jǐn)?shù)的噪聲。具體來(lái)說(shuō),給定輸入 X,首先將它們投射成查詢、鍵和值 Q_1、Q_2、K_1、K_2、V。然后差分注意力算子 DiffAttn (?) 通過以下方式計(jì)算輸出:

其中 W^Q、W^K 、W^V 是參數(shù),λ 是可學(xué)習(xí)的標(biāo)量。為了同步學(xué)習(xí)動(dòng)態(tài),將標(biāo)量 λ 重新參數(shù)化為:

其中 λ_q1、λ_k1、λ_q2、λ_k2 是可學(xué)習(xí)的向量,λ_init ∈ (0, 1) 是用于初始化 λ 的常數(shù)。該團(tuán)隊(duì)通過經(jīng)驗(yàn)發(fā)現(xiàn),設(shè)置 λ_init = 0.8 ? 0.6 × exp (?0.3?(l ? 1)) 在實(shí)踐中效果很好,其中 l ∈ [1, L] 表示層索引。它在實(shí)驗(yàn)中被用作默認(rèn)策略。

他們也探索了另一種初始化策略:對(duì)所有層使用相同的 λ_init(例如 0.8)。如后面消融研究所示,使用不同的初始化策略時(shí),性能相對(duì)穩(wěn)健。

差分注意力利用兩個(gè) softmax 注意力函數(shù)之間的差來(lái)消除注意力噪聲。這個(gè)想法類似于電氣工程中提出的差分放大器,其中兩個(gè)信號(hào)之間的差用作輸出,這樣就可以消除輸入的共模噪聲。此外,降噪耳機(jī)的設(shè)計(jì)也基于類似的想法。

該團(tuán)隊(duì)也為差分注意力使用了多頭機(jī)制。令 h 表示注意力頭的數(shù)量。他們對(duì)各個(gè)頭使用不同的投影矩陣 W^Q_i 、W^K_i 、W^V_i ,i ∈ [1, h]。標(biāo)量 λ 在同一層內(nèi)的頭之間共享。然后對(duì)頭輸出執(zhí)行歸一化,并投射成最終結(jié)果,如下所示:


其中 λ_init 是 (2) 式中的常數(shù)標(biāo)量,W^O 是可學(xué)習(xí)的投影矩陣,LN (?) 是對(duì)每個(gè)頭使用 RMSNorm,Concat (?) 的作用是沿通道維度將頭連接在一起。這里使用一個(gè)固定乘數(shù)(1 ? λ_init)作為 LN (?) 的縮放尺度,以使梯度與 Transformer 對(duì)齊。

圖 2 使用了 GroupNorm (?) 來(lái)強(qiáng)調(diào) LN (?) 獨(dú)立應(yīng)用于每個(gè) head。由于差分注意力往往具有更稀疏的模式,因此頭之間的統(tǒng)計(jì)信息更加多樣化。為了改進(jìn)梯度的統(tǒng)計(jì)情況,LN (?) 算子會(huì)在連接操作之前對(duì)每個(gè)頭進(jìn)行歸一化。

整體架構(gòu)

其整體架構(gòu)會(huì)堆疊 L 層,其中每層包含一個(gè)多頭差分注意力模塊和一個(gè)前向網(wǎng)絡(luò)模塊。如此,便可將差分 Transformer 層描述為:

其中 LN (?) 是 RMSNorm,SwiGLU (X) = (swish (XW^G) ⊙ XW_1) W_2,且 W^G、W_1、W_2 是可學(xué)習(xí)的矩陣。

實(shí)驗(yàn)

該團(tuán)隊(duì)從以下角度評(píng)估了差分 Transformer 在 LLM 中的應(yīng)用,包括對(duì)比評(píng)估、應(yīng)用評(píng)估和消融研究。這里我們僅關(guān)注實(shí)驗(yàn)結(jié)果,更多實(shí)驗(yàn)過程請(qǐng)?jiān)L問原論文。

語(yǔ)言建模評(píng)估

該團(tuán)隊(duì)評(píng)估了差分 Transformer 的語(yǔ)言建模能力。為此,他們使用 1T token 訓(xùn)練了一個(gè) 3B 大小的差分 Transformer 語(yǔ)言模型,并與之前的 Transformer 語(yǔ)言模型做了比較。

結(jié)果見表 1,其中報(bào)告的是在 LM Eval Harness 基準(zhǔn)上的零樣本結(jié)果。

可以看到,3B 規(guī)模下,差分 Transformer 語(yǔ)言模型的表現(xiàn)優(yōu)于之前的 Transformer 語(yǔ)言模型。此外,實(shí)驗(yàn)也表明差分 Transformer 在多種任務(wù)上都勝過 Transformer,詳見原論文附錄。

與 Transformer 的可擴(kuò)展性比較

該團(tuán)隊(duì)也比較了新舊 Transformer 的可擴(kuò)展性。結(jié)果見圖 3,其中 a 比較了模型規(guī)模方面的可擴(kuò)展性,而 b 則是訓(xùn)練 token 數(shù)量方面的可擴(kuò)展性。

可以看到,在這兩個(gè)方面,差分 Transformer 的可擴(kuò)展性均優(yōu)于常規(guī) Transformer:僅需后者 65% 左右的模型大小或訓(xùn)練 token 數(shù)量就能達(dá)到相媲美的性能。

長(zhǎng)上下文評(píng)估

當(dāng) 3B 模型上下文長(zhǎng)度增長(zhǎng)至 64K,模型的表現(xiàn)又如何呢?又使用另外 1.5B token 訓(xùn)練了 3B 版本的檢查點(diǎn)模型之后,該團(tuán)隊(duì)發(fā)現(xiàn)隨著上下文長(zhǎng)度的增加,累積平均負(fù)對(duì)數(shù)似然(NLL)持續(xù)下降。差分 Transformer 得到的 NLL 值低于常規(guī) Transformer。見圖 4,這樣的結(jié)果表明,差分 Transformer 可以有效地利用不斷增加的上下文。

關(guān)鍵信息檢索

為了檢驗(yàn)差分 Transformer 檢索關(guān)鍵信息的能力,該團(tuán)隊(duì)執(zhí)行了 Needle-In-A-Haystack(草堆找針)測(cè)試。

表 2 給出了 4K 上下文長(zhǎng)度的情況,其中 N 是針的數(shù)量,R 是查詢引用的數(shù)量??梢钥吹剑罘?Transformer 的多針檢索準(zhǔn)確度高于常規(guī) Transformer,尤其是當(dāng)針數(shù)量較多時(shí),差分 Transformer 的優(yōu)勢(shì)會(huì)更加明顯。

那么當(dāng)上下文長(zhǎng)度提升至 64K 時(shí),又會(huì)如何呢?結(jié)果見圖 5,這里使用的上下文長(zhǎng)度在 8K 到 64K 之間,使用了 N = 8 和 R = 1 的設(shè)置。

可以看到,在不同的上下文長(zhǎng)度下,差分 Transformer 能夠保持相對(duì)穩(wěn)定的性能。而當(dāng)上下文長(zhǎng)度越來(lái)越大時(shí),常規(guī) Transformer 的性能會(huì)逐漸下降。

另外,表 3 展示了分配給關(guān)鍵信息檢索任務(wù)的答案范圍和噪聲上下文的注意力分?jǐn)?shù)。該分?jǐn)?shù)可代表模型保留有用信息、抵抗注意力噪聲的能力。

可以看到,相比于常規(guī) Transformer,差分 Transformer 能為答案范圍分配更高的注意力分?jǐn)?shù),同時(shí)為注意力噪聲分配更低的注意力分?jǐn)?shù)。

上下文學(xué)習(xí)能力評(píng)估

該團(tuán)隊(duì)從兩個(gè)角度評(píng)估模型的上下文學(xué)習(xí)能力,包括多樣本分類和上下文學(xué)習(xí)的穩(wěn)健性。

圖 6 展示了新舊 Transformer 模型的多樣本分類結(jié)果。結(jié)果表明,在不同的數(shù)據(jù)集和不同的演示樣本數(shù)量上,差分 Transformer 均穩(wěn)定地優(yōu)于 Transformer。此外,差分 Transformer 的平均準(zhǔn)確度優(yōu)勢(shì)也很明顯,從 5.2% 到 21.6% 不等。

圖 7 則展示了兩種模型的上下文學(xué)習(xí)穩(wěn)健性結(jié)果。該分析基于 TREC 數(shù)據(jù)集,并且采用了兩種提示詞格式:示例隨機(jī)排列(圖 7a)和按類別交替排列(圖 7b)。

在這兩種設(shè)置下,差分 Transformer 的性能方差要小得多。結(jié)果表明,新方法在上下文學(xué)習(xí)任務(wù)中更為穩(wěn)健。相比之下,Transformer 容易受到順序排列的影響,導(dǎo)致最佳結(jié)果與最差結(jié)果之間差距巨大。

上下文幻覺評(píng)估

該團(tuán)隊(duì)基于文本摘要和問答任務(wù)評(píng)估了模型的上下文幻覺現(xiàn)象。結(jié)果見表 4。

可以看到,相比于常規(guī) Transformer,差分 Transformer 在摘要和問答任務(wù)上的上下文幻覺更低。該團(tuán)隊(duì)表示,原因可能是差分 Transformer 能更好地關(guān)注任務(wù)所需的基本信息,而不是無(wú)關(guān)上下文。

激活異常值分析

在 LLM 中,一部分激活值明顯大于大多數(shù)激活值的現(xiàn)象被稱為激活異常值(activation outliers)。異常值導(dǎo)致訓(xùn)練和推理過程中模型量化困難。實(shí)驗(yàn)表明差分 Transformer 可以降低激活異常值的幅度,從而可能實(shí)現(xiàn)更低的量化位寬。

表 5 展示了兩個(gè)訓(xùn)練得到 Transformer 和差分 Transformer 模型的激活值統(tǒng)計(jì)情況。這里分析了兩種類型的激活,包括注意力 logit(即 pre-softmax 激活)和隱藏狀態(tài)(即層輸出)??梢钥吹?,盡管中位數(shù)相似,但與 Transformer 相比,差分 Transformer 的較大激活值要低得多。這表明新方法產(chǎn)生的激活異常值較少。

圖 8 則展示了將注意力 logit 量化到更低位的情況。這里使用的方案是:使用 absmax 量化的動(dòng)態(tài)后訓(xùn)練量化。其中,16 位配置表示未經(jīng)量化的原始結(jié)果。模型逐步量化為 8 位、6 位和 4 位。這里報(bào)告的是在 HellaSwag 上的零樣本準(zhǔn)確度,但該團(tuán)隊(duì)也指出在其它數(shù)據(jù)集上也有類似表現(xiàn)。

從圖中可知,即使降低位寬,差分 Transformer 也能保持較高性能。相較之下,常規(guī) Transformer 的準(zhǔn)確度在 6 位和 4 位量化時(shí)會(huì)顯著下降。這一結(jié)果表明,差分 Transformer 本身就能緩解注意力分?jǐn)?shù)中的激活異常值問題,從而可為低位 FlashAttention 的實(shí)現(xiàn)提供新機(jī)會(huì)。

最后,該團(tuán)隊(duì)也進(jìn)行了消融實(shí)驗(yàn),證明了各個(gè)新設(shè)計(jì)的有效性。

文章轉(zhuǎn)自微信公眾號(hào)@算法進(jìn)階

上一篇:

超完整!11 種經(jīng)典時(shí)間序列預(yù)測(cè)方法!

下一篇:

圖神經(jīng)網(wǎng)絡(luò)加速綜述: 算法、系統(tǒng)和硬件
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊(cè)

多API并行試用

數(shù)據(jù)驅(qū)動(dòng)選型,提升決策效率

查看全部API→
??

熱門場(chǎng)景實(shí)測(cè),選對(duì)API

#AI文本生成大模型API

對(duì)比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個(gè)渠道
一鍵對(duì)比試用API 限時(shí)免費(fèi)

#AI深度推理大模型API

對(duì)比大模型API的邏輯推理準(zhǔn)確性、分析深度、可視化建議合理性

10個(gè)渠道
一鍵對(duì)比試用API 限時(shí)免費(fèi)