{ "cells": [ { "cell_type": "markdown", "id": "a4ff380b", "metadata": {}, "source": [ "# Multimodal Training/Prediction using TabularPredictor\n", "\n", "In this hands-on, you train a model on multimodal data with `TabularPredictor`.\n", "\n", "## TextPredictor & TabularPredictor\n", "\n", "In addition to `MultiModalPredictor`, you can also train multimodal models with `TextPredictor` or `TabularPredictor`.\n", "If, rather than relying on transformer embeddings alone, you want to ensemble and stack the results of gradient boosted tree models such as XGBoost, LightGBM, and CatBoost, train with `TabularPredictor.fit(..., hyperparameters='multimodal')`. Note that if you train without specifying `hyperparameters='multimodal'`, AutoGluon Tabular automatically converts the text data to N-gram features and trains only tabular models." ] }, { "cell_type": "code", "execution_count": 1, "id": "9da3bcb6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "import os\n", "import torch\n", "import mxnet as mx\n", "num_gpus = torch.cuda.device_count()\n", "\n", "# Allow AutoGluon's text model to train on CPU when no GPU is detected\n", "if num_gpus == 0:\n", "    os.environ['AUTOGLUON_TEXT_TRAIN_WITHOUT_GPU'] = '1'\n", "\n", "print(num_gpus)" ] }, { "cell_type": "code", "execution_count": 2, "id": "1eb4f4f4", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pprint\n", "import random\n", "from autogluon.tabular import TabularPredictor\n", "import mxnet as mx\n", "\n", "# Fix the random seeds so the subsample and training runs are repeatable\n", "np.random.seed(123)\n", "random.seed(123)\n", "mx.random.seed(123)" ] }, { "cell_type": "markdown", "id": "4a280a6d", "metadata": {}, "source": [ "## 1. Data preparation and Training\n", "\n", "This hands-on uses the product review dataset from a MachineHack hackathon on product sentiment. The `Sentiment` label takes one of four values (0, 1, 2, 3), so this is a multiclass classification problem." 
] }, { "cell_type": "code", "execution_count": 3, "id": "ad2036bb", "metadata": {}, "outputs": [], "source": [ "save_path = 'ag-02-multimodal-tabularpredictor'\n", "!rm -rf $save_path product_sentiment" ] }, { "cell_type": "code", "execution_count": 4, "id": "eb956663", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-08-30 00:17:50-- https://autogluon-text-data.s3.amazonaws.com/multimodal_text/machine_hack_product_sentiment/train.csv\n", "Resolving autogluon-text-data.s3.amazonaws.com (autogluon-text-data.s3.amazonaws.com)... 52.216.24.52\n", "Connecting to autogluon-text-data.s3.amazonaws.com (autogluon-text-data.s3.amazonaws.com)|52.216.24.52|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 689486 (673K) [text/csv]\n", "Saving to: ‘product_sentiment/train.csv’\n", "\n", "100%[======================================>] 689,486 --.-K/s in 0.006s \n", "\n", "2022-08-30 00:17:50 (111 MB/s) - ‘product_sentiment/train.csv’ saved [689486/689486]\n", "\n", "--2022-08-30 00:17:51-- https://autogluon-text-data.s3.amazonaws.com/multimodal_text/machine_hack_product_sentiment/dev.csv\n", "Resolving autogluon-text-data.s3.amazonaws.com (autogluon-text-data.s3.amazonaws.com)... 52.216.24.52\n", "Connecting to autogluon-text-data.s3.amazonaws.com (autogluon-text-data.s3.amazonaws.com)|52.216.24.52|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 75517 (74K) [text/csv]\n", "Saving to: ‘product_sentiment/dev.csv’\n", "\n", "100%[======================================>] 75,517 --.-K/s in 0.002s \n", "\n", "2022-08-30 00:17:51 (45.8 MB/s) - ‘product_sentiment/dev.csv’ saved [75517/75517]\n", "\n", "--2022-08-30 00:17:51-- https://autogluon-text-data.s3.amazonaws.com/multimodal_text/machine_hack_product_sentiment/test.csv\n", "Resolving autogluon-text-data.s3.amazonaws.com (autogluon-text-data.s3.amazonaws.com)... 
52.216.24.52\n", "Connecting to autogluon-text-data.s3.amazonaws.com (autogluon-text-data.s3.amazonaws.com)|52.216.24.52|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 312194 (305K) [text/csv]\n", "Saving to: ‘product_sentiment/test.csv’\n", "\n", "100%[======================================>] 312,194 --.-K/s in 0.008s \n", "\n", "2022-08-30 00:17:51 (35.7 MB/s) - ‘product_sentiment/test.csv’ saved [312194/312194]\n", "\n" ] } ], "source": [ "!mkdir -p product_sentiment\n", "!wget https://autogluon-text-data.s3.amazonaws.com/multimodal_text/machine_hack_product_sentiment/train.csv -O product_sentiment/train.csv\n", "!wget https://autogluon-text-data.s3.amazonaws.com/multimodal_text/machine_hack_product_sentiment/dev.csv -O product_sentiment/dev.csv\n", "!wget https://autogluon-text-data.s3.amazonaws.com/multimodal_text/machine_hack_product_sentiment/test.csv -O product_sentiment/test.csv" ] }, { "cell_type": "code", "execution_count": 5, "id": "9da3ebf1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of training samples: 1000\n", "Number of dev samples: 637\n", "Number of test samples: 2728\n" ] } ], "source": [ "subsample_size = 1000 # for quick demo, try setting to larger values\n", "feature_columns = ['Product_Description', 'Product_Type']\n", "label = 'Sentiment'\n", "\n", "train_df = pd.read_csv('product_sentiment/train.csv', index_col=0).sample(subsample_size, random_state=123)\n", "dev_df = pd.read_csv('product_sentiment/dev.csv', index_col=0)\n", "test_df = pd.read_csv('product_sentiment/test.csv', index_col=0)\n", "\n", "train_df = train_df[feature_columns + [label]]\n", "dev_df = dev_df[feature_columns + [label]]\n", "test_df = test_df[feature_columns]\n", "print('Number of training samples:', len(train_df))\n", "print('Number of dev samples:', len(dev_df))\n", "print('Number of test samples:', len(test_df))" ] }, { "cell_type": "code", "execution_count": 6, "id": 
"65d36516", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>Product_Description</th><th>Product_Type</th><th>Sentiment</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>4532</th><td>they took away the lego pit but replaced it wi...</td><td>0</td><td>1</td></tr>\n",
"    <tr><th>1831</th><td>#Apple to Open Pop-Up Shop at #SXSW [REPORT]: ...</td><td>9</td><td>2</td></tr>\n",
"    <tr><th>3536</th><td>RT @mention False Alarm: Google Circles Not Co...</td><td>5</td><td>1</td></tr>\n",
"    <tr><th>5157</th><td>Will Google reveal a new social network called...</td><td>9</td><td>2</td></tr>\n",
"    <tr><th>4643</th><td>Niceness RT @mention Less than 2 hours until w...</td><td>6</td><td>3</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " Product_Description Product_Type \\\n", "4532 they took away the lego pit but replaced it wi... 0 \n", "1831 #Apple to Open Pop-Up Shop at #SXSW [REPORT]: ... 9 \n", "3536 RT @mention False Alarm: Google Circles Not Co... 5 \n", "5157 Will Google reveal a new social network called... 9 \n", "4643 Niceness RT @mention Less than 2 hours until w... 6 \n", "\n", " Sentiment \n", "4532 1 \n", "1831 2 \n", "3536 1 \n", "5157 2 \n", "4643 3 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head(5)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c866791e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Beginning AutoGluon training ... Time limit = 60s\n", "AutoGluon will save models to \"ag-02-multimodal-tabularpredictor/\"\n", "AutoGluon Version: 0.5.2\n", "Python Version: 3.8.12\n", "Operating System: Linux\n", "Train Data Rows: 1000\n", "Train Data Columns: 2\n", "Label Column: Sentiment\n", "Preprocessing data ...\n", "AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).\n", "\t4 unique label values: [1, 2, 3, 0]\n", "\tIf 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n", "Train Data Class Count: 4\n", "Using Feature Generators to preprocess the data ...\n", "Fitting AutoMLPipelineFeatureGenerator...\n", "\tAvailable Memory: 10593.33 MB\n", "\tTrain Data (Original) Memory Usage: 0.17 MB (0.0% of available memory)\n", "\tInferring data type of each feature based on column values. 
Set feature_metadata_in to manually specify special dtypes of the features.\n", "\tStage 1 Generators:\n", "\t\tFitting AsTypeFeatureGenerator...\n", "\tStage 2 Generators:\n", "\t\tFitting FillNaFeatureGenerator...\n", "\tStage 3 Generators:\n", "\t\tFitting IdentityFeatureGenerator...\n", "\t\tFitting IdentityFeatureGenerator...\n", "\t\t\tFitting RenameFeatureGenerator...\n", "\t\tFitting CategoryFeatureGenerator...\n", "\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n", "\t\tFitting TextSpecialFeatureGenerator...\n", "\t\t\tFitting BinnedFeatureGenerator...\n", "\t\t\tFitting DropDuplicatesFeatureGenerator...\n", "\t\tFitting TextNgramFeatureGenerator...\n", "\t\t\tFitting CountVectorizer for text features: ['Product_Description']\n", "\t\t\tCountVectorizer fit with vocabulary size = 108\n", "\tStage 4 Generators:\n", "\t\tFitting DropUniqueFeatureGenerator...\n", "\tTypes of features in original data (raw dtype, special dtypes):\n", "\t\t('int', []) : 1 | ['Product_Type']\n", "\t\t('object', ['text']) : 1 | ['Product_Description']\n", "\tTypes of features in processed data (raw dtype, special dtypes):\n", "\t\t('int', []) : 1 | ['Product_Type']\n", "\t\t('int', ['binned', 'text_special']) : 28 | ['Product_Description.char_count', 'Product_Description.word_count', 'Product_Description.capital_ratio', 'Product_Description.lower_ratio', 'Product_Description.digit_ratio', ...]\n", "\t\t('int', ['text_ngram']) : 109 | ['__nlp__.about', '__nlp__.all', '__nlp__.amp', '__nlp__.an', '__nlp__.an ipad', ...]\n", "\t\t('object', ['text']) : 1 | ['Product_Description_raw_text']\n", "\t0.4s = Fit runtime\n", "\t2 features in original data used to generate 139 features in processed data.\n", "\tTrain Data (Processed) Memory Usage: 0.43 MB (0.0% of available memory)\n", "Data preprocessing and feature engineering runtime = 0.42s ...\n", "AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n", "\tTo change this, specify the eval_metric 
parameter of Predictor()\n", "Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200\n", "Fitting 9 L1 models ...\n", "Fitting model: LightGBM ... Training model for up to 59.58s of the 59.58s of remaining time.\n", "\t0.855\t = Validation score (accuracy)\n", "\t1.9s\t = Training runtime\n", "\t0.03s\t = Validation runtime\n", "Fitting model: LightGBMXT ... Training model for up to 57.63s of the 57.63s of remaining time.\n", "\t0.845\t = Validation score (accuracy)\n", "\t1.09s\t = Training runtime\n", "\t0.02s\t = Validation runtime\n", "Fitting model: CatBoost ... Training model for up to 56.5s of the 56.5s of remaining time.\n", "\t0.86\t = Validation score (accuracy)\n", "\t1.38s\t = Training runtime\n", "\t0.01s\t = Validation runtime\n", "Fitting model: XGBoost ... Training model for up to 55.11s of the 55.11s of remaining time.\n", "\t0.85\t = Validation score (accuracy)\n", "\t1.01s\t = Training runtime\n", "\t0.03s\t = Validation runtime\n", "Fitting model: NeuralNetTorch ... Training model for up to 54.07s of the 54.07s of remaining time.\n", "\t0.86\t = Validation score (accuracy)\n", "\t1.76s\t = Training runtime\n", "\t0.02s\t = Validation runtime\n", "Fitting model: VowpalWabbit ... Training model for up to 52.28s of the 52.28s of remaining time.\n", "\tWarning: Exception caused VowpalWabbit to fail during training (ImportError)... Skipping this model.\n", "\t\t`import vowpalwabbit` failed.\n", "A quick tip is to install via `pip install vowpalwabbit==8.10.1\n", "Fitting model: LightGBMLarge ... Training model for up to 52.26s of the 52.26s of remaining time.\n", "\t0.85\t = Validation score (accuracy)\n", "\t3.31s\t = Training runtime\n", "\t0.04s\t = Validation runtime\n", "Fitting model: TextPredictor ... 
Training model for up to 48.63s of the 48.63s of remaining time.\n", "Global seed set to 0\n", "Auto select gpus: [0]\n", "Using 16bit native Automatic Mixed Precision (AMP)\n", "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "-------------------------------------------------------------------\n", "0 | model | HFAutoModelForTextPrediction | 13.5 M\n", "1 | validation_metric | Accuracy | 0 \n", "2 | loss_func | CrossEntropyLoss | 0 \n", "-------------------------------------------------------------------\n", "13.5 M Trainable params\n", "0 Non-trainable params\n", "13.5 M Total params\n", "26.968 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1ab41138df7c4a8bb93e561132c45b5d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 0, global step 3: 'val_accuracy' reached 0.58500 (best 0.58500), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=0-step=3.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] 
}, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 0, global step 7: 'val_accuracy' reached 0.58500 (best 0.58500), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=0-step=7.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1, global step 10: 'val_accuracy' reached 0.61500 (best 0.61500), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=1-step=10.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1, global step 14: 'val_accuracy' was not in top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2, global step 17: 'val_accuracy' reached 0.65500 (best 0.65500), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=2-step=17.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2, 
global step 21: 'val_accuracy' reached 0.81000 (best 0.81000), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=2-step=21.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 3, global step 24: 'val_accuracy' reached 0.76500 (best 0.81000), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=3-step=24.ckpt' as top 3\n", "Time limit reached. Elapsed time is 0:00:48. Signaling Trainer to stop.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 3, global step 26: 'val_accuracy' reached 0.82500 (best 0.82500), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-02-multimodal-tabularpredictor/models/TextPredictor/epoch=3-step=26.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f859eba3626a4aecbc5551ecf2d181a0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "58b211076ef344fb8d45dc0cd17ca6f6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "65cdb397e3d44c5ab1558707cb272de9", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f396907d2c8b4cf486c01fde712ad699", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\t0.85\t = Validation score (accuracy)\n", "\t67.68s\t = Training runtime\n", "\t0.6s\t = Validation runtime\n", "Fitting model: WeightedEnsemble_L2 ... Training model for up to 59.58s of the -19.95s of remaining time.\n", "\t0.87\t = Validation score (accuracy)\n", "\t0.22s\t = Training runtime\n", "\t0.0s\t = Validation runtime\n", "AutoGluon training complete, total runtime = 80.3s ... Best model: \"WeightedEnsemble_L2\"\n", "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"ag-02-multimodal-tabularpredictor/\")\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from autogluon.tabular import TabularPredictor\n", "predictor = TabularPredictor(label='Sentiment', path=save_path)\n", "predictor.fit(train_df, hyperparameters='multimodal', time_limit=60)" ] }, { "cell_type": "markdown", "id": "316116f8", "metadata": {}, "source": [ "
\n", "\n", "## 2. Evaluation and Prediction" ] }, { "cell_type": "code", "execution_count": 8, "id": "da1b283d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Load pretrained checkpoint: ag-02-multimodal-tabularpredictor/models/TextPredictor/text_nn/model.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "768ed30f668b44ba8e7fdc1b7efacb27", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " model score_test score_val pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order\n", "0 NeuralNetTorch 0.886970 0.860 0.089318 0.024708 1.755834 0.089318 0.024708 1.755834 1 True 5\n", "1 WeightedEnsemble_L2 0.886970 0.870 0.239991 0.060011 3.696689 0.005654 0.000496 0.223408 2 True 8\n", "2 CatBoost 0.885400 0.860 0.054497 0.008354 1.378052 0.054497 0.008354 1.378052 1 True 3\n", "3 TextPredictor 0.882261 0.850 1.998655 0.596547 67.676712 1.998655 0.596547 67.676712 1 True 7\n", "4 LightGBMLarge 0.879121 0.850 0.363917 0.036365 3.314543 0.363917 0.036365 3.314543 1 True 6\n", "5 LightGBM 0.877551 0.855 0.097271 0.026469 1.899688 0.097271 0.026469 1.899688 1 True 1\n", "6 XGBoost 0.875981 0.850 0.117441 0.030043 1.007675 0.117441 0.030043 1.007675 1 True 4\n", "7 LightGBMXT 0.866562 0.845 0.062399 0.021118 1.087554 0.062399 0.021118 1.087554 1 True 2\n" ] }, { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>model</th><th>score_test</th><th>score_val</th><th>pred_time_test</th><th>pred_time_val</th><th>fit_time</th><th>pred_time_test_marginal</th><th>pred_time_val_marginal</th><th>fit_time_marginal</th><th>stack_level</th><th>can_infer</th><th>fit_order</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>NeuralNetTorch</td><td>0.886970</td><td>0.860</td><td>0.089318</td><td>0.024708</td><td>1.755834</td><td>0.089318</td><td>0.024708</td><td>1.755834</td><td>1</td><td>True</td><td>5</td></tr>\n",
"    <tr><th>1</th><td>WeightedEnsemble_L2</td><td>0.886970</td><td>0.870</td><td>0.239991</td><td>0.060011</td><td>3.696689</td><td>0.005654</td><td>0.000496</td><td>0.223408</td><td>2</td><td>True</td><td>8</td></tr>\n",
"    <tr><th>2</th><td>CatBoost</td><td>0.885400</td><td>0.860</td><td>0.054497</td><td>0.008354</td><td>1.378052</td><td>0.054497</td><td>0.008354</td><td>1.378052</td><td>1</td><td>True</td><td>3</td></tr>\n",
"    <tr><th>3</th><td>TextPredictor</td><td>0.882261</td><td>0.850</td><td>1.998655</td><td>0.596547</td><td>67.676712</td><td>1.998655</td><td>0.596547</td><td>67.676712</td><td>1</td><td>True</td><td>7</td></tr>\n",
"    <tr><th>4</th><td>LightGBMLarge</td><td>0.879121</td><td>0.850</td><td>0.363917</td><td>0.036365</td><td>3.314543</td><td>0.363917</td><td>0.036365</td><td>3.314543</td><td>1</td><td>True</td><td>6</td></tr>\n",
"    <tr><th>5</th><td>LightGBM</td><td>0.877551</td><td>0.855</td><td>0.097271</td><td>0.026469</td><td>1.899688</td><td>0.097271</td><td>0.026469</td><td>1.899688</td><td>1</td><td>True</td><td>1</td></tr>\n",
"    <tr><th>6</th><td>XGBoost</td><td>0.875981</td><td>0.850</td><td>0.117441</td><td>0.030043</td><td>1.007675</td><td>0.117441</td><td>0.030043</td><td>1.007675</td><td>1</td><td>True</td><td>4</td></tr>\n",
"    <tr><th>7</th><td>LightGBMXT</td><td>0.866562</td><td>0.845</td><td>0.062399</td><td>0.021118</td><td>1.087554</td><td>0.062399</td><td>0.021118</td><td>1.087554</td><td>1</td><td>True</td><td>2</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " model score_test score_val pred_time_test pred_time_val \\\n", "0 NeuralNetTorch 0.886970 0.860 0.089318 0.024708 \n", "1 WeightedEnsemble_L2 0.886970 0.870 0.239991 0.060011 \n", "2 CatBoost 0.885400 0.860 0.054497 0.008354 \n", "3 TextPredictor 0.882261 0.850 1.998655 0.596547 \n", "4 LightGBMLarge 0.879121 0.850 0.363917 0.036365 \n", "5 LightGBM 0.877551 0.855 0.097271 0.026469 \n", "6 XGBoost 0.875981 0.850 0.117441 0.030043 \n", "7 LightGBMXT 0.866562 0.845 0.062399 0.021118 \n", "\n", " fit_time pred_time_test_marginal pred_time_val_marginal \\\n", "0 1.755834 0.089318 0.024708 \n", "1 3.696689 0.005654 0.000496 \n", "2 1.378052 0.054497 0.008354 \n", "3 67.676712 1.998655 0.596547 \n", "4 3.314543 0.363917 0.036365 \n", "5 1.899688 0.097271 0.026469 \n", "6 1.007675 0.117441 0.030043 \n", "7 1.087554 0.062399 0.021118 \n", "\n", " fit_time_marginal stack_level can_infer fit_order \n", "0 1.755834 1 True 5 \n", "1 0.223408 2 True 8 \n", "2 1.378052 1 True 3 \n", "3 67.676712 1 True 7 \n", "4 3.314543 1 True 6 \n", "5 1.899688 1 True 1 \n", "6 1.007675 1 True 4 \n", "7 1.087554 1 True 2 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor.leaderboard(dev_df)" ] }, { "cell_type": "code", "execution_count": 9, "id": "5e3581a2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Evaluation: accuracy on test data: 0.8869701726844584\n", "Evaluations on test data:\n", "{\n", " \"accuracy\": 0.8869701726844584,\n", " \"balanced_accuracy\": 0.48100029507697384,\n", " \"mcc\": 0.7830290724042951\n", "}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'accuracy': 0.8869701726844584, 'balanced_accuracy': 0.48100029507697384, 'mcc': 0.7830290724042951}\n" ] } ], "source": [ "score = predictor.evaluate(dev_df)\n", "print(score)" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": 
"conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }