本文將通過醫(yī)學(xué)數(shù)據(jù),使用 Python 演示如何復(fù)現(xiàn) SHAP 依賴圖,并詳細(xì)解釋連續(xù)性特征對模型預(yù)測結(jié)果的影響
SHAP 依賴圖用于可視化單個特征對機(jī)器學(xué)習(xí)模型預(yù)測結(jié)果的影響,具體來說,x 軸是特征值,y 軸是 SHAP 值(度量特征對預(yù)測結(jié)果的重要性),這些圖可以直觀地顯示出某個特征是對模型預(yù)測起正向還是負(fù)向作用
代碼實(shí)現(xiàn)
數(shù)據(jù)集加載
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('Dataset.csv')
# 劃分特征和目標(biāo)變量
X = df.drop(['target'], axis=1)
y = df['target']
# 劃分訓(xùn)練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42, stratify=df['target'])
df.head()
首先,需要加載數(shù)據(jù)集并將其劃分為特征 X 和目標(biāo)變量 y,然后進(jìn)行訓(xùn)練集和測試集的劃分。目標(biāo)變量是我們要預(yù)測的值,X 是輸入的特征,這是一個分類任務(wù),目標(biāo)是預(yù)測患者是否患有心臟病。雖然是分類任務(wù),但無論是分類問題還是回歸問題,SHAP 依賴圖的使用方式和原理是相同的,都可以用來解釋模型中各個特征對預(yù)測結(jié)果的貢獻(xiàn)
訓(xùn)練機(jī)器學(xué)習(xí)模型
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV
# GBT模型參數(shù)
params_gbt = {
'learning_rate': 0.02, # 學(xué)習(xí)率,控制每一步的步長,用于防止過擬合。典型值范圍:0.01 - 0.1
'max_depth': 3, # 樹的深度,控制模型復(fù)雜度
'random_state': 42, # 隨機(jī)種子,用于重現(xiàn)模型的結(jié)果
'subsample': 0.7, # 每次迭代時隨機(jī)選擇的樣本比例,用于增加模型的泛化能力
}
# 初始化Gradient Boosting分類模型
model_gbt = GradientBoostingClassifier(**params_gbt)
# 定義參數(shù)網(wǎng)格,用于網(wǎng)格搜索
param_grid = {
'n_estimators': [100, 200, 300], # 樹的數(shù)量
'max_depth': [3, 4, 5], # 樹的深度
'learning_rate': [0.01, 0.1], # 學(xué)習(xí)率
}
# 使用GridSearchCV進(jìn)行網(wǎng)格搜索和k折交叉驗(yàn)證
grid_search = GridSearchCV(
estimator=model_gbt,
param_grid=param_grid,
scoring='neg_log_loss', # 評價指標(biāo)為負(fù)對數(shù)損失
cv=5, # 5折交叉驗(yàn)證
n_jobs=-1, # 并行計算
verbose=1 # 輸出詳細(xì)進(jìn)度信息
)
# 訓(xùn)練模型
grid_search.fit(X_train, y_train)
# 使用最優(yōu)參數(shù)訓(xùn)練模型
best_model = grid_search.best_estimator_
這里使用了梯度提升樹(GBT),這是一個強(qiáng)大且常用的機(jī)器學(xué)習(xí)算法,通過網(wǎng)格搜索進(jìn)行參數(shù)優(yōu)化
計算 SHAP 值
import shap
explainer = shap.TreeExplainer(best_model)
# 計算shap值為numpy.array數(shù)組
shap_values_numpy = explainer.shap_values(X)
# 計算shap值為Explanation格式
shap_values_Explanation = explainer(X)
模型訓(xùn)練完畢后,可以使用 shap 包來計算 SHAP 值,SHAP 值用于衡量特定特征對模型輸出的影響,這里分別通過 explainer.shap_values(X) 計算 SHAP 值為數(shù)組格式以便自定義繪制,和通過 explainer(X) 計算為 Explanation 格式,直接使用 SHAP 自帶的繪圖函數(shù)進(jìn)行可視化
默認(rèn)參數(shù)下繪制
# 繪制 'age' 特征的SHAP依賴圖
shap.dependence_plot('age', shap_values_Explanation.values, X, show=False)
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight',dpi=1200)
圖展示了 age(年齡) 特征對模型預(yù)測結(jié)果的 SHAP 值的依賴關(guān)系,說明不同年齡段如何影響模型的預(yù)測
從圖中可以看到:
展示了年齡對模型預(yù)測的非線性影響,同時揭示了另一個特征(thal)如何與年齡共同作用,影響預(yù)測結(jié)果,然而,與文獻(xiàn)中的圖表樣式相比,仍存在一些細(xì)微的差別繪制無顏色條的年齡 SHAP 依賴圖
# 繪制 'age' 特征的 SHAP 依賴圖,不顯示顏色條
shap.dependence_plot('age', shap_values_Explanation.values, X, interaction_index=None, show=False)
# 添加 SHAP=0 的橫線
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.savefig("SHAP Dependence Plot_2.pdf", format='pdf',bbox_inches='tight',dpi=1200)
plt.show()
在這里,通過設(shè)置 interaction_index=None 可以關(guān)閉顏色條,不顯示交互特征的影響。不過,該函數(shù)目前沒有內(nèi)置參數(shù)可以直接在 SHAP 值為 0 的位置添加一條橫線。為了實(shí)現(xiàn)這一功能,可以利用 matplotlib 的 plt.axhline() 方法,在繪制依賴圖后手動添加橫線
接下來,還可以通過 explainer.shap_values(X) 格式繪制這個shap依賴圖,以便實(shí)現(xiàn)自定義繪圖
將 SHAP 值轉(zhuǎn)換為 DataFrame 格式以便于自定義繪圖
shap_values_df = pd.DataFrame(shap_values_numpy, columns=X.columns)
shap_values_df.head()
單個shap依賴圖繪制
# 繪制散點(diǎn)圖,x軸是'age'特征,y軸是SHAP值
plt.figure(figsize=(6, 4),dpi=1200)
plt.scatter(df['age'], shap_values_df['age'], s=10)
# 添加shap=0的橫線
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.xlabel('Age', fontsize=12)
plt.ylabel('SHAP value for\nAge', fontsize=12)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.savefig("SHAP Dependence Plot_3.pdf", format='pdf',bbox_inches='tight')
plt.show()
代碼生成一個 SHAP 值依賴圖,其中展示了特征 age 對模型輸出的貢獻(xiàn),同時對圖表進(jìn)行了一些格式上的優(yōu)化,比如隱藏不必要的邊框線條、在 SHAP=0 處添加一條基準(zhǔn)線,并最終將圖像保存為高分辨率的 PDF 文件。相比于直接使用 shap.dependence_plot() 的默認(rèn)作圖方式,這種方法提供了更高的靈活性,特別是在定制化繪圖方面,可以根據(jù)不同場景、需求對圖表進(jìn)行高度定制,從而提高可視化的效果和表達(dá)的準(zhǔn)確性
多個shap依賴圖繪制
# 定義繪制 SHAP 依賴圖的函數(shù)
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
plt.subplots_adjust(hspace=0.4, wspace=0.4)
# 循環(huán)繪制每個特征的 SHAP 依賴圖
for i, feature in enumerate(feature_list):
row = i // 3 # 行號
col = i % 3 # 列號
ax = axs[row, col]
# 繪制散點(diǎn)圖,x軸是特征值,y軸是SHAP值
ax.scatter(df[feature], shap_values_df[feature], s=10)
# 添加shap=0的橫線
ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
# 設(shè)置x和y軸標(biāo)簽
ax.set_xlabel(feature, fontsize=12)
ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
# 隱藏頂部和右側(cè)的脊柱
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# 隱藏最后一個空圖表的坐標(biāo)軸 (若畫布未關(guān)閉)
axs[1, 2].axis('off')
plt.savefig(file_name, format='pdf', bbox_inches='tight')
plt.show()
# 使用函數(shù)繪制age、trestbps、chol、thalach、oldpeak的shap依賴圖
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
plot_shap_dependence(feature_list, df, shap_values_df)
這段代碼定義一個函數(shù) plot_shap_dependence,用于繪制給定特征列表的 SHAP 依賴圖,生成 2 行 3 列的圖表布局,并在 SHAP=0 處添加基準(zhǔn)線,最后保存為高分辨率 PDF,該圖的樣式基本上與文獻(xiàn)中的 SHAP 依賴圖形式一致,包括散點(diǎn)圖、SHAP 值為 0 的基準(zhǔn)線、去掉頂部和右側(cè)脊柱的簡潔圖形設(shè)計等
本文章轉(zhuǎn)載微信公眾號@Python機(jī)器學(xué)習(xí)AI