1.2 RNN 介紹

循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)是基于序列數(shù)據(jù)(如語(yǔ)言、語(yǔ)音、時(shí)間序列)的遞歸性質(zhì)而設(shè)計(jì)的,是一種反饋類型的神經(jīng)網(wǎng)絡(luò),它專門用于處理序列數(shù)據(jù),如逐字生成文本或預(yù)測(cè)時(shí)間序列數(shù)據(jù)(例如股票價(jià)格、詩(shī)歌生成)。

RNN和全連接神經(jīng)網(wǎng)絡(luò)的本質(zhì)差異在于“輸入是帶有反饋信息的”,RNN除了接受每一步的輸入x(t) ,同時(shí)還有輸入上一步的歷史反饋信息——隱藏狀態(tài)h (t-1) ,也就是當(dāng)前時(shí)刻的隱藏狀態(tài)h(t) 或決策輸出O(t) 由當(dāng)前時(shí)刻的輸入 x(t) 和上一時(shí)刻的隱藏狀態(tài)h (t-1) 共同決定。從某種程度,RNN和大腦的決策很像,大腦接受當(dāng)前時(shí)刻感官到的信息(外部的x(t) )和之前的想法(內(nèi)部的h (t-1) )的輸入一起決策。

RNN的結(jié)構(gòu)原理可以簡(jiǎn)要概述為兩個(gè)公式。

RNN的隱藏狀態(tài)為:h(t) = f( U * x(t) + W * h(t-1) + b1), f為激活函數(shù),常用tanh、relu;

RNN的輸出為:o(t) = g( V * h(t) + b2),g為激活函數(shù),當(dāng)用于分類任務(wù),一般用softmax;

1.3 從RNN到LSTM

但是在實(shí)際中,RNN在長(zhǎng)序列數(shù)據(jù)處理中,容易導(dǎo)致梯度爆炸或者梯度消失,也就是長(zhǎng)期依賴(long-term dependencies)問(wèn)題,其根本原因就是模型“記憶”的序列信息太長(zhǎng)了,都會(huì)一股腦地記憶和學(xué)習(xí),時(shí)間一長(zhǎng),就容易忘掉更早的信息(梯度消失)或者崩潰(梯度爆炸)。

梯度消失:歷史時(shí)間步的信息距離當(dāng)前時(shí)間步越長(zhǎng),反饋的梯度信號(hào)就會(huì)越弱(甚至為0)的現(xiàn)象,梯度被近距離梯度主導(dǎo),導(dǎo)致模型難以學(xué)到遠(yuǎn)距離的依賴關(guān)系。

改善措施:可以使用 ReLU 激活函數(shù);門控RNN 如GRU、LSTM 以改善梯度消失

梯度爆炸:網(wǎng)絡(luò)層之間的梯度(值大于 1)重復(fù)相乘導(dǎo)致的指數(shù)級(jí)增長(zhǎng)會(huì)產(chǎn)生梯度爆炸,導(dǎo)致模型無(wú)法有效學(xué)習(xí)。

改善措施:可以使用 梯度截?cái)啵灰龑?dǎo)信息流的正則化;ReLU 激活函數(shù);門控RNN 如GRU、LSTM(和普通 RNN 相比多經(jīng)過(guò)了很多次導(dǎo)數(shù)都小于 1激活函數(shù),因此 LSTM 發(fā)生梯度爆炸的頻率要低得多)以改善梯度爆炸。

所以,如果我們能讓 RNN 在接受上一時(shí)刻的狀態(tài)和當(dāng)前時(shí)刻的輸入時(shí),有選擇地記憶和遺忘一部分內(nèi)容(或者說(shuō)信息),問(wèn)題就可以解決了。比如上上句話提及”我去考試了“,然后后面提及”我考試通過(guò)了“,那么在此之前說(shuō)的”我去考試了“的內(nèi)容就沒(méi)那么重要,選擇性地遺忘就好了。這也就是長(zhǎng)短期記憶網(wǎng)絡(luò)(Long Short-Term Memory, LSTM)的基本思想。

二、LSTM原理

LSTM是種特殊RNN網(wǎng)絡(luò),在RNN的基礎(chǔ)上引入了“門控”的選擇性機(jī)制,分別是遺忘門、輸入門和輸出門,從而有選擇性地保留或刪除信息,以能夠較好地學(xué)習(xí)長(zhǎng)期依賴關(guān)系。如下圖RNN(上) 對(duì)比 LSTM(下):

2.1 LSTM的核心

在RNN基礎(chǔ)上引入門控后的LSTM,結(jié)構(gòu)看起來(lái)好復(fù)雜!但其實(shí)LSTM作為一種反饋神經(jīng)網(wǎng)絡(luò),核心還是歷史的隱藏狀態(tài)信息的反饋,也就是下圖的Ct:

對(duì)標(biāo)RNN的ht隱藏狀態(tài)的更新,LSTM的Ct只是多個(gè)些“門控”刪除或添加信息到狀態(tài)信息。由下面依次介紹LSTM的“門控”:遺忘門,輸入門,輸出門的功能,LSTM的原理也就好理解了。

2.2 遺忘門

LSTM 的第一步是通過(guò)”遺忘門”從上個(gè)時(shí)間點(diǎn)的狀態(tài)Ct-1中丟棄哪些信息。

具體來(lái)說(shuō),輸入Ct-1,會(huì)先根據(jù)上一個(gè)時(shí)間點(diǎn)的輸出ht-1和當(dāng)前時(shí)間點(diǎn)的輸入xt,并通過(guò)sigmoid激活函數(shù)的輸出結(jié)果ft來(lái)確定要讓Ct-1,來(lái)忘記多少,sigmoid后等于1表示要保存多一些Ct-1的比重,等于0表示完全忘記之前的Ct-1。

2.3 輸入門

下一步是通過(guò)輸入門,決定我們將在狀態(tài)中存儲(chǔ)哪些新信息。

我們根據(jù)上一個(gè)時(shí)間點(diǎn)的輸出ht-1和當(dāng)前時(shí)間點(diǎn)的輸入xt 生成兩部分信息i t 及C~t,通過(guò)sigmoid輸出i t,用tanh輸出C~t。之后通過(guò)把i t 及C~t兩個(gè)部分相乘,共同決定在狀態(tài)中存儲(chǔ)哪些新信息。

在輸入門 + 遺忘門控制下,當(dāng)前時(shí)間點(diǎn)狀態(tài)信息Ct為:

2.4 輸出門

最后,我們根據(jù)上一個(gè)時(shí)間點(diǎn)的輸出ht-1和當(dāng)前時(shí)間點(diǎn)的輸入xt 通過(guò)sigmid 輸出Ot,再根據(jù)Ot 與 tanh控制的當(dāng)前時(shí)間點(diǎn)狀態(tài)信息Ct 相乘作為最終的輸出。

綜上,一張圖可以說(shuō)清LSTM原理:

三、LSTM簡(jiǎn)單寫詩(shī)

本節(jié)項(xiàng)目利用深層LSTM模型,學(xué)習(xí)大小為10M的詩(shī)歌數(shù)據(jù)集,自動(dòng)可以生成詩(shī)歌。

如下代碼構(gòu)建LSTM模型。

## 本項(xiàng)目完整代碼:github.com/aialgorithm/Blog
# 或“算法進(jìn)階”公眾號(hào)文末閱讀原文可見(jiàn)

model = tf.keras.Sequential([
# 不定長(zhǎng)度的輸入
tf.keras.layers.Input((None,)),
# 詞嵌入層
tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
# 第一個(gè)LSTM層,返回序列作為下一層的輸入
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
# 第二個(gè)LSTM層,返回序列作為下一層的輸入
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
# 對(duì)每一個(gè)時(shí)間點(diǎn)的輸出都做softmax,預(yù)測(cè)下一個(gè)詞的概率
tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])

# 查看模型結(jié)構(gòu)
model.summary()
# 配置優(yōu)化器和損失函數(shù)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

模型訓(xùn)練,考慮訓(xùn)練時(shí)長(zhǎng),就簡(jiǎn)單訓(xùn)練2個(gè)epoch。

class Evaluate(tf.keras.callbacks.Callback):
"""
訓(xùn)練過(guò)程評(píng)估,在每個(gè)epoch訓(xùn)練完成后,保留最優(yōu)權(quán)重,并隨機(jī)生成SHOW_NUM首古詩(shī)展示
"""

def __init__(self):
super().__init__()
# 給loss賦一個(gè)較大的初始值
self.lowest = 1e10

def on_epoch_end(self, epoch, logs=None):
# 在每個(gè)epoch訓(xùn)練完成后調(diào)用
# 如果當(dāng)前l(fā)oss更低,就保存當(dāng)前模型參數(shù)
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save(BEST_MODEL_PATH)
# 隨機(jī)生成幾首古體詩(shī)測(cè)試,查看訓(xùn)練效果
print("cun'h")
for i in range(SHOW_NUM):
print(generate_acrostic(tokenizer, model, head="春花秋月"))

# 創(chuàng)建數(shù)據(jù)集
data_generator = PoetryDataGenerator(poetry, random=True)
# 開(kāi)始訓(xùn)練
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=TRAIN_EPOCHS,
callbacks=[Evaluate()])

加載簡(jiǎn)單訓(xùn)練的LSTM模型,輸入關(guān)鍵字(如:算法進(jìn)階)后,自動(dòng)生成藏頭詩(shī)。可以看出詩(shī)句粗略看上去挺優(yōu)雅,但實(shí)際上經(jīng)不起推敲。后面增加訓(xùn)練的epoch及數(shù)據(jù)集應(yīng)該可以更好些。

# 加載訓(xùn)練好的模型
model = tf.keras.models.load_model(BEST_MODEL_PATH)

keywords = input('輸入關(guān)鍵字:\n')

# 生成藏頭詩(shī)
for i in range(SHOW_NUM):
print(generate_acrostic(tokenizer, model, head=keywords),'\n')

參考資料:https://colah.github.io/posts/2015-08-Understanding-LSTMs/ https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21 https://www.zhihu.com/question/34878706

文章轉(zhuǎn)自微信公眾號(hào)@算法進(jìn)階

上一篇:

神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)到的是什么?(Python)

下一篇:

時(shí)序預(yù)測(cè)的深度學(xué)習(xí)算法介紹
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊(cè)

多API并行試用

數(shù)據(jù)驅(qū)動(dòng)選型,提升決策效率

查看全部API→
??

熱門場(chǎng)景實(shí)測(cè),選對(duì)API

#AI文本生成大模型API

對(duì)比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個(gè)渠道
一鍵對(duì)比試用API 限時(shí)免費(fèi)

#AI深度推理大模型API

對(duì)比大模型API的邏輯推理準(zhǔn)確性、分析深度、可視化建議合理性

10個(gè)渠道
一鍵對(duì)比試用API 限時(shí)免費(fèi)