{
"cells": [
{
"cell_type": "markdown",
"id": "3c130782-9478-4cb4-b6ab-852cfbc07d4c",
"metadata": {},
"source": [
"# Lab 1: Korean NLI (Natural Language Inference) Training on AWS\n",
"\n",
"## Introduction\n",
"---\n",
"\n",
"본 모듈에서는 허깅페이스 트랜스포머(Hugging Face transformers) 라이브러리를 사용하여 한국어 자연어 추론 (Korean NLI; Natural Language Inference) 쌍을 훈련합니다. 자연어 추론은 전제(premise)와 가설(hypothesis)이 포함된 두 문장 사이에서 전제가 참이라고 가정할 때, 가설의 연결이 참인지(entailment), 모순이 있는지(contradiction), 알 수 없는지(neutral)로 구별하는 다운스트림 태스크입니다.\n",
"\n",
"***[Note] SageMaker Studio Lab, SageMaker Studio, SageMaker 노트북 인스턴스, 또는 여러분의 로컬 머신에서 이 데모를 실행할 수 있습니다. SageMaker Studio Lab을 사용하는 경우 GPU를 활성화하세요.***\n",
"\n",
"### References\n",
"- Hugging Face Tutorial: https://huggingface.co/docs/transformers/tasks/question_answering\n",
"- Fine-tuning with custom datasets: https://huggingface.co/transformers/v4.11.3/custom_datasets.html#question-answering-with-squad-2-0\n",
"- KorNLI datasets: https://github.com/kakaobrain/KorNLUDatasets/tree/master/KorNLI\n",
"- KLUE: https://github.com/KLUE-benchmark/KLUE"
]
},
{
"cell_type": "markdown",
"id": "44cf4937-b073-4611-9982-f4be3b58efa2",
"metadata": {},
"source": [
"\n",
"## 1. Setup Environments\n",
"---\n",
"\n",
"### Import modules"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "fbb97a1c-4057-4e14-a1af-48a773879fd7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import json\n",
"import logging\n",
"import argparse\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from IPython.display import display, HTML\n",
"\n",
"from transformers import (\n",
" AutoTokenizer, AutoModelForSequenceClassification,\n",
" Trainer, TrainingArguments, set_seed\n",
")\n",
"from transformers.trainer_utils import get_last_checkpoint\n",
"from datasets import load_dataset, load_metric, ClassLabel, Sequence\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO, \n",
" format='[{%(filename)s:%(lineno)d} %(levelname)s - %(message)s',\n",
" handlers=[\n",
" logging.StreamHandler(sys.stdout)\n",
" ]\n",
")\n",
"logger = logging.getLogger(__name__)"
]
},
{
"cell_type": "markdown",
"id": "0ba7a122-0df6-4022-a921-b07605d77973",
"metadata": {},
"source": [
"### Argument parser"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "36090222-56c2-4a51-b915-4801d12c80a7",
"metadata": {},
"outputs": [],
"source": [
"def parser_args(train_notebook=False):\n",
" parser = argparse.ArgumentParser()\n",
"\n",
" # Default Setting\n",
" parser.add_argument(\"--epochs\", type=int, default=3)\n",
" parser.add_argument(\"--seed\", type=int, default=42)\n",
" parser.add_argument(\"--train_batch_size\", type=int, default=32)\n",
" parser.add_argument(\"--eval_batch_size\", type=int, default=64)\n",
" parser.add_argument(\"--max_length\", type=int, default=384)\n",
" parser.add_argument(\"--stride\", type=int, default=64)\n",
" parser.add_argument(\"--warmup_steps\", type=int, default=100)\n",
" parser.add_argument(\"--logging_steps\", type=int, default=100)\n",
" parser.add_argument(\"--learning_rate\", type=str, default=3e-5)\n",
" parser.add_argument(\"--disable_tqdm\", type=bool, default=False)\n",
" parser.add_argument(\"--fp16\", type=bool, default=True)\n",
" parser.add_argument(\"--tokenizer_id\", type=str, default='klue/roberta-base')\n",
" parser.add_argument(\"--model_id\", type=str, default='klue/roberta-base')\n",
" \n",
" # SageMaker Container environment\n",
" parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n",
" parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n",
" parser.add_argument(\"--n_gpus\", type=str, default=os.environ[\"SM_NUM_GPUS\"])\n",
" parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n",
" parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n",
" parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints') \n",
"\n",
" if train_notebook:\n",
" args = parser.parse_args([])\n",
" else:\n",
" args = parser.parse_args()\n",
" return args"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "acd57291-8742-46e9-9ff5-c7984eb4ca56",
"metadata": {},
"outputs": [],
"source": [
"train_dir = 'nli_train'\n",
"valid_dir = 'nli_valid'\n",
"!rm -rf {train_dir} {valid_dir}\n",
"os.makedirs(train_dir, exist_ok=True)\n",
"os.makedirs(valid_dir, exist_ok=True) "
]
},
{
"cell_type": "markdown",
"id": "c2b42b7b-6486-44ec-b35f-94c718cf589e",
"metadata": {},
"source": [
"### Load Arguments\n",
"\n",
"주피터 노트북에서 곧바로 실행할 수 있도록 설정값들을 로드합니다. 물론 노트북 환경이 아닌 커맨드라인에서도 `cd scripts & python3 train.py` 커맨드로 훈련 스크립트를 실행할 수 있습니다."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5659bafa-5ff5-44b0-8901-f31ceea25bfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{204499775.py:21} INFO - ***** Arguments *****\n",
"[{204499775.py:22} INFO - epochs=3\n",
"seed=42\n",
"train_batch_size=32\n",
"eval_batch_size=64\n",
"max_length=384\n",
"stride=64\n",
"warmup_steps=100\n",
"logging_steps=100\n",
"learning_rate=3e-05\n",
"disable_tqdm=False\n",
"fp16=True\n",
"tokenizer_id=klue/roberta-base\n",
"model_id=klue/roberta-base\n",
"output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/nli/data\n",
"model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/nli/model\n",
"n_gpus=4\n",
"train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/nli/nli_train\n",
"valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/nli/nli_valid\n",
"chkpt_dir=chkpt\n",
"\n"
]
}
],
"source": [
"chkpt_dir = 'chkpt'\n",
"model_dir = 'model'\n",
"output_data_dir = 'data'\n",
"num_gpus = torch.cuda.device_count()\n",
"\n",
"!rm -rf {chkpt_dir} {model_dir} {output_data_dir} \n",
"\n",
"if os.environ.get('SM_CURRENT_HOST') is None:\n",
" is_sm_container = False\n",
"\n",
" #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n",
" src_dir = os.getcwd()\n",
" os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n",
" os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n",
" os.environ['SM_NUM_GPUS'] = str(num_gpus)\n",
" os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n",
" os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n",
"\n",
"args = parser_args(train_notebook=True) \n",
"args.chkpt_dir = chkpt_dir\n",
"logger.info(\"***** Arguments *****\")\n",
"logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n",
"\n",
"os.makedirs(args.chkpt_dir, exist_ok=True) \n",
"os.makedirs(args.model_dir, exist_ok=True)\n",
"os.makedirs(args.output_data_dir, exist_ok=True) "
]
},
{
"cell_type": "markdown",
"id": "1c4109c9-ae21-4f42-9c75-d8e13d8b6e51",
"metadata": {},
"source": [
"
\n",
"\n",
"## 2. Preparation & Custructing Feature set\n",
"---\n",
"\n",
"### Dataset\n",
"\n",
"본 핸즈온에서 사용할 데이터셋은 KLUE-NLI로 허깅페이스의 dataset 라이브러리로 곧바로 로드할 수 있습니다.\n",
"- KLUE: https://github.com/KLUE-benchmark/KLUE"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4d2065a2-58b3-448a-9af3-e93dc92d146b",
"metadata": {},
"outputs": [],
"source": [
"from datasets import ClassLabel, Sequence\n",
"import random\n",
"import pandas as pd\n",
"from IPython.display import display, HTML\n",
"\n",
"def show_random_elements(dataset, num_examples=10):\n",
" assert num_examples <= len(\n",
" dataset\n",
" ), \"Can't pick more elements than there are in the dataset.\"\n",
" picks = []\n",
" for _ in range(num_examples):\n",
" pick = random.randint(0, len(dataset) - 1)\n",
" while pick in picks:\n",
" pick = random.randint(0, len(dataset) - 1)\n",
" picks.append(pick)\n",
"\n",
" df = pd.DataFrame(dataset[picks])\n",
" for column, typ in dataset.features.items():\n",
" if isinstance(typ, ClassLabel):\n",
" df[column] = df[column].transform(lambda i: typ.names[i])\n",
" elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
" df[column] = df[column].transform(\n",
" lambda x: [typ.feature.names[i] for i in x]\n",
" )\n",
" display(HTML(df.to_html()))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a2050410-9cf1-4e1d-8736-2af109599a92",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{builder.py:641} WARNING - Reusing dataset klue (/home/ec2-user/.cache/huggingface/datasets/klue/nli/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c1d075e224724a62a7c0417d80b3343d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"
\n", " | guid | \n", "source | \n", "premise | \n", "hypothesis | \n", "label | \n", "
---|---|---|---|---|---|
0 | \n", "klue-nli-v1_train_03020 | \n", "wikipedia | \n", "그는 1903년에 자신과 아내 파울리네 사이에 있었던 한 부부 싸움을 떠올렸다. | \n", "그는 1903년 다음 해에 파울리네와 부부가 되었다. | \n", "contradiction | \n", "
1 | \n", "klue-nli-v1_train_20075 | \n", "airbnb | \n", "지금껏 다녔던 숙소보다 너무 좋았어요. | \n", "지금껏 다녔던 숙소 중에 제일 별로였어요. | \n", "contradiction | \n", "
2 | \n", "klue-nli-v1_train_19388 | \n", "wikinews | \n", "조선민주주의인민공화국이 2013년 2월 12일, 제3차 핵실험을 성공하였다고 공식 발표하였다. | \n", "조선민주주의인민공화국은 핵실험에 실패한 적이 있다. | \n", "neutral | \n", "
3 | \n", "klue-nli-v1_train_03544 | \n", "wikinews | \n", "그런데, 아직도 문제를 단순히 공식만으로 풀게 하고, 지루하게 계산만을 반복시키는 그런 수학, 이거 안 통합니다. | \n", "이제 더 이상 수학에 공식외우기는 없어야 합니다. | \n", "neutral | \n", "
4 | \n", "klue-nli-v1_train_03100 | \n", "wikitree | \n", "그는 주머니에서 오만 원권 두 장을 꺼내 팬에게 건넸고, 공연장에 있던 사람들은 그의 행동에 환호했다. | \n", "그는 주머니에서 만원 권 두 장만을 꺼내 팬에게 건넸다. | \n", "contradiction | \n", "
5 | \n", "klue-nli-v1_train_03446 | \n", "policy | \n", "그러나, 앞날을 결코 낙관할 수 없습니다. | \n", "낙관할 수 없는 앞날입니다. | \n", "entailment | \n", "
6 | \n", "klue-nli-v1_train_01274 | \n", "NSMC | \n", "감독이라는것도 정말 일정 자격 시험이 필요한게 아닐까 하는 생각이 들게 한다 | \n", "일정 자격 시험을 통과해서 감독 자격을 받도록 해야 하는게 아닐까 하는 생각이 들게 한다. | \n", "entailment | \n", "
7 | \n", "klue-nli-v1_train_22075 | \n", "NSMC | \n", "타란티노가 연출했다면 더 잘 살렸을듯 연출이 각본을 따라가질못한다 | \n", "타란티노가 연출했다면 더 좋았을 듯. | \n", "entailment | \n", "
8 | \n", "klue-nli-v1_train_09824 | \n", "airbnb | \n", "빵이나 커피 잼 과일 맘껏 먹을 수 있고요 | \n", "잼은 두 종류 뿐이고요. | \n", "neutral | \n", "
9 | \n", "klue-nli-v1_train_23753 | \n", "wikitree | \n", "한편, 구는 수립된 안전관리계획을 책자로 제작해 관련 유관기관 등에 배부할 예정이다. | \n", "안전관리계획을 책자로 제작하는 곳은 유관기관이다. | \n", "contradiction | \n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Accuracy | \n", "
---|---|---|---|
1 | \n", "No log | \n", "0.504815 | \n", "0.818333 | \n", "
2 | \n", "No log | \n", "0.436733 | \n", "0.847000 | \n", "
3 | \n", "0.440700 | \n", "0.450579 | \n", "0.852667 | \n", "
"
],
"text/plain": [
"
\n",
"\n",
"## 4. Evaluation\n",
"---\n",
"\n",
"평가를 수행합니다."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "34f4daa1-89bb-4839-ba64-14d777db0901",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The following columns in the test set don't have a corresponding argument in `RobertaForSequenceClassification.forward` and have been ignored: source, guid, premise, hypothesis. If source, guid, premise, hypothesis are not expected by `RobertaForSequenceClassification.forward`, you can safely ignore this message.\n",
"***** Running Prediction *****\n",
" Num examples = 3000\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"data": {
"text/html": [
"\n",
"
\n",
"\n",
"## 5. Prediction\n",
"---\n",
"\n",
"여러분만의 샘플 문장을 만들어서 자유롭게 추론을 수행해 보세요."
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "e9101731-9409-43a7-bb89-4691ae0dfc10",
"metadata": {},
"outputs": [],
"source": [
"idx = [i for i in range(len(train_dataset.features['label'].names))]\n",
"classes = train_dataset.features['label'].names\n",
"\n",
"model.config.label2id = dict(zip(classes, idx))\n",
"model.config.id2label = dict(zip(idx, classes))"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "7ceb8e53-9931-426c-9dad-44e237c74d5c",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"classifier = pipeline(\n",
" task=\"text-classification\",\n",
" model=model, \n",
" tokenizer=tokenizer,\n",
" top_k=1,\n",
" device=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "9b197f11-85c0-449e-b9ac-5d5e53430f12",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'label': 'contradiction', 'score': 0.9946433305740356}]"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example = f\"머신러닝은 쉽다. {tokenizer.sep_token} 머신러닝은 어렵다.\"\n",
"classifier(sentences)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}