{
"cells": [
{
"cell_type": "markdown",
"id": "a7a5062b-19e2-4106-837a-a64af46bce89",
"metadata": {},
"source": [
"# Lab 1: Korean-English Translation model Training\n",
"\n",
"## Introduction\n",
"---\n",
"\n",
"본 모듈에서는 허깅페이스 트랜스포머(Hugging Face transformers) 라이브러리를 사용하여 한영 번역 모델을 훈련합니다. 번역은 시퀀스-투-시퀀스(sequence-to-sequence) 태스크의 가장 대표적인 형태로, 어텐션 메커니즘과 트랜스포머 기반 언어 모델의 기반이 되었던 다운스트림 태스크입니다.\n",
"\n",
"\n",
"### References\n",
"\n",
"- Hugging Face Tutorial: https://huggingface.co/docs/transformers/training\n",
"- Translation fine-tuning: https://huggingface.co/docs/transformers/tasks/translation\n",
"- KDE4 dataset: https://huggingface.co/datasets/kde4\n",
"- 관련 논문: http://www.lrec-conf.org/proceedings/lrec2012/pdf/463_Paper.pdf"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0f74e730-f022-4310-8dd0-5a06b278c1bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n",
"Requirement already satisfied: sacrebleu in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (2.2.0)\n",
"Requirement already satisfied: portalocker in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (2.5.1)\n",
"Requirement already satisfied: lxml in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (4.8.0)\n",
"Requirement already satisfied: tabulate>=0.8.9 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (0.8.9)\n",
"Requirement already satisfied: colorama in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (0.4.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (1.21.2)\n",
"Requirement already satisfied: regex in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (2021.11.10)\n",
"\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2 is available.\n",
"You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"!pip install sacrebleu"
]
},
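{
"cell_type": "markdown",
"id": "c3d9e1f2-5a6b-4c7d-8e9f-0a1b2c3d4e5f",
"metadata": {},
"source": [
"`sacrebleu` computes corpus-level BLEU directly on detokenized text, which is why it is installed here. A minimal sketch of the API this lab relies on (the sentences below are illustrative, not taken from the dataset):\n",
"\n",
"```python\n",
"import sacrebleu\n",
"\n",
"hyps = ['The cat sat on the mat.']\n",
"refs = [['The cat sat on the mat.']]  # one list per reference set\n",
"bleu = sacrebleu.corpus_bleu(hyps, refs)\n",
"print(round(bleu.score, 1))  # 100.0 for an exact match\n",
"```"
]
},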
{
"cell_type": "markdown",
"id": "a8f970b9-8a51-4c51-a9de-6329c5cc8b76",
"metadata": {},
"source": [
"\n",
"## 1. Setup Environments\n",
"---\n",
"\n",
"### Import modules"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2177e1af-87fd-4d5d-8621-22c42b762ad7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import json\n",
"import logging\n",
"import argparse\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from IPython.display import display, HTML\n",
"\n",
"from transformers import (\n",
" AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq,\n",
" Trainer, TrainingArguments, set_seed\n",
")\n",
"from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
"\n",
"from transformers.trainer_utils import get_last_checkpoint\n",
"from datasets import load_dataset, load_metric, ClassLabel, Sequence\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO, \n",
" format='[{%(filename)s:%(lineno)d} %(levelname)s - %(message)s',\n",
" handlers=[\n",
" logging.StreamHandler(sys.stdout)\n",
" ]\n",
")\n",
"logger = logging.getLogger(__name__)"
]
},
{
"cell_type": "markdown",
"id": "05163c75-04f3-4c05-9b17-a2ffa95d6354",
"metadata": {},
"source": [
"### Argument parser"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "238afbb6-6c5c-43e0-a2d7-2960a61f8289",
"metadata": {},
"outputs": [],
"source": [
"def parser_args(train_notebook=False):\n",
" parser = argparse.ArgumentParser()\n",
"\n",
" # Default Setting\n",
" parser.add_argument(\"--epochs\", type=int, default=1)\n",
" parser.add_argument(\"--seed\", type=int, default=42)\n",
" parser.add_argument(\"--train_batch_size\", type=int, default=32)\n",
" parser.add_argument(\"--eval_batch_size\", type=int, default=32)\n",
" parser.add_argument(\"--max_length\", type=int, default=128)\n",
" parser.add_argument(\"--stride\", type=int, default=32)\n",
" parser.add_argument(\"--warmup_steps\", type=int, default=100)\n",
" parser.add_argument(\"--logging_steps\", type=int, default=100)\n",
" parser.add_argument(\"--learning_rate\", type=str, default=3e-5)\n",
" parser.add_argument(\"--disable_tqdm\", type=bool, default=False)\n",
" parser.add_argument(\"--fp16\", type=bool, default=True)\n",
" parser.add_argument(\"--debug\", type=bool, default=False) \n",
" parser.add_argument(\"--tokenizer_id\", type=str, default='Helsinki-NLP/opus-mt-ko-en')\n",
" parser.add_argument(\"--model_id\", type=str, default='Helsinki-NLP/opus-mt-ko-en')\n",
" \n",
" # SageMaker Container environment\n",
" parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n",
" parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n",
" parser.add_argument(\"--n_gpus\", type=str, default=os.environ[\"SM_NUM_GPUS\"])\n",
" parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n",
" parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n",
" parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints') \n",
"\n",
" if train_notebook:\n",
" args = parser.parse_args([])\n",
" else:\n",
" args = parser.parse_args()\n",
" return args"
]
},
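{
"cell_type": "markdown",
"id": "7e2f8b4a-1c3d-4e5f-8a9b-0c1d2e3f4a5b",
"metadata": {},
"source": [
"Passing an empty list to `parse_args([])` is what lets `parser_args(train_notebook=True)` run inside Jupyter: argparse then ignores the kernel's own `sys.argv` and falls back to the declared defaults. A minimal illustration of the pattern (the arguments here are a hypothetical subset of the parser above):\n",
"\n",
"```python\n",
"import argparse\n",
"\n",
"parser = argparse.ArgumentParser()\n",
"parser.add_argument('--epochs', type=int, default=1)\n",
"parser.add_argument('--learning_rate', type=float, default=3e-5)\n",
"\n",
"args = parser.parse_args([])  # ignore the kernel's argv, use defaults\n",
"print(args.epochs, args.learning_rate)  # 1 3e-05\n",
"```"
]
},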
{
"cell_type": "code",
"execution_count": 4,
"id": "c27daaf3-2498-4a45-a339-b64aa3087ae1",
"metadata": {},
"outputs": [],
"source": [
"train_dir = 'seq2seq_translate_train'\n",
"valid_dir = 'seq2seq_translate_valid'\n",
"!rm -rf {train_dir} {valid_dir}\n",
"os.makedirs(train_dir, exist_ok=True)\n",
"os.makedirs(valid_dir, exist_ok=True) "
]
},
{
"cell_type": "markdown",
"id": "5d0f072f-b098-4635-ad94-b531b207d58a",
"metadata": {},
"source": [
"### Load Arguments\n",
"\n",
"주피터 노트북에서 곧바로 실행할 수 있도록 설정값들을 로드합니다. 물론 노트북 환경이 아닌 커맨드라인에서도 `cd scripts & python3 train.py` 커맨드로 훈련 스크립트를 실행할 수 있습니다."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b831876c-5335-4133-bfe2-dfb2ade48b27",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{204499775.py:21} INFO - ***** Arguments *****\n",
"[{204499775.py:22} INFO - epochs=1\n",
"seed=42\n",
"train_batch_size=32\n",
"eval_batch_size=32\n",
"max_length=128\n",
"stride=32\n",
"warmup_steps=100\n",
"logging_steps=100\n",
"learning_rate=3e-05\n",
"disable_tqdm=False\n",
"fp16=True\n",
"debug=False\n",
"tokenizer_id=Helsinki-NLP/opus-mt-ko-en\n",
"model_id=Helsinki-NLP/opus-mt-ko-en\n",
"output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/data\n",
"model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/model\n",
"n_gpus=4\n",
"train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/seq2seq_translate_train\n",
"valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/seq2seq_translate_valid\n",
"chkpt_dir=chkpt\n",
"\n"
]
}
],
"source": [
"chkpt_dir = 'chkpt'\n",
"model_dir = 'model'\n",
"output_data_dir = 'data'\n",
"num_gpus = torch.cuda.device_count()\n",
"\n",
"!rm -rf {chkpt_dir} {model_dir} {output_data_dir} \n",
"\n",
"if os.environ.get('SM_CURRENT_HOST') is None:\n",
" is_sm_container = False\n",
"\n",
" #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n",
" src_dir = os.getcwd()\n",
" os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n",
" os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n",
" os.environ['SM_NUM_GPUS'] = str(num_gpus)\n",
" os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n",
" os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n",
"\n",
"args = parser_args(train_notebook=True) \n",
"args.chkpt_dir = chkpt_dir\n",
"logger.info(\"***** Arguments *****\")\n",
"logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n",
"\n",
"os.makedirs(args.chkpt_dir, exist_ok=True) \n",
"os.makedirs(args.model_dir, exist_ok=True)\n",
"os.makedirs(args.output_data_dir, exist_ok=True) "
]
},
{
"cell_type": "markdown",
"id": "109b1d73-65a7-4178-b2ce-91178f4be83e",
"metadata": {},
"source": [
"
\n",
"\n",
"## 2. Preparation & Custructing Feature set\n",
"---\n",
"\n",
"### Dataset\n",
"\n",
"본 핸즈온에서 사용할 데이터셋은 KDE4 데이터셋으로 한국어를 포함한 100여가지에 육박하는 언어를 지원하고 있습니다. 이 데이터셋을 사용하여, 대규모 Opus 데이터셋 (https://opus.nlpl.eu/) 으로 사전 훈련된 한영 번역 Marian 모델을 파인튜닝합니다.\n",
"\n",
"\n",
"- KDE4 dataset: https://huggingface.co/datasets/kde4\n",
"- 관련 논문: http://www.lrec-conf.org/proceedings/lrec2012/pdf/463_Paper.pdf\n",
"\n",
"```\n",
"{\n",
" 'id': '15',\n",
" 'translation': \n",
" {\n",
" 'en': '& kde; provides a highly configurable desktop environment. This overview assumes that you are using the default environment.',\n",
" 'ko': '& kde; 는 다양한 부분을 설정할 수 있는 데스크톱 환경입니다. 이 문서에서는 여러분이 기본적인 데스크톱 환경을 사용한다는 것을 가정합니다.'\n",
" }\n",
"}\n",
"```` "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d61e29dc-1d61-4b89-b169-84cebc8da48f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{builder.py:463} WARNING - Using custom data configuration en-ko-lang1=en,lang2=ko\n",
"[{builder.py:641} WARNING - Reusing dataset kde4 (/home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8cccc3380ad04c5895f7050f3fd02512",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset, load_metric\n",
"\n",
"raw_datasets = load_dataset(\"kde4\", lang1=\"en\", lang2=\"ko\")"
]
},
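{
"cell_type": "markdown",
"id": "2a6b9c1d-8e0f-4a2b-9c3d-4e5f6a7b8c9d",
"metadata": {},
"source": [
"KDE4 ships only a single `train` split, so a held-out set has to be carved out manually. A sketch using the Hugging Face `datasets` API (the 10% test fraction is an arbitrary choice for illustration, not prescribed by the lab):\n",
"\n",
"```python\n",
"split_datasets = raw_datasets['train'].train_test_split(test_size=0.1, seed=42)\n",
"split_datasets['validation'] = split_datasets.pop('test')  # rename for clarity\n",
"```"
]
},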
{
"cell_type": "code",
"execution_count": 7,
"id": "41c0cfe0-a2e1-4fd3-be3a-c5e4250843cb",
"metadata": {},
"outputs": [],
"source": [
"from datasets import ClassLabel, Sequence\n",
"import random\n",
"import pandas as pd\n",
"from IPython.display import display, HTML\n",
"\n",
"def show_random_elements(dataset, num_examples=10):\n",
" assert num_examples <= len(\n",
" dataset\n",
" ), \"Can't pick more elements than there are in the dataset.\"\n",
" picks = []\n",
" for _ in range(num_examples):\n",
" pick = random.randint(0, len(dataset) - 1)\n",
" while pick in picks:\n",
" pick = random.randint(0, len(dataset) - 1)\n",
" picks.append(pick)\n",
"\n",
" df = pd.DataFrame(dataset[picks])\n",
" for column, typ in dataset.features.items():\n",
" if isinstance(typ, ClassLabel):\n",
" df[column] = df[column].transform(lambda i: typ.names[i])\n",
" elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n",
" df[column] = df[column].transform(\n",
" lambda x: [typ.feature.names[i] for i in x]\n",
" )\n",
" display(HTML(df.to_html()))"
]
},
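{
"cell_type": "markdown",
"id": "5d7e9f0a-2b4c-4d6e-8f0a-1b2c3d4e5f6a",
"metadata": {},
"source": [
"Before training, each example must be tokenized into model inputs (the source language) and labels (the target language). A sketch of typical seq2seq preprocessing, assuming a tokenizer loaded from `args.tokenizer_id` and the `args.max_length` setting above; `as_target_tokenizer` is the target-side context manager the Hugging Face tokenizers provide for translation models:\n",
"\n",
"```python\n",
"tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)\n",
"\n",
"def preprocess_function(examples):\n",
"    inputs = [ex['ko'] for ex in examples['translation']]\n",
"    targets = [ex['en'] for ex in examples['translation']]\n",
"    model_inputs = tokenizer(inputs, max_length=args.max_length, truncation=True)\n",
"    with tokenizer.as_target_tokenizer():\n",
"        labels = tokenizer(targets, max_length=args.max_length, truncation=True)\n",
"    model_inputs['labels'] = labels['input_ids']\n",
"    return model_inputs\n",
"```"
]
},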
{
"cell_type": "code",
"execution_count": 8,
"id": "917e3994-3f9d-4ca3-9d75-fb3d2d22194d",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n", " | id | \n", "translation | \n", "
---|---|---|
0 | \n", "12899 | \n", "{'en': 'Archive deleted.', 'ko': '압축 파일을 삭제했습니다.'} | \n", "
1 | \n", "11403 | \n", "{'en': 'Two Terminals, Horizontally', 'ko': '두 개의 터미널, 수평@ action'} | \n", "
2 | \n", "52907 | \n", "{'en': 'Pacific/ Fakaofo', 'ko': '태평양/ 파카오포'} | \n", "
3 | \n", "50684 | \n", "{'en': 'brown1', 'ko': 'color'} | \n", "
4 | \n", "8330 | \n", "{'en': 'Replace selection', 'ko': '선택부분 바꾸기'} | \n", "
5 | \n", "40103 | \n", "{'en': 'Central Region', 'ko': 'Central RegionRegion/ state in Russia'} | \n", "
6 | \n", "52832 | \n", "{'en': 'Europe/ Kaliningrad', 'ko': '유럽/ 칼리닌그라드'} | \n", "
7 | \n", "35057 | \n", "{'en': 'Border/ Coast', 'ko': 'ukraine. kgm'} | \n", "
8 | \n", "8446 | \n", "{'en': 'Public Domain', 'ko': 'Public Domain'} | \n", "
9 | \n", "46288 | \n", "{'en': 'Sami (Northern, Sweden)', 'ko': '북부 사미어 (스웨덴)'} | \n", "
Step | \n", "Training Loss | \n", "
---|
"
],
"text/plain": [
"
\n",
"\n",
"## 4. Evaluation\n",
"---\n",
"\n",
"평가를 수행합니다."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "673e411f-f515-4ceb-867d-ee0db5aa8ead",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running Prediction *****\n",
" Num examples = 300\n",
" Batch size = 128\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": []
},
{
"cell_type": "markdown",
"id": "4c8d0e2f-7a9b-4c1d-8e3f-5a6b7c8d9e0f",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 5. Prediction\n",
"---\n",
"\n",
"Create your own sample sentences and feel free to run inference on them."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "03a6dbf7-3030-4da9-bf0c-45f8819726ba",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"translator = pipeline(\n",
" task=\"translation\",\n",
" model=model, \n",
" tokenizer=tokenizer,\n",
" device=0\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "7b196a44-cfca-4452-9795-75452eb72190",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'translation_text': \"It's easy and fast to develop an ML model through the Amazon SageMaker, a fully managed service.\"}]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"translator(\"머신 러닝 완전 관리형 서비스인 Amazon SageMaker를 통해 쉽고 빠르게 ML모델을 개발하세요\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}