
Implementing AutoML with AutoGluon (Minimal Implementation Edition)


To be honest, I'm not all that interested in machine learning engineering itself; as someone who just wants decent accuracy with a minimum of effort, AutoML has always had a certain appeal. I recently happened to hear a talk on AutoGluon, and it turned out to be far easier to use and far more accurate than I had imagined, so I'd like to take this opportunity to try implementing it myself.

This post is the minimal implementation edition: the goal is to get a feel for how easy AutoML can be with an absolutely bare-bones implementation. In the next post, I plan to experiment with more practical ways of using it.


What is AutoGluon?

AutoGluon is an open-source AutoML toolkit provided by AWS. It lets you build the entire flow from data preprocessing through modeling (training, ensembling, and hyperparameter search) to accuracy evaluation with very little code. It also supports unstructured data such as images and natural language.

The explanation at the link below is easy to follow.

https://pages.awscloud.com/rs/112-TZM-766/images/1.AWS_AutoML_AutoGluon.pdf


Implementation

So, let's get right to the implementation.

Using the familiar Titanic survival prediction task (a binary classification problem) as the example, I'll implement AutoGluon by the shortest route possible. The execution environment is, as usual, Google Colab.

The implementation is based on the official quick start.

auto.gluon.ai


1. Installing the libraries

First, install AutoGluon. This part may vary depending on your environment. I started from the official instructions, but they fail as-is on Google Colab, so after some trial and error across various references I settled on the code below. Each install still leaves the warning "WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager.", and a quick search turned up a mix of "your package versions are conflicting" and "don't sweat the small stuff" style answers. Since everything runs for now, I'm letting it slide, but it does leave me a bit uneasy.

!pip install --upgrade pip
!pip install --upgrade setuptools
!pip install --upgrade mxnet
!pip install autogluon.tabular
Requirement already satisfied: pip in /usr/local/lib/python3.7/dist-packages (21.3.1)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (60.2.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Requirement already satisfied: mxnet in /usr/local/lib/python3.7/dist-packages (1.9.0)
Requirement already satisfied: requests<3,>=2.20.0 in /usr/local/lib/python3.7/dist-packages (from mxnet) (2.23.0)
Requirement already satisfied: graphviz<0.9.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from mxnet) (0.8.4)
Requirement already satisfied: numpy<2.0.0,>1.16.0 in /usr/local/lib/python3.7/dist-packages (from mxnet) (1.19.5)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.20.0->mxnet) (1.25.11)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Requirement already satisfied: autogluon.tabular in /usr/local/lib/python3.7/dist-packages (0.3.1)
Requirement already satisfied: pytest in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (3.6.4)
Requirement already satisfied: numpy<1.22,>=1.19 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (1.19.5)
Requirement already satisfied: autogluon.core==0.3.1 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (0.3.1)
Requirement already satisfied: scikit-learn<0.25,>=0.23.2 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (0.24.2)
Requirement already satisfied: autogluon.features==0.3.1 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (0.3.1)
Requirement already satisfied: scipy<1.7,>=1.5.4 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (1.6.3)
Requirement already satisfied: networkx<3.0,>=2.3 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (2.6.3)
Requirement already satisfied: pandas<2.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (1.1.5)
Requirement already satisfied: psutil<5.9,>=5.7.3 in /usr/local/lib/python3.7/dist-packages (from autogluon.tabular) (5.8.0)
Requirement already satisfied: dask>=2.6.0 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (2021.12.0)
Requirement already satisfied: ConfigSpace==0.4.19 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (0.4.19)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (3.2.2)
Requirement already satisfied: cython in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (0.29.24)
Requirement already satisfied: autograd>=1.3 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (1.3)
Requirement already satisfied: tornado>=5.0.1 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (5.1.1)
Requirement already satisfied: boto3 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (1.20.26)
Requirement already satisfied: graphviz<1.0,>=0.8.1 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (0.8.4)
Requirement already satisfied: paramiko>=2.4 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (2.9.1)
Requirement already satisfied: distributed>=2.6.0 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (2021.12.0)
Requirement already satisfied: tqdm>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (4.62.3)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (2.23.0)
Requirement already satisfied: dill<1.0,>=0.3.3 in /usr/local/lib/python3.7/dist-packages (from autogluon.core==0.3.1->autogluon.tabular) (0.3.4)
Requirement already satisfied: pyparsing in /usr/local/lib/python3.7/dist-packages (from ConfigSpace==0.4.19->autogluon.core==0.3.1->autogluon.tabular) (3.0.6)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas<2.0,>=1.0.0->autogluon.tabular) (2018.9)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas<2.0,>=1.0.0->autogluon.tabular) (2.8.2)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn<0.25,>=0.23.2->autogluon.tabular) (1.1.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn<0.25,>=0.23.2->autogluon.tabular) (3.0.0)
Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (21.2.0)
Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (1.11.0)
Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (1.15.0)
Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (1.4.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (60.2.0)
Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (0.7.1)
Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.7/dist-packages (from pytest->autogluon.tabular) (8.12.0)
Requirement already satisfied: future>=0.15.2 in /usr/local/lib/python3.7/dist-packages (from autograd>=1.3->autogluon.core==0.3.1->autogluon.tabular) (0.16.0)
Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (3.13)
Requirement already satisfied: toolz>=0.8.2 in /usr/local/lib/python3.7/dist-packages (from dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (0.11.2)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (21.3)
Requirement already satisfied: fsspec>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (2021.11.1)
Requirement already satisfied: cloudpickle>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (2.0.0)
Requirement already satisfied: partd>=0.3.10 in /usr/local/lib/python3.7/dist-packages (from dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (1.2.0)
Requirement already satisfied: sortedcontainers!=2.0.0,!=2.0.1 in /usr/local/lib/python3.7/dist-packages (from distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (2.4.0)
Requirement already satisfied: msgpack>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (1.0.3)
Requirement already satisfied: tblib>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (1.7.0)
Requirement already satisfied: click>=6.6 in /usr/local/lib/python3.7/dist-packages (from distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (7.1.2)
Requirement already satisfied: zict>=0.1.3 in /usr/local/lib/python3.7/dist-packages (from distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (2.0.0)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (2.11.3)
Requirement already satisfied: pynacl>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from paramiko>=2.4->autogluon.core==0.3.1->autogluon.tabular) (1.4.0)
Requirement already satisfied: cryptography>=2.5 in /usr/local/lib/python3.7/dist-packages (from paramiko>=2.4->autogluon.core==0.3.1->autogluon.tabular) (36.0.1)
Requirement already satisfied: bcrypt>=3.1.3 in /usr/local/lib/python3.7/dist-packages (from paramiko>=2.4->autogluon.core==0.3.1->autogluon.tabular) (3.2.0)
Requirement already satisfied: botocore<1.24.0,>=1.23.26 in /usr/local/lib/python3.7/dist-packages (from boto3->autogluon.core==0.3.1->autogluon.tabular) (1.23.26)
Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /usr/local/lib/python3.7/dist-packages (from boto3->autogluon.core==0.3.1->autogluon.tabular) (0.5.0)
Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from boto3->autogluon.core==0.3.1->autogluon.tabular) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->autogluon.core==0.3.1->autogluon.tabular) (1.3.2)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->autogluon.core==0.3.1->autogluon.tabular) (0.11.0)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->autogluon.core==0.3.1->autogluon.tabular) (1.25.11)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->autogluon.core==0.3.1->autogluon.tabular) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->autogluon.core==0.3.1->autogluon.tabular) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->autogluon.core==0.3.1->autogluon.tabular) (3.0.4)
Requirement already satisfied: cffi>=1.1 in /usr/local/lib/python3.7/dist-packages (from bcrypt>=3.1.3->paramiko>=2.4->autogluon.core==0.3.1->autogluon.tabular) (1.15.0)
Requirement already satisfied: locket in /usr/local/lib/python3.7/dist-packages (from partd>=0.3.10->dask>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (0.2.1)
Requirement already satisfied: heapdict in /usr/local/lib/python3.7/dist-packages (from zict>=0.1.3->distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (1.0.1)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->distributed>=2.6.0->autogluon.core==0.3.1->autogluon.tabular) (2.0.1)
Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.1->bcrypt>=3.1.3->paramiko>=2.4->autogluon.core==0.3.1->autogluon.tabular) (2.21)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
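
If you want to double-check that the install actually went through despite those warnings, a quick version query is enough. A minimal sketch, not part of the original post (pkg_resources ships with setuptools, so nothing extra is needed):

# Optional sanity check: confirm the package is installed and print its version
import pkg_resources
print(pkg_resources.get_distribution("autogluon.tabular").version)  # e.g. 0.3.1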


2. Imports

Import the required libraries.

# The usual suspects
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# AutoGluon
from autogluon.tabular import TabularPredictor


3. Preparing the dataset

This step has nothing to do with AutoGluon itself: load the example dataset and split it into a training set and a test set for inference.


# Load the data
data_dir = "path/to/data"  # wherever train.csv is stored
file_name = "train.csv"
df_in = pd.read_csv(data_dir + "/" + file_name)

# Split the data (70% train / 30% test)
df_train, df_test = train_test_split(df_in, test_size=0.3)
print("train : " , df_train.shape)
print("test : " , df_test.shape)

df_in.head()
PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
0 1 0 3 Braund, Mr. Owen Harris male 22 1 0 A/5 21171 7.25 NaN S
1 2 1 1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38 1 0 PC 17599 71.2833 C85 C
2 3 1 3 Heikkinen, Miss. Laina female 26 0 0 STON/O2. 3101282 7.925 NaN S
3 4 1 1 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35 1 0 113803 53.1 C123 S
4 5 0 3 Allen, Mr. William Henry male 35 0 0 373450 8.05 NaN S
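
A small aside that isn't in the original flow: if you want the 70/30 split to be reproducible across runs, train_test_split accepts a random_state.

# Optional: fix the random seed so the train/test split is identical every run
df_train, df_test = train_test_split(df_in, test_size=0.3, random_state=42)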


4. Training and inference with AutoGluon

Now for the main event: modeling with AutoGluon.


4.1. Training

All it takes is specifying the target column and calling fit(). No missing-value handling, no encoding: you just throw the raw training set at it, which is wonderfully low-effort. As the line "AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed)." in the training log shows, it even works out on its own that this is a binary classification problem. (A hedged sketch of specifying the problem type and metric explicitly follows the training log below.)

label = "Survived"
predictor = TabularPredictor(label=label).fit(df_train)
No path specified. Models will be saved in: "AutogluonModels/ag-20220103_191942/"
Beginning AutoGluon training ...
AutoGluon will save models to "AutogluonModels/ag-20220103_191942/"
AutoGluon Version:  0.3.1
Train Data Rows:    623
Train Data Columns: 11
Preprocessing data ...
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
    2 unique label values:  [1, 0]
    If 'binary' is not the correct problem_type, please manually specify the problem_type argument in fit() (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
Selected class <--> label mapping:  class 1 = 1, class 0 = 0
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
    Available Memory:                    12092.11 MB
    Train Data (Original)  Memory Usage: 0.22 MB (0.0% of available memory)
    Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
    Stage 1 Generators:
        Fitting AsTypeFeatureGenerator...
            Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
    Stage 2 Generators:
        Fitting FillNaFeatureGenerator...
    Stage 3 Generators:
        Fitting IdentityFeatureGenerator...
        Fitting CategoryFeatureGenerator...
            Fitting CategoryMemoryMinimizeFeatureGenerator...
        Fitting TextSpecialFeatureGenerator...
            Fitting BinnedFeatureGenerator...
            Fitting DropDuplicatesFeatureGenerator...
        Fitting TextNgramFeatureGenerator...
            Fitting CountVectorizer for text features: ['Name']
            CountVectorizer fit with vocabulary size = 5
    Stage 4 Generators:
        Fitting DropUniqueFeatureGenerator...
    Types of features in original data (raw dtype, special dtypes):
        ('float', [])        : 2 | ['Age', 'Fare']
        ('int', [])          : 4 | ['PassengerId', 'Pclass', 'SibSp', 'Parch']
        ('object', [])       : 4 | ['Sex', 'Ticket', 'Cabin', 'Embarked']
        ('object', ['text']) : 1 | ['Name']
    Types of features in processed data (raw dtype, special dtypes):
        ('category', [])                    : 3 | ['Ticket', 'Cabin', 'Embarked']
        ('float', [])                       : 2 | ['Age', 'Fare']
        ('int', [])                         : 4 | ['PassengerId', 'Pclass', 'SibSp', 'Parch']
        ('int', ['binned', 'text_special']) : 9 | ['Name.char_count', 'Name.word_count', 'Name.capital_ratio', 'Name.lower_ratio', 'Name.special_ratio', ...]
        ('int', ['bool'])                   : 1 | ['Sex']
        ('int', ['text_ngram'])             : 6 | ['__nlp__.john', '__nlp__.miss', '__nlp__.mr', '__nlp__.mrs', '__nlp__.william', ...]
    0.4s = Fit runtime
    11 features in original data used to generate 25 features in processed data.
    Train Data (Processed) Memory Usage: 0.05 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.46s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
    To change this, specify the eval_metric argument of fit()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 498, Val Rows: 125
Fitting 13 L1 models ...
Fitting model: KNeighborsUnif ...
    0.632   = Validation score   (accuracy)
    0.02s   = Training   runtime
    0.11s   = Validation runtime
Fitting model: KNeighborsDist ...
    0.656   = Validation score   (accuracy)
    0.02s   = Training   runtime
    0.1s    = Validation runtime
Fitting model: LightGBMXT ...
/usr/local/lib/python3.7/dist-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
  _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
    0.816   = Validation score   (accuracy)
    0.73s   = Training   runtime
    0.01s   = Validation runtime
Fitting model: LightGBM ...
/usr/local/lib/python3.7/dist-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
  _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
    0.832   = Validation score   (accuracy)
    0.27s   = Training   runtime
    0.01s   = Validation runtime
Fitting model: RandomForestGini ...
    0.792   = Validation score   (accuracy)
    0.79s   = Training   runtime
    0.11s   = Validation runtime
Fitting model: RandomForestEntr ...
    0.8     = Validation score   (accuracy)
    0.76s   = Training   runtime
    0.11s   = Validation runtime
Fitting model: CatBoost ...
    0.816   = Validation score   (accuracy)
    0.59s   = Training   runtime
    0.01s   = Validation runtime
Fitting model: ExtraTreesGini ...
    0.792   = Validation score   (accuracy)
    0.76s   = Training   runtime
    0.11s   = Validation runtime
Fitting model: ExtraTreesEntr ...
    0.8     = Validation score   (accuracy)
    0.76s   = Training   runtime
    0.11s   = Validation runtime
Fitting model: NeuralNetFastAI ...
    0.848   = Validation score   (accuracy)
    7.38s   = Training   runtime
    0.03s   = Validation runtime
Fitting model: XGBoost ...
    0.824   = Validation score   (accuracy)
    0.3s    = Training   runtime
    0.01s   = Validation runtime
Fitting model: NeuralNetMXNet ...
    0.768   = Validation score   (accuracy)
    10.17s  = Training   runtime
    0.17s   = Validation runtime
Fitting model: LightGBMLarge ...
/usr/local/lib/python3.7/dist-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
  _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
    0.824   = Validation score   (accuracy)
    0.42s   = Training   runtime
    0.01s   = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
    0.848   = Validation score   (accuracy)
    0.42s   = Training   runtime
    0.0s    = Validation runtime
AutoGluon training complete, total runtime = 26.12s ...
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("AutogluonModels/ag-20220103_191942/")
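
As the log notes, the inferred problem type and evaluation metric can also be pinned down explicitly. A hedged sketch, assuming the TabularPredictor API where problem_type and eval_metric are constructor arguments (not needed here, since AutoGluon's inference was already correct):

# Sketch: fix the problem type and metric instead of letting AutoGluon infer them
predictor = TabularPredictor(
    label="Survived",
    problem_type="binary",    # one of 'binary', 'multiclass', 'regression'
    eval_metric="accuracy",   # metric used for validation and model selection
).fit(df_train)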


Let's also look at the summary of the training run. A variety of models, including an ensemble, were tried, and in the end "NeuralNetFastAI" appears to have scored best.


fit_summary = predictor.fit_summary(show_plot=True)
*** Summary of fit() ***
Estimated performance of each model:
                  model  score_val  pred_time_val   fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0       NeuralNetFastAI      0.848       0.026925   7.384434                0.026925           7.384434            1       True         10
1   WeightedEnsemble_L2      0.848       0.027650   7.804674                0.000725           0.420240            2       True         14
2              LightGBM      0.832       0.011829   0.273988                0.011829           0.273988            1       True          4
3               XGBoost      0.824       0.010160   0.298455                0.010160           0.298455            1       True         11
4         LightGBMLarge      0.824       0.011892   0.415987                0.011892           0.415987            1       True         13
5              CatBoost      0.816       0.008856   0.594643                0.008856           0.594643            1       True          7
6            LightGBMXT      0.816       0.012459   0.727325                0.012459           0.727325            1       True          3
7      RandomForestEntr      0.800       0.107793   0.761485                0.107793           0.761485            1       True          6
8        ExtraTreesEntr      0.800       0.108856   0.762689                0.108856           0.762689            1       True          9
9      RandomForestGini      0.792       0.107973   0.787341                0.107973           0.787341            1       True          5
10       ExtraTreesGini      0.792       0.108095   0.757838                0.108095           0.757838            1       True          8
11       NeuralNetMXNet      0.768       0.170961  10.173865                0.170961          10.173865            1       True         12
12       KNeighborsDist      0.656       0.104259   0.015277                0.104259           0.015277            1       True          2
13       KNeighborsUnif      0.632       0.105555   0.016072                0.105555           0.016072            1       True          1
Number of models trained: 14
Types of models trained:
{'CatBoostModel', 'WeightedEnsembleModel', 'NNFastAiTabularModel', 'XGBoostModel', 'TabularNeuralNetModel', 'LGBModel', 'KNNModel', 'RFModel', 'XTModel'}
Bagging used: False 
Multi-layer stack-ensembling used: False 
Feature Metadata (Processed):
(raw dtype, special dtypes):
('category', [])                    : 3 | ['Ticket', 'Cabin', 'Embarked']
('float', [])                       : 2 | ['Age', 'Fare']
('int', [])                         : 4 | ['PassengerId', 'Pclass', 'SibSp', 'Parch']
('int', ['binned', 'text_special']) : 9 | ['Name.char_count', 'Name.word_count', 'Name.capital_ratio', 'Name.lower_ratio', 'Name.special_ratio', ...]
('int', ['bool'])                   : 1 | ['Sex']
('int', ['text_ngram'])             : 6 | ['__nlp__.john', '__nlp__.miss', '__nlp__.mr', '__nlp__.mrs', '__nlp__.william', ...]
Plot summary of models saved to file: AutogluonModels/ag-20220103_191942/SummaryOfModels.html
*** End of fit() summary ***
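
For a more compact view than fit_summary(), the leaderboard() method ranks every trained model, and passing the held-out test set adds scores on unseen data as well. A minimal sketch (leaderboard is part of the TabularPredictor API; the column selection is just for readability):

# Sketch: rank all trained models, including their scores on the test split
lb = predictor.leaderboard(df_test, silent=True)
print(lb[["model", "score_test", "score_val", "fit_time"]])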


4.2. Inference

Now run inference on the test set. Unless you specify otherwise, the model with the best validation score during training (here "NeuralNetFastAI") is used for prediction.


predictions = predictor.predict(df_test)
predictions
269    1
544    0
59     0
655    0
566    0
      ..
653    1
587    0
159    0
570    0
123    1
Name: Survived, Length: 268, dtype: int64
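
If you need class probabilities instead of hard labels, or want to run inference with a specific model from the training log rather than the default best one, both are one-liners. A sketch (the model name must match one reported during training):

# Sketch: probability outputs and inference with an explicitly chosen model
proba = predictor.predict_proba(df_test)                    # per-class probabilities
preds_lgbm = predictor.predict(df_test, model="LightGBM")   # force a specific model
print(predictor.get_model_best())                           # model used by default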


4.3. Evaluation

Let's evaluate the predictions on the test set. The accuracy comes out to about 0.81, so the model seems to be predicting reasonably well.


perf = predictor.evaluate_predictions(y_true=df_test[label], y_pred=predictions)
Evaluation: accuracy on test data: 0.8134328358208955
Evaluations on test data:
{
    "accuracy": 0.8134328358208955,
    "balanced_accuracy": 0.7895168855534709,
    "mcc": 0.6008483860727447,
    "f1": 0.7395833333333334,
    "precision": 0.8068181818181818,
    "recall": 0.6826923076923077
}
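
Incidentally, if you only need the scores, evaluate() computes the metrics directly from the raw test DataFrame, so the separate predict step isn't strictly required for scoring. A minimal sketch, assuming the evaluate() method of TabularPredictor:

# Sketch: score the predictor on the test split in a single call
perf_direct = predictor.evaluate(df_test)
print(perf_direct)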


Closing thoughts

So there you have it. I split things into separate cells for the sake of showing intermediate output and explanations, but compressed down, everything from preprocessing to prediction fits in about three lines (imports included). The accuracy isn't bad at all, either. (The compressed version is sketched below.)
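
For reference, here is that compressed version, reusing the df_train/df_test split prepared earlier:

# The whole pipeline, compressed: import, train, predict
from autogluon.tabular import TabularPredictor
predictor = TabularPredictor(label="Survived").fit(df_train)
predictions = predictor.predict(df_test)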

Next time, I'd like to try things out in a more practice-oriented way.