{ "cells": [ { "cell_type": "markdown", "id": "7f4b5902", "metadata": {}, "source": [ "# Multimodal Training/Prediction for Text + Tabular\n", "\n", "Since version 0.5.1, AutoGluon has provided the AutoMM class, which trains multimodal models automatically, so datasets that mix image, text, and tabular data are easy to train on.\n", "It also supports training unimodal deep learning models (image, text, or tabular), so people who are not machine learning experts can use it easily.\n", "\n", "To prepare a training dataset, add the text columns (for example, the text of movie reviews) to an existing tabular dataset as they are. During training, each column is embedded according to its data type:\n", "\n", "- Text columns: embedded with a pretrained Transformer (BERT) backbone.\n", "- Categorical columns: embedded with an Embedding-MLP.\n", "- Numerical columns: embedded with a standard MLP.\n", "\n", "## TextPredictor & TabularPredictor\n", "\n", "Besides `MultiModalPredictor`, `TextPredictor` and `TabularPredictor` can also train multimodal models.\n", "If, rather than relying on Transformer embeddings alone, you want to ensemble and stack gradient boosted trees such as XGBoost, LightGBM, and CatBoost, train with `TabularPredictor.fit(..., hyperparameters='multimodal')`. Note that without `hyperparameters='multimodal'`, AutoGluon Tabular automatically converts text columns into N-gram features and trains only tabular models.
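
The `TabularPredictor` route described above can be sketched as follows. This is a minimal sketch, not a definitive recipe: the helper name `fit_tabular_multimodal` and its default `time_limit` are hypothetical, and `train_data` is assumed to be a pandas DataFrame that contains the label column. AutoGluon is imported inside the function so the sketch can be defined even where the library is not installed.

```python
def fit_tabular_multimodal(train_data, label, time_limit=120):
    # Hypothetical helper for illustration; not part of AutoGluon itself.
    # Imported lazily so that defining this function does not require AutoGluon.
    from autogluon.tabular import TabularPredictor

    predictor = TabularPredictor(label=label)
    # hyperparameters='multimodal' selects the preset that trains text-aware
    # models together with the tree models and ensembles them.
    predictor.fit(train_data, hyperparameters='multimodal', time_limit=time_limit)
    return predictor
```

With the book price data used in this notebook, a call would look like `fit_tabular_multimodal(train_data, label='Price')`.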
" ] }, { "cell_type": "code", "execution_count": 1, "id": "538764ba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "import os\n", "import torch\n", "num_gpus = torch.cuda.device_count()\n", "\n", "if num_gpus == 0:\n", " os.environ['AUTOGLUON_TEXT_TRAIN_WITHOUT_GPU'] = '1'\n", "\n", "print(num_gpus) " ] }, { "cell_type": "code", "execution_count": 2, "id": "4fa5305e", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import os\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "np.random.seed(123)\n", "\n", "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 3, "id": "4e2b757a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n", "Requirement already satisfied: openpyxl in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (3.0.9)\n", "Requirement already satisfied: et-xmlfile in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from openpyxl) (1.0.1)\n", "\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2.2 is available.\n", "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!python3 -m pip install openpyxl" ] }, { "cell_type": "markdown", "id": "7e17f828", "metadata": {}, "source": [ "
\n", "\n", "## 1. Data preparation and Training\n", "\n", "This hands-on uses the book price prediction dataset from the MachineHack Salary Prediction Hackathon. The goal is to predict a book's price from features such as its title (Title), author (Author), and review score (Reviews)." ] }, { "cell_type": "markdown", "id": "568e3885", "metadata": {}, "source": [ "### Data preparation" ] }, { "cell_type": "code", "execution_count": 4, "id": "81fe7827", "metadata": {}, "outputs": [], "source": [ "save_path = 'ag-01-multimodal-text-tabular'\n", "!rm -rf $save_path price_of_books" ] }, { "cell_type": "code", "execution_count": 5, "id": "7c91ab7e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2022-08-30 00:17:32-- https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip\n", "Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.217.111.100\n", "Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.217.111.100|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 3521673 (3.4M) [application/zip]\n", "Saving to: ‘price_of_books/Data.zip’\n", "\n", "100%[======================================>] 3,521,673 --.-K/s in 0.1s \n", "\n", "2022-08-30 00:17:32 (26.2 MB/s) - ‘price_of_books/Data.zip’ saved [3521673/3521673]\n", "\n", "Archive: Data.zip\n", " inflating: Participants_Data/Data_Test.xlsx \n", " inflating: Participants_Data/Data_Train.xlsx \n", " inflating: Participants_Data/Sample_Submission.xlsx \n", "Data_Test.xlsx\tData_Train.xlsx Sample_Submission.xlsx\n" ] } ], "source": [ "!mkdir -p price_of_books\n", "!wget https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip -O price_of_books/Data.zip\n", "!cd price_of_books && unzip -o Data.zip\n", "!ls price_of_books/Participants_Data" ] }, { "cell_type": "code", "execution_count": 6, "id": "991d67bf", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>Title</th><th>Author</th><th>Edition</th><th>Reviews</th><th>Ratings</th><th>Synopsis</th><th>Genre</th><th>BookCategory</th><th>Price</th></tr>\n", "  </thead>\n", "  <tbody>\n", "
    <tr><th>0</th><td>The Prisoner's Gold (The Hunters 3)</td><td>Chris Kuzneski</td><td>Paperback,– 10 Mar 2016</td><td>4.0 out of 5 stars</td><td>8 customer reviews</td><td>THE HUNTERS return in their third brilliant no...</td><td>Action &amp; Adventure (Books)</td><td>Action &amp; Adventure</td><td>220.00</td></tr>\n", "
    <tr><th>1</th><td>Guru Dutt: A Tragedy in Three Acts</td><td>Arun Khopkar</td><td>Paperback,– 7 Nov 2012</td><td>3.9 out of 5 stars</td><td>14 customer reviews</td><td>A layered portrait of a troubled genius for wh...</td><td>Cinema &amp; Broadcast (Books)</td><td>Biographies, Diaries &amp; True Accounts</td><td>202.93</td></tr>\n", "
    <tr><th>2</th><td>Leviathan (Penguin Classics)</td><td>Thomas Hobbes</td><td>Paperback,– 25 Feb 1982</td><td>4.8 out of 5 stars</td><td>6 customer reviews</td><td>\"During the time men live without a common Pow...</td><td>International Relations</td><td>Humour</td><td>299.00</td></tr>\n", "
    <tr><th>3</th><td>A Pocket Full of Rye (Miss Marple)</td><td>Agatha Christie</td><td>Paperback,– 5 Oct 2017</td><td>4.1 out of 5 stars</td><td>13 customer reviews</td><td>A handful of grain is found in the pocket of a...</td><td>Contemporary Fiction (Books)</td><td>Crime, Thriller &amp; Mystery</td><td>180.00</td></tr>\n", "
    <tr><th>4</th><td>LIFE 70 Years of Extraordinary Photography</td><td>Editors of Life</td><td>Hardcover,– 10 Oct 2006</td><td>5.0 out of 5 stars</td><td>1 customer review</td><td>For seven decades, \"Life\" has been thrilling t...</td><td>Photography Textbooks</td><td>Arts, Film &amp; Photography</td><td>965.62</td></tr>\n", "
  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Title Author \\\n", "0 The Prisoner's Gold (The Hunters 3) Chris Kuzneski \n", "1 Guru Dutt: A Tragedy in Three Acts Arun Khopkar \n", "2 Leviathan (Penguin Classics) Thomas Hobbes \n", "3 A Pocket Full of Rye (Miss Marple) Agatha Christie \n", "4 LIFE 70 Years of Extraordinary Photography Editors of Life \n", "\n", " Edition Reviews Ratings \\\n", "0 Paperback,– 10 Mar 2016 4.0 out of 5 stars 8 customer reviews \n", "1 Paperback,– 7 Nov 2012 3.9 out of 5 stars 14 customer reviews \n", "2 Paperback,– 25 Feb 1982 4.8 out of 5 stars 6 customer reviews \n", "3 Paperback,– 5 Oct 2017 4.1 out of 5 stars 13 customer reviews \n", "4 Hardcover,– 10 Oct 2006 5.0 out of 5 stars 1 customer review \n", "\n", " Synopsis \\\n", "0 THE HUNTERS return in their third brilliant no... \n", "1 A layered portrait of a troubled genius for wh... \n", "2 \"During the time men live without a common Pow... \n", "3 A handful of grain is found in the pocket of a... \n", "4 For seven decades, \"Life\" has been thrilling t... 
\n", "\n", " Genre BookCategory Price \n", "0 Action & Adventure (Books) Action & Adventure 220.00 \n", "1 Cinema & Broadcast (Books) Biographies, Diaries & True Accounts 202.93 \n", "2 International Relations Humour 299.00 \n", "3 Contemporary Fiction (Books) Crime, Thriller & Mystery 180.00 \n", "4 Photography Textbooks Arts, Film & Photography 965.62 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df = pd.read_excel(os.path.join('price_of_books', 'Participants_Data', 'Data_Train.xlsx'), engine='openpyxl')\n", "train_df.head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "2df2e37e", "metadata": {}, "outputs": [], "source": [ "def preprocess(df):\n", "    df = df.copy(deep=True)\n", "    # '4.0 out of 5 stars' -> 4.0\n", "    df.loc[:, 'Reviews'] = pd.to_numeric(df['Reviews'].apply(lambda ele: ele[:-len(' out of 5 stars')]))\n", "    # '1,234 customer reviews' -> 1234. A singular '1 customer review' is one character shorter,\n", "    # so the slice leaves an empty string, which most likely explains the NaN Ratings in the sample below.\n", "    df.loc[:, 'Ratings'] = pd.to_numeric(df['Ratings'].apply(lambda ele: ele.replace(',', '')[:-len(' customer reviews')]))\n", "    # Train on log(Price + 1); predictions are mapped back with exp(x) - 1.\n", "    df.loc[:, 'Price'] = np.log(df['Price'] + 1)\n", "    return df" ] }, { "cell_type": "code", "execution_count": 8, "id": "8407fb07", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\"><th></th><th>Title</th><th>Author</th><th>Edition</th><th>Reviews</th><th>Ratings</th><th>Synopsis</th><th>Genre</th><th>BookCategory</th><th>Price</th></tr>\n", "  </thead>\n", "  <tbody>\n", "
    <tr><th>949</th><td>Furious Hours</td><td>Casey Cep</td><td>Paperback,– 1 Jun 2019</td><td>4.0</td><td>NaN</td><td>‘It’s been a long time since I picked up a boo...</td><td>True Accounts (Books)</td><td>Biographies, Diaries &amp; True Accounts</td><td>5.743003</td></tr>\n", "
    <tr><th>5504</th><td>REST API Design Rulebook</td><td>Mark Masse</td><td>Paperback,– 7 Nov 2011</td><td>5.0</td><td>NaN</td><td>In todays market, where rival web services com...</td><td>Computing, Internet &amp; Digital Media (Books)</td><td>Computing, Internet &amp; Digital Media</td><td>5.786897</td></tr>\n", "
    <tr><th>5856</th><td>The Atlantropa Articles: A Novel</td><td>Cody Franklin</td><td>Paperback,– Import, 1 Nov 2018</td><td>4.5</td><td>2.0</td><td>#1 Amazon Best Seller! Dystopian Alternate His...</td><td>Action &amp; Adventure (Books)</td><td>Romance</td><td>6.893656</td></tr>\n", "
    <tr><th>4137</th><td>Hickory Dickory Dock (Poirot)</td><td>Agatha Christie</td><td>Paperback,– 5 Oct 2017</td><td>4.3</td><td>21.0</td><td>There’s more than petty theft going on in a Lo...</td><td>Action &amp; Adventure (Books)</td><td>Crime, Thriller &amp; Mystery</td><td>5.192957</td></tr>\n", "
    <tr><th>3205</th><td>The Stanley Kubrick Archives (Bibliotheca Univ...</td><td>Alison Castle</td><td>Hardcover,– 21 Aug 2016</td><td>4.6</td><td>3.0</td><td>In 1968, when Stanley Kubrick was asked to com...</td><td>Cinema &amp; Broadcast (Books)</td><td>Humour</td><td>6.889591</td></tr>\n", "
  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " Title Author \\\n", "949 Furious Hours Casey Cep \n", "5504 REST API Design Rulebook Mark Masse \n", "5856 The Atlantropa Articles: A Novel Cody Franklin \n", "4137 Hickory Dickory Dock (Poirot) Agatha Christie \n", "3205 The Stanley Kubrick Archives (Bibliotheca Univ... Alison Castle \n", "\n", " Edition Reviews Ratings \\\n", "949 Paperback,– 1 Jun 2019 4.0 NaN \n", "5504 Paperback,– 7 Nov 2011 5.0 NaN \n", "5856 Paperback,– Import, 1 Nov 2018 4.5 2.0 \n", "4137 Paperback,– 5 Oct 2017 4.3 21.0 \n", "3205 Hardcover,– 21 Aug 2016 4.6 3.0 \n", "\n", " Synopsis \\\n", "949 ‘It’s been a long time since I picked up a boo... \n", "5504 In todays market, where rival web services com... \n", "5856 #1 Amazon Best Seller! Dystopian Alternate His... \n", "4137 There’s more than petty theft going on in a Lo... \n", "3205 In 1968, when Stanley Kubrick was asked to com... \n", "\n", " Genre \\\n", "949 True Accounts (Books) \n", "5504 Computing, Internet & Digital Media (Books) \n", "5856 Action & Adventure (Books) \n", "4137 Action & Adventure (Books) \n", "3205 Cinema & Broadcast (Books) \n", "\n", " BookCategory Price \n", "949 Biographies, Diaries & True Accounts 5.743003 \n", "5504 Computing, Internet & Digital Media 5.786897 \n", "5856 Romance 6.893656 \n", "4137 Crime, Thriller & Mystery 5.192957 \n", "3205 Humour 6.889591 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_subsample_size = 1500 # subsample for faster demo, you can try setting to larger values\n", "test_subsample_size = 5\n", "train_df = preprocess(train_df)\n", "train_data = train_df.iloc[100:].sample(train_subsample_size, random_state=123)\n", "test_data = train_df.iloc[:100].sample(test_subsample_size, random_state=245)\n", "train_data.head()" ] }, { "cell_type": "markdown", "id": "0bf0d437", "metadata": {}, "source": [ "### Training\n", "\n", "시간을 절약하기 위해 데이터를 서브샘플링하고 2분 동안만 훈련합니다. 
도서의 가격을 예측하는 Regression 모델이기에, RMSE(Root Mean Squared Error)가 디폴트 metric입니다." ] }, { "cell_type": "code", "execution_count": null, "id": "92a2f09b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Global seed set to 123\n", "Auto select gpus: [0]\n", "Using 16bit native Automatic Mixed Precision (AMP)\n", "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "----------------------------------------------------------\n", "0 | model | MultimodalFusionMLP | 109 M \n", "1 | validation_metric | MeanSquaredError | 0 \n", "2 | loss_func | MSELoss | 0 \n", "----------------------------------------------------------\n", "109 M Trainable params\n", "0 Non-trainable params\n", "109 M Total params\n", "219.565 Total estimated model params size (MB)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a57adcb676984b8ca2d8e78da14f930f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 0, global step 4: 'val_rmse' reached 1.52672 (best 1.52672), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-01-multimodal-text-tabular/epoch=0-step=4.ckpt' as top 3\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { 
"model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 0, global step 10: 'val_rmse' reached 1.09024 (best 1.09024), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-01-multimodal-text-tabular/epoch=0-step=10.ckpt' as top 3\n", "Time limit reached. Elapsed time is 0:02:00. Signaling Trainer to stop.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Validation: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1, global step 14: 'val_rmse' reached 0.88354 (best 0.88354), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-01-multimodal-text-tabular/epoch=1-step=14.ckpt' as top 3\n" ] } ], "source": [ "from autogluon.multimodal import MultiModalPredictor\n", "time_limit = 2 * 60 # set to larger value in your applications\n", "predictor = MultiModalPredictor(label='Price', path=save_path)\n", "predictor.fit(train_data, time_limit=time_limit)" ] }, { "cell_type": "markdown", "id": "cc159166", "metadata": {}, "source": [ "
\n", "\n", "## 2. Prediction" ] }, { "cell_type": "code", "execution_count": null, "id": "6693fb26", "metadata": {}, "outputs": [], "source": [ "predictions = predictor.predict(test_data)\n", "print('Predictions:')\n", "print('------------')\n", "print(np.exp(predictions) - 1)\n", "print()\n", "print('True Value:')\n", "print('------------')\n", "print(np.exp(test_data['Price']) - 1)" ] }, { "cell_type": "code", "execution_count": null, "id": "54ca8eb7", "metadata": {}, "outputs": [], "source": [ "performance = predictor.evaluate(test_data)\n", "print(performance)" ] }, { "cell_type": "code", "execution_count": null, "id": "83df1981", "metadata": {}, "outputs": [], "source": [ "embeddings = predictor.extract_embedding(test_data)\n", "embeddings.shape" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }