{
"cells": [
{
"cell_type": "markdown",
"id": "7f4b5902",
"metadata": {},
"source": [
"# Multimodal Training/Prediction for Text + Tabular \n",
"\n",
"AutoGluon은 자동으로 멀티모달(multimodal) 모델을 학습하는 AutoMM 클래스를 0.5.1버전부터 지원하기에, 이미지/텍스트/tabular가 혼용된 데이터를 쉽게 훈련할 수 있습니다.\n",
"뿐만 아니라, 유니모달(unimodal) 딥러닝 모델(이미지, 텍스트, tabular)의 훈련도 지원하기에, 머신 러닝 비전문가들도 쉽게 활용 가능합니다. \n",
"\n",
"훈련에 필요한 데이터셋은 기존 tabular 데이터셋에서 텍스트 열(예: 영화 리뷰 컨텐츠 텍스트)를 그대로 추가하면 되며, 모델 훈련 시 데이터 범주에 따라 아래와 같은 임베딩을 수행합니다. \n",
"\n",
"- 텍스트 열: 사전 훈련된 Transformer(BERT) backbone을 사용하여 데이터를 임베딩합니다. \n",
"- 범주형 열: Embedding-MLP을 통해 범주형 데이터를 임베딩합니다.\n",
"- 수치형 열: 표준 MLP를 사용하여 범주형 데이터를 임베딩합니다.\n",
"\n",
"## TextPredictor & TabularPredictor\n",
"\n",
"`MultiModalPredictor` 외에 `TextPredictor`나 `TabularPredictor`로도 멀티모달 모델 훈련이 가능합니다.\n",
"만약 트랜스포머 임베딩 대신, XGBoost/LightGBM/CatBoost 등의 Gradient Boosted Tree 결과를 앙상블 및 스태킹하여 모델링하고 싶다면, `TabularPredictor.fit (..., hyperparameters = 'multimodal')`로 훈련하세요. 주의할 점은 `hyperparameters = 'multimodal'`을 지정하지 않고 훈련 시, AutoGluon Tabular는 텍스트 데이터를 N-gram으로 자동으로 변환 후 tabular 모델로만 훈련합니다. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "538764ba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
}
],
"source": [
"import os\n",
"import torch\n",
"num_gpus = torch.cuda.device_count()\n",
"\n",
"if num_gpus == 0:\n",
" os.environ['AUTOGLUON_TEXT_TRAIN_WITHOUT_GPU'] = '1'\n",
"\n",
"print(num_gpus) "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4fa5305e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"np.random.seed(123)\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4e2b757a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n",
"Requirement already satisfied: openpyxl in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (3.0.9)\n",
"Requirement already satisfied: et-xmlfile in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from openpyxl) (1.0.1)\n",
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2.2 is available.\n",
"You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"!python3 -m pip install openpyxl"
]
},
{
"cell_type": "markdown",
"id": "7e17f828",
"metadata": {},
"source": [
"
\n",
"\n",
"## 1. Data preparation and Training\n",
"\n",
"본 핸즈온은 MachineHack Salary Prediction Hackathon의 도서 가격 예측 데이터셋을 사용합니다. 제목(Title), 저자(Author), 리뷰 평점(Reviews) 등과 같은 다양한 피쳐를 고려하여 도서의 가격을 예측합니다."
]
},
{
"cell_type": "markdown",
"id": "568e3885",
"metadata": {},
"source": [
"### Data preparation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "81fe7827",
"metadata": {},
"outputs": [],
"source": [
"save_path = 'ag-01-multimodal-text-tabular'\n",
"!rm -rf $save_path price_of_books"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7c91ab7e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-08-30 00:17:32-- https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip\n",
"Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.217.111.100\n",
"Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.217.111.100|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3521673 (3.4M) [application/zip]\n",
"Saving to: ‘price_of_books/Data.zip’\n",
"\n",
"100%[======================================>] 3,521,673 --.-K/s in 0.1s \n",
"\n",
"2022-08-30 00:17:32 (26.2 MB/s) - ‘price_of_books/Data.zip’ saved [3521673/3521673]\n",
"\n",
"Archive: Data.zip\n",
" inflating: Participants_Data/Data_Test.xlsx \n",
" inflating: Participants_Data/Data_Train.xlsx \n",
" inflating: Participants_Data/Sample_Submission.xlsx \n",
"Data_Test.xlsx\tData_Train.xlsx Sample_Submission.xlsx\n"
]
}
],
"source": [
"!mkdir -p price_of_books\n",
"!wget https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip -O price_of_books/Data.zip\n",
"!cd price_of_books && unzip -o Data.zip\n",
"!ls price_of_books/Participants_Data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "991d67bf",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Title | \n",
" Author | \n",
" Edition | \n",
" Reviews | \n",
" Ratings | \n",
" Synopsis | \n",
" Genre | \n",
" BookCategory | \n",
" Price | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" The Prisoner's Gold (The Hunters 3) | \n",
" Chris Kuzneski | \n",
" Paperback,– 10 Mar 2016 | \n",
" 4.0 out of 5 stars | \n",
" 8 customer reviews | \n",
" THE HUNTERS return in their third brilliant no... | \n",
" Action & Adventure (Books) | \n",
" Action & Adventure | \n",
" 220.00 | \n",
"
\n",
" \n",
" 1 | \n",
" Guru Dutt: A Tragedy in Three Acts | \n",
" Arun Khopkar | \n",
" Paperback,– 7 Nov 2012 | \n",
" 3.9 out of 5 stars | \n",
" 14 customer reviews | \n",
" A layered portrait of a troubled genius for wh... | \n",
" Cinema & Broadcast (Books) | \n",
" Biographies, Diaries & True Accounts | \n",
" 202.93 | \n",
"
\n",
" \n",
" 2 | \n",
" Leviathan (Penguin Classics) | \n",
" Thomas Hobbes | \n",
" Paperback,– 25 Feb 1982 | \n",
" 4.8 out of 5 stars | \n",
" 6 customer reviews | \n",
" \"During the time men live without a common Pow... | \n",
" International Relations | \n",
" Humour | \n",
" 299.00 | \n",
"
\n",
" \n",
" 3 | \n",
" A Pocket Full of Rye (Miss Marple) | \n",
" Agatha Christie | \n",
" Paperback,– 5 Oct 2017 | \n",
" 4.1 out of 5 stars | \n",
" 13 customer reviews | \n",
" A handful of grain is found in the pocket of a... | \n",
" Contemporary Fiction (Books) | \n",
" Crime, Thriller & Mystery | \n",
" 180.00 | \n",
"
\n",
" \n",
" 4 | \n",
" LIFE 70 Years of Extraordinary Photography | \n",
" Editors of Life | \n",
" Hardcover,– 10 Oct 2006 | \n",
" 5.0 out of 5 stars | \n",
" 1 customer review | \n",
" For seven decades, \"Life\" has been thrilling t... | \n",
" Photography Textbooks | \n",
" Arts, Film & Photography | \n",
" 965.62 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Title Author \\\n",
"0 The Prisoner's Gold (The Hunters 3) Chris Kuzneski \n",
"1 Guru Dutt: A Tragedy in Three Acts Arun Khopkar \n",
"2 Leviathan (Penguin Classics) Thomas Hobbes \n",
"3 A Pocket Full of Rye (Miss Marple) Agatha Christie \n",
"4 LIFE 70 Years of Extraordinary Photography Editors of Life \n",
"\n",
" Edition Reviews Ratings \\\n",
"0 Paperback,– 10 Mar 2016 4.0 out of 5 stars 8 customer reviews \n",
"1 Paperback,– 7 Nov 2012 3.9 out of 5 stars 14 customer reviews \n",
"2 Paperback,– 25 Feb 1982 4.8 out of 5 stars 6 customer reviews \n",
"3 Paperback,– 5 Oct 2017 4.1 out of 5 stars 13 customer reviews \n",
"4 Hardcover,– 10 Oct 2006 5.0 out of 5 stars 1 customer review \n",
"\n",
" Synopsis \\\n",
"0 THE HUNTERS return in their third brilliant no... \n",
"1 A layered portrait of a troubled genius for wh... \n",
"2 \"During the time men live without a common Pow... \n",
"3 A handful of grain is found in the pocket of a... \n",
"4 For seven decades, \"Life\" has been thrilling t... \n",
"\n",
" Genre BookCategory Price \n",
"0 Action & Adventure (Books) Action & Adventure 220.00 \n",
"1 Cinema & Broadcast (Books) Biographies, Diaries & True Accounts 202.93 \n",
"2 International Relations Humour 299.00 \n",
"3 Contemporary Fiction (Books) Crime, Thriller & Mystery 180.00 \n",
"4 Photography Textbooks Arts, Film & Photography 965.62 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_df = pd.read_excel(os.path.join('price_of_books', 'Participants_Data', 'Data_Train.xlsx'), engine='openpyxl')\n",
"train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2df2e37e",
"metadata": {},
"outputs": [],
"source": [
"def preprocess(df):\n",
" df = df.copy(deep=True)\n",
" df.loc[:, 'Reviews'] = pd.to_numeric(df['Reviews'].apply(lambda ele: ele[:-len(' out of 5 stars')]))\n",
" df.loc[:, 'Ratings'] = pd.to_numeric(df['Ratings'].apply(lambda ele: ele.replace(',', '')[:-len(' customer reviews')]))\n",
" df.loc[:, 'Price'] = np.log(df['Price'] + 1)\n",
" return df"
]
},
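{
"cell_type": "markdown",
"id": "a3f1c2d4",
"metadata": {},
"source": [
"The cell below is a small, self-contained sketch of the kind of string parsing `preprocess` performs, run on three hand-written values rather than on the real dataset. The `parse_count` helper is hypothetical and not part of the original notebook: it pulls out the leading number with a regular expression, which is one possible way to also handle the singular `1 customer review` string."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e4d5f6",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hand-written sample values in the same format as the raw 'Ratings' column.\n",
"sample = pd.Series(['8 customer reviews', '1 customer review', '1,234 customer reviews'])\n",
"\n",
"def parse_count(series):\n",
"    # Hypothetical helper: drop thousands separators, then keep the leading digits.\n",
"    digits = series.str.replace(',', '', regex=False).str.extract(r'^(\\d+)', expand=False)\n",
"    return pd.to_numeric(digits)\n",
"\n",
"print(parse_count(sample).tolist())"
]
},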
{
"cell_type": "code",
"execution_count": 8,
"id": "8407fb07",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Title | \n",
" Author | \n",
" Edition | \n",
" Reviews | \n",
" Ratings | \n",
" Synopsis | \n",
" Genre | \n",
" BookCategory | \n",
" Price | \n",
"
\n",
" \n",
" \n",
" \n",
" 949 | \n",
" Furious Hours | \n",
" Casey Cep | \n",
" Paperback,– 1 Jun 2019 | \n",
" 4.0 | \n",
" NaN | \n",
" ‘It’s been a long time since I picked up a boo... | \n",
" True Accounts (Books) | \n",
" Biographies, Diaries & True Accounts | \n",
" 5.743003 | \n",
"
\n",
" \n",
" 5504 | \n",
" REST API Design Rulebook | \n",
" Mark Masse | \n",
" Paperback,– 7 Nov 2011 | \n",
" 5.0 | \n",
" NaN | \n",
" In todays market, where rival web services com... | \n",
" Computing, Internet & Digital Media (Books) | \n",
" Computing, Internet & Digital Media | \n",
" 5.786897 | \n",
"
\n",
" \n",
" 5856 | \n",
" The Atlantropa Articles: A Novel | \n",
" Cody Franklin | \n",
" Paperback,– Import, 1 Nov 2018 | \n",
" 4.5 | \n",
" 2.0 | \n",
" #1 Amazon Best Seller! Dystopian Alternate His... | \n",
" Action & Adventure (Books) | \n",
" Romance | \n",
" 6.893656 | \n",
"
\n",
" \n",
" 4137 | \n",
" Hickory Dickory Dock (Poirot) | \n",
" Agatha Christie | \n",
" Paperback,– 5 Oct 2017 | \n",
" 4.3 | \n",
" 21.0 | \n",
" There’s more than petty theft going on in a Lo... | \n",
" Action & Adventure (Books) | \n",
" Crime, Thriller & Mystery | \n",
" 5.192957 | \n",
"
\n",
" \n",
" 3205 | \n",
" The Stanley Kubrick Archives (Bibliotheca Univ... | \n",
" Alison Castle | \n",
" Hardcover,– 21 Aug 2016 | \n",
" 4.6 | \n",
" 3.0 | \n",
" In 1968, when Stanley Kubrick was asked to com... | \n",
" Cinema & Broadcast (Books) | \n",
" Humour | \n",
" 6.889591 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Title Author \\\n",
"949 Furious Hours Casey Cep \n",
"5504 REST API Design Rulebook Mark Masse \n",
"5856 The Atlantropa Articles: A Novel Cody Franklin \n",
"4137 Hickory Dickory Dock (Poirot) Agatha Christie \n",
"3205 The Stanley Kubrick Archives (Bibliotheca Univ... Alison Castle \n",
"\n",
" Edition Reviews Ratings \\\n",
"949 Paperback,– 1 Jun 2019 4.0 NaN \n",
"5504 Paperback,– 7 Nov 2011 5.0 NaN \n",
"5856 Paperback,– Import, 1 Nov 2018 4.5 2.0 \n",
"4137 Paperback,– 5 Oct 2017 4.3 21.0 \n",
"3205 Hardcover,– 21 Aug 2016 4.6 3.0 \n",
"\n",
" Synopsis \\\n",
"949 ‘It’s been a long time since I picked up a boo... \n",
"5504 In todays market, where rival web services com... \n",
"5856 #1 Amazon Best Seller! Dystopian Alternate His... \n",
"4137 There’s more than petty theft going on in a Lo... \n",
"3205 In 1968, when Stanley Kubrick was asked to com... \n",
"\n",
" Genre \\\n",
"949 True Accounts (Books) \n",
"5504 Computing, Internet & Digital Media (Books) \n",
"5856 Action & Adventure (Books) \n",
"4137 Action & Adventure (Books) \n",
"3205 Cinema & Broadcast (Books) \n",
"\n",
" BookCategory Price \n",
"949 Biographies, Diaries & True Accounts 5.743003 \n",
"5504 Computing, Internet & Digital Media 5.786897 \n",
"5856 Romance 6.893656 \n",
"4137 Crime, Thriller & Mystery 5.192957 \n",
"3205 Humour 6.889591 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_subsample_size = 1500 # subsample for faster demo, you can try setting to larger values\n",
"test_subsample_size = 5\n",
"train_df = preprocess(train_df)\n",
"train_data = train_df.iloc[100:].sample(train_subsample_size, random_state=123)\n",
"test_data = train_df.iloc[:100].sample(test_subsample_size, random_state=245)\n",
"train_data.head()"
]
},
{
"cell_type": "markdown",
"id": "0bf0d437",
"metadata": {},
"source": [
"### Training\n",
"\n",
"시간을 절약하기 위해 데이터를 서브샘플링하고 2분 동안만 훈련합니다. 도서의 가격을 예측하는 Regression 모델이기에, RMSE(Root Mean Squared Error)가 디폴트 metric입니다."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92a2f09b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Global seed set to 123\n",
"Auto select gpus: [0]\n",
"Using 16bit native Automatic Mixed Precision (AMP)\n",
"GPU available: True, used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------------------------\n",
"0 | model | MultimodalFusionMLP | 109 M \n",
"1 | validation_metric | MeanSquaredError | 0 \n",
"2 | loss_func | MSELoss | 0 \n",
"----------------------------------------------------------\n",
"109 M Trainable params\n",
"0 Non-trainable params\n",
"109 M Total params\n",
"219.565 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a57adcb676984b8ca2d8e78da14f930f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0, global step 4: 'val_rmse' reached 1.52672 (best 1.52672), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-01-multimodal-text-tabular/epoch=0-step=4.ckpt' as top 3\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 0, global step 10: 'val_rmse' reached 1.09024 (best 1.09024), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-01-multimodal-text-tabular/epoch=0-step=10.ckpt' as top 3\n",
"Time limit reached. Elapsed time is 0:02:00. Signaling Trainer to stop.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: 0it [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1, global step 14: 'val_rmse' reached 0.88354 (best 0.88354), saving model to '/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag-01-multimodal-text-tabular/epoch=1-step=14.ckpt' as top 3\n"
]
}
],
"source": [
"from autogluon.multimodal import MultiModalPredictor\n",
"time_limit = 2 * 60 # set to larger value in your applications\n",
"predictor = MultiModalPredictor(label='Price', path=save_path)\n",
"predictor.fit(train_data, time_limit=time_limit)"
]
},
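{
"cell_type": "markdown",
"id": "c9a8b7e6",
"metadata": {},
"source": [
"The introduction mentions `TabularPredictor.fit(..., hyperparameters='multimodal')` as an alternative to `MultiModalPredictor`. The cell below is a minimal sketch of that call on the same `train_data`. The output path `ag-01-tabular-multimodal` is a made-up name, and the sketch assumes `autogluon.tabular` is installed in this environment; because this preset also trains a text model, a GPU is advisable."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2f3e4a5",
"metadata": {},
"outputs": [],
"source": [
"from autogluon.tabular import TabularPredictor\n",
"\n",
"# Hypothetical output directory, kept separate from save_path used above.\n",
"tabular_predictor = TabularPredictor(label='Price', path='ag-01-tabular-multimodal')\n",
"# 'multimodal' asks AutoGluon Tabular to combine a text model with tabular models.\n",
"tabular_predictor.fit(train_data, hyperparameters='multimodal', time_limit=time_limit)"
]
},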
{
"cell_type": "markdown",
"id": "cc159166",
"metadata": {},
"source": [
"
\n",
"\n",
"## 2. Prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6693fb26",
"metadata": {},
"outputs": [],
"source": [
"predictions = predictor.predict(test_data)\n",
"print('Predictions:')\n",
"print('------------')\n",
"print(np.exp(predictions) - 1)\n",
"print()\n",
"print('True Value:')\n",
"print('------------')\n",
"print(np.exp(test_data['Price']) - 1)"
]
},
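{
"cell_type": "markdown",
"id": "e5b6c7d8",
"metadata": {},
"source": [
"The model was trained on `log(price + 1)`, which is why the cell above applies `exp(x) - 1` to return to the original price scale. As a side note, NumPy's `np.log1p` and `np.expm1` compute the same pair of transforms. The sketch below checks the round trip on three prices copied from the data preview."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1a2b3c4",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"prices = np.array([220.00, 202.93, 299.00])\n",
"log_prices = np.log1p(prices)      # same as np.log(prices + 1)\n",
"recovered = np.expm1(log_prices)   # same as np.exp(log_prices) - 1\n",
"print(np.allclose(recovered, prices))"
]
},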
{
"cell_type": "code",
"execution_count": null,
"id": "54ca8eb7",
"metadata": {},
"outputs": [],
"source": [
"performance = predictor.evaluate(test_data)\n",
"print(performance)"
]
},
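{
"cell_type": "markdown",
"id": "0a1b2c3d",
"metadata": {},
"source": [
"As a rough cross-check of the score above, RMSE can also be computed by hand. This sketch assumes `predictor` and `test_data` from the earlier cells and that `predict` returns one value per row. Note that the error is measured on the log-transformed price, not on the original price scale."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d5e6f7a",
"metadata": {},
"outputs": [],
"source": [
"y_true = np.asarray(test_data['Price'], dtype=float)\n",
"y_pred = np.asarray(predictor.predict(test_data), dtype=float).reshape(-1)\n",
"# Root of the mean squared difference, in log(price + 1) units.\n",
"rmse_log = np.sqrt(np.mean((y_pred - y_true) ** 2))\n",
"print(rmse_log)"
]
},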
{
"cell_type": "code",
"execution_count": null,
"id": "83df1981",
"metadata": {},
"outputs": [],
"source": [
"embeddings = predictor.extract_embedding(test_data)\n",
"embeddings.shape"
]
}
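,
{
"cell_type": "markdown",
"id": "8b9c0d1e",
"metadata": {},
"source": [
"One possible use of the extracted embeddings is to compare rows by cosine similarity. The cell below is a sketch with plain NumPy that assumes `embeddings` is a 2-D array of shape `(num_rows, embedding_dim)` with no all-zero rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e3f4a5b",
"metadata": {},
"outputs": [],
"source": [
"emb = np.asarray(embeddings, dtype=float)\n",
"# L2-normalize each row so that the dot product equals cosine similarity.\n",
"emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)\n",
"similarity = emb @ emb.T\n",
"print(similarity.shape)"
]
}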
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}