Mamba: Linear-Time Sequence Modeling with Selective State Spaces一文中提出了Mamba,我們在之前的文章中也有詳細(xì)的介紹。

在本篇文章中,通過將繪制RNN,transformer,和Mamba的架構(gòu)圖,并進(jìn)行詳細(xì)的對比,這樣我們可以更詳細(xì)的了解它們之間的區(qū)別。

為了說明為什么Mamba是這樣一個有趣的架構(gòu),讓我們先介紹Transformer。

Transformer

Transformer將任何文本輸入視為由令牌組成的序列。

transformer的一個主要優(yōu)點是,無論它接收到多長的輸入,它都使用序列中的任何令牌信息(無論序列有多長)來對輸入數(shù)據(jù)進(jìn)行處理。

這就是我們在論文中看到的注意力機(jī)制的作用,但是為了獲得全局信息,注意力機(jī)制在長序列上非常耗費顯存,這個我們后面說。

Transformer由兩個結(jié)構(gòu)組成,一組用于表示文本的編碼器塊和一組用于生成文本的解碼器塊。這些結(jié)構(gòu)可以用于多種任務(wù),包括翻譯。

我們可以采用這種結(jié)構(gòu)來創(chuàng)建僅使用解碼器的生成模型。比如基于Transformer的GPT,使用解碼器塊來完成一些輸入文本。

單個解碼器塊由兩個主要部分組成,一個是自注意力模塊,另一個是前饋神經(jīng)網(wǎng)絡(luò)。

注意力創(chuàng)建一個矩陣,將每個令牌與之前的每個令牌進(jìn)行比較。矩陣中的權(quán)重由令牌對之間的相關(guān)性決定。

它支持并行化,所以可以極大地加快訓(xùn)練速度!

但是當(dāng)生成下一個令牌時,我們需要重新計算整個序列的注意力,即使我們已經(jīng)生成了一些新的令牌。

為長度為L的序列生成令牌大約需要L2的計算量,如果序列長度增加,計算量可能會很大。并且在這里需要計算所有令牌的注意力,所以如果序列很長,那么內(nèi)存占用也會很大。所以需要重新計算整個序列是Transformer體系結(jié)構(gòu)的主要瓶頸。當(dāng)然也有很多技巧來提升注意力機(jī)制的效率,這里我們暫時不提,只看最經(jīng)典的原始論文。

RNN

下面我們介紹更早的序列模型RNN。循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)是一種基于序列的網(wǎng)絡(luò)。它在序列的每個時間步長取兩個輸入,即時間步長t的輸入和前一個時間步長t-1的隱藏狀態(tài),以生成下一個隱藏狀態(tài)并預(yù)測輸出。

RNN有一個循環(huán)機(jī)制,允許它們將信息從上一步傳遞到下一步。我們可以“展開”這個可視化,使它更明確。

在生成輸出時,RNN只需要考慮之前的隱藏狀態(tài)和當(dāng)前的輸入。這樣不會重新計算以前的隱藏狀態(tài),這正Transformer沒有的。

這種流程可以讓RNN進(jìn)行快速推理,因為的時間與序列長度線性擴(kuò)展!并且可以有無限的上下文長度(理論上),因為每次推理他只取一個隱藏狀態(tài)和當(dāng)前輸入,內(nèi)存的占用是非常穩(wěn)定的。

我們將RNN應(yīng)用于之前使用過的輸入文本。

每個隱藏狀態(tài)都是以前所有隱藏狀態(tài)的聚合。但是這里就出現(xiàn)了問題,在生成名稱“Maarten”時,最后一個隱藏狀態(tài)不再包含關(guān)于單詞“Hello”的信息(或者說最早的信息會被坐進(jìn)的信息覆蓋)。這會導(dǎo)致隨著時間的推移,rnn會忘記信息,因為它們只考慮前一個狀態(tài)。

并且rnn的這種順序性產(chǎn)生了另一個問題。訓(xùn)練不能并行進(jìn)行,因為它需要按順序完成每一步。

與Transformer相比,rnn的問題完全相反!它的推理速度非常快,但不能并行化導(dǎo)致訓(xùn)練很慢。

人們一直在尋找一種既能像Transformer那樣并行化訓(xùn)練,能夠記住先前的信息,并且在推理時間還是隨序列長度線性增長的模型,Mamba就是這樣宣傳的。

在介紹Mamba之前,讓我們還需要介紹以下狀態(tài)空間模型

The State Space Model (SSM)

狀態(tài)空間模型(SSM),像Transformer和RNN一樣,可以處理序列信息,比如文本,也包括信號。

狀態(tài)空間是包含能夠完全描述一個系統(tǒng)的最少數(shù)量變量的概念。它是一種通過定義系統(tǒng)可能的狀態(tài)來數(shù)學(xué)表示問題的方式。

比如說我們正在通過一個迷宮。”狀態(tài)空間” 就是所有可能位置(狀態(tài))的地圖。每個點代表迷宮中的一個獨特位置,具有特定的細(xì)節(jié),比如你離出口有多遠(yuǎn)。

“狀態(tài)空間表示” 是對這個地圖的簡化描述。它展示了你當(dāng)前所處的位置(當(dāng)前狀態(tài)),以及下一步可以去哪里(可能的未來)。

雖然狀態(tài)空間模型使用方程和矩陣來跟蹤這種行為,描述狀態(tài)的變量,在我們的例子中是X和Y坐標(biāo)以及到出口的距離,可以表示為“狀態(tài)向量”。

聽起來熟悉嗎?這不就是強(qiáng)化學(xué)習(xí)中的狀態(tài)嗎,我個人認(rèn)為是可以這么理解的,那么怎么和序列有關(guān)呢?

因為語言模型中的嵌入或向量也經(jīng)常用于描述輸入序列的“狀態(tài)”。例如,你當(dāng)前位置的向量(狀態(tài)向量)可能看起來像這樣:

在神經(jīng)網(wǎng)絡(luò)中,“狀態(tài)”通常是指其隱藏狀態(tài),在大型語言模型的背景下,這是生成新標(biāo)記的一個最重要的方面之一。

狀態(tài)空間模型(SSMs)是用于描述這些狀態(tài)表示并根據(jù)某些輸入進(jìn)行下一個狀態(tài)預(yù)測的模型。

在時間t,狀態(tài)空間模型(SSMs):

這里就與強(qiáng)化學(xué)習(xí)中使用離散序列(如僅向左移動一次)不同,它將連續(xù)序列作為輸入并預(yù)測輸出序列。

ssm假設(shè)動態(tài)系統(tǒng),例如在三維空間中移動的物體,可以通過兩個方程從時間t的狀態(tài)預(yù)測。

通過求解這些方程,假設(shè)可以揭示基于觀測數(shù)據(jù)(輸入序列和先前狀態(tài))預(yù)測系統(tǒng)狀態(tài)的統(tǒng)計原理。

它的目標(biāo)是找到這個狀態(tài)表示h(t)這樣我們就可以從一個輸入序列到一個輸出序列。

這兩個方程就是是狀態(tài)空間模型的核心。狀態(tài)方程描述了基于輸入如何影響狀態(tài)(通過矩陣B)的狀態(tài)變化(通過矩陣A)。

h(t)表示任意時刻t的潛在狀態(tài)表示,而x(t)表示某個輸入。

輸出方程描述了狀態(tài)如何轉(zhuǎn)化為輸出(通過矩陣C),以及輸入如何影響輸出(通過矩陣D)。

矩陣A、B、C和D通常被稱為參數(shù),因為它們是可學(xué)習(xí)的。將這兩個方程可視化,我們可以得到如下架構(gòu):

下面我們看看這些矩陣如何影響學(xué)習(xí)過程。

假設(shè)我們有一個輸入信號x(t)這個信號首先乘以矩陣B它描述了輸入如何影響系統(tǒng)。

更新狀態(tài)(h)是包含環(huán)境核心“知識”的潛在空間。我們將狀態(tài)與矩陣A相乘,矩陣A描述了所有內(nèi)部狀態(tài)是如何連接的,因為它們代表了系統(tǒng)的潛在表示。

這里可以看到,在創(chuàng)建狀態(tài)表示之前應(yīng)用矩陣A,并在狀態(tài)表示更新之后更新矩陣A。

然后使用矩陣C來描述如何將狀態(tài)轉(zhuǎn)換為輸出。

最后利用矩陣D提供從輸入到輸出的直接信號。這通常也被稱為跳過(殘差)連接。

由于矩陣D類似于跳過連接,所以SSM通常被視為為不進(jìn)行跳過連接的部分

回到我們的簡化視圖,現(xiàn)在可以將重點放在矩陣A、B和C上,它們是SSM的核心。

更新原始方程并添加一些顏色來表示每個矩陣的目的

這兩個方程根據(jù)觀測數(shù)據(jù)預(yù)測系統(tǒng)的狀態(tài)。由于期望輸入是連續(xù)的,SSM是連續(xù)時間表示。

但是因為文字都是離散的輸入,我們還需要將模型離散化。這里就要使用* Zero-order hold * 技術(shù)

每次我們接收到一個離散信號,都會保證他的值不變,直到接收到一個新的離散信號再改變。這個過程創(chuàng)建了一個SSM可以使用的連續(xù)信號:

我們保持該值的時間由一個新的可學(xué)習(xí)參數(shù)表示,稱為步長?。這樣就得到了一個連續(xù)的信號并且可以只根據(jù)輸入的時間步長對值進(jìn)行采樣。

這些采樣值就是我們的離散輸出!在數(shù)學(xué)上,我們可以應(yīng)用Zero-order hold如下:

因為我們SSM處理的是離散信號,所以這里不是一個函數(shù)到函數(shù),x(t)→y(t),而是一個序列到序列,x?→y?,我們用公式表示如下:

矩陣A和B現(xiàn)在表示模型的離散參數(shù),用k代替t來表示離散的時間步長。

離散化的SSM允許在特定的時間步中處理信息。就像我們之前在循環(huán)神經(jīng)網(wǎng)絡(luò)(RNNs)中看到的那樣,循環(huán)方法在這里也非常有用,可以將問題重新表述為時間步驟:

在每個時間步長,我們計算當(dāng)前輸入(Bx?)如何影響前一個狀態(tài)(Ah??),然后計算預(yù)測輸出(Ch?)。

這種表示看起來是不是有點熟悉?其實他的處理方法和RNN一樣

也可以這樣展開:

這種技術(shù)與RNN類似,快速推理和慢速訓(xùn)練。

另一種ssm的表示是卷積的表示。我們應(yīng)用過濾器(核)來獲得聚合特征:

因為我們處理的是文本而不是圖像,所以我只要一維的視角:

我們用來表示這個“過濾器”的核是由SSM公式推導(dǎo)出來的:

可以使用SSM核遍歷每一組令牌并計算輸出:

上圖也說明了padding 可能對輸出產(chǎn)生的影響,所以我們一般都會在末尾padding而不是在前面。第二步核被移動一次來執(zhí)行下一步計算:

在最后一步,我們可以看到核的完整效果:

卷積的一個主要好處是它可以并行訓(xùn)練。但是由于核大小是固定,它們的推理不如rnn快速并且對序列長度有限制。

上面的三種SMM都有各自的優(yōu)缺點

這里可以使用一個簡單的技巧,即根據(jù)任務(wù)選擇表示。在訓(xùn)練過程中使用可以并行化的卷積表示,在推理過程中,我們使用高效的循環(huán)表示:

聽起來有點奇幻,但是有人就是實現(xiàn)出來了,這個模型叫做Linear State-Space Layer (LSSL)

https://proceedings.neurips.cc/paper_files/paper/2021/hash/05546b0e38ab9175cd905eebcc6ebb76-Abstract.html

它結(jié)合了線性動態(tài)系統(tǒng)理論和神經(jīng)網(wǎng)絡(luò)的概念,可以有效地捕獲數(shù)據(jù)中的時序信息和動態(tài)特征。LSSL 基于線性動態(tài)系統(tǒng)理論,這種系統(tǒng)可以用狀態(tài)空間模型表示。在這個模型中,系統(tǒng)的行為由狀態(tài)變量的演化和外部控制信號的影響決定。狀態(tài)變量是系統(tǒng)的內(nèi)部表示,可以捕獲系統(tǒng)的動態(tài)特性。

這些表示都有一個重要的特性,即線性時不變性(LTI)。LTI表示ssm參數(shù)A、B和C對于所有時間步長都是固定的。這意味著對于SSM生成的每個令牌,矩陣A、B和C都是相同的。

也就是說無論給SSM的序列是什么,A、B和C的值都保持不變。這樣就得到了一個不感知內(nèi)容的靜態(tài)表示。但是靜態(tài)表示沒有任何意義對吧,所以Mamba解決的就是這個問題,但是在介紹Mamba之前,我們還有一個知識點需要強(qiáng)調(diào),那就是矩陣A

因為SSM公式中最重要的就是矩陣a。正如我們之前在循環(huán)表示中看到的那樣,它捕獲了關(guān)于前一個狀態(tài)的信息來構(gòu)建新狀態(tài),如果矩陣a如果跟RNN一樣會遺忘掉非常靠前的信息那么SMM將沒有任何的意義,對吧。

矩陣A產(chǎn)生隱藏狀態(tài):

如何保留大上下文大小的方式創(chuàng)建矩陣A呢?

HiPPO 的模型結(jié)合了遞歸記憶(Recurrent Memory)和最優(yōu)多項式投影(Optimal Polynomial Projections)的概念,這種投影技術(shù)可以顯著改善遞歸記憶的性能,特別是在處理長序列和長期依賴關(guān)系時。

https://proceedings.neurips.cc/paper/2020/hash/102f0bb6efb3a6128a3c750dd16729be-Abstract.html

使用矩陣A來構(gòu)建一個狀態(tài)表示,該狀態(tài)表示可以很好地捕獲最近的令牌并衰減較舊的令牌。其公式可表示為:

具體的詳細(xì)內(nèi)容我們就不介紹了,有興趣的查看原論文。

這樣我們就基本上解決了所有的問題:1、狀態(tài)空間模型;2、處理遠(yuǎn)程依賴關(guān)系;3、離散化和并行計算

如果想深入了解有關(guān)如何計算HiPPO矩陣和自己構(gòu)建S4模型建議您閱讀注釋的S4。

https://srush.github.io/annotated-s4/

Mamba

上面介紹完所有必要的基礎(chǔ)知識,最后就是我們的重點了

Mamba 有兩個主要貢獻(xiàn):

1、選擇性掃描算法,模型可以過濾有關(guān)和無關(guān)的信息

2、硬件感知算法,通過并行掃描、核融合和重計算有效地存儲(中間)結(jié)果。

在探討這兩個主要貢獻(xiàn)之前,我們先看看一下為什么它們是必要的。

狀態(tài)空間模型,S4(Structured State Space Model),在語言建模和生成中的某些任務(wù)上表現(xiàn)不佳

比如在選擇性復(fù)制任務(wù)中,SSM的目標(biāo)是按順序復(fù)制輸入和輸出的部分:

(循環(huán)/卷積)SSM在這個任務(wù)中表現(xiàn)不佳,因為它是線性時不變的。對于SSM生成的每個令牌,矩陣A、B和C都是相同的。

因為它將每個令牌平等地視為固定的a、B和C矩陣的結(jié)果,所以SSM不能執(zhí)行內(nèi)容感知推理

SSM表現(xiàn)不佳的第二個任務(wù)是重現(xiàn)輸入中發(fā)現(xiàn)的模式:

我們的提示在“教”模型在每個“Q:”之后提供“A:”響應(yīng)。但是由于ssm是時間不變的,它不能選擇從其歷史中獲取先前的令牌。

以矩陣B為例不管輸入x是什么,矩陣B保持完全相同,并且與x無關(guān):

同理無論輸入是什么,A和C也不變,這就是我們上面說的靜態(tài)。

而Transformers 可以根據(jù)輸入序列動態(tài)地改變注意力??梢赃x擇性地“看”或“注意”序列的不同部分,再加上位置編碼,這使得Transformers對于這種任務(wù)非常的簡單。

ssm在這些任務(wù)上的糟糕性能說明了定常ssm的潛在問題,矩陣A、B和C的靜態(tài)特性導(dǎo)致了內(nèi)容感知問題。

選擇性地保留信息

SSM的循環(huán)表示創(chuàng)建了一個非常有效的小狀態(tài),因為它壓縮了整個歷史信息,所以與不壓縮歷史(注意力矩陣)的Transformer模型相比,它的功能要弱得多。

Mamba 的目標(biāo)是獲得Transformer一樣強(qiáng)大的“小”狀態(tài)

通過有選擇地將數(shù)據(jù)壓縮到狀態(tài),當(dāng)輸入一個句子時,通常會有一些信息,比如停頓詞,這些信息沒有太多的意義。

我們先看看SSM在訓(xùn)練期時的輸入和輸出維度:

在結(jié)構(gòu)化狀態(tài)空間模型(S4)中,矩陣a、B和C獨立于輸入,因為它們的維度N和D是靜態(tài)的,不會改變。

而Mamba通過結(jié)合輸入的序列長度和批量大小,使矩陣B和C,甚至步長?依賴于輸入:

這意味著對于每個輸入標(biāo)記,有不同的B和C矩陣,這解決了內(nèi)容感知問題!這里矩陣A保持不變,因為希望狀態(tài)本身保持靜態(tài),但影響它的方式(通過B和C)是動態(tài)的。

也就是說它們一起選擇性地選擇將什么保留在隱藏狀態(tài)中,什么需要忽略,這都是由輸入確定的。

較小的步長?導(dǎo)致忽略特定的單詞,而是更多地使用之前的上下文,而較大的步長?則更多地關(guān)注輸入單詞而不是上下文:

掃描操作

這些矩陣現(xiàn)在是動態(tài)的了所以它們不能使用卷積表示來計算,只能使用循環(huán)進(jìn)行處理,這就使得無法進(jìn)行并行化。

為了實現(xiàn)并行化,我們先看看循環(huán)的輸出:

每個狀態(tài)都是前一個狀態(tài)(乘以A)加上當(dāng)前輸入(乘以B)的和。這被稱為掃描操作,可以很容易地通過for循環(huán)計算出來。但是并行化似乎是不可能的,因為每個狀態(tài)只有在我們有前一個狀態(tài)時才能計算出來。

但是Mamba使用并行掃描算,通過關(guān)聯(lián)屬性假定執(zhí)行操作的順序無關(guān)緊要。這樣就可以計算部分序列并迭代組合它們:

這樣還有一個好處是因為順序不重要,也可以省略掉Transformer的位置編碼。

硬件感知的算法

最近gpu的一個缺點是它們在小但高效的SRAM和大但效率稍低的DRAM之間的傳輸(IO)速度有限。在SRAM和DRAM之間頻繁地復(fù)制信息成為瓶頸。

Mamba的DRAM和SRAM分配的具體實例如下:

中間狀態(tài)不被保存,但對于反向傳播計算梯度是必要的。作者重新計算了反向傳遞過程中的中間狀態(tài)。盡管這看起來效率很低,但它比從相對較慢的DRAM讀取所有這些中間狀態(tài)的成本要低得多。

這里我們就不詳細(xì)說明了,因為這部分我也沒太研究過

Mamba 塊

選擇性SSM可以作為一個塊,就像在Transformer中的的注意力模塊一樣。我們可以堆疊多個塊,并使用它們的輸出作為下一個曼巴塊的輸入:

最后一個端到端(輸入到輸出)的例子包含了歸一化層和選擇輸出標(biāo)記softmax。

這樣就得到了快速的推理和訓(xùn)練,而且是“無限”長度上下文的模型

總結(jié)

看完這篇文章,我希望你能對Mamba 和狀態(tài)空間模型有一定的了解,最后我們以作者的發(fā)現(xiàn)為結(jié)尾:

作者發(fā)現(xiàn)模型與相同尺寸的Transformer模型的性能相當(dāng),有時甚至超過了它們!

本文章轉(zhuǎn)載微信公眾號@算法進(jìn)階

上一篇:

全面!深度學(xué)習(xí)時間序列分類的綜述!

下一篇:

時空圖神經(jīng)網(wǎng)絡(luò)原理及Pytorch實現(xiàn)
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊

多API并行試用

數(shù)據(jù)驅(qū)動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準(zhǔn)確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費