其中:

2. 輸入門(Input Gate)

輸入門用于決定當前時刻哪些信息需要寫入細胞狀態(tài)。包括兩個步驟:

計算輸入門的激活值:

生成新的候選細胞狀態(tài):

其中:

3. 更新 Cell State

細胞狀態(tài)的更新包括兩部分:保留部分舊的細胞狀態(tài)和添加新的細胞狀態(tài)。公式如下:

其中:

4. 輸出門(Output Gate)

輸出門用于決定當前時刻的隱藏狀態(tài)。公式如下:

當前時刻的隱藏狀態(tài):

其中:

總的來說,LSTM通過引入遺忘門、輸入門和輸出門,有效地控制了信息的流動,解決了傳統(tǒng)RNN在處理長序列數(shù)據時的梯度消失和梯度爆炸問題。

這里再匯總一下上面的公式:

  1. 遺忘門:
  2. 輸入門:
  3. 候選細胞狀態(tài):$ \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) $
  4. 更新細胞狀態(tài):
  5. 輸出門:
  6. 隱藏狀態(tài):

通過這些公式,LSTM能夠有效地捕捉和保留序列中的長期依賴關系,從而在處理時間序列數(shù)據、自然語言處理等任務中表現(xiàn)出色。

案例:利用LSTM進行時間序列預測

數(shù)據集介紹

數(shù)據集來自UCI機器學習庫的北京PM2.5數(shù)據集。數(shù)據包含2010年至2014年北京空氣質量監(jiān)測數(shù)據,包括PM2.5濃度、天氣數(shù)據等。

公眾號后臺,回復「數(shù)據集」,取PRSA_data_2010.1.1-2014.12.31.csv

PRSA_data_2010.1.1-2014.12.31.csv 數(shù)據集包括以下列:

算法流程

  1. 數(shù)據加載與預處理
  2. 構建LSTM模型
  3. 模型訓練
  4. 模型評估

完整代碼

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense
from keras.callbacks import EarlyStopping

# 數(shù)據加載
data = pd.read_csv('PRSA_data_2010.1.1-2014.12.31.csv')

# 數(shù)據預處理
data['pm2.5'].fillna(data['pm2.5'].mean(), inplace=True) # 處理缺失值

# 選擇特征和目標變量
features = ['DEWP', 'TEMP', 'PRES', 'Iws']
target = 'pm2.5'

# 標準化
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data[features + [target]])

# 創(chuàng)建時間序列數(shù)據集
def create_dataset(dataset, look_back=1):
X, y = [], []
for i in range(len(dataset) - look_back):
X.append(dataset[i:(i + look_back), :-1])
y.append(dataset[i + look_back, -1])
return np.array(X), np.array(y)

look_back = 24
X, y = create_dataset(data_scaled, look_back)

# 劃分訓練集和測試集
train_size = int(len(X) * 0.8)
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# 構建LSTM模型
model = Sequential()
model.add(LSTM(50, return_sequences=True, input_shape=(look_back, len(features))))
model.add(LSTM(50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')

# 模型訓練
early_stopping = EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(X_train, y_train, epochs=100, batch_size=32, validation_split=0.2, callbacks=[early_stopping])

# 模型評估
y_pred = model.predict(X_test)

# 反標準化
y_test_inv = scaler.inverse_transform(np.concatenate((X_test[:, -1, :-1], y_test.reshape(-1, 1)), axis=1))[:, -1]
y_pred_inv = scaler.inverse_transform(np.concatenate((X_test[:, -1, :-1], y_pred), axis=1))[:, -1]

# 繪制結果
plt.figure(figsize=(12, 6))
plt.plot(y_test_inv, label='True')
plt.plot(y_pred_inv, label='Predicted')
plt.legend()
plt.show()

# 繪制損失曲線
plt.figure(figsize=(12, 6))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.show()

代碼中的步驟一些細節(jié),再給大家說一下:

  1. 超參數(shù)調整
  2. 特征工程
  3. 數(shù)據處理
  4. 模型改進

最后

以上,整個是一個完整的LSTM時間序列預測的案例,包括數(shù)據預處理、模型構建、訓練、評估和可視化。大家在自己實際的實驗中,根據需求進行進一步的調整和優(yōu)化。

本文章轉載微信公眾號@深夜努力寫Python

上一篇:

選擇最佳模型,輕松上手 GBDT、LightGBM、XGBoost、AdaBoost ?。?/h5>

下一篇:

講透一個強大算法模型,層次聚類!!
#你可能也喜歡這些API文章!

我們有何不同?

API服務商零注冊

多API并行試用

數(shù)據驅動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉化潛力

25個渠道
一鍵對比試用API 限時免費

#AI深度推理大模型API

對比大模型API的邏輯推理準確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費