{ "cells": [ { "cell_type": "markdown", "id": "4c825f0a", "metadata": {}, "source": [ "# Sentiment Classification for Movie Review Dataset (English)\n", "\n", "In this hands-on lab, we train AutoGluon on the Stanford Sentiment Treebank (SST) dataset, which labels the sentiment of movie reviews (0: negative, 1: positive)." ] }, { "cell_type": "code", "execution_count": 1, "id": "91702aa2", "metadata": {}, "outputs": [], "source": [ "# If you are using a GPU instance and did not set up the lab environment with CloudFormation, uncomment the line below, run this cell, and then restart the notebook.\n", "# !pip install mxnet-cu110==1.9.1" ] }, { "cell_type": "code", "execution_count": 1, "id": "03d55f77", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "import os\n", "import torch\n", "import mxnet as mx\n", "num_gpus = torch.cuda.device_count()\n", "\n", "# Allow AutoGluon text training to run on CPU when no GPU is available\n", "if num_gpus == 0:\n", "    os.environ['AUTOGLUON_TEXT_TRAIN_WITHOUT_GPU'] = '1'\n", "\n", "print(num_gpus) " ] }, { "cell_type": "code", "execution_count": 2, "id": "e30cdf62", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import warnings\n", "import matplotlib.pyplot as plt\n", "warnings.filterwarnings('ignore')\n", "np.random.seed(123)" ] }, { "cell_type": "markdown", "id": "51f3b0d7", "metadata": {}, "source": [ "
\n", "\n", "## 1. Data Preparation and Training" ] }, { "cell_type": "code", "execution_count": 3, "id": "1ee9ba24", "metadata": {}, "outputs": [], "source": [ "save_path = 'ag-01-sentiment-classifcation-eng'\n", "!rm -rf $save_path" ] }, { "cell_type": "markdown", "id": "1753830e", "metadata": {}, "source": [ "Download the sample dataset. Parquet and CSV formats are supported, and you can either download a dataset stored remotely or load one directly from local storage." ] }, { "cell_type": "code", "execution_count": 4, "id": "f6977aab", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>sentence</th>\n", "      <th>label</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>43787</th>\n", "      <td>very pleasing at its best moments</td>\n", "      <td>1</td>\n", "    </tr>\n", "    <tr>\n", "      <th>16159</th>\n", "      <td>, american chai is enough to make you put away...</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>59015</th>\n", "      <td>too much like an infomercial for ram dass 's l...</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>5108</th>\n", "      <td>a stirring visual sequence</td>\n", "      <td>1</td>\n", "    </tr>\n", "    <tr>\n", "      <th>67052</th>\n", "      <td>cool visual backmasking</td>\n", "      <td>1</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>"
 ], "text/plain": [ " sentence label\n", "43787 very pleasing at its best moments 1\n", "16159 , american chai is enough to make you put away... 0\n", "59015 too much like an infomercial for ram dass 's l... 0\n", "5108 a stirring visual sequence 1\n", "67052 cool visual backmasking 1" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from autogluon.core.utils.loaders import load_pd\n", "train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/train.parquet')\n", "test_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/dev.parquet')\n", "subsample_size = 1000 # subsample data for faster demo, try setting this to larger values\n", "train_data = train_data.sample(n=subsample_size, random_state=0)\n", "train_data.head(5)" ] }, { "cell_type": "markdown", "id": "717c6304", "metadata": {}, "source": [ "If you need finer control over hyperparameters when configuring training, the predefined presets are a convenient starting point. TextPredictor ships with pretrained backbones such as BERT, RoBERTa, and ELECTRA; to train on multilingual data, including Korean, use the `multilingual` preset listed by `list_text_presets()`."
] }, { "cell_type": "code", "execution_count": 6, "id": "63b9d325", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['default',\n", " 'medium_quality_faster_train',\n", " 'high_quality',\n", " 'best_quality',\n", " 'multilingual']" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from autogluon.text import TextPredictor, list_text_presets\n", "list_text_presets()" ] }, { "cell_type": "code", "execution_count": 7, "id": "cb016a9c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'default': {'model.hf_text.checkpoint_name': 'google/electra-base-discriminator'},\n", " 'medium_quality_faster_train': {'model.hf_text.checkpoint_name': 'google/electra-small-discriminator',\n", " 'optimization.learning_rate': 0.0004},\n", " 'high_quality': {'model.hf_text.checkpoint_name': 'google/electra-base-discriminator'},\n", " 'best_quality': {'model.hf_text.checkpoint_name': 'microsoft/deberta-v3-base',\n", " 'env.per_gpu_batch_size': 2},\n", " 'multilingual': {'model.hf_text.checkpoint_name': 'microsoft/mdeberta-v3-base',\n", " 'optimization.top_k': 1,\n", " 'env.precision': 'bf16',\n", " 'env.per_gpu_batch_size': 4}}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list_text_presets(verbose=True)" ] }, { "cell_type": "code", "execution_count": 8, "id": "df8b4896", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Global seed set to 123\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a0c53a0b52ad4829b5d21096c9f9999b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/29.0 [00:00" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictor = TextPredictor(label='label', eval_metric='acc', path=save_path)\n", "#predictor.fit(train_data, time_limit=60)\n", "\n", "predictor.fit(\n", " train_data=train_data,\n", " 
presets=\"medium_quality_faster_train\",\n", " time_limit=60,\n", ")" ] }, { "cell_type": "markdown", "id": "90741a98", "metadata": {}, "source": [ "
\n", "\n", "## 2. Evaluation and Prediction" ] }, { "cell_type": "markdown", "id": "e7cd47e0", "metadata": {}, "source": [ "### Evaluation\n", "Use `predictor.evaluate()` to evaluate the model; you can also request additional metrics such as the F1 score." ] }, { "cell_type": "code", "execution_count": 9, "id": "b81a6285", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ae00cc9f714f4384ae20fcaf29f99b19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'acc': 0.801605504587156, 'f1': 0.808839779005525}\n" ] } ], "source": [ "if num_gpus > 0:\n", "    test_score = predictor.evaluate(test_data, metrics=['acc', 'f1'])\n", "    print(test_score) " ] }, { "cell_type": "markdown", "id": "4a73f665", "metadata": {}, "source": [ "### Prediction\n", "Use `predictor.predict()` to generate predictions." ] }, { "cell_type": "code", "execution_count": 10, "id": "5c551a7c", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "db0cb828639343b5b37471e6f2e8c963", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\"Sentence\": it's a charming and often affecting journey. \"Predicted Sentiment\": 1\n", "\"Sentence\": It's slow, very, very, very slow. \"Predicted Sentiment\": 0\n" ] } ], "source": [ "sentence1 = \"it's a charming and often affecting journey.\"\n", "sentence2 = \"It's slow, very, very, very slow.\"\n", "predictions = predictor.predict({'sentence': [sentence1, sentence2]})\n", "print('\"Sentence\":', sentence1, '\"Predicted Sentiment\":', predictions[0])\n", "print('\"Sentence\":', sentence2, '\"Predicted Sentiment\":', predictions[1])" ] }, { "cell_type": "code", "execution_count": 11, "id": "128c6366", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d858790fe6264688baa1fcbc8548665c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\"Sentence\": it's a charming and often affecting journey. \"Predicted Class-Probabilities\": 0 0.000639\n", "1 0.999361\n", "Name: 0, dtype: float32\n", "\"Sentence\": It's slow, very, very, very slow. \"Predicted Class-Probabilities\": 0 0.869077\n", "1 0.130923\n", "Name: 1, dtype: float32\n" ] } ], "source": [ "probs = predictor.predict_proba({'sentence': [sentence1, sentence2]})\n", "# predict_proba returns one row per sentence; select rows with .iloc (probs[0] would select the class-0 column)\n", "print('\"Sentence\":', sentence1, '\"Predicted Class-Probabilities\":', probs.iloc[0])\n", "print('\"Sentence\":', sentence2, '\"Predicted Class-Probabilities\":', probs.iloc[1])" ] }, { "cell_type": "markdown", "id": "6c4ee289", "metadata": {}, "source": [ "You can also run inference on the entire dataset." ] }, { "cell_type": "code", "execution_count": 12, "id": "3881122f", "metadata": {}, "outputs": [], "source": [ "# test_predictions = predictor.predict(test_data)\n", "# test_predictions.head()" ] }, { "cell_type": "markdown", "id": "3aa4fd86", "metadata": {}, "source": [ "### Save and Load\n", "\n", "The predictor saves the model automatically while training with `fit()`, and you can reload it with `load()`. You can also save the model explicitly with `save()`."
] }, { "cell_type": "code", "execution_count": 13, "id": "243b4b54", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Load pretrained checkpoint: ag-01-sentiment-classifcation-eng/model.ckpt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ace43ac80e74d6eb3ed3cf1701c6726", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>0</th>\n", "      <th>1</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>0.000639</td>\n", "      <td>0.999361</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>0.869077</td>\n", "      <td>0.130923</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>"
 ], "text/plain": [ " 0 1\n", "0 0.000639 0.999361\n", "1 0.869077 0.130923" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded_predictor = TextPredictor.load(save_path)\n", "loaded_predictor.predict_proba({'sentence': [sentence1, sentence2]})" ] }, { "cell_type": "markdown", "id": "d82ce2bf", "metadata": {}, "source": [ "### Extract Embeddings\n", "You can also use the trained predictor to extract embeddings, which map each row of the data to an embedding vector.\n", "\n", "The code cell below uses t-SNE to visualize the extracted embeddings. You can see two clusters corresponding to the two labels." ] }, { "cell_type": "code", "execution_count": 14, "id": "1fdd1c13", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c449803eb8514d0ca28886bfa0af3aff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[[-0.9950073 0.19104779 0.30278355 ... -0.15227251 -0.2078394\n", " -0.8929618 ]\n", " [-0.37055644 0.54944015 -0.08293767 ... -0.37117034 -1.52807\n", " -1.0482882 ]\n", " [-0.89614105 0.13071844 0.14226806 ... 0.04598733 -0.5800471\n", " -0.8197252 ]\n", " ...\n", " [-0.48994136 0.35927728 0.02496865 ... 0.6498142 -0.7644585\n", " -0.8643593 ]\n", " [-0.35284802 -0.0030367 0.31872958 ... 0.31603014 -0.7678774\n", " -0.7084229 ]\n", " [-0.64051795 0.5525908 0.25442797 ... 0.22971036 -0.82027364\n", " -0.7009061 ]]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "if num_gpus > 0:\n", " from sklearn.manifold import TSNE\n", " embeddings = predictor.extract_embedding(test_data)\n", " print(embeddings)\n", " \n", " X_embedded = TSNE(n_components=2, random_state=123).fit_transform(embeddings)\n", " for val, color in [(0, 'red'), (1, 'blue')]:\n", " idx = (test_data['label'].to_numpy() == val).nonzero()\n", " plt.scatter(X_embedded[idx, 0], X_embedded[idx, 1], c=color, label=f'label={val}')\n", " plt.legend(loc='best') " ] }, { "cell_type": "markdown", "id": "cad5d6b5", "metadata": {}, "source": [ "
\n", "\n", "## 3. Continuous Training\n", "\n", "You can load a previously trained model and call `fit()` again to continue training on new data or on the existing data." ] }, { "cell_type": "code", "execution_count": 15, "id": "e5a7dee0", "metadata": {}, "outputs": [], "source": [ "save_cont_path = 'ag-01-sentiment-classifcation-cont-eng'\n", "!rm -rf $save_cont_path" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }