高斯混合模型

高斯混合模型(GMM)?是由多個高斯分布混合而成的。假設(shè)數(shù)據(jù)集是由?k?個高斯分布組成的混合模型,那么給定數(shù)據(jù)點?x,它的概率分布可以表示為每個分布的加權(quán)和:

EM算法推導(dǎo)

由于我們不知道每個點屬于哪個高斯分布,因此 GMM 采用?EM算法(期望最大化算法)來迭代估計參數(shù)。

通過多次迭代,EM算法可以讓這些參數(shù)收斂到合適的值。

2. 案例實現(xiàn)

我們會從 Kaggle 下載一個數(shù)據(jù)集,使用 GMM 對數(shù)據(jù)進(jìn)行分類。為了演示原理,我們使用一個簡單的二維數(shù)據(jù)集。并且根據(jù)原理進(jìn)行代碼的編寫。

數(shù)據(jù)集下載

選擇一個簡單的、適合分類的二維數(shù)據(jù)集,例如 Kaggle 上的?Iris 數(shù)據(jù)集?或?Mall Customers 數(shù)據(jù)集。我們將以 Mall Customers 為例,用顧客的年收入和消費(fèi)得分來進(jìn)行聚類分析。

數(shù)據(jù)分析及可視化

我們要畫出 4 個及以上的分析圖,來逐步理解數(shù)據(jù)和模型效果。

Python 實現(xiàn)

下面是?Mall Customers 數(shù)據(jù)集?的完整 Python 實現(xiàn):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

# 讀取數(shù)據(jù)集
data = pd.read_csv("Mall_Customers.csv")
X = data[['Annual Income (k$)', 'Spending Score (1-100)']].values

# 初始化參數(shù)
def initialize_params_fixed(X, K):
n, d = X.shape
pi = np.ones(K) / K # 初始化每個混合成分的權(quán)重
mu = X[np.random.choice(n, K, False), :] # 隨機(jī)選擇K個初始均值
sigma = np.array([np.eye(d) for _ in range(K)]) # 初始化協(xié)方差矩陣為單位矩陣
return pi, mu, sigma

# 計算多元正態(tài)分布
def multivariate_gaussian(X, mu, sigma):
return multivariate_normal(mean=mu, cov=sigma).pdf(X)

# E 步:計算每個點屬于每個成分的責(zé)任值 (gamma)
def expectation_step_stable(X, pi, mu, sigma):
N = X.shape[0]
K = len(pi)
gamma = np.zeros((N, K))

for k in range(K):
try:
gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])
except np.linalg.LinAlgError:
# 如果協(xié)方差矩陣是奇異矩陣,加入微小正則化項以確保正定性
sigma[k] += np.eye(X.shape[1]) * 1e-6
gamma[:, k] = pi[k] * multivariate_gaussian(X, mu[k], sigma[k])

# 防止零除錯誤,保證數(shù)值穩(wěn)定性
gamma_sum = np.sum(gamma, axis=1, keepdims=True)
gamma_sum[gamma_sum == 0] = 1e-10 # 防止除以零
gamma = gamma / gamma_sum

return gamma

# M 步:更新GMM的參數(shù)
def maximization_step(X, gamma):
N, d = X.shape
K = gamma.shape[1]

Nk = np.sum(gamma, axis=0) # 計算每個聚類的總責(zé)任值
pi = Nk / N # 更新混合系數(shù)
mu = np.dot(gamma.T, X) / Nk[:, np.newaxis] # 更新均值

sigma = np.zeros((K, d, d)) # 更新協(xié)方差矩陣
for k in range(K):
X_centered = X - mu[k]
gamma_diag = np.diag(gamma[:, k])
sigma[k] = np.dot(X_centered.T, np.dot(gamma_diag, X_centered)) / Nk[k]

return pi, mu, sigma

# 計算對數(shù)似然
def compute_log_likelihood(X, pi, mu, sigma):
N = X.shape[0]
K = len(pi)
log_likelihood = 0

for n in range(N):
tmp = 0
for k in range(K):
tmp += pi[k] * multivariate_gaussian(X[n], mu[k], sigma[k])
log_likelihood += np.log(tmp)

return log_likelihood

# GMM 實現(xiàn),包含數(shù)值穩(wěn)定性修復(fù)
def gmm_fixed_stable(X, K, max_iter=100, tol=1e-6):
pi, mu, sigma = initialize_params_fixed(X, K)
log_likelihoods = []

for i in range(max_iter):
# E 步
gamma = expectation_step_stable(X, pi, mu, sigma)

# M 步
pi, mu, sigma = maximization_step(X, gamma)

# 添加小的正則化項,確保協(xié)方差矩陣為正定
sigma += np.eye(sigma.shape[1]) * 1e-6

# 計算對數(shù)似然
log_likelihood = compute_log_likelihood(X, pi, mu, sigma)
log_likelihoods.append(log_likelihood)

# 檢查是否收斂
if i > 0 and abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
break

return pi, mu, sigma, log_likelihoods, gamma

# 數(shù)據(jù)可視化:原始數(shù)據(jù)分布
def plot_original_data(X):
plt.scatter(X[:, 0], X[:, 1], c='blue', label='Data points', alpha=0.5)
plt.title('Original Data Distribution')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.show()

# 分類結(jié)果展示
def plot_clusters(X, gamma, mu):
K = gamma.shape[1]
colors = ['r', 'g', 'b', 'y', 'm']

for k in range(K):
plt.scatter(X[:, 0], X[:, 1], c=gamma[:, k], cmap='viridis', label=f'Cluster {k+1}', alpha=0.6)

plt.scatter(mu[:, 0], mu[:, 1], c='black', marker='x', s=100, label='Centroids')
plt.title('GMM Clustering')
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.legend()
plt.show()

# 對數(shù)似然收斂圖
def plot_log_likelihood(log_likelihoods):
plt.plot(log_likelihoods)
plt.title('Log Likelihood Convergence')
plt.xlabel('Iterations')
plt.ylabel('Log Likelihood')
plt.show()

# 各類別概率分布圖
def plot_probability_distributions(gamma):
K = gamma.shape[1]
for k in range(K):
plt.hist(gamma[:, k], bins=20, alpha=0.5, label=f'Cluster {k+1}')

plt.title('Probability Distributions for Each Cluster')
plt.xlabel('Probability')
plt.ylabel('Number of Points')
plt.legend()
plt.show()

# 運(yùn)行 GMM 算法
K = 3 # 假設(shè)數(shù)據(jù)有 3 個聚類
pi, mu, sigma, log_likelihoods, gamma = gmm_fixed_stable(X, K)

# 繪制圖形
plot_original_data(X) # 原始數(shù)據(jù)分布圖
plot_clusters(X, gamma, mu) # 分類結(jié)果圖
plot_log_likelihood(log_likelihoods) # 對數(shù)似然收斂圖
plot_probability_distributions(gamma) # 各類別概率分布圖

代碼部分細(xì)節(jié)解釋:

1. 原始數(shù)據(jù)分布圖:展示客戶年收入和消費(fèi)得分的散點圖,幫助我們直觀理解數(shù)據(jù)分布情況。

2. 分類結(jié)果圖:展示 GMM 分類后每個客戶所屬的類別,以及每個類別的均值點(質(zhì)心)。

3. 對數(shù)似然收斂圖:展示對數(shù)似然值的收斂過程,判斷模型是否收斂。

4. 各類別概率分布圖:展示不同類別的概率分布,幫助理解分類的置信度。

通過高斯混合模型(GMM)的推導(dǎo)與 Python 實現(xiàn),咱們完成了從基礎(chǔ)原理到實際應(yīng)用的完整過程。有問題,大家可以評論區(qū)討論~

文章轉(zhuǎn)自微信公眾號@深夜努力寫Python

上一篇:

突破最強(qiáng)時間序列模型,移動平均??!

下一篇:

突破最強(qiáng)時間序列模型,LightGBM??!
#你可能也喜歡這些API文章!

我們有何不同?

API服務(wù)商零注冊

多API并行試用

數(shù)據(jù)驅(qū)動選型,提升決策效率

查看全部API→
??

熱門場景實測,選對API

#AI文本生成大模型API

對比大模型API的內(nèi)容創(chuàng)意新穎性、情感共鳴力、商業(yè)轉(zhuǎn)化潛力

25個渠道
一鍵對比試用API 限時免費(fèi)

#AI深度推理大模型API

對比大模型API的邏輯推理準(zhǔn)確性、分析深度、可視化建議合理性

10個渠道
一鍵對比試用API 限時免費(fèi)