{
"cells": [
{
"cell_type": "markdown",
"id": "3f3af765",
"metadata": {},
"source": [
"# Multi-Label Prediction\n",
"\n",
"AutoGluon의 `TabularPredictor`는 하나의 레이블(열)만 예측하지만, 사용자 정의 클래스를 추가하면 여러 레이블(다중 레이블)에 대해서도 쉽게 훈련을 수행할 수 있습니다. 본 핸즈온을 통해 다중 레이블 훈련의 예시를 살펴보겠습니다."
]
},
{
"cell_type": "markdown",
"id": "573971e8",
"metadata": {},
"source": [
"## 1. MultilabelPredictor Class\n",
"\n",
"레이블마다 하나씩 생성되는 `TabularPredictor` 객체들의 컬렉션을 관리하는 사용자 정의 `MultilabelPredictor` 클래스를 정의합니다."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0ce05726",
"metadata": {},
"outputs": [],
"source": [
"from autogluon.tabular import TabularDataset, TabularPredictor\n",
"from autogluon.common.utils.utils import setup_outputdir\n",
"from autogluon.core.utils.loaders import load_pkl\n",
"from autogluon.core.utils.savers import save_pkl\n",
"import os.path\n",
"\n",
"class MultilabelPredictor():\n",
"    \"\"\" Tabular Predictor for predicting multiple columns in table.\n",
"    Creates multiple TabularPredictor objects which you can also use individually.\n",
"    You can access the TabularPredictor for a particular label via: `multilabel_predictor.get_predictor(label_i)`\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    labels : List[str]\n",
"        The ith element of this list is the column (i.e. `label`) predicted by the ith TabularPredictor stored in this object.\n",
"    path : str, default = None\n",
"        Path to directory where models and intermediate outputs should be saved.\n",
"        If unspecified, a time-stamped folder called \"AutogluonModels/ag-[TIMESTAMP]\" will be created in the working directory to store all models.\n",
"        Note: To call `fit()` twice and save all results of each fit, you must specify different `path` locations or don't specify `path` at all.\n",
"        Otherwise files from first `fit()` will be overwritten by second `fit()`.\n",
"        Caution: when predicting many labels, this directory may grow large as it needs to store many TabularPredictors.\n",
"    problem_types : List[str], default = None\n",
"        The ith element is the `problem_type` for the ith TabularPredictor stored in this object.\n",
"    eval_metrics : List[str], default = None\n",
"        The ith element is the `eval_metric` for the ith TabularPredictor stored in this object.\n",
"    consider_labels_correlation : bool, default = True\n",
"        Whether the predictions of multiple labels should account for label correlations or predict each label independently of the others.\n",
"        If True, the ordering of `labels` may affect resulting accuracy as each label is predicted conditional on the previous labels appearing earlier in this list (i.e. in an auto-regressive fashion).\n",
"        Set to False if during inference you may want to individually use just the ith TabularPredictor without predicting all the other labels.\n",
"    kwargs :\n",
"        Arguments passed into the initialization of each TabularPredictor.\n",
"\n",
"    Notes\n",
"    -----\n",
"    This object is persisted with pickle (`save()` / `load()`); only `load()` directories you trust.\n",
"    \"\"\"\n",
"\n",
"    multi_predictor_file = 'multilabel_predictor.pkl'\n",
"\n",
"    def __init__(self, labels, path=None, problem_types=None, eval_metrics=None, consider_labels_correlation=True, **kwargs):\n",
"        if len(labels) < 2:\n",
"            raise ValueError(\"MultilabelPredictor is only intended for predicting MULTIPLE labels (columns), use TabularPredictor for predicting one label (column).\")\n",
"        if (problem_types is not None) and (len(problem_types) != len(labels)):\n",
"            raise ValueError(\"If provided, `problem_types` must have same length as `labels`\")\n",
"        if (eval_metrics is not None) and (len(eval_metrics) != len(labels)):\n",
"            raise ValueError(\"If provided, `eval_metrics` must have same length as `labels`\")\n",
"        self.path = setup_outputdir(path, warn_if_exist=False)\n",
"        self.labels = labels\n",
"        self.consider_labels_correlation = consider_labels_correlation\n",
"        self.predictors = {}  # key = label, value = TabularPredictor or str path to the TabularPredictor for this label\n",
"        if eval_metrics is None:\n",
"            self.eval_metrics = {}\n",
"        else:\n",
"            self.eval_metrics = {labels[i]: eval_metrics[i] for i in range(len(labels))}\n",
"        for i, label in enumerate(labels):\n",
"            # os.path.join does not depend on setup_outputdir() returning a trailing separator;\n",
"            # plain concatenation would silently yield e.g. 'ag-02-multilabelPredictor_class' without one.\n",
"            path_i = os.path.join(self.path, \"Predictor_\" + str(label))\n",
"            problem_type = problem_types[i] if problem_types is not None else None\n",
"            eval_metric = eval_metrics[i] if eval_metrics is not None else None\n",
"            self.predictors[label] = TabularPredictor(label=label, problem_type=problem_type, eval_metric=eval_metric, path=path_i, **kwargs)\n",
"\n",
"    def fit(self, train_data, tuning_data=None, **kwargs):\n",
"        \"\"\" Fits a separate TabularPredictor to predict each of the labels.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        train_data, tuning_data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n",
"            See documentation for `TabularPredictor.fit()`.\n",
"        kwargs :\n",
"            Arguments passed into the `fit()` call for each TabularPredictor.\n",
"        \"\"\"\n",
"        if isinstance(train_data, str):\n",
"            train_data = TabularDataset(train_data)\n",
"        if isinstance(tuning_data, str):\n",
"            tuning_data = TabularDataset(tuning_data)\n",
"        # Keep untouched copies; each label's training frame is derived from these.\n",
"        train_data_og = train_data.copy()\n",
"        tuning_data_og = tuning_data.copy() if tuning_data is not None else None\n",
"        # Record the metric AutoGluon picks only when the user did not supply `eval_metrics`.\n",
"        save_metrics = len(self.eval_metrics) == 0\n",
"        for i, label in enumerate(self.labels):\n",
"            predictor = self.get_predictor(label)\n",
"            if not self.consider_labels_correlation:\n",
"                # Independent mode: hide every other label so each predictor is usable on its own.\n",
"                labels_to_drop = [l for l in self.labels if l != label]\n",
"            else:\n",
"                # Auto-regressive mode: labels earlier in the list stay as features, later ones are hidden.\n",
"                labels_to_drop = list(self.labels[i + 1:])\n",
"            train_data = train_data_og.drop(labels_to_drop, axis=1)\n",
"            if tuning_data_og is not None:\n",
"                tuning_data = tuning_data_og.drop(labels_to_drop, axis=1)\n",
"            print(f\"Fitting TabularPredictor for label: {label} ...\")\n",
"            predictor.fit(train_data=train_data, tuning_data=tuning_data, **kwargs)\n",
"            # Keep only the path; get_predictor() reloads the fitted predictor from disk when needed.\n",
"            self.predictors[label] = predictor.path\n",
"            if save_metrics:\n",
"                self.eval_metrics[label] = predictor.eval_metric\n",
"        self.save()\n",
"\n",
"    def predict(self, data, **kwargs):\n",
"        \"\"\" Returns DataFrame with label columns containing predictions for each label.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n",
"            Data to make predictions for. If label columns are present in this data, they will be ignored. See documentation for `TabularPredictor.predict()`.\n",
"        kwargs :\n",
"            Arguments passed into the predict() call for each TabularPredictor.\n",
"        \"\"\"\n",
"        return self._predict(data, as_proba=False, **kwargs)\n",
"\n",
"    def predict_proba(self, data, **kwargs):\n",
"        \"\"\" Returns dict where each key is a label and the corresponding value is the `predict_proba()` output for just that label.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n",
"            Data to make predictions for. See documentation for `TabularPredictor.predict()` and `TabularPredictor.predict_proba()`.\n",
"        kwargs :\n",
"            Arguments passed into the `predict_proba()` call for each TabularPredictor (also passed into a `predict()` call).\n",
"        \"\"\"\n",
"        return self._predict(data, as_proba=True, **kwargs)\n",
"\n",
"    def evaluate(self, data, **kwargs):\n",
"        \"\"\" Returns dict where each key is a label and the corresponding value is the `evaluate()` output for just that label.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        data : str or autogluon.tabular.TabularDataset or pd.DataFrame\n",
"            Data to evaluate predictions of all labels for, must contain all labels as columns. See documentation for `TabularPredictor.evaluate()`.\n",
"        kwargs :\n",
"            Arguments passed into the `evaluate()` call for each TabularPredictor (also passed into the `predict()` call).\n",
"        \"\"\"\n",
"        data = self._get_data(data)\n",
"        eval_dict = {}\n",
"        for label in self.labels:\n",
"            print(f\"Evaluating TabularPredictor for label: {label} ...\")\n",
"            predictor = self.get_predictor(label)\n",
"            eval_dict[label] = predictor.evaluate(data, **kwargs)\n",
"            if self.consider_labels_correlation:\n",
"                # Later predictors were trained with this label as a feature; feed them the\n",
"                # prediction (as happens at inference time) instead of the ground truth.\n",
"                data[label] = predictor.predict(data, **kwargs)\n",
"        return eval_dict\n",
"\n",
"    def save(self):\n",
"        \"\"\" Save MultilabelPredictor to disk. \"\"\"\n",
"        for label in self.labels:\n",
"            if not isinstance(self.predictors[label], str):\n",
"                self.predictors[label] = self.predictors[label].path\n",
"        save_pkl.save(path=os.path.join(self.path, self.multi_predictor_file), object=self)\n",
"        print(f\"MultilabelPredictor saved to disk. Load with: MultilabelPredictor.load('{self.path}')\")\n",
"\n",
"    @classmethod\n",
"    def load(cls, path):\n",
"        \"\"\" Load MultilabelPredictor from disk `path` previously specified when creating this MultilabelPredictor.\n",
"\n",
"        Unpickles the stored object: only load from directories you trust.\n",
"        \"\"\"\n",
"        path = os.path.expanduser(path)\n",
"        return load_pkl.load(path=os.path.join(path, cls.multi_predictor_file))\n",
"\n",
"    def get_predictor(self, label):\n",
"        \"\"\" Returns TabularPredictor which is used to predict this label. \"\"\"\n",
"        predictor = self.predictors[label]\n",
"        if isinstance(predictor, str):\n",
"            return TabularPredictor.load(path=predictor)\n",
"        return predictor\n",
"\n",
"    def _get_data(self, data):\n",
"        \"\"\" Load `data` if it is a path, otherwise return a copy so callers' frames are never mutated. \"\"\"\n",
"        if isinstance(data, str):\n",
"            return TabularDataset(data)\n",
"        return data.copy()\n",
"\n",
"    def _predict(self, data, as_proba=False, **kwargs):\n",
"        \"\"\" Shared implementation of `predict()` (as_proba=False) and `predict_proba()` (as_proba=True). \"\"\"\n",
"        data = self._get_data(data)\n",
"        predproba_dict = {}\n",
"        for label in self.labels:\n",
"            print(f\"Predicting with TabularPredictor for label: {label} ...\")\n",
"            predictor = self.get_predictor(label)\n",
"            if as_proba:\n",
"                predproba_dict[label] = predictor.predict_proba(data, as_multiclass=True, **kwargs)\n",
"            # Always write the point prediction back: with consider_labels_correlation=True,\n",
"            # predictors later in the list were trained with this label as a feature.\n",
"            data[label] = predictor.predict(data, **kwargs)\n",
"        if as_proba:\n",
"            return predproba_dict\n",
"        return data[list(self.labels)]"
]
},
{
"cell_type": "markdown",
"id": "68a45e75",
"metadata": {},
"source": [
"## 2. Data preparation and Training\n",
"\n",
"`01_binary_classification.ipynb`와 동일한 데이터셋을 사용합니다."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2b9c36d5",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" education-num | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" sex | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
" class | \n",
"
\n",
" \n",
" \n",
" \n",
" 6118 | \n",
" 51 | \n",
" Private | \n",
" 39264 | \n",
" Some-college | \n",
" 10 | \n",
" Married-civ-spouse | \n",
" Exec-managerial | \n",
" Wife | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >50K | \n",
"
\n",
" \n",
" 23204 | \n",
" 58 | \n",
" Private | \n",
" 51662 | \n",
" 10th | \n",
" 6 | \n",
" Married-civ-spouse | \n",
" Other-service | \n",
" Wife | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 8 | \n",
" United-States | \n",
" <=50K | \n",
"
\n",
" \n",
" 29590 | \n",
" 40 | \n",
" Private | \n",
" 326310 | \n",
" Some-college | \n",
" 10 | \n",
" Married-civ-spouse | \n",
" Craft-repair | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 44 | \n",
" United-States | \n",
" <=50K | \n",
"
\n",
" \n",
" 18116 | \n",
" 37 | \n",
" Private | \n",
" 222450 | \n",
" HS-grad | \n",
" 9 | \n",
" Never-married | \n",
" Sales | \n",
" Not-in-family | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 2339 | \n",
" 40 | \n",
" El-Salvador | \n",
" <=50K | \n",
"
\n",
" \n",
" 33964 | \n",
" 62 | \n",
" Private | \n",
" 109190 | \n",
" Bachelors | \n",
" 13 | \n",
" Married-civ-spouse | \n",
" Exec-managerial | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 15024 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >50K | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education education-num \\\n",
"6118 51 Private 39264 Some-college 10 \n",
"23204 58 Private 51662 10th 6 \n",
"29590 40 Private 326310 Some-college 10 \n",
"18116 37 Private 222450 HS-grad 9 \n",
"33964 62 Private 109190 Bachelors 13 \n",
"\n",
" marital-status occupation relationship race sex \\\n",
"6118 Married-civ-spouse Exec-managerial Wife White Female \n",
"23204 Married-civ-spouse Other-service Wife White Female \n",
"29590 Married-civ-spouse Craft-repair Husband White Male \n",
"18116 Never-married Sales Not-in-family White Male \n",
"33964 Married-civ-spouse Exec-managerial Husband White Male \n",
"\n",
" capital-gain capital-loss hours-per-week native-country class \n",
"6118 0 0 40 United-States >50K \n",
"23204 0 0 8 United-States <=50K \n",
"29590 0 0 44 United-States <=50K \n",
"18116 0 2339 40 El-Salvador <=50K \n",
"33964 15024 0 40 United-States >50K "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the training split directly from its public S3 URL.\n",
"train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')\n",
"subsample_size = 500 # subsample subset of data for faster demo, try setting this to much larger values\n",
"# A fixed random_state selects the same subsample on every run.\n",
"train_data = train_data.sample(n=subsample_size, random_state=0)\n",
"train_data.head()"
]
},
{
"cell_type": "markdown",
"id": "352deaed",
"metadata": {},
"source": [
"이번에는 학력의 수치형 표현(`education-num`, 회귀), 학력(`education`, 다중 클래스 분류), 소득 수준(`class`, 연소득 50K 초과 여부를 예측하는 이진 분류)을 동시에 예측해 보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3d3c62bd",
"metadata": {},
"outputs": [],
"source": [
"labels = ['education-num','education','class'] # which columns to predict based on the others\n",
"problem_types = ['regression','multiclass','binary'] # type of each prediction problem\n",
"save_path = 'ag-02-multilabel'\n",
"time_limit = 60"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3f4cfa0c",
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"\n",
"# Remove artifacts from a previous run so fit() starts from an empty directory.\n",
"# shutil.rmtree avoids shelling out with an unquoted, interpolated path (`!rm -rf $save_path`)\n",
"# and also works on platforms without `rm`; a missing directory is silently ignored, like `rm -rf`.\n",
"shutil.rmtree(save_path, ignore_errors=True)"
]
},
{
"cell_type": "markdown",
"id": "57c7f78b",
"metadata": {},
"source": [
"사용자 정의 클래스로 훈련을 수행합니다. 레이블 간 상관 관계를 고려하려면 `consider_labels_correlation=True`(기본값)로 설정하세요. 이 경우 `labels` 목록의 순서대로 훈련하며 각 레이블은 앞선 레이블들을 피처로 사용하므로, 레이블 순서가 성능에 영향을 줄 수 있습니다. 각 레이블을 독립적으로 예측하려면 `consider_labels_correlation=False`로 설정하세요."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "92701f08",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Beginning AutoGluon training ... Time limit = 60s\n",
"AutoGluon will save models to \"ag-02-multilabel/Predictor_education-num/\"\n",
"AutoGluon Version: 0.5.2\n",
"Python Version: 3.8.12\n",
"Operating System: Linux\n",
"Train Data Rows: 500\n",
"Train Data Columns: 12\n",
"Label Column: education-num\n",
"Preprocessing data ...\n",
"Using Feature Generators to preprocess the data ...\n",
"Fitting AutoMLPipelineFeatureGenerator...\n",
"\tAvailable Memory: 14967.6 MB\n",
"\tTrain Data (Original) Memory Usage: 0.26 MB (0.0% of available memory)\n",
"\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
"\tStage 1 Generators:\n",
"\t\tFitting AsTypeFeatureGenerator...\n",
"\t\t\tNote: Converting 1 features to boolean dtype as they only contain 2 unique values.\n",
"\tStage 2 Generators:\n",
"\t\tFitting FillNaFeatureGenerator...\n",
"\tStage 3 Generators:\n",
"\t\tFitting IdentityFeatureGenerator...\n",
"\t\tFitting CategoryFeatureGenerator...\n",
"\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n",
"\tStage 4 Generators:\n",
"\t\tFitting DropUniqueFeatureGenerator...\n",
"\tTypes of features in original data (raw dtype, special dtypes):\n",
"\t\t('int', []) : 5 | ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
"\t\t('object', []) : 7 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]\n",
"\tTypes of features in processed data (raw dtype, special dtypes):\n",
"\t\t('category', []) : 6 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]\n",
"\t\t('int', []) : 5 | ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']\n",
"\t\t('int', ['bool']) : 1 | ['sex']\n",
"\t0.1s = Fit runtime\n",
"\t12 features in original data used to generate 12 features in processed data.\n",
"\tTrain Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)\n",
"Data preprocessing and feature engineering runtime = 0.08s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: 'root_mean_squared_error'\n",
"\tThis metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.\n",
"\tTo change this, specify the eval_metric parameter of Predictor()\n",
"Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100\n",
"Fitting 11 L1 models ...\n",
"Fitting model: KNeighborsUnif ... Training model for up to 59.92s of the 59.91s of remaining time.\n",
"\t-2.703\t = Validation score (-root_mean_squared_error)\n",
"\t0.01s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: KNeighborsDist ... Training model for up to 59.9s of the 59.9s of remaining time.\n",
"\t-2.7447\t = Validation score (-root_mean_squared_error)\n",
"\t0.0s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"Fitting model: LightGBMXT ... Training model for up to 59.89s of the 59.89s of remaining time.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting TabularPredictor for label: education-num ...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t-2.294\t = Validation score (-root_mean_squared_error)\n",
"\t0.73s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBM ... Training model for up to 59.15s of the 59.15s of remaining time.\n",
"\t-2.3176\t = Validation score (-root_mean_squared_error)\n",
"\t0.2s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: RandomForestMSE ... Training model for up to 58.94s of the 58.94s of remaining time.\n",
"\t-2.2587\t = Validation score (-root_mean_squared_error)\n",
"\t0.45s\t = Training runtime\n",
"\t0.05s\t = Validation runtime\n",
"Fitting model: CatBoost ... Training model for up to 58.43s of the 58.42s of remaining time.\n",
"\t-2.1682\t = Validation score (-root_mean_squared_error)\n",
"\t1.39s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: ExtraTreesMSE ... Training model for up to 57.03s of the 57.02s of remaining time.\n",
"\t-2.2983\t = Validation score (-root_mean_squared_error)\n",
"\t0.39s\t = Training runtime\n",
"\t0.05s\t = Validation runtime\n",
"Fitting model: NeuralNetFastAI ... Training model for up to 56.57s of the 56.57s of remaining time.\n",
"\t-2.4422\t = Validation score (-root_mean_squared_error)\n",
"\t2.22s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: XGBoost ... Training model for up to 54.33s of the 54.32s of remaining time.\n",
"\t-2.1456\t = Validation score (-root_mean_squared_error)\n",
"\t0.29s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: NeuralNetTorch ... Training model for up to 54.03s of the 54.02s of remaining time.\n",
"\t-2.3312\t = Validation score (-root_mean_squared_error)\n",
"\t1.73s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBMLarge ... Training model for up to 52.28s of the 52.27s of remaining time.\n",
"\t-2.3514\t = Validation score (-root_mean_squared_error)\n",
"\t0.38s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: WeightedEnsemble_L2 ... Training model for up to 59.92s of the 51.87s of remaining time.\n",
"\t-2.1049\t = Validation score (-root_mean_squared_error)\n",
"\t0.25s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"AutoGluon training complete, total runtime = 8.39s ... Best model: \"WeightedEnsemble_L2\"\n",
"TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"ag-02-multilabel/Predictor_education-num/\")\n",
"Beginning AutoGluon training ... Time limit = 60s\n",
"AutoGluon will save models to \"ag-02-multilabel/Predictor_education/\"\n",
"AutoGluon Version: 0.5.2\n",
"Python Version: 3.8.12\n",
"Operating System: Linux\n",
"Train Data Rows: 500\n",
"Train Data Columns: 13\n",
"Label Column: education\n",
"Preprocessing data ...\n",
"Warning: Some classes in the training set have fewer than 10 examples. AutoGluon will only keep 11 out of 15 classes for training and will not try to predict the rare classes. To keep more classes, increase the number of datapoints from these rare classes in the training data or reduce label_count_threshold.\n",
"Fraction of data from classes with at least 10 examples that will be kept for training models: 0.976\n",
"Train Data Class Count: 11\n",
"Using Feature Generators to preprocess the data ...\n",
"Fitting AutoMLPipelineFeatureGenerator...\n",
"\tAvailable Memory: 14527.17 MB\n",
"\tTrain Data (Original) Memory Usage: 0.25 MB (0.0% of available memory)\n",
"\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
"\tStage 1 Generators:\n",
"\t\tFitting AsTypeFeatureGenerator...\n",
"\t\t\tNote: Converting 1 features to boolean dtype as they only contain 2 unique values.\n",
"\tStage 2 Generators:\n",
"\t\tFitting FillNaFeatureGenerator...\n",
"\tStage 3 Generators:\n",
"\t\tFitting IdentityFeatureGenerator...\n",
"\t\tFitting CategoryFeatureGenerator...\n",
"\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n",
"\tStage 4 Generators:\n",
"\t\tFitting DropUniqueFeatureGenerator...\n",
"\tTypes of features in original data (raw dtype, special dtypes):\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('object', []) : 7 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]\n",
"\tTypes of features in processed data (raw dtype, special dtypes):\n",
"\t\t('category', []) : 6 | ['workclass', 'marital-status', 'occupation', 'relationship', 'race', ...]\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('int', ['bool']) : 1 | ['sex']\n",
"\t0.1s = Fit runtime\n",
"\t13 features in original data used to generate 13 features in processed data.\n",
"\tTrain Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)\n",
"Data preprocessing and feature engineering runtime = 0.08s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
"\tTo change this, specify the eval_metric parameter of Predictor()\n",
"Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 390, Val Rows: 98\n",
"Fitting 13 L1 models ...\n",
"Fitting model: KNeighborsUnif ... Training model for up to 59.92s of the 59.92s of remaining time.\n",
"\t0.2653\t = Validation score (accuracy)\n",
"\t0.0s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"Fitting model: KNeighborsDist ... Training model for up to 59.91s of the 59.9s of remaining time.\n",
"\t0.2347\t = Validation score (accuracy)\n",
"\t0.0s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"Fitting model: NeuralNetFastAI ... Training model for up to 59.9s of the 59.89s of remaining time.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting TabularPredictor for label: education ...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t0.8163\t = Validation score (accuracy)\n",
"\t1.13s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBMXT ... Training model for up to 58.75s of the 58.75s of remaining time.\n",
"\t0.9694\t = Validation score (accuracy)\n",
"\t0.87s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBM ... Training model for up to 57.8s of the 57.8s of remaining time.\n",
"\t1.0\t = Validation score (accuracy)\n",
"\t0.47s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: RandomForestGini ... Training model for up to 57.32s of the 57.32s of remaining time.\n",
"\t0.9082\t = Validation score (accuracy)\n",
"\t0.82s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: RandomForestEntr ... Training model for up to 56.39s of the 56.39s of remaining time.\n",
"\t0.8776\t = Validation score (accuracy)\n",
"\t0.73s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: CatBoost ... Training model for up to 55.56s of the 55.56s of remaining time.\n",
"\t1.0\t = Validation score (accuracy)\n",
"\t17.74s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: ExtraTreesGini ... Training model for up to 37.81s of the 37.81s of remaining time.\n",
"\t0.9592\t = Validation score (accuracy)\n",
"\t0.68s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: ExtraTreesEntr ... Training model for up to 37.04s of the 37.04s of remaining time.\n",
"\t0.949\t = Validation score (accuracy)\n",
"\t0.59s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: XGBoost ... Training model for up to 36.36s of the 36.36s of remaining time.\n",
"\t1.0\t = Validation score (accuracy)\n",
"\t0.49s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: NeuralNetTorch ... Training model for up to 35.86s of the 35.86s of remaining time.\n",
"\t0.9592\t = Validation score (accuracy)\n",
"\t3.16s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBMLarge ... Training model for up to 32.69s of the 32.68s of remaining time.\n",
"\t1.0\t = Validation score (accuracy)\n",
"\t1.92s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: WeightedEnsemble_L2 ... Training model for up to 59.92s of the 30.68s of remaining time.\n",
"\t1.0\t = Validation score (accuracy)\n",
"\t0.28s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"AutoGluon training complete, total runtime = 29.62s ... Best model: \"WeightedEnsemble_L2\"\n",
"TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"ag-02-multilabel/Predictor_education/\")\n",
"Beginning AutoGluon training ... Time limit = 60s\n",
"AutoGluon will save models to \"ag-02-multilabel/Predictor_class/\"\n",
"AutoGluon Version: 0.5.2\n",
"Python Version: 3.8.12\n",
"Operating System: Linux\n",
"Train Data Rows: 500\n",
"Train Data Columns: 14\n",
"Label Column: class\n",
"Preprocessing data ...\n",
"Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K\n",
"\tNote: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.\n",
"\tTo explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.\n",
"Using Feature Generators to preprocess the data ...\n",
"Fitting AutoMLPipelineFeatureGenerator...\n",
"\tAvailable Memory: 13846.86 MB\n",
"\tTrain Data (Original) Memory Usage: 0.29 MB (0.0% of available memory)\n",
"\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
"\tStage 1 Generators:\n",
"\t\tFitting AsTypeFeatureGenerator...\n",
"\t\t\tNote: Converting 1 features to boolean dtype as they only contain 2 unique values.\n",
"\tStage 2 Generators:\n",
"\t\tFitting FillNaFeatureGenerator...\n",
"\tStage 3 Generators:\n",
"\t\tFitting IdentityFeatureGenerator...\n",
"\t\tFitting CategoryFeatureGenerator...\n",
"\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n",
"\tStage 4 Generators:\n",
"\t\tFitting DropUniqueFeatureGenerator...\n",
"\tTypes of features in original data (raw dtype, special dtypes):\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]\n",
"\tTypes of features in processed data (raw dtype, special dtypes):\n",
"\t\t('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]\n",
"\t\t('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]\n",
"\t\t('int', ['bool']) : 1 | ['sex']\n",
"\t0.1s = Fit runtime\n",
"\t14 features in original data used to generate 14 features in processed data.\n",
"\tTrain Data (Processed) Memory Usage: 0.03 MB (0.0% of available memory)\n",
"Data preprocessing and feature engineering runtime = 0.09s ...\n",
"AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
"\tTo change this, specify the eval_metric parameter of Predictor()\n",
"Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100\n",
"Fitting 13 L1 models ...\n",
"Fitting model: KNeighborsUnif ... Training model for up to 59.91s of the 59.91s of remaining time.\n",
"\t0.73\t = Validation score (accuracy)\n",
"\t0.0s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"Fitting model: KNeighborsDist ... Training model for up to 59.9s of the 59.89s of remaining time.\n",
"\t0.65\t = Validation score (accuracy)\n",
"\t0.01s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: LightGBMXT ... Training model for up to 59.88s of the 59.88s of remaining time.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting TabularPredictor for label: class ...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t0.83\t = Validation score (accuracy)\n",
"\t0.21s\t = Training runtime\n",
"\t0.02s\t = Validation runtime\n",
"Fitting model: LightGBM ... Training model for up to 59.64s of the 59.64s of remaining time.\n",
"\t0.85\t = Validation score (accuracy)\n",
"\t0.25s\t = Training runtime\n",
"\t0.02s\t = Validation runtime\n",
"Fitting model: RandomForestGini ... Training model for up to 59.35s of the 59.35s of remaining time.\n",
"\t0.84\t = Validation score (accuracy)\n",
"\t0.51s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: RandomForestEntr ... Training model for up to 58.75s of the 58.75s of remaining time.\n",
"\t0.83\t = Validation score (accuracy)\n",
"\t0.57s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: CatBoost ... Training model for up to 58.09s of the 58.09s of remaining time.\n",
"\t0.85\t = Validation score (accuracy)\n",
"\t1.25s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: ExtraTreesGini ... Training model for up to 56.83s of the 56.83s of remaining time.\n",
"\t0.82\t = Validation score (accuracy)\n",
"\t0.62s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: ExtraTreesEntr ... Training model for up to 56.12s of the 56.12s of remaining time.\n",
"\t0.81\t = Validation score (accuracy)\n",
"\t0.61s\t = Training runtime\n",
"\t0.07s\t = Validation runtime\n",
"Fitting model: NeuralNetFastAI ... Training model for up to 55.43s of the 55.43s of remaining time.\n",
"\t0.82\t = Validation score (accuracy)\n",
"\t0.68s\t = Training runtime\n",
"\t0.02s\t = Validation runtime\n",
"Fitting model: XGBoost ... Training model for up to 54.72s of the 54.72s of remaining time.\n",
"\t0.87\t = Validation score (accuracy)\n",
"\t0.33s\t = Training runtime\n",
"\t0.01s\t = Validation runtime\n",
"Fitting model: NeuralNetTorch ... Training model for up to 54.37s of the 54.37s of remaining time.\n",
"\t0.85\t = Validation score (accuracy)\n",
"\t2.28s\t = Training runtime\n",
"\t0.02s\t = Validation runtime\n",
"Fitting model: LightGBMLarge ... Training model for up to 52.07s of the 52.07s of remaining time.\n",
"\t0.83\t = Validation score (accuracy)\n",
"\t0.75s\t = Training runtime\n",
"\t0.04s\t = Validation runtime\n",
"Fitting model: WeightedEnsemble_L2 ... Training model for up to 59.91s of the 51.24s of remaining time.\n",
"\t0.87\t = Validation score (accuracy)\n",
"\t0.59s\t = Training runtime\n",
"\t0.0s\t = Validation runtime\n",
"AutoGluon training complete, total runtime = 9.37s ... Best model: \"WeightedEnsemble_L2\"\n",
"TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"ag-02-multilabel/Predictor_class/\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MultilabelPredictor saved to disk. Load with: MultilabelPredictor.load('ag-02-multilabel/')\n"
]
}
],
"source": [
"multi_predictor = MultilabelPredictor(labels=labels, problem_types=problem_types, path=save_path)\n",
"multi_predictor.fit(train_data, time_limit=time_limit)"
]
},
{
"cell_type": "markdown",
"id": "d933db7a",
"metadata": {},
"source": [
"## 3. Inference and Evaluation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "aef32e97",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loaded data from: https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv | Columns = 15 / 15 | Rows = 9769 -> 9769\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" sex | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
"
\n",
" \n",
" \n",
" \n",
" 5454 | \n",
" 41 | \n",
" Self-emp-not-inc | \n",
" 408498 | \n",
" Married-civ-spouse | \n",
" Exec-managerial | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
"
\n",
" \n",
" 6111 | \n",
" 39 | \n",
" Private | \n",
" 746786 | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 55 | \n",
" United-States | \n",
"
\n",
" \n",
" 5282 | \n",
" 50 | \n",
" Private | \n",
" 62593 | \n",
" Married-civ-spouse | \n",
" Farming-fishing | \n",
" Husband | \n",
" Asian-Pac-Islander | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
"
\n",
" \n",
" 3046 | \n",
" 31 | \n",
" Private | \n",
" 248178 | \n",
" Married-civ-spouse | \n",
" Other-service | \n",
" Husband | \n",
" Black | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 35 | \n",
" United-States | \n",
"
\n",
" \n",
" 2162 | \n",
" 43 | \n",
" State-gov | \n",
" 52849 | \n",
" Married-civ-spouse | \n",
" Prof-specialty | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt marital-status occupation \\\n",
"5454 41 Self-emp-not-inc 408498 Married-civ-spouse Exec-managerial \n",
"6111 39 Private 746786 Married-civ-spouse Prof-specialty \n",
"5282 50 Private 62593 Married-civ-spouse Farming-fishing \n",
"3046 31 Private 248178 Married-civ-spouse Other-service \n",
"2162 43 State-gov 52849 Married-civ-spouse Prof-specialty \n",
"\n",
" relationship race sex capital-gain capital-loss \\\n",
"5454 Husband White Male 0 0 \n",
"6111 Husband White Male 0 0 \n",
"5282 Husband Asian-Pac-Islander Male 0 0 \n",
"3046 Husband Black Male 0 0 \n",
"2162 Husband White Male 0 0 \n",
"\n",
" hours-per-week native-country \n",
"5454 50 United-States \n",
"6111 55 United-States \n",
"5282 40 United-States \n",
"3046 35 United-States \n",
"2162 40 United-States "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')\n",
"test_data = test_data.sample(n=subsample_size, random_state=0)\n",
"test_data_nolab = test_data.drop(columns=labels) # unnecessary, just to demonstrate we're not cheating here\n",
"test_data_nolab.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9ce86965",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicting with TabularPredictor for label: education-num ...\n",
"Predicting with TabularPredictor for label: education ...\n",
"Predicting with TabularPredictor for label: class ...\n",
"Predictions: \n",
" education-num education class\n",
"5454 10.308321 Some-college >50K\n",
"6111 12.796276 HS-grad >50K\n",
"5282 9.429871 HS-grad >50K\n",
"3046 9.370234 HS-grad <=50K\n",
"2162 12.537287 HS-grad >50K\n",
"... ... ... ...\n",
"6965 9.409938 HS-grad >50K\n",
"4762 8.725597 11th <=50K\n",
"234 10.437899 Some-college <=50K\n",
"6291 10.456116 Some-college >50K\n",
"9575 10.091375 Some-college >50K\n",
"\n",
"[500 rows x 3 columns]\n"
]
}
],
"source": [
"multi_predictor = MultilabelPredictor.load(save_path) # unnecessary, just demonstrates how to load previously-trained multilabel predictor from file\n",
"\n",
"predictions = multi_predictor.predict(test_data_nolab)\n",
"print(\"Predictions: \\n\", predictions)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "22750d4f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating TabularPredictor for label: education-num ...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Evaluation: root_mean_squared_error on test data: -2.1815121451689987\n",
"\tNote: Scores are always higher_is_better. This metric score can be multiplied by -1 to get the metric value.\n",
"Evaluations on test data:\n",
"{\n",
" \"root_mean_squared_error\": -2.1815121451689987,\n",
" \"mean_squared_error\": -4.758995239519847,\n",
" \"mean_absolute_error\": -1.6262979860305786,\n",
" \"r2\": 0.3846415427345674,\n",
" \"pearsonr\": 0.6331502292829242,\n",
" \"median_absolute_error\": -1.1948366165161133\n",
"}\n",
"Evaluation: accuracy on test data: 0.226\n",
"Evaluations on test data:\n",
"{\n",
" \"accuracy\": 0.226,\n",
" \"balanced_accuracy\": 0.0896378687902662,\n",
" \"mcc\": 0.043530002398949885\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating TabularPredictor for label: education ...\n",
"Evaluating TabularPredictor for label: class ...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Evaluation: accuracy on test data: 0.812\n",
"Evaluations on test data:\n",
"{\n",
" \"accuracy\": 0.812,\n",
" \"balanced_accuracy\": 0.7039219373576547,\n",
" \"mcc\": 0.466508799319164,\n",
" \"roc_auc\": 0.8491088405524562,\n",
" \"f1\": 0.5688073394495413,\n",
" \"precision\": 0.6966292134831461,\n",
" \"recall\": 0.4806201550387597\n",
"}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'education-num': {'root_mean_squared_error': -2.1815121451689987, 'mean_squared_error': -4.758995239519847, 'mean_absolute_error': -1.6262979860305786, 'r2': 0.3846415427345674, 'pearsonr': 0.6331502292829242, 'median_absolute_error': -1.1948366165161133}, 'education': {'accuracy': 0.226, 'balanced_accuracy': 0.0896378687902662, 'mcc': 0.043530002398949885}, 'class': {'accuracy': 0.812, 'balanced_accuracy': 0.7039219373576547, 'mcc': 0.466508799319164, 'roc_auc': 0.8491088405524562, 'f1': 0.5688073394495413, 'precision': 0.6966292134831461, 'recall': 0.4806201550387597}}\n",
"Evaluated using metrics: {'education-num': root_mean_squared_error, 'education': accuracy, 'class': accuracy}\n"
]
}
],
"source": [
"evaluations = multi_predictor.evaluate(test_data)\n",
"print(evaluations)\n",
"print(\"Evaluated using metrics:\", multi_predictor.eval_metrics)"
]
},
{
"cell_type": "markdown",
"id": "7d578a8a",
"metadata": {},
"source": [
"다음과 같이 `get_predictor()`로 특정 레이블을 담당하는 개별 TabularPredictor를 가져와 단독으로 사용할 수도 있습니다. 아래 예시에서는 `class` 레이블 predictor의 리더보드를 확인합니다."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a5e506d3",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" model | \n",
" score_val | \n",
" pred_time_val | \n",
" fit_time | \n",
" pred_time_val_marginal | \n",
" fit_time_marginal | \n",
" stack_level | \n",
" can_infer | \n",
" fit_order | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" XGBoost | \n",
" 0.87 | \n",
" 0.007190 | \n",
" 0.334163 | \n",
" 0.007190 | \n",
" 0.334163 | \n",
" 1 | \n",
" True | \n",
" 11 | \n",
"
\n",
" \n",
" 1 | \n",
" WeightedEnsemble_L2 | \n",
" 0.87 | \n",
" 0.008057 | \n",
" 0.920020 | \n",
" 0.000867 | \n",
" 0.585857 | \n",
" 2 | \n",
" True | \n",
" 14 | \n",
"
\n",
" \n",
" 2 | \n",
" CatBoost | \n",
" 0.85 | \n",
" 0.007188 | \n",
" 1.249770 | \n",
" 0.007188 | \n",
" 1.249770 | \n",
" 1 | \n",
" True | \n",
" 7 | \n",
"
\n",
" \n",
" 3 | \n",
" NeuralNetTorch | \n",
" 0.85 | \n",
" 0.020120 | \n",
" 2.275299 | \n",
" 0.020120 | \n",
" 2.275299 | \n",
" 1 | \n",
" True | \n",
" 12 | \n",
"
\n",
" \n",
" 4 | \n",
" LightGBM | \n",
" 0.85 | \n",
" 0.022127 | \n",
" 0.248642 | \n",
" 0.022127 | \n",
" 0.248642 | \n",
" 1 | \n",
" True | \n",
" 4 | \n",
"
\n",
" \n",
" 5 | \n",
" RandomForestGini | \n",
" 0.84 | \n",
" 0.066845 | \n",
" 0.514668 | \n",
" 0.066845 | \n",
" 0.514668 | \n",
" 1 | \n",
" True | \n",
" 5 | \n",
"
\n",
" \n",
" 6 | \n",
" LightGBMXT | \n",
" 0.83 | \n",
" 0.021293 | \n",
" 0.209066 | \n",
" 0.021293 | \n",
" 0.209066 | \n",
" 1 | \n",
" True | \n",
" 3 | \n",
"
\n",
" \n",
" 7 | \n",
" LightGBMLarge | \n",
" 0.83 | \n",
" 0.041951 | \n",
" 0.751162 | \n",
" 0.041951 | \n",
" 0.751162 | \n",
" 1 | \n",
" True | \n",
" 13 | \n",
"
\n",
" \n",
" 8 | \n",
" RandomForestEntr | \n",
" 0.83 | \n",
" 0.074112 | \n",
" 0.570750 | \n",
" 0.074112 | \n",
" 0.570750 | \n",
" 1 | \n",
" True | \n",
" 6 | \n",
"
\n",
" \n",
" 9 | \n",
" NeuralNetFastAI | \n",
" 0.82 | \n",
" 0.015175 | \n",
" 0.678817 | \n",
" 0.015175 | \n",
" 0.678817 | \n",
" 1 | \n",
" True | \n",
" 10 | \n",
"
\n",
" \n",
" 10 | \n",
" ExtraTreesGini | \n",
" 0.82 | \n",
" 0.071165 | \n",
" 0.616249 | \n",
" 0.071165 | \n",
" 0.616249 | \n",
" 1 | \n",
" True | \n",
" 8 | \n",
"
\n",
" \n",
" 11 | \n",
" ExtraTreesEntr | \n",
" 0.81 | \n",
" 0.072873 | \n",
" 0.605156 | \n",
" 0.072873 | \n",
" 0.605156 | \n",
" 1 | \n",
" True | \n",
" 9 | \n",
"
\n",
" \n",
" 12 | \n",
" KNeighborsUnif | \n",
" 0.73 | \n",
" 0.004318 | \n",
" 0.003909 | \n",
" 0.004318 | \n",
" 0.003909 | \n",
" 1 | \n",
" True | \n",
" 1 | \n",
"
\n",
" \n",
" 13 | \n",
" KNeighborsDist | \n",
" 0.65 | \n",
" 0.005875 | \n",
" 0.005210 | \n",
" 0.005875 | \n",
" 0.005210 | \n",
" 1 | \n",
" True | \n",
" 2 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" model score_val pred_time_val fit_time \\\n",
"0 XGBoost 0.87 0.007190 0.334163 \n",
"1 WeightedEnsemble_L2 0.87 0.008057 0.920020 \n",
"2 CatBoost 0.85 0.007188 1.249770 \n",
"3 NeuralNetTorch 0.85 0.020120 2.275299 \n",
"4 LightGBM 0.85 0.022127 0.248642 \n",
"5 RandomForestGini 0.84 0.066845 0.514668 \n",
"6 LightGBMXT 0.83 0.021293 0.209066 \n",
"7 LightGBMLarge 0.83 0.041951 0.751162 \n",
"8 RandomForestEntr 0.83 0.074112 0.570750 \n",
"9 NeuralNetFastAI 0.82 0.015175 0.678817 \n",
"10 ExtraTreesGini 0.82 0.071165 0.616249 \n",
"11 ExtraTreesEntr 0.81 0.072873 0.605156 \n",
"12 KNeighborsUnif 0.73 0.004318 0.003909 \n",
"13 KNeighborsDist 0.65 0.005875 0.005210 \n",
"\n",
" pred_time_val_marginal fit_time_marginal stack_level can_infer \\\n",
"0 0.007190 0.334163 1 True \n",
"1 0.000867 0.585857 2 True \n",
"2 0.007188 1.249770 1 True \n",
"3 0.020120 2.275299 1 True \n",
"4 0.022127 0.248642 1 True \n",
"5 0.066845 0.514668 1 True \n",
"6 0.021293 0.209066 1 True \n",
"7 0.041951 0.751162 1 True \n",
"8 0.074112 0.570750 1 True \n",
"9 0.015175 0.678817 1 True \n",
"10 0.071165 0.616249 1 True \n",
"11 0.072873 0.605156 1 True \n",
"12 0.004318 0.003909 1 True \n",
"13 0.005875 0.005210 1 True \n",
"\n",
" fit_order \n",
"0 11 \n",
"1 14 \n",
"2 7 \n",
"3 12 \n",
"4 4 \n",
"5 5 \n",
"6 3 \n",
"7 13 \n",
"8 6 \n",
"9 10 \n",
"10 8 \n",
"11 9 \n",
"12 1 \n",
"13 2 "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predictor_class = multi_predictor.get_predictor('class')\n",
"predictor_class.leaderboard(silent=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}