{ "cells": [ { "cell_type": "markdown", "id": "b6667a31", "metadata": {}, "source": [ "# Multimodal Training/Prediction for Image + Text + Tabular\n", "\n", "Multimodal data that includes images can also be used for training with a single line of code." ] }, { "cell_type": "code", "execution_count": 1, "id": "55e17642", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "np.random.seed(123)" ] }, { "cell_type": "markdown", "id": "bc8ebca0", "metadata": {}, "source": [ "
\n", "\n", "## 1. Data preparation and Training\n", "\n", "### Load Data\n", "To keep this hands-on lab quick, we use only a subset of the PetFinder dataset. The task is to predict how quickly a pet is adopted from its adoption profile; to simplify it, we converted it into a binary classification problem with labels 0 (slow) and 1 (fast).\n", "\n", "The classes are also balanced at 300/300 (a 1:1 ratio), unlike the imbalanced problems commonly encountered in practice." ] }, { "cell_type": "code", "execution_count": 2, "id": "c7e6ca44", "metadata": {}, "outputs": [], "source": [ "save_path = 'ag-03-multimodal-img-text-tabular'\n", "download_dir = './ag_automm_tutorial'\n", "!rm -rf $save_path $download_dir" ] }, { "cell_type": "code", "execution_count": 3, "id": "e046ee50", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0.00/18.8M [00:00\n", "
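<table>
<thead>
<tr><th></th><th>Type</th><th>Name</th><th>Age</th><th>Breed1</th><th>Breed2</th><th>Gender</th><th>Color1</th><th>Color2</th><th>Color3</th><th>MaturitySize</th><th>...</th><th>Quantity</th><th>Fee</th><th>State</th><th>RescuerID</th><th>VideoAmt</th><th>Description</th><th>PetID</th><th>PhotoAmt</th><th>AdoptionSpeed</th><th>Images</th></tr>
</thead>
<tbody>
<tr><th>0</th><td>2</td><td>Yumi Hamasaki</td><td>4</td><td>292</td><td>265</td><td>2</td><td>1</td><td>5</td><td>7</td><td>2</td><td>...</td><td>1</td><td>0</td><td>41326</td><td>bcc4e1b9557a8b3aaf545ea8e6e86991</td><td>0</td><td>I rescued Yumi Hamasaki at a food stall far aw...</td><td>7d7a39d71</td><td>3.0</td><td>0</td><td>images/7d7a39d71-1.jpg</td></tr>
<tr><th>1</th><td>2</td><td>Nene/ Kimie</td><td>12</td><td>285</td><td>0</td><td>2</td><td>5</td><td>6</td><td>7</td><td>2</td><td>...</td><td>1</td><td>0</td><td>41326</td><td>f0450bf0efe0fa3ff9321d0b827b1237</td><td>0</td><td>Has adopted by a friend with new pet name Kimie</td><td>0e107c82f</td><td>3.0</td><td>0</td><td>images/0e107c82f-1.jpg</td></tr>
<tr><th>2</th><td>2</td><td>Mattie</td><td>12</td><td>266</td><td>0</td><td>2</td><td>1</td><td>7</td><td>0</td><td>2</td><td>...</td><td>1</td><td>0</td><td>41401</td><td>9b52af6d48a4521fd01d4028eb5879a3</td><td>0</td><td>I rescued Mattie with a broken leg. After surg...</td><td>1a8fd6707</td><td>5.0</td><td>0</td><td>images/1a8fd6707-1.jpg</td></tr>
<tr><th>3</th><td>1</td><td>NaN</td><td>1</td><td>189</td><td>307</td><td>2</td><td>1</td><td>2</td><td>0</td><td>2</td><td>...</td><td>1</td><td>0</td><td>41401</td><td>88da1210e021a5cf43480b074778f3bc</td><td>0</td><td>She born on 30 September . I really hope the a...</td><td>bca8b44ae</td><td>3.0</td><td>0</td><td>images/bca8b44ae-1.jpg</td></tr>
<tr><th>4</th><td>2</td><td>Coco</td><td>6</td><td>276</td><td>285</td><td>2</td><td>2</td><td>4</td><td>7</td><td>2</td><td>...</td><td>1</td><td>100</td><td>41326</td><td>227d7b1bcfaffb5f9882bf57b5ee8fab</td><td>0</td><td>Calico Tame and easy going Diet RC Kitten Supp...</td><td>2def67952</td><td>1.0</td><td>0</td><td>images/2def67952-1.jpg</td></tr>
</tbody>
</table>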
\n", "

5 rows × 25 columns

\n", "" ], "text/plain": [ " Type Name Age Breed1 Breed2 Gender Color1 Color2 Color3 \\\n", "0 2 Yumi Hamasaki 4 292 265 2 1 5 7 \n", "1 2 Nene/ Kimie 12 285 0 2 5 6 7 \n", "2 2 Mattie 12 266 0 2 1 7 0 \n", "3 1 NaN 1 189 307 2 1 2 0 \n", "4 2 Coco 6 276 285 2 2 4 7 \n", "\n", " MaturitySize ... Quantity Fee State RescuerID \\\n", "0 2 ... 1 0 41326 bcc4e1b9557a8b3aaf545ea8e6e86991 \n", "1 2 ... 1 0 41326 f0450bf0efe0fa3ff9321d0b827b1237 \n", "2 2 ... 1 0 41401 9b52af6d48a4521fd01d4028eb5879a3 \n", "3 2 ... 1 0 41401 88da1210e021a5cf43480b074778f3bc \n", "4 2 ... 1 100 41326 227d7b1bcfaffb5f9882bf57b5ee8fab \n", "\n", " VideoAmt Description PetID \\\n", "0 0 I rescued Yumi Hamasaki at a food stall far aw... 7d7a39d71 \n", "1 0 Has adopted by a friend with new pet name Kimie 0e107c82f \n", "2 0 I rescued Mattie with a broken leg. After surg... 1a8fd6707 \n", "3 0 She born on 30 September . I really hope the a... bca8b44ae \n", "4 0 Calico Tame and easy going Diet RC Kitten Supp... 2def67952 \n", "\n", " PhotoAmt AdoptionSpeed Images \n", "0 3.0 0 images/7d7a39d71-1.jpg \n", "1 3.0 0 images/0e107c82f-1.jpg \n", "2 5.0 0 images/1a8fd6707-1.jpg \n", "3 3.0 0 images/bca8b44ae-1.jpg \n", "4 1.0 0 images/2def67952-1.jpg \n", "\n", "[5 rows x 25 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data.head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "c38ab36c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Type 2\n", "Name Yumi Hamasaki\n", "Age 4\n", "Breed1 292\n", "Breed2 265\n", "Gender 2\n", "Color1 1\n", "Color2 5\n", "Color3 7\n", "MaturitySize 2\n", "FurLength 2\n", "Vaccinated 1\n", "Dewormed 3\n", "Sterilized 2\n", "Health 1\n", "Quantity 1\n", "Fee 0\n", "State 41326\n", "RescuerID bcc4e1b9557a8b3aaf545ea8e6e86991\n", "VideoAmt 0\n", "Description I rescued Yumi Hamasaki at a food stall far aw...\n", "PetID 7d7a39d71\n", "PhotoAmt 3.0\n", "AdoptionSpeed 0\n", "Images 
images/7d7a39d71-1.jpg\n", "Name: 0, dtype: object" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example_row = train_data.iloc[0]\n", "example_row" ] }, { "cell_type": "markdown", "id": "4d1e92f0", "metadata": {}, "source": [ "### Modify image path to absolute image path" ] }, { "cell_type": "code", "execution_count": 8, "id": "e6ecb49a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 images/7d7a39d71-1.jpg\n", "1 images/0e107c82f-1.jpg\n", "2 images/1a8fd6707-1.jpg\n", "3 images/bca8b44ae-1.jpg\n", "4 images/2def67952-1.jpg\n", " ... \n", "595 images/aeffc6dfd-1.jpg\n", "596 images/a4f6acdb6-1.jpg\n", "597 images/38063ee14-1.jpg\n", "598 images/d3faf6a0e-1.jpg\n", "599 images/8b0fa13e1-1.jpg\n", "Name: Images, Length: 600, dtype: object" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data['Images']" ] }, { "cell_type": "code", "execution_count": 9, "id": "37e985e8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag_automm_tutorial/petfinder_for_tutorial/images/7d7a39d71-1.jpg'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image_col = 'Images'\n", "train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0]) # Use the first image for a quick tutorial\n", "test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])\n", "\n", "def path_expander(path, base_folder):\n", " path_l = path.split(';')\n", " return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])\n", "\n", "train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n", "test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n", "\n", "train_data[image_col].iloc[0]" ] }, { "cell_type": "code", 
"execution_count": 10, "id": "aafbdfe6", "metadata": {}, "outputs": [ { "data": { "image/jpeg": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "example_row = train_data.iloc[0]\n", "example_image = example_row[image_col]\n", "\n", "from IPython.display import Image, display\n", "pil_img = Image(filename=example_image)\n", "display(pil_img)" ] }, { "cell_type": "markdown", "id": "a18c22ae", "metadata": {}, "source": [ "
\n", "\n", "## 2. Training" ] }, { "cell_type": "code", "execution_count": null, "id": "804432cc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extension horovod.torch has not been built: /home/ec2-user/anaconda3/envs/mxnet_p37/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found\n", "If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.\n", "Warning! MPI libs are missing, but python applications are still avaiable.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Global seed set to 123\n" ] } ], "source": [ "from autogluon.multimodal import MultiModalPredictor\n", "predictor = MultiModalPredictor(label=label_col, path=save_path)\n", "predictor.fit(\n", " train_data=train_data,\n", " time_limit=120, # seconds\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "fda231ea", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "61ae2a7bc4504ba389d238fdc1ca3565", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'roc_auc': 0.9408}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores = predictor.evaluate(test_data, metrics=[\"roc_auc\"])\n", "scores" ] }, { "cell_type": "code", "execution_count": null, "id": "b10893e8", "metadata": {}, "outputs": [], "source": [ "# predictor.save('my_saved_dir')\n", "# loaded_predictor = MultiModalPredictor.load('my_saved_dir')\n", "# scores2 = loaded_predictor.evaluate(test_data, metrics=[\"roc_auc\"])\n", "# scores2" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }