雑記 in hibernation

頭の整理と備忘録

追い詰められたのでShap入門します

本職でクソモデルをこしらえた結果、モデルの中身に対する説明責任が発生してしまいました。逃げ場を失ったので素直にShapに入門します。


1. Shapとは

ビジネスの場で機械学習モデルを適用したり改善したりする場合、各変数が予測値に対してどのような影響を与えているのかを理解することは重要です。しかし「とりあえずlightGBMに思いつく限りの変数をぶち込めばOK」ってなノリでモデリングした結果生まれるのは、往々にして「ブラックボックスだけどそこそこ精度はでてんだよな〜」的なサムシングです。それを素直に「ブラックボックスだけどそこそこ精度はでてんすよ」などと報告しようものなら上司とクライアントにすっ叩かれて終了なわけで、そこで登場するのが機械学習モデルに一定の解釈性・説明性を担保する「説明可能な AI(XAI)」という概念です。そしてShapはそれを実現する技術の一つであり、OSSとしてライブラリが提供されています。

Shapの原理に関する説明は、以下のページの概説が大変わかりやすいです。

www.datarobot.com


ざっくりいうと、ゲーム理論において成果に対するプレイヤーの寄与度を定量化する手法「シャープレイ値(Shapley Value)」を応用して、機械学習モデルのアウトプットに対する各変数の寄与度を推し量ろう、という考え方です。シャープレイ値の計算と同等の処理フローをそのまま実行した場合、説明変数の数に応じて計算量が爆発してしまいます。そこで、Shapでは近似的に計算して処理コストを抑えることで、現実的な処理時間で寄与度の大きさ(Shap値)を算出できます。


2. 使ってみる

さて、実際にShapを利用して説明変数と予測結果の関係を可視化して見ましょう。

2.1. モデルを用意する

なにはともあれ、まずは学習済みの機械学習モデルを用意しましょう。Shapのライブラリに用意されているボストン住宅価格のデータセットを利用し、回帰タスクを解くlightGBMモデルを構築します。

ライブラリの読み込み

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split

import shap


データセットの読み込み

# 目的変数の名前
target_col = "MEDV"

# 検証用データのインポート
df_input, y = shap.datasets.boston()
df_input[target_col] = y 
print(df_input.shape)
df_input.head()
(506, 14)
Function load_boston is deprecated; `load_boston` is deprecated in 1.0 and will be removed in 1.2.

    The Boston housing prices dataset has an ethical problem. You can refer to
    the documentation of this function for further details.

    The scikit-learn maintainers therefore strongly discourage the use of this
    dataset unless the purpose of the code is to study and educate about
    ethical issues in data science and machine learning.

    In this special case, you can fetch the dataset from the original
    source::

        import pandas as pd
        import numpy as np


        data_url = "http://lib.stat.cmu.edu/datasets/boston"
        raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None)
        data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]])
        target = raw_df.values[1::2, 2]

    Alternative datasets include the California housing dataset (i.e.
    :func:`~sklearn.datasets.fetch_california_housing`) and the Ames housing
    dataset. You can load the datasets as follows::

        from sklearn.datasets import fetch_california_housing
        housing = fetch_california_housing()

    for the California housing dataset and::

        from sklearn.datasets import fetch_openml
        housing = fetch_openml(name="house_prices", as_frame=True)

    for the Ames housing dataset.
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT MEDV
0 0.00632 18 2.31 0 0.538 6.575 65.2 4.09 1 296 15.3 396.9 4.98 24
1 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.9 9.14 21.6
2 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 34.7
3 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 33.4
4 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.9 5.33 36.2


データセットの分割

 データセットの分割
df_train, df_test = train_test_split(df_input, test_size = 0.3, random_state = 1111)
df_train, df_valid = train_test_split(df_train, test_size = 0.3, random_state = 1111)
print("train", df_train.shape)
print("valid", df_valid.shape)
print("test", df_test.shape)

# LightGBM用のデータセットに入れる
lgb_train = lgb.Dataset(df_train.drop([target_col], axis=1), df_train[target_col])
lgb_eval = lgb.Dataset(df_valid.drop([target_col], axis=1), df_valid[target_col])
lgb_test = lgb.Dataset(df_test.drop([target_col], axis=1), df_test[target_col])
train (247, 14)
valid (107, 14)
test (152, 14)


モデルの学習

# パラメタ設定
params = {
    'objective': 'regression',
    'metric' : 'rmse',
    }

# 学習
model = lgb.train(
        params,
        lgb_train,
        valid_sets=lgb_eval,
        verbose_eval=10,
        early_stopping_rounds = 10
    )

# スコアを出す
print("`\nscore")
y_pred_train = model.predict(df_train.drop([target_col], axis=1))
score = mean_absolute_error(df_train[target_col], y_pred_train)
print(" train:", score)

y_pred_valid = model.predict(df_valid.drop([target_col], axis=1))
score = mean_absolute_error(df_valid[target_col], y_pred_valid)
print(" eval:", score)

y_pred_test = model.predict(df_test.drop([target_col], axis=1))
score = mean_absolute_error(df_test[target_col], y_pred_test)
print(" test:", score)
Training until validation scores don't improve for 10 rounds.
[10]  valid_0's rmse: 5.52729
[20]   valid_0's rmse: 4.62637
[30]  valid_0's rmse: 4.45367
[40]   valid_0's rmse: 4.33433
[50]  valid_0's rmse: 4.3036
[60]   valid_0's rmse: 4.29522
[70]  valid_0's rmse: 4.27483
[80]   valid_0's rmse: 4.27121
Early stopping, best iteration is:
[78]  valid_0's rmse: 4.26293
`
score
 train: 1.284695603988035
 eval: 2.7761173036246194
 test: 2.0338790049942674


はい、ということでモデルができたわけですが、このモデルを理解しようと思ったとき、僕に思いつく(そして実装できる)範囲でとりうる手段はせいぜい以下くらいなものです。

  • 特徴量の重要度を出してみる
  • 各変数を切り口として誤差分布や予測vs正解の散布図を出してみる

まあこれはこれで得られる示唆は色々あるとは思いますが、Shapを利用すればもっと突っ込んだ情報を得ることができそうです。


2.2. Shapの準備

モデルからShap値を計算するためのインスタンスを用意し、これを利用してデータセットから各サンプルに対するShap値(とそれにまつわる諸々の値)を算出します。今回は決定木のブースティングによるモデルなのでTreeExplainer()を利用していますが、これは利用するモデル(例えばNNや線形回帰など)によって適宜使い分ける必要があります。詳しくは公式のドキュメントをご参照ください。

# 説明変数のみ抽出したデータフレーム
X=df_input.drop([target_col], axis=1)

# モデルからShap値を計算するためにインスタンス
explainer = shap.TreeExplainer(model)

# Shap値にまつわる諸々の計算値を格納
shap_values = explainer(X)


ためしにshap_valuesの中身を覗いてみましょう。101番目のサンプルを指定して中身を見てみると、Shap値、ベースライン(これは全てのサンプルに共通の値)、元のデータセットの値がそれぞれ格納されていることがわかります。

sample_idx = 100
shap_values[sample_idx]
.values =
array([ 1.83374644e-01, -3.60976797e-03, -5.98876067e-02,  0.00000000e+00,
       -8.41275585e-02, -1.55915002e+00, -1.09658240e+00,  2.13409308e+00,
        5.06677749e-02,  4.36630215e-02, -9.08025713e-01, -2.04534972e-02,
        4.63101128e+00])

.base_values =
23.172064821751373

.data =
array([1.4866e-01, 0.0000e+00, 8.5600e+00, 0.0000e+00, 5.2000e-01,
       6.7270e+00, 7.9900e+01, 2.7778e+00, 5.0000e+00, 3.8400e+02,
       2.0900e+01, 3.9476e+02, 9.4200e+00])


2.3. Shap値を可視化してみる

さて、以降はshapに実装されている関数を用いて前段で取得したshap値を様々な観点から可視化していきます。

2.3.1. データセット全体の予測に対して重要な説明変数は?

summary_plot()で全サンプルのShap値の絶対値平均を説明変数ごとにプロットできます。feature_importanceのShap値版って感じですね。

サンプル全体としては、低所得者人口の割合(LSTAT)と1戸当たりの平均部屋数(RM)が予測に強く効いているようです。

shap.summary_plot(shap_values, X, plot_type="bar")

f:id:toeming:20220409222741p:plain


2.3.2. 説明変数のShap値分布は?予測値との相関は?

summary_plot()で各説明変数に対してShap値を横軸にとって各サンプルをプロットできます。このとき、説明変数の値は各ドットの色として表現されます。

低所得者人口の割合(LSTAT)ではShap値が低いほど変数の値が高い、つまり負の相関の傾向が確認できます。一方で戸当たりの平均部屋数(RM)はその逆で、正の相関がみられます。また、RMに関して、一部のサンプルでは予測値の増加にとても強く影響しているようです。単純な相関が見られない変数としては、例えば非小売業の割合(INDUS)は変数の水準が低いことが予測値の若干の増加または減少を引き起こしていることがわかります(ただし、基本的に全てのサンプルに対して予測値に対する影響力そのものは小さい)。

あとどうでもいいですが、このカラースケールのプロット見ると、なんというかShap感を強く感じます。

shap.summary_plot(shap_values, X)

f:id:toeming:20220409222909p:plain


dependence_plot()では特定の説明変数に絞って、説明変数の値を横軸、Shap値を縦軸として各サンプルをプロットできます。また、Shap値のばらつきに対して交互作用が強いと思われる変数の値で色分けされます(これ、個人的に特に面白いと思ったポイントです)。

ためしに戸当たりの平均部屋数(RM)を見てみます。6.5部屋くらいが予測値の正負のどちらに寄与するかの境目になっていて、この境界を超えてからは部屋数が増えるほど予測値を引き上げる効果が高くなるようです。7.5部屋以上ではShap値が顕著に高いサンプル群がありますが、こういったイレギュラーな傾向があるサンプルは、本来であれば予測の正確性も含めて深掘りしておきたい部分ですね。

プロットからわかるように同じ部屋数でも(当然)Shap値にはばらつきが発生するわけですが、RMの場合 、ばらつきの主な要因となっている変数は主要施設への距離(DIS)であり、これが色分けの基準に採用されています。距離が遠い場合6.5部屋以下ではShap値を引き下げ、それ以上では引き上げる、という交互作用を見て取ることができます。

shap.dependence_plot(ind="RM", shap_values=shap_values.values, features=X)

f:id:toeming:20220409222824p:plain


2.3.3. 説明変数同士の関係は?

先述の通り、特に交互作用が強い変数に対するプロットはdependence_plot()により自動で選択されてプロットの色分けとして出力できますが、任意の変数を指定することもできます。

以下のグラフでは低所得者人口の割合(LSTAT)を指定しています。

shap.dependence_plot(ind="RM",interaction_index='LSTAT', shap_values=shap_values.values, features=X)

f:id:toeming:20220409222947p:plain

縦軸と横軸に変数の値、色分けをShap値、みたいなプロットができるといいなと思ったのですが、そういった可視化関数の情報は見つかりませんでした(shap_valuesの値を使えば関数がなくても自力で可視化することはできるはず)。


2.3.4. 特定のサンプルに対する予測値の根拠は?

force_plot()ではサンプルの行番号を指定することで、指定したサンプルに対して説明変数の予測における各変数のShap値の内訳を可視化することができます。実際のユースケースとしては、予測精度が甘いサンプルをサンプリングして可視化することで誤差の要因を探る、みたいな使い方になる気がします。なお、実装上の注意として、Google Colaboratoryではセルごとにshap.initjs()を実行しないとエラーになります。

ためしにデータセットの先頭のサンプルを可視化してみました。このサンプルでは、主に低所得者人口の割合(LSTAT)や生徒と先生の比率(PTRATIO)が予測値を引き上げる一方、戸当たりの平均部屋数(RM)や一酸化窒素濃度(NOX)が予測値を引き下げていることがわかります。

shap.initjs()

sample_idx = 0 #sample_idx : 何番目のサンプルの情報をplotするか指定
force_plot = shap.force_plot(base_value=explainer.expected_value,
                shap_values=shap_values.values[sample_idx],
                features=X.iloc[sample_idx,:]
                )
force_plot

f:id:toeming:20220409225958p:plain


force_plot()で複数サンプルの行番号を指定した場合、サンプルを横一列に並べて先のShap値の積み上がり方を可視化することもできます。また、横軸と縦軸に表示されているタブから、サンプルの並び順の基準を変える、縦軸の指標を各変数のShap値に変更する、といったことも可能です。

shap.initjs()

n=100 #sample_idx : 何番目のサンプルの情報をplotするか指定
force_plot = shap.force_plot(
    base_value=explainer.expected_value,
    shap_values=shap_values.values[sample_idx:sample_idx+n],
    features=X.iloc[sample_idx:sample_idx+n,:])
force_plot

f:id:toeming:20220409230014p:plain


plots.waterfall()では、指定したサンプルについて、ベースラインの値からどのように変数ごとのShap値が積み上がって予測値に至ったかを可視化できます。

shap.plots.waterfall(shap_values[sample_idx])

f:id:toeming:20220409223302p:plain


複数のサンプルを並べてShap値の積み上がり方を確認したい場合は、decision_plot()を使うことで、plots.waterfall()と同様の意味合いのプロットを複数のサンプルに対して折線グラフで並べて出力できます。

ためしにデータセットの先頭から5番目までの5つのサンプルをプロットすると、以下のようになります。

n=5
shap.decision_plot(base_value=explainer.expected_value,
                   shap_values=shap_values.values[sample_idx:sample_idx+n],
                   features=X.iloc[sample_idx:sample_idx+n,:],
                   feature_names=X.columns.tolist())

f:id:toeming:20220409222009p:plain


いったん僕が把握できた範囲の可視化は以上です。

3. おわりに

ということで、Shap入門でした。これでモデルの挙動は掌握したも同然。「あーそーゆーことね 完全に理解した(←わかってない)」状態です。

他にも面白そうな可視化の仕方があったら随時更新していきたいです。