Gakushukun1’s diary

20代エンジニア, 統計的機械学習勉強中 twitter: @a96665004

ADVI(PyStan)を用いて混合正規分布の推測

目的

混合正規分布の推測をStanのADVI(Automatic Differentiation Variational Inference)を用いて行う.


混合正規分布の事後分布は, StanのNUTS(No-U-Turn-Sampling)ではうまく作れないことが多いので, 代わりに変分ベイズ法のADVI(Automatic Differentiation Variational Inference)を用いて作る.
データは, 真のパラメータが既知である3個の正規分布の混合から発生する(サンプルサイズ n=600). ADVIでは, 推定されるパラメータが揺らいでいるので, 1000回得られたパラメータサンプルのうち, 後半の500個を平均して, モデルにplug-inすることで, 予測分布を構成した.
実験に使ったコードを次に示す.

import pystan
import numpy as np
from scipy.stats import multivariate_normal, norm
import pandas as pd
import matplotlib.pyplot as plt


if __name__ == '__main__':
    np.random.seed(seed=123)
    N = 600 # サンプル数
    E = [[1.0, 0.0], [0.0, 1.0]] # 単位行列
    t_a = np.array([1/3]*3) # 真の混合比
    # t_mu = np.array([[0.0, 2.0], [-np.sqrt(3), -1.0], [np.sqrt(3), -1.0]])
    t_mu = np.array([[0.0, 1.5], [-np.sqrt(3) * 2 / 3, -0.75], [np.sqrt(3) *2 / 3, -0.75]]) # 真の平均値
    t_sigma = np.array([E, E, E]) # 真の分散共分散(単位行列)
    t_z = [np.argmax(np.random.multinomial(1, t_a)) for i in range(N)] # 真の潜在変数

    # 入力するサンプル
    Y = [np.random.multivariate_normal(t_mu[t_z[i]], t_sigma[t_z[i]]) for i in range(N)]

    # Stanに入力
    data = {'N': N, 'K': 3, 'M': 2, 'Y': Y}
    sm = pystan.StanModel(file='mixnorm_.stan')
    fit = sm.vb(data=data, seed=123)

    # 結果の取得
    sample = pd.read_csv(fit['args']['sample_file'].decode('utf-8'), comment='#')
    sample = sample.drop([0, 1])

    # 予測されたパラメータ
    predict_a = [np.mean(sample['a.1'][501:1000]), np.mean(sample['a.2'][501:1000]), np.mean(sample['a.3'][501:1000])]
    predict_mu = [[np.mean(sample['mu.1.1'][501:1000]), np.mean(sample['mu.1.2'][501:1000])],
                  [np.mean(sample['mu.2.1'][501:1000]), np.mean(sample['mu.2.2'][501:1000])],
                  [np.mean(sample['mu.3.1'][501:1000]), np.mean(sample['mu.3.2'][501:1000])]]
    predict_sigma = np.mean(sample['sigma'][501:1000])
    
    x, y = np.meshgrid(np.arange(-5.0, 5.0, 0.05), np.arange(-5.0, 5.0, 0.05))
    pos = np.dstack((x, y))
    gauss = [multivariate_normal(mean=t_mu[k], cov=t_sigma[k]) for k in range(3)]
    z = sum([t_a[k] * gauss[k].pdf(pos) for k in range(3)])

    color = ['red', 'blue', 'black']

    predict_gauss = [[norm(loc=predict_mu[k][m], scale=predict_sigma) for m in range(2)] for k in range(3)]
    predict_z = [[sum([predict_a[k] * np.prod([predict_gauss[k][m].pdf(pos[i, j, m]) for m in range(2)]) for k in range(3)]) for j in range(200)] for i in range(200)]

    # プロット
    plt.figure()
    plt.subplot(2, 1, 1)
    plt.contour(x, y, z)

    for i in range(N):
        plt.scatter(Y[i][0], Y[i][1], c=color[t_z[i]], s=10)

    plt.subplot(2, 1, 2)
    plt.contour(x, y, predict_z)

    plt.show()
  • Stan
data {
    int N;
    int K;
    int M;
    vector[M] Y[N];
}

parameters {
    simplex[K] a;
    vector[M] mu[K];
    real<lower=0> sigma;    
}

transformed parameters {
    vector[K] lp[N];
    
    for (n in 1:N) {
        for (k in 1:K) {
            lp[n, k] = log(a[k]);
            for (m in 1:M) {
                lp[n, k] += normal_lpdf(Y[n, m] | mu[k, m], sigma);
            }
        }
    }
}

model {
    a ~ dirichlet(rep_vector(5.0, K));
    
    for (k in 1:K) {
        for (m in 1:M) {
            mu[k, m] ~ normal(0.0, 5);
        }
    }

    for (n in 1:N) {
        target += log_sum_exp(lp[n]);
    }
}

generated quantities {
    simplex[K] pr_z[N];

    for(n in 1:N) {
        pr_z[n] = softmax(lp[n]);
    }
}

真の分布は, 二次元のユークリッド空間上に定義され, サンプル点は二次元の点からなる.
真の分布と, サンプル点は, 次の図の通りである. 各点の色は, どの正規分布から発生したかを表しているが, 学習モデルにはその情報は与えられない(潜在変数).
f:id:gakushukun1:20190615135909p:plain

このサンプルに対して, ADVIを用いて作った予測分布を次に示す. 良好に推定されていることが分かる.
f:id:gakushukun1:20190615140257p:plain


次に, 3個の山が重なりを持つ場合について実験した. 真の分布とサンプルは次の通りである.
f:id:gakushukun1:20190615140502p:plain

このサンプルに対して, ADVIを用いて作った予測分布を次に示す. 予測分布の右下が他と比べてやや大きくなっているが, これは発生したサンプルの揺らぎが原因であると思われ, 推測自体は良好にできていると考えられる.
f:id:gakushukun1:20190615140554p:plain

実行中に,

WARNING:pystan:Automatic Differentiation Variational Inference (ADVI) is an EXPERIMENTAL ALGORITHM.

という警告が発生するが, 今回は良好に推測できていた. ただし, ここでは全ての正規分布の分散共分散行列を共通の単位行列の定数倍というモデルを用いており, コンポーネントごとに異なる分散共分散を持つモデルでは適切な推測は困難であった.

そもそも, 異なる分散共分散行列を持つ混合正規分布の推測は, 他の方法によっても困難であるので, StanのADVIは有力な方法であると考えられる.

まとめ

StanのADVIを用いて混合正規分布の学習を行った.
StanのADVIは開発中とのことであるが, 一般にNUTSは混合分布には適さないことが多いように思われるので, 今の所はADVIを使うほうがよいかもしれない.