{ "cells": [ { "cell_type": "markdown", "id": "b6667a31", "metadata": {}, "source": [ "# Multimodal Training/Prediction for Image + Text + Tabular\n", "\n", "Multimodal data that includes images can also be trained easily with a single line of code." ] }, { "cell_type": "code", "execution_count": 1, "id": "55e17642", "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "np.random.seed(123)" ] }, { "cell_type": "markdown", "id": "bc8ebca0", "metadata": {}, "source": [ "
\n", "\n", "## 1. Data Preparation\n", "\n", "### Load Data\n", "To keep this hands-on quick, we use only a subset of the PetFinder dataset. The task is to predict how quickly a pet is adopted from its adoption profile; to simplify it, we converted it into a binary classification problem with labels 0 (slow) and 1 (fast).\n", "\n", "The classes are also balanced at 300/300 (a 1:1 ratio), unlike the imbalanced problems often encountered in practice." ] }, { "cell_type": "code", "execution_count": 2, "id": "c7e6ca44", "metadata": {}, "outputs": [], "source": [ "save_path = 'ag-03-multimodal-img-text-tabular'\n", "download_dir = './ag_automm_tutorial'\n", "!rm -rf $save_path $download_dir" ] }, { "cell_type": "code", "execution_count": 3, "id": "e046ee50", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0.00/18.8M [00:00\n", "
\n", "" ], "text/plain": [ " Type Name Age Breed1 Breed2 Gender Color1 Color2 Color3 \\\n", "0 2 Yumi Hamasaki 4 292 265 2 1 5 7 \n", "1 2 Nene/ Kimie 12 285 0 2 5 6 7 \n", "2 2 Mattie 12 266 0 2 1 7 0 \n", "3 1 NaN 1 189 307 2 1 2 0 \n", "4 2 Coco 6 276 285 2 2 4 7 \n", "\n", " MaturitySize ... Quantity Fee State RescuerID \\\n", "0 2 ... 1 0 41326 bcc4e1b9557a8b3aaf545ea8e6e86991 \n", "1 2 ... 1 0 41326 f0450bf0efe0fa3ff9321d0b827b1237 \n", "2 2 ... 1 0 41401 9b52af6d48a4521fd01d4028eb5879a3 \n", "3 2 ... 1 0 41401 88da1210e021a5cf43480b074778f3bc \n", "4 2 ... 1 100 41326 227d7b1bcfaffb5f9882bf57b5ee8fab \n", "\n", " VideoAmt Description PetID \\\n", "0 0 I rescued Yumi Hamasaki at a food stall far aw... 7d7a39d71 \n", "1 0 Has adopted by a friend with new pet name Kimie 0e107c82f \n", "2 0 I rescued Mattie with a broken leg. After surg... 1a8fd6707 \n", "3 0 She born on 30 September . I really hope the a... bca8b44ae \n", "4 0 Calico Tame and easy going Diet RC Kitten Supp... 2def67952 \n", "\n", " PhotoAmt AdoptionSpeed Images \n", "0 3.0 0 images/7d7a39d71-1.jpg \n", "1 3.0 0 images/0e107c82f-1.jpg \n", "2 5.0 0 images/1a8fd6707-1.jpg \n", "3 3.0 0 images/bca8b44ae-1.jpg \n", "4 1.0 0 images/2def67952-1.jpg \n", "\n", "[5 rows x 25 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data.head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "c38ab36c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Type 2\n", "Name Yumi Hamasaki\n", "Age 4\n", "Breed1 292\n", "Breed2 265\n", "Gender 2\n", "Color1 1\n", "Color2 5\n", "Color3 7\n", "MaturitySize 2\n", "FurLength 2\n", "Vaccinated 1\n", "Dewormed 3\n", "Sterilized 2\n", "Health 1\n", "Quantity 1\n", "Fee 0\n", "State 41326\n", "RescuerID bcc4e1b9557a8b3aaf545ea8e6e86991\n", "VideoAmt 0\n", "Description I rescued Yumi Hamasaki at a food stall far aw...\n", "PetID 7d7a39d71\n", "PhotoAmt 3.0\n", "AdoptionSpeed 0\n", "Images 
images/7d7a39d71-1.jpg\n", "Name: 0, dtype: object" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example_row = train_data.iloc[0]\n", "example_row" ] }, { "cell_type": "markdown", "id": "4d1e92f0", "metadata": {}, "source": [ "### Convert image paths to absolute paths" ] }, { "cell_type": "code", "execution_count": 8, "id": "e6ecb49a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 images/7d7a39d71-1.jpg\n", "1 images/0e107c82f-1.jpg\n", "2 images/1a8fd6707-1.jpg\n", "3 images/bca8b44ae-1.jpg\n", "4 images/2def67952-1.jpg\n", " ... \n", "595 images/aeffc6dfd-1.jpg\n", "596 images/a4f6acdb6-1.jpg\n", "597 images/38063ee14-1.jpg\n", "598 images/d3faf6a0e-1.jpg\n", "599 images/8b0fa13e1-1.jpg\n", "Name: Images, Length: 600, dtype: object" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_data['Images']" ] }, { "cell_type": "code", "execution_count": 9, "id": "37e985e8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/home/ec2-user/SageMaker/autogluon-on-aws/4.multimodal/ag_automm_tutorial/petfinder_for_tutorial/images/7d7a39d71-1.jpg'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image_col = 'Images'\n", "# A row may list several ';'-separated images; keep only the first one for a quick tutorial\n", "train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0])\n", "test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])\n", "\n", "def path_expander(path, base_folder):\n", "    # Turn each ';'-separated relative path into an absolute path under base_folder\n", "    path_l = path.split(';')\n", "    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])\n", "\n", "train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n", "test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))\n", "\n", "# Check the expanded path of the first training image\n", "train_data[image_col].iloc[0]" ] }, { "cell_type": "code", 
"execution_count": 10, "id": "aafbdfe6", "metadata": {}, "outputs": [ { "data": { "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAGQASwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDyXSNCa+uG8uQAL1Zveu7t9OS3tYoUDEIMbu9PsbK2sbdYoV2g8knkmry5wRjge1Ypdzoq1Ob3YKyKjQ7TkjrxULwPxhwBnuKvl8uwAzj2oKZx09c02rGKszOaA4AFM8gjHHNaTIQcYBFRlcHJAFOztqF9S6oxNA3ZgR+n/wBarXpxUAIC2rY/5aKPzBFXtoI6c1pBtIp2ICPbtTccc45qwYxio2iqxXKk4zGarKSjdjir08ZEZ44FUeA3A5zQrjsh0zhm4RVHoDVfpIvPGakkOCKiH+s6U3d7jSLfljHQYpjRKc5Apwb60pIHWlYdrlZ4IyD8ophtUOOMfSrBcdaTeCRUyQ1ZEH2cgja7j6GkKSjI80kY71YDjODx604ctWbiWrrqcxrGhTXLQvBs+QNnceSSax30HUN2PLX8Gr0F4xgdPyqPyhngD2qZQuaQrSirJnnbaTfIoL28hPtzUEtjcJw8UgwO4r0owg54prW6nqo/Ks+Vm31mfVHmTIQOh5Pcc0zaSMjvXphsomzujU5/2RUT6RbNnNvHz/sipsy1iEt0ebkAjHbsaCMtwOtd7JoFkzZ+zqOO1V5PDNmzZCMpPoaLspVo9jiRw3PWkA5wB2rr28KwknZJIv0xUD+FGOdk/Ud1ouCrw6s5c8Hp+VJt5OPz9K6J/C91nKyqc9cjpUEnhy+ByFU+uDRzFe0g+pi49O3XFJtyx61qSaJfqADbtn2qu2n3SH5oJPrtp8yBOL6lZYzLIEUHLHaB9a77SrcRmRgOFIjU+yjH+NcvoljI+qozRELEDIcr6V21smyEcgE8n6mpkzGorysTjr2p2B6frSZ4BxS/8BoSb0uZyRKkiggDnI7CnEDA/pQmAQBgDHWnY67umeK3SRwXBUHI3delL5ZBOR1HWkU5b2PfFSqeOg9sVSsIhMQDZGeagdGAJA6VfVCeM8GkaMt2BoshXRA0mLGJ/wC7JGen+0K0w3bNYV1FfeQ8UVuzJkYxjPXNSfb7uJ/31nIExncFojKKKdjZJPU00nrWTHrKNLtdJE4zllNB16AOVMFw6/3o4iwrWMosTVi
9dOVt5GHUCsKR3YcSHf7Vdm1qxmjaNWkDt8oVomGT6dKpTOjSMccE9x0qo2LiiF5ZkALMV+oqa0k81s5OBTGEezJxjpk061kVnKrjg1TsVdWNEoSOv0qN1arC9Onaq1xcRwDMrqg9zimo3eiuZSnZXbIyT3BpN3vVZtZsQxBnU/TmhdTsZGwJkz78U3h6n8rJjiKfcshjkdalQ8ioo2ilG5GBHsc1YjTOOaylG25qppkp5HTNIF4xTwpx0oUelZuxcWhuzp/hRtOfWpMd6dt9qhl3I41G/LdO9Th7ZvlHHvVK/mFvbnnBPSs20u8NhyDmp5ktLCsbU1vgBl5WqxTnkVPbzgkD+E9ae6YfoMUnFFJlTaNwAzk+1Ls9qs7PXGKNg9OKmyC6KpSlMWe1WdlL5YosguU/K4GaZ5AJHFXvK/yKTyhSshXK8cC7WOO2DTfJXFXlj+U0wpjtRyoLlURHP3j60vlv61a28ijb9aaimxOViDgkZbrQ0xjGW+7njjrT1j2xg7ct607AZO2c5wR0rVJnPa6FRwxB2kfWpEwR+P5UibiegPtVxLKQgYUDvkmnYdtCELycmpUAPA4JFTLYMDl5OfSpktY1xyaaFykKRlyV4xTjAdgGcjuDVlUUHIFPCD0p3QcrM9oFdTgA0iwhCQF4x0xWltUCjaDSuHIcprAuZJVhgichcMxVM8+lZb2mpluIbkcc/JxXoKoW+lK0ZUZyM1oqluhomo2R529nqn3ds/HXMZ5qGeyvXASWCcJ3KqVP5jpXpNLjNP2z7BzJnnGpQzabZRXFtc37FhnymYsPx4z3rlJftdxKXuDKxf8AvA8V7lgdwKQxRsfmRPfIq44qdOPLBW8zndCMpczPDDZqiq3mbjj5sL92mfZyQSuduM5bvXt8mnWcufMtYG9d0YNVW8P6VIMyadbhvZah4vEN3UvkX7ClbY8eimu7UCSKR8DHzL0zXU6N4hjmUpdAq69wuc12c/hXSLjKm12DOSI3KgkdOBVe38E6TaXKXFv56OpyP3mQPzrb60qkbVUr9zJ0HBp02UrTVLC7mMME+6QAkqUYcevIq8NueoFaX9mRg5DHPrig6XAR8wJPqK47o6NTO2r6inYx2xVo6UqnKNx6Gs6fw0WuHnjubmJnO4iOYhQfYUrIpXMTxFMsbxq3Q+lUIEBUMrbl7Gn6wzwXfkykzxqcEt1FQ2qwhR5cpCnsT0qGle5pbQ07aUqVXdz6Vs20olTkcrXNbjG4DEZ7H1rQ07UI/P8AILfORwPWqsU7dDdCj0o2j0rPl1b7MSJLK7xnqke4H8qiXxJpxJ8x5Y/XzImX+lFtTK5q7RTto9Kox6zpsv3L+3P/AAMD+dWkuIZB8k0bZ9GBo5WIftHoKMDPSlBzzkUUuWwXExjoO1JsHendKXtRYBnlLS7BUlGafLfoS3oZZduiAFQPyq7aWssuHc7Qcc9zT7GzyBLLznotanAHSr3BIZFAkQAUfietSdfpSd6kRgvVaBsbsOO9OC05pc8AYFMDZ7UO4rhtpcHp1oBFLkGkO4YoxxRuApRgmgTYgBFLgmnDFApXsNiYFLj0pfejrVIloApJAp4iOKYCadvJ6mjQauNIxkUnfpSnqaXtgUD1EAGKDS45o/GgYhJFJ7d6dTO9IHck2HHSoZGEaMW4wKlDEDg1la/dfZdMlmJ5xRoFmcFe3SXWoXCnn5jUS2xU5jOP61jx3ivfNlgGY8Z4ya07e+5IbgjgqazfmaxuWGbgA8fqKx768ltb2KUHBVgQc81stNDOhwPnFc9rBZYwTgru/KiwXfY9Z0C5i1DTo5SoJI5rUNnbt1iU1xXgW82W3lHpnArvlII6cVSbMmrlB9F06XIe0ibPXKCqr+E9Fc5+xop9V4/lW3jmlximFznj4P08HMUt1Ee2yZqQeGJI0ZYdXu1543YbH510iBSRuHFSGKMjIOKLOwXOV/sDVU+5qyOP+mkA/oaDputovD2Uv4MtdNtoxU6
jOZ8jV0J3WEb+6Tf40u6/HB0ybP8AvCumAHbGaTbTTZMlocnpOoGVQh64raXeRnArzq0uCjKyk8jkZxXU2twZIVnimZJE6gnIcf3TmrXYucbao3grdcUp3DtzT1dWG4HI9RUNzceRbvIBuYDhc4yfSnYzF3H0pDKFBLHH1rPL3M5U+c6HjKpwKkNq8mSxLtjGWNRfsHKXVkU9CD+NPDCs17QIG+XGeuOKI2liwBIWUdm5/Wjm7j5WaWRmngAc1QS9RZI4rhdpkO1XHTPYfjWiIgOhNVuTsGKD1oMZxwaFVuu4UrMdxwPFLxSBWo+f+7VCDjNHaj6jFJkdKQxQKUikGMdaDmgaExig0ckUYoATNIKcYzjOKaOOaVg1HZHpXJeLLoSQG1JAU9a6W5uFhhLEivNvEdy92JSuc9RiixSVzmbtPsjqCA8fTOeRUoVmRZVbJxyfX60yGF5g6MCSCGGe9dBZaerwBgpKng1k4muq3ZlxzkANnketV74GdTGR8o5x6VvnRG6L0/pT20jDHK9RjpTUWS2mO8IK8anOcbuDXo1tMSgBNcpoth5Y4XFdRBHtAFXFENsvBven1CqnrT881XKybkg4NLUYb0pwJpWCwtKKaSfSgc0mFmOxwaWkzzimluaEgZ5BNG9lclGB2Z4rRsb3yyAeUPUGtHU9Mju4WdImLjocY/nXNIz28phk4YHih6rQ1g+h29jqcUOQscjR4A2xrnB/Orj3i3ULRmKeMHo3AP1rltPuUWVBIoyOhPat8TQxIOqhuc4JH50OVtCJQ1HWd8BKI7iFoWLlUzyH9/bpW1HKhHX3rC2TMRJFKjr1G5f5EVJ57Rgb1I45I6UlJInkNeQCTgYqs8LdqhiuMruDAg96nW4z160NqQWaIJLVWKlwTtOQO2ack81rjZmRAMeWx5/A1K8wbtUI5PGaV7PQqya1NOC6juYBJHnB4IPUHuDUoIrN0uCaNZpJYhH5kmVUdcdMn3NaI4Faoydh4NGcc5pooyM5PSqEOzQQD2qqmoQvKUTLqBnzFHy59M+tTrNG4+VhmldF2DYPUikKsOhzT/WgmlZDTIskdRShxketS8GmlQe35UMA8447fWmk5pDF1wajbcvWk0xqxla+jtaYjbDVyT2pmVo2Hz45rsb1TMhXtWYLTDhiOfWhRuF0c7b6Ykc6HacdD9PWtmysxbtJFjKseDV82ykjjoasJEKFGwc1yOK2BI4qf7FGwHAqZVC04N6Ch2DUkt4EhXAFTqwBNQpnvUiMAeRQIuJginlQRUMcgqwrAkVRAnk5GRR5ZA61MOlHanYakV8EcE0c1I696rtMiDLcYrNqxVyXimEjPUVnXOqqgIXrWU+rOWzmhILSNHyQyLtTfkeoFc5r+hs8BuI49jjsTn+VGm641sAlx931rRuPEWkzqLWWbiX5WPTaPXNUuV7BJSTOHhnIfa3Dqa6rRdQSRVhkf5j096xte01ElF3aMrxnupyKzrS9aGRXU4IPrUtLqaXTiejNbMI8RbVPUccUx4+mDUNnqSz2kZJ+ZuB7mrMFs0EOCSWJLNn1JzUNIi5RnhlSH/RdqSBt20jhvUGrVsj3MAlhw699pzg9xTpByTVOI/2dqC3METP58ipKqnjnjdj1pRSejG290acdhcM/zhUQj1yfyq9BZx24Lcs3q1Tg8UMc1vGEVsZtsb1oZlQEsQFHUntRXP3Eo1K4O9CIoXZUw+Vk7EmqclFXJSbNGbVolk8uKOSR8ZBC/Kf+BVTZ7iaZZJZSCM/InC/j601ECBVQBVXgD0FWEjNYym2aRikN24Hy8e1PVGxxVlIx6VKUVV7VKix3K8YkU8McCpRI2MtzTHlCHtVd7gnPPGOtVdILNl8S+1OEin+IZrIa6BiMgYsg/ujNNW5ZzkK/JxypFLmDlNk+uahkPaqMd0TKYxuBHqOKtrluT1rSLuRJWIXXPUVWZADVxxwagYVQiIKKepxTHoAz2oYx4PIqRfl5pirxmo3
lOdo71OiHcseZzxQJAelVA2TTw+BSY7F2FySRmrSSEHms+JsAHvVgPzTRLZe8/AyTSm4VVyTVEtkVVuLjCEdPxp3C1yzc6qEBA7Vz97qzMfvYzWfqF8UcqDj3zXPzapum2Fdx71lKbR0Rp2Nxr7JySee/pUTuzNnbn3zWdFIpGV/I96d56jhZdo9M9KmL1B33HpIrvsYbW9DTbmz84YIwegNdBfaVFJcJMg5BycVW1Dy4ym1fyrflvqtLEqe0ZIpaVpsttHdSXDD7IU3HnOCO+KyLsRuu+HDKScH1x7Vr/wBpCzUvkFOhX2rQ0eDS9UXzLayiQHOX39foKIvuZThbbYqeD79RcmORRvHAJ9K7mRcjIrmbTw4bbVfOQkKDXUsuEA9KWqFfUzphjPAxWXfxtcy2tqmfOklVhg42gck1rzsin5iB71TmgDusisUlTlXXqKyTszR6o6IHijuap6beG6tgJBiZOJAeOfX6Gr3FbproZNNDG4U965rT0xA5VdsbSMUTHCjPSunxk4rIvNPkt7tZbYKtsQxlQdm7ED+dTUV0ONrjAOetTx5FVlkVxlWBHtUm4kHFY3satXLQm2riobi6CxlmYKoGSSeAKhJIHWssy/2jd3FqCyRRAByCMsT/AE/nVXbE7Fn7RPc3XlxRlYQMmVv4sjjb61aVQDnGTjGaRRtAUKAB0xUg5HIqW7jG9OgwKbnPIIPuKpyk3F3L5xK20J2hDxvbqSfUVlXniCNZvItWGRxtHpVWYrnQxyAyhetW3lEY5IxWfooaaLzpOpqv4j1BbeAxoRk1pBWRD1LzX0TfdcH6GneapXPrXmUmoNbyeaZ2j5zyeK6fSNaS+t8BhvHXHeqdxWV9zpAQxpT04qtbTbl561KXoHysVn9KiZqcwyoqrLIIwalsEiXeE4z1p6uHbAqgZCTuqa3buDSLaNIHaKRp8Z5quzk96haT5WqrklkXeQeeKrXMu5CSeKpzzeWi4H1rKvL2aVSqZwO4qWNK5nas5MjbfzrKijOdxbJ96lOoASlJSMHjmq12UD5yVB5BB/lWUl1OmLdrGpGQwxx7VXmCeadzFT6YrLF5NbsMvuHY+1TNqZ3dqz5uUp05W0Z6kH71HJbxS9QOaRSDjac/jTgfSruc7MHU/DwuY2EbYB7VF4W0e40q5cs52Z6Gtt7qBH2PPGGJHylhmnCdRceSQyv1G5SA30Pem3IzsbIuhjkdO9KblCvWsxjvjKHIzxlTgioI4HijKJcyHuDIA2KrnYuUuyuHPt70iZMeWxnPGPSoIkbaRM6yZ9FxxSzXEdtFulZVQcZJxU7srUWN5LfV7WRCWEgMbR56jrn6iul/nWBpllNNqBu51/chR5Kt6/3vaug7dK2gtCJO7BepNK3NIoxS9askrT2FrckNLCCwGAwOCPyqhNpr27Brd3kjwd6yOS3titntWfqV+tqnkxMhu5AfLQn9T6UpRTWo02tjFl1COJczB4snGJBtqvp8iC4eRBuM+GdlBIB+uOmKtR2ivIJpUDzkHczc9eo+lWUVR8g2g44UelYPsjVJjl6+tSqMk81FGAzMq53KcHIxU4Ug0kmDMXxNObfTXkUqCB1PNeWaVctLrauzcFsMmMZHrXr+pWf2y3aNkBz6157HoX2XXxCqgZAbr0rTlZm07Ho9mwitFweMVyXiWQvcqexOK6y0jbyVU4xisXV9PWctE3ynOQ3pWlxpaM8p1NJ59SEXzbMgDjgV1fhOwlt7sxsSdoqre293ZXBjltQ/92RejV1XhiwdIDPL99+3pTbuKKaN+OPatO6mntkDFQSShB1qWVqxXfjrxVeVVxljxVaa8ANV3vA4AzSepSix8jkk46VLE7AYApiFWXGRzU2G24XrQkBKshxgmmSSbQajIZeoGajlLY9BQK2oyRt4AxWLqMrxArEOa02lwCBxWfcqWVmHU81my0jgtQW4ExkY8k1NBLN5PlyfMnUZPSte7szM/wAwGAazZ18pdo6
jpWcm2tjeFiNyFXDNnHWozMM/xflUXzEZLAn86lEUpAOwGkoo6XOyR7uNKsvNWVbdUcd04z9cdaBo9gkxmW3AcnOdxx+WauE/L70A5rrseO2VY9MsopTLHaQrIf4ggzU81tHcRGKVcr7HB/OpDQCfTpVWEZMukSof9GucKf4ZRu/I1G1jeiOUiOMsp+Qb/vD+lbXOaOg5FTyRYczMaOwvpo1LiOBj97ncV+nrVy20eBPKa5/0iZP43Hf1x0q+MGnqOaFBIOZscB6Uo4pBS4xTGOAxR6CmbuuKN3cmn1C4/rXMzSo+v3sfz70VPvLgAY7etdGGrI1hWjuLe4DIsZbZICPvZ+7zUzV4ji0mMQZptpIkgkI271cqwAxg0BuakVueR1rDY2LCDIxmpMADGc1V83YpIBOOw71XmuLp5I0RAkbDLsWww9gKrQkvSEdQKpy2ltPOsjxKXUcNipcl+/NSRRkkZFNNsTVhsREbBcYFPuraKdASB9ae0O9uOMU8rtXBNbJENo566hR5PJ8vIHcir8MawwhQMYp8gAckLzVeSRlJ3flSQ1qLNMFGTWTdXXJANTXEjMMA1QlCquWNSzogktyrM/cms2bU4YpNrSAY61R1zWBDlImBPSuSmnkU7nJLHkc9Kmw5y5dz1XS54JkDiQMM+vStwLHjKcivGNJ1G4hn/dSlCOfrXomi6yLuAb22yLwwzTuk7GTd9ToJUJHSqEuUPTNXo5g64J5qCePccg/WiQKxnn5h0xSKgzhhxU7QNnINRupXqajYt6lG6tok5AFcnrD7Gwq4/Cu3LRkbWFcf4hsiXEkZPXoamSiFPSWpgwkeaOR75ro7WCKS3Vs9fesCC2eZgUPI9q6yy0aeS1VmHJqWotlykn1PWsUYIFIGFO3V0nGAHrSkYGKXj1o6imJjduKTmnd6FGSTRcLDec5qRWoxigqDTuFh2R36Uo5yaYQc9ad260DDgg0mMmjFIRzQIdgYpksMdxE0cyBkbgg0/oKU+lNDsYkulSWKu1ozyxk58pmyV9cE9fpUKSl1B+ZD/dYYIreOMdaZJDHMpEiKwIwQaiUE2OMmjIDMBjP50oDH0NTvoVm4KgyqpOcCQ4pbTSEs1IWeRx2DEYFR7MvnQRR5IyPxq0BtAGeKaIlj6ZpJnCqD+dWo2IbuSGQKpzVSSRpHwDwKr3d3stmOcEdDWXDesDndQ3YqMG1c3vMUDGBnvVK+TuMVVW8B6tzmkuLwHvkUrlKDuVZPu571j3++RSink8VskeapK9KrNaFnDAcUmaJnB6rolyEaRQW4zgVkSWP2q3R4fvDgj0r1Sa22pkLnHauQ1HS0+0tLbkxOxydp4NLSK1JkpTZz0dj9lRWYYbufWprK9nF7/o5OO+KtHTJ5WAnnJHpV+xsIbbG1QT6+tROor3Q4QaWp0NhqJdFEhw/etdJlPBwciuTfKy7l/CtGC5bAzUqo1oy5QubzcDgCqkxz04piXJIGDSM2TyKq90JRK8iEng81Surfz49rKCRWmF3U5bUs2OoqHe2hS0MLTdGxdjCfL9K9Fs7SKO2RcDis2ysBGQ3Q1pIGjXbuLd8042RlUaZohR1p20VCrGpAxrpOe6HbVzmjbx1pN/NG6izAawI705RgUh5Ip3AoGOB9eaUHvTe1NBb0pASetANMyQOlG72qhEmRRTN3anZxQMcfaimhhS5FAhTj2ppFL+NJnJoATjIGcZ6UNhV96cAAailPXvVCIt+W/pTJ18yIikTr3/GpGPy81Nh31OX1FnEbxsT7ViLdMuBycV1d9AsoORzXIXdu0VyQpIzWbTudtFole9wCVNPguluW2LJyKx5LWdmwGI555rV0vTFhO7JLnvmovqaTUbaG7bZVQCKuBUwCBUcEYjGW6VPKFEZK4Naq5zSsxhMJGD+NYmo2NuSWHGe4q60oLYzUMg3deRUyaYJNdTn5LDH8VMMJUjAzjvW19kO48nHpR9lUHnrWLjc05jPSEOOlTJB
tAyKtiMKBipRFnqKFBBcrKmBxUyRFvrVlIParUEA3DIp8jJuVobUuenWtW30/gGrltaR46VpRwKo45rRQM5VEZ39nzhT5bIT2zxUh0+7J+UxY+prTAAp4IxT5EZ87sZoU5/Cn4PtQCMU4c1ZIxgQO1M3HqVp79MetRk84obCyHqwzTgQxqIVJGPl+tSmxjj0pyjjFRtnNPHHFWSOPWgY9KgmlZCAPrUQuXHXBqeZAkXMAnmjiqq3Weq1YBB59qd7j1F2g0FOmDS0ZNUhDCDmngcYppPzU7PFIAB684qCXPWpGbBqCWQHjOapAhFbnkYpszEDGaRMk9eKjuG696Q0itIc1kXtp5h3AcitRzULkdOtS9TaOhg/Zznkc1NDuRhjIrQeNCelIIE/Gs+XUttsdHO/GQCOlTq4ZSSMVDsVV696Gk2tgdx+dUIR4k4P6io2j2/j3pGcofVD0prMy8jkelKwaikdxSeWSaTPcdKXzMGk0MUQ+1SpFjFMV85qQE0rDsSqnNXraHPOKpRZYitW0QgZHNOKuKTsi7AuMcVaFRJ90VKDx71oczdxRT8e9N60fjQkLZFAAYp2AM0wcYp2cd6Bsawy/XGKTZznNKD1JoHTNFgDYQe1PAwuKTPehzhTTQCDls1J0qNMAdKfQLoVbhszH2FQZyauPCHYtkgmq7QHccMKjlbHdEY61dUkYA9KqrC4bJIwKs+hqoohyJVJJNOqNRgg9TjFPPFUUMbO7ik3HHWms2TSE8VLASR+OaiL5GMVFPNjio4XLv15qkx2Lw+SLOc1QlYljV1ztTFUJOtNlRIHbmomPFSOcVEeag0RERSBvnpzccU0AjrSuWKWwlMl42t6UMeKJf9Wv0oCwgwylD16rUanHB6UoOAG9KSbht3Y80gWomSrYAp2AwyKjJyoPcU5TzkfjSZQ8LjkVLGSetNXpViOPLdKB6Fi3Qkg1s26jAxxVC2jIwa1IQABmmkY1GS8EZxzSjNJt6kcilHaqMESA07JpgNLkVSE3oUAKPalBBoJFSMaeBgDPtTweKYOWp1AC5xSHB7UhxjGKaPvYFMB4p2eKT6UGgBTSEZHSj1pepxQIbt+U0u0cd6XPFAwKLgA4pWPFIaRulNAMAzyRTZQAp5p4xtqCfO0980mBiXd0yz7RUtg7mTL0z+y5DcNMR16ZPSrltEIyfahFFmRuOtVX5JqV2FQO3aqbKSG4FRt6Y4pSw6Umc1mUmRMPamnpTyCajcgHPpSNENYfSkl6AY7UgbNOfDGi5SIBzxigjcnTpTiACBSr0IpNgRRrzjFSIvODSqMcVMF5pDFjTpxVqJKjjUZzVpRgVRLZZg44qfkHiqsTDOKsg8ipdjOSLdu2VwamqpAcSY9atH6VUXdGTHHI6c0u4+lIp9OlSc1QmjLzgdaqtIS5OT1qaR8RE+1VRilMcWSLM65w1OWeTqcH8KgxzTlBLAepqLlWRfBJx700g9qcOlG6tUyBoVs9adlsZpQQeRTWPOKb0AMn0o8zmmgkHrQzH1qbjY8N70pYYpF5HNLtFPQQZ796GwVNG0E0beOtUhAv3aAAZOmaMgDqKEOAWpXAc205qnLgZPFTsTtLVVlOQT2oHHcrseeahcjPvUhxVeVsGmzREUkmCKcrg9DxVK4k5PPNNju1B25HArNtFmkTxxUTLn3pkUobHNTe9IaZHswRTMENU+KYfelctMhIy2Kco5p/FOUDNJBcYFyfpUqgdaTb/OnD0qmBKmMZqVX5xVSeTy48jrS2khkbJqbiZswIpwSBV9ETHQflVKA8CrqGrsjGQ9Y1HIAzTmGcUY96UjIxQkiECcHFOJ5phHAI7U7GeapAzIYb12mmfZzjg1KmcU8H9KHZ7iWiKhgcHjmnxIwfJGMVZ6mkbpU8iQ7i5GR0oxzTDweKBuOMVbESjpUZPzGms7A9OKOozSaAX8qB1waNpxSc/jUMehIBgUoxQG4xSDiqSEOHtSnNICKG6de1UK5Rc5k
I9TVo42hR+JrPuN5Q+WfmzxV6MqsY3HLY5xUwGwkkH3c1XkOeKSTmQkdBTXYdOtVcIogc4qjO+FNW5mzkVn3B4I9qTZvFGXdzlUJJ5rH+1Orfe781ZvpOduePrWYxyMVi5PmNEjobS/BXJYVuLIBGpyOe9cFFMyOB2BFdBNqIMEYDf/XpqfcHDsbbzqo6ioGnzjnrWNHdNPk7u9XYs4HNCYWsaEbZBqVDVaP7oqdTgZpiJc4o3BR2qJpBGu4njtWRLfSSTsFHANDaRSL15P8AOFB96s2J5BrMRDIdzc1qWihGGehpLcmUjagb5avxNkVQiwB9atxOMCtDBtlsc0oNMU+hqQDigQLkk9KCpz1FJ/HUn4UyZbGOh4FSg1CvGKk9KNx6jxj0pDjFHfFL1FAhm3jinL096TGDSAkcU0wFPrQpo9eKUEelFxCk80YB60Gm5qQHbVzimZzwKUKx+lIBg8U0wDnFKScGkOTnNBJ7CqJbKc8MhjfYRntWMdfW3bypwY5AcEHpXRt05H5VhavpMd6wdcLIO/rScbbFRmupPa38d4f3bhgD2qw5AJwKoaNp72SPvCZJ4IFXJhg54pJs0j5EMjVQueVNWnbmqN1JhT60nI1SOY1E/wCk4qoD2FPvH825Zh60wA+1c7fc10Gk7W6478VOZCxVQaqSBieMZrSs7c7lJApXdzS6SuaNrb7EBPWtGJMCoYRhRnk1aXpmtbWMWSoOQKkPXvimCmTziNDk8029BlS+uAQV5x6VThILg5ODTJJN7nnJpYBgisnLUdtDWgGVFXocNxVW3GUFXI12tWyZzyZo25bAGc1ajNU4Dk88GrqDDU2SmWFbmpUb5u/NRY9KeDSBokbqKXP1oPKjNPHStFclmQOuakGDVcOw7U9XIPTioUkBYxmg9KRDu5pzDNVcQwGlAzzTfpSjNF7DDZ15o2kUo4yaXJBHFNMTY0qacAAKTdz0pcjpSuAucZpmQTTieKZnk00IPWjvijoTzQOTVJkWuI1Qsu8gY5qwVwM5+tRKwzu7YpXC12K6qq4ziqEw54q7kHJNUbnPO0cd6hmyKEzYJwazLpwQeavzZJPHSse8kaNWYpkgcVDZ0QRgTjbO+P71RlsHFKGMjFu7EmmuCD/Lis0yrokhgM3zkcA8D1rXtQQfm4qG0j2W6g9+eKsgEEEUrJFWuX4zxU6uBmqMchPUHFTxvk9KpSQcpaVuMnoKrSyCQnuKkmcrATjmqUbZJ6Cm2FkSGJSOgqSGFA/AApUGTVuKPkUWuTKSRLCmAP51dVfeoY4+fxq2q4HFaJIxk0SxnmrqHOPWqScVdiOVHNVbQzRaVcgYpdh9aIzxipOvalZA2KvAwTSZ96XjGDTcD2ppik9DFZiDT1NMJGeaUMO1QO5ZQ4FPJJxUanpRnP5VpaxN7inIp28UzPXvTH4Qn2oYbii5XcRg8U4XUZ71Q3HnHrS4wM1lzMqxeFxGeM1Iu081m4GKvRZ8pcVUW2JomwMdaYMUmTg+1IvStBCgetKg5JpNwB5oD4PAzQQlqSsvFZs8zJeCIA4YZz6VdaTA6Vm3z7F85eSvP4UmyluWWI28VTkJydo96bFdCZOCCMUjOd2SBioepokRSsCvoxPasbUpEjhbd1PTNX5rhYyXY8Dr7Vj3kg1CcJGMqpySKltG0dDGSEHnHWo5oiARz6Zre+yBV4FU7i36kCodylILM77dVLZIGDjtVwoCAMnIrKtCYLgr2atdMkAnp6UlqXEIwc+1TY2sMDrTVwo7c0xbmJXYM43L2zQiuYmvJAIggOCarIOtMd2mffjHpVmHDYJxmnuSyeEZxV+Jfeq0ScZzVxF4q0mZSsWUA9anxhetQRjK+lTAHbjNUmZySRInrVuE4FZ4JUVcibpTIL0Z/OpxVWP2qypNAxxHFRnrUo6c0hXmqRLMMxmlCHOeKkKgjrxTcY6H86SQroeMjNKuRmkU84604E46UxXDJHUVFO37s1KWqGVdy4B
waGhpoq55xzT92Rg0hjYN900EEHoaxsVcQ8t3rQXhQKoICZQCO9Xu3SriiZMcTSZAFHbvQ3TnOM1oTe6Ez/OncY60mKU4pSeg47iMMj0qvJGGGDVhulRHHU85qG7GiSMCW2bT5nliyyN1WqM1/OzHy4zn0PNdUY1fORmoHtIVyQgz9KVi4zS3RzH2K6u2DSEgHr2q7BYJbJhfxrY8pcVG6DHShJD52zFncKxUDpVKVweCp/CtG6QeY2BVN4+aiTZcWjNkjRn7gmlYXOwCOX5R7VbaJT2qWCNdpGO9Qn7xsmrFEG5IK9O3FEFqFcMck1pFFGOPzpmBuOKqwNoaqcYpUO1uOnWnDg051G3Ip6LYhluB6vLisi3bHGetacJyOtO5DSLaHBqx0GRVRWIcelWwRitEzOQY4xT4m2tTN3FAb5utBmaMbc1ZVuKzVn2LnBqaK755FDaKs2aCmnnrVVLlDjOR9am8+P1o5kSzLY5HSmnvUpWo2XrTTJsC8VIDxUYBHrmnDOBzQxBnPWmnlvalAPNI3TpRcdhygdTxQVHFCg44pSMU0JpjQo3U49QM8UDHc0E5PWgnUXJ60Z+tAUkU0kg0DY4AHPrSsh4IpvfvUvRRzQxpEJDYzj9ahLGrL1ERgVLRaY1GHP0prkYNP4zTHwKExkbAY4NROBjr1qRgMd6jbpQwRm3SfOT61SdMcZrUuITJytUnhcdjWUlZmqZTKmnwjk59KeY3z9006JCGJKnpUrc1ixpA3VHs5OancVGf/wBdNsq4zbkc1IFJSmeueakTIJHana5LZCAQc+9aEBzxVRvapLeT5hUppE2NNQeCKuIMrVOMgjNWojwK1iZtDyMD6U0HmnsCe9R4NUZtDzkqRmhD700H+VJGfmIP0rOTKRbHSpOaiQ81Lk+1IbsB4A9aYx9KeRn8KYw5zWiM2hpbmnA8daYP1p+c1V7EiqDSuOaUdcUHg0igHXGKU0i8tTiRjpQBGeppBinEAnrSYGDVGbQDOaM8UY7ijkjIFDHYAcYpxfjmmng+lI3I9qYcwpbI4qNjjrSt7VG54pMuMhwIpj88UZ4NMJ9qVh3BhxULE4pxY+tRk0WGhhPFRPg05j6UxuPpUjRFjnkUMopw5Jo/CkzaJXdCagcEZ471cf61C4+tTYtMrjr0qRWwcYxRtANOA56UrWCyAjPOMUirtYYHFTquV6Cm4wTxTJsXICCMYq6vCjFULdhkEVoJyoNNEND1Jpjk56GnjPfpTzjv0pk2K3Q5pfLJfcozmnlcH0p8fB4oaAkSJ+OBU3lt/dp0JyORz9anB4oUbktoqLIGHBBpjFiDUEGfN6DpVg8DoKINtGciMA55NOXjnPNJnk0oyQOKsnUkBNDGmNkDrzTSWqWzREq9TTm5OKgEhzjFSLIpHfNNO4nccV56U2nFuDzTN2eR0qrGdncUdCaUE/1puf5Uq9fajUp3FpOKB1oYHI5piEJFRseKccjPFMLZFSxpDevQ01hTuOOaY3U4NIdiI9eQKYw/Cnt1pjHHegaRGQKibI6elSljn2prHpmk0WiHBpcHHWjNHpmkaoQjj3FRMKkyTxTHBNS1cuzIj17U4DIHSj34/wAKenQik0UCcDB5p7J+VAFOI+XHemQxIcYxWlDgrWapIPSr8DmnFElkDHSl7YpR0pcDFGpIzr1GaQYHpTyMNntSNinqJosRdetWfwFVIiKtZ4/+vTIZnQDkn2qZhwBTIf8AVk96fuHNKK0M2tRvanfSk6Gkz29qoVhHPT1pD1oJ+YUhOOtSyxCcHNO7U3PPJp2feiwNEn8FCjig9MZoXOauxFtRSoxRwTxSsemTSBfSmyxpyPpUfnpjrUkikIcHtVHqO3WpcmhqJYM6EdaZ5qdC3PuahKYHWo9oOajnGolvhj1prDg88VHbjrzUpGKtO5TRC33s9qayn8KlZQWz2pHXKGjQUUVXGDxmmnJAp7deaacZqS+Uh+6ScHNKTn1oPAPNMVskDdT
diloOzzSNg04g880wqxGKmxaYzvSgDHekwQcUoHy0rDuSqecfrTmqNeDxUhztptIljR1NWoH5xVJfvEHirMPPTijQhmmjZFPBHbpUMfTGc1Jgjmmidxdo6elDU0k5zTzzzT6AxF4wQatq52iqWTmp1fAHNCTInex//9k=\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "example_row = train_data.iloc[0]\n", "example_image = example_row[image_col]\n", "\n", "from IPython.display import Image, display\n", "pil_img = Image(filename=example_image)\n", "display(pil_img)" ] }, { "cell_type": "markdown", "id": "a18c22ae", "metadata": {}, "source": [ "
\n", "\n", "## 2. Training" ] }, { "cell_type": "code", "execution_count": null, "id": "804432cc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Extension horovod.torch has not been built: /home/ec2-user/anaconda3/envs/mxnet_p37/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found\n", "If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.\n", "Warning! MPI libs are missing, but python applications are still avaiable.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Global seed set to 123\n" ] } ], "source": [ "from autogluon.multimodal import MultiModalPredictor\n", "predictor = MultiModalPredictor(label=label_col, path=save_path)\n", "predictor.fit(\n", " train_data=train_data,\n", " time_limit=120, # seconds\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "fda231ea", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "61ae2a7bc4504ba389d238fdc1ca3565", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'roc_auc': 0.9408}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores = predictor.evaluate(test_data, metrics=[\"roc_auc\"])\n", "scores" ] }, { "cell_type": "code", "execution_count": null, "id": "b10893e8", "metadata": {}, "outputs": [], "source": [ "# predictor.save('my_saved_dir')\n", "# loaded_predictor = MultiModalPredictor.load('my_saved_dir')\n", "# scores2 = loaded_predictor.evaluate(test_data, metrics=[\"roc_auc\"])\n", "# scores2" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }