{ "cells": [ { "cell_type": "markdown", "id": "a7a5062b-19e2-4106-837a-a64af46bce89", "metadata": {}, "source": [ "# Lab 1: Korean-English Translation Model Training\n", "\n", "## Introduction\n", "---\n", "\n", "In this module, we train a Korean-to-English translation model with the Hugging Face transformers library. Translation is the most representative sequence-to-sequence task, and it is the downstream task that gave rise to the attention mechanism and Transformer-based language models.\n", "\n", "\n", "### References\n", "\n", "- Hugging Face Tutorial: https://huggingface.co/docs/transformers/training\n", "- Translation fine-tuning: https://huggingface.co/docs/transformers/tasks/translation\n", "- KDE4 dataset: https://huggingface.co/datasets/kde4\n", "- Related paper: http://www.lrec-conf.org/proceedings/lrec2012/pdf/463_Paper.pdf" ] }, { "cell_type": "code", "execution_count": 1, "id": "0f74e730-f022-4310-8dd0-5a06b278c1bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n", "Requirement already satisfied: sacrebleu in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (2.2.0)\n", "Requirement already satisfied: portalocker in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (2.5.1)\n", "Requirement already satisfied: lxml in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (4.8.0)\n", "Requirement already satisfied: tabulate>=0.8.9 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (0.8.9)\n", "Requirement already satisfied: colorama in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (0.4.3)\n", "Requirement already satisfied: numpy>=1.17 in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) (1.21.2)\n", "Requirement already satisfied: regex in /home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages (from sacrebleu) 
(2021.11.10)\n", "\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2 is available.\n", "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_p38/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install sacrebleu" ] }, { "cell_type": "markdown", "id": "a8f970b9-8a51-4c51-a9de-6329c5cc8b76", "metadata": {}, "source": [ "\n", "## 1. Setup Environments\n", "---\n", "\n", "### Import modules" ] }, { "cell_type": "code", "execution_count": 2, "id": "2177e1af-87fd-4d5d-8621-22c42b762ad7", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import json\n", "import logging\n", "import argparse\n", "import torch\n", "from torch import nn\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm\n", "from IPython.display import display, HTML\n", "\n", "from transformers import (\n", "    AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq,\n", "    Trainer, TrainingArguments, set_seed\n", ")\n", "from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer\n", "\n", "from transformers.trainer_utils import get_last_checkpoint\n", "from datasets import load_dataset, load_metric, ClassLabel, Sequence\n", "\n", "logging.basicConfig(\n", "    level=logging.INFO,\n", "    format='[{%(filename)s:%(lineno)d} %(levelname)s - %(message)s',\n", "    handlers=[\n", "        logging.StreamHandler(sys.stdout)\n", "    ]\n", ")\n", "logger = logging.getLogger(__name__)" ] }, { "cell_type": "markdown", "id": "05163c75-04f3-4c05-9b17-a2ffa95d6354", "metadata": {}, "source": [ "### Argument parser" ] }, { "cell_type": "code", "execution_count": 3, "id": "238afbb6-6c5c-43e0-a2d7-2960a61f8289", "metadata": {}, "outputs": [], "source": [ "def str2bool(value):\n", "    # argparse's type=bool treats any non-empty string (even 'False') as True,\n", "    # so boolean options are parsed explicitly instead.\n", "    if isinstance(value, bool):\n", "        return value\n", "    return str(value).lower() in ('true', '1', 'yes', 'y')\n", "\n", "def parser_args(train_notebook=False):\n", "    parser = argparse.ArgumentParser()\n", "\n", "    # Default Setting\n", "    parser.add_argument(\"--epochs\", type=int, default=1)\n", "    parser.add_argument(\"--seed\", type=int, default=42)\n", "    parser.add_argument(\"--train_batch_size\", type=int, default=32)\n", "    parser.add_argument(\"--eval_batch_size\", type=int, default=32)\n", "    parser.add_argument(\"--max_length\", type=int, default=128)\n", "    parser.add_argument(\"--stride\", type=int, default=32)\n", "    parser.add_argument(\"--warmup_steps\", type=int, default=100)\n", "    parser.add_argument(\"--logging_steps\", type=int, default=100)\n", "    # The learning rate is numeric, so parse it as a float rather than a string\n", "    parser.add_argument(\"--learning_rate\", type=float, default=3e-5)\n", "    parser.add_argument(\"--disable_tqdm\", type=str2bool, default=False)\n", "    parser.add_argument(\"--fp16\", type=str2bool, default=True)\n", "    parser.add_argument(\"--debug\", type=str2bool, default=False)\n", "    parser.add_argument(\"--tokenizer_id\", type=str, default='Helsinki-NLP/opus-mt-ko-en')\n", "    parser.add_argument(\"--model_id\", type=str, default='Helsinki-NLP/opus-mt-ko-en')\n", "\n", "    # SageMaker Container environment\n", "    parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n", "    parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n", "    parser.add_argument(\"--n_gpus\", type=str, default=os.environ[\"SM_NUM_GPUS\"])\n", "    parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n", "    parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n", "    parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints')\n", "\n", "    # In a notebook, ignore the kernel's own command-line arguments\n", "    if train_notebook:\n", "        args = parser.parse_args([])\n", "    else:\n", "        args = parser.parse_args()\n", "    return args" ] }, { "cell_type": "code", "execution_count": 4, "id": "c27daaf3-2498-4a45-a339-b64aa3087ae1", "metadata": {}, "outputs": [], "source": [ "train_dir = 'seq2seq_translate_train'\n", "valid_dir = 'seq2seq_translate_valid'\n", "!rm -rf {train_dir} {valid_dir}\n", "os.makedirs(train_dir, exist_ok=True)\n", "os.makedirs(valid_dir, exist_ok=True) " ] }, { "cell_type": "markdown", "id": 
"5d0f072f-b098-4635-ad94-b531b207d58a", "metadata": {}, "source": [ "### Load Arguments\n", "\n", "Load the configuration values so that the code can run directly in a Jupyter notebook. Outside the notebook, you can also run the training script from the command line with `cd scripts && python3 train.py`." ] }, { "cell_type": "code", "execution_count": 5, "id": "b831876c-5335-4133-bfe2-dfb2ade48b27", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{204499775.py:21} INFO - ***** Arguments *****\n", "[{204499775.py:22} INFO - epochs=1\n", "seed=42\n", "train_batch_size=32\n", "eval_batch_size=32\n", "max_length=128\n", "stride=32\n", "warmup_steps=100\n", "logging_steps=100\n", "learning_rate=3e-05\n", "disable_tqdm=False\n", "fp16=True\n", "debug=False\n", "tokenizer_id=Helsinki-NLP/opus-mt-ko-en\n", "model_id=Helsinki-NLP/opus-mt-ko-en\n", "output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/data\n", "model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/model\n", "n_gpus=4\n", "train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/seq2seq_translate_train\n", "valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/translation/seq2seq_translate_valid\n", "chkpt_dir=chkpt\n", "\n" ] } ], "source": [ "chkpt_dir = 'chkpt'\n", "model_dir = 'model'\n", "output_data_dir = 'data'\n", "num_gpus = torch.cuda.device_count()\n", "\n", "!rm -rf {chkpt_dir} {model_dir} {output_data_dir} \n", "\n", "if os.environ.get('SM_CURRENT_HOST') is None:\n", "    is_sm_container = False\n", "\n", "    #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n", "    src_dir = os.getcwd()\n", "    os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n", "    os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n", "    os.environ['SM_NUM_GPUS'] = str(num_gpus)\n", "    os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n", "    os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n", "\n", "args = parser_args(train_notebook=True) \n", "args.chkpt_dir = chkpt_dir\n", "logger.info(\"***** Arguments 
*****\")\n", "logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n", "\n", "os.makedirs(args.chkpt_dir, exist_ok=True) \n", "os.makedirs(args.model_dir, exist_ok=True)\n", "os.makedirs(args.output_data_dir, exist_ok=True) " ] }, { "cell_type": "markdown", "id": "109b1d73-65a7-4178-b2ce-91178f4be83e", "metadata": {}, "source": [ "
\n", "\n", "## 2. Preparation & Constructing Feature set\n", "---\n", "\n", "### Dataset\n", "\n", "This hands-on lab uses the KDE4 dataset, which covers close to 100 languages, including Korean. We use it to fine-tune a Korean-to-English Marian translation model that was pre-trained on the large-scale Opus dataset (https://opus.nlpl.eu/).\n", "\n", "\n", "- KDE4 dataset: https://huggingface.co/datasets/kde4\n", "- Related paper: http://www.lrec-conf.org/proceedings/lrec2012/pdf/463_Paper.pdf\n", "\n", "```\n", "{\n", "    'id': '15',\n", "    'translation': \n", "    {\n", "        'en': '& kde; provides a highly configurable desktop environment. This overview assumes that you are using the default environment.',\n", "        'ko': '& kde; 는 다양한 부분을 설정할 수 있는 데스크톱 환경입니다. 이 문서에서는 여러분이 기본적인 데스크톱 환경을 사용한다는 것을 가정합니다.'\n", "    }\n", "}\n", "```" ] }, { "cell_type": "code", "execution_count": 6, "id": "d61e29dc-1d61-4b89-b169-84cebc8da48f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{builder.py:463} WARNING - Using custom data configuration en-ko-lang1=en,lang2=ko\n", "[{builder.py:641} WARNING - Reusing dataset kde4 (/home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8cccc3380ad04c5895f7050f3fd02512", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00\n", " \n", " \n", " \n", " id\n", " translation\n", " \n", " \n", " \n", " \n", " 0\n", " 12899\n", " {'en': 'Archive deleted.', 'ko': '압축 파일을 삭제했습니다.'}\n", " \n", " \n", " 1\n", " 11403\n", " {'en': 'Two Terminals, Horizontally', 'ko': '두 개의 터미널, 수평@ action'}\n", " \n", " \n", " 2\n", " 52907\n", " {'en': 'Pacific/ Fakaofo', 'ko': '태평양/ 파카오포'}\n", " \n", " \n", " 3\n", " 50684\n", " {'en': 'brown1', 'ko': 'color'}\n", " \n", " \n", " 4\n", " 8330\n", " {'en': 'Replace selection', 'ko': '선택부분 바꾸기'}\n", " \n", " \n", " 5\n", " 40103\n", " {'en': 'Central Region', 'ko': 'Central RegionRegion/ state in 
Russia'}\n", " \n", " \n", " 6\n", " 52832\n", " {'en': 'Europe/ Kaliningrad', 'ko': '유럽/ 칼리닌그라드'}\n", " \n", " \n", " 7\n", " 35057\n", " {'en': 'Border/ Coast', 'ko': 'ukraine. kgm'}\n", " \n", " \n", " 8\n", " 8446\n", " {'en': 'Public Domain', 'ko': 'Public Domain'}\n", " \n", " \n", " 9\n", " 46288\n", " {'en': 'Sami (Northern, Sweden)', 'ko': '북부 사미어 (스웨덴)'}\n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_random_elements(raw_datasets[\"train\"])" ] }, { "cell_type": "code", "execution_count": 9, "id": "22f321a0-2217-4d61-b5e2-854db7fcc15e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{arrow_dataset.py:3615} WARNING - Loading cached split indices for dataset at /home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac/cache-7c7521c035973fea.arrow and /home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac/cache-d6acd5fd405706c8.arrow\n" ] } ], "source": [ "split_datasets = raw_datasets[\"train\"].train_test_split(train_size=0.9, seed=42)\n", "split_datasets[\"validation\"] = split_datasets.pop(\"test\")" ] }, { "cell_type": "markdown", "id": "fc27a2dd-5f3a-42d8-adb4-92a97a2778da", "metadata": {}, "source": [ "### Tokenization\n", "Tokenize the dataset. Both the source text and the target translation must be tokenized, and the target translation must be tokenized inside the `as_target_tokenizer()` context manager.\n", "\n", "For more details on tokenization, see https://huggingface.co/docs/datasets/process#processing-data-with-map." 
] }, { "cell_type": "code", "execution_count": 10, "id": "cb129797-de35-4254-adb0-06b234c9f3b6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/transformers/models/marian/tokenization_marian.py:198: UserWarning: Recommended: pip install sacremoses.\n", " warnings.warn(\"Recommended: pip install sacremoses.\")\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id, return_tensors=\"pt\")" ] }, { "cell_type": "markdown", "id": "63e95aac-793b-460b-8fcd-b7f6cc507647", "metadata": {}, "source": [ "#### Tokenize Sample Data" ] }, { "cell_type": "code", "execution_count": 11, "id": "2eca14f9-6ae7-4434-b060-eed6825f74c4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "({'input_ids': [20993, 343, 1068, 2266, 0], 'attention_mask': [1, 1, 1, 1, 1]},\n", " {'input_ids': [1097, 38774, 15193, 46, 0], 'attention_mask': [1, 1, 1, 1, 1]})" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ko_sentence = split_datasets[\"train\"][10][\"translation\"][\"ko\"]\n", "en_sentence = split_datasets[\"train\"][10][\"translation\"][\"en\"]\n", "\n", "inputs = tokenizer(ko_sentence)\n", "\n", "# If you forget to tokenize the target within the context manager, the target is tokenized by the input tokenizer. 
\n", "with tokenizer.as_target_tokenizer():\n", "    targets = tokenizer(en_sentence)\n", "\n", "inputs, targets " ] }, { "cell_type": "code", "execution_count": 12, "id": "ad2cc4f8-ab1b-45e1-a6f9-2a65bf76df87", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{arrow_dataset.py:2597} WARNING - Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac/cache-2149d75340913593.arrow\n", "[{arrow_dataset.py:2597} WARNING - Loading cached processed dataset at /home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac/cache-02f755fd9fd8c249.arrow\n" ] } ], "source": [ "max_input_length = 128\n", "max_target_length = 128\n", "\n", "# tokenizer helper function\n", "def preprocess_function(examples, source_lang=\"ko\", target_lang=\"en\"):\n", "    inputs = [ex[source_lang] for ex in examples[\"translation\"]]\n", "    targets = [ex[target_lang] for ex in examples[\"translation\"]]\n", "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n", "\n", "    # Setup the tokenizer for targets;\n", "    # If you forget to tokenize the target within the context manager, the target is tokenized by the input tokenizer.\n", "    with tokenizer.as_target_tokenizer():\n", "        labels = tokenizer(targets, max_length=max_target_length, truncation=True)\n", "\n", "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n", "    return model_inputs\n", "\n", "tokenized_datasets = split_datasets.map(\n", "    preprocess_function,\n", "    batched=True,\n", "    remove_columns=split_datasets[\"train\"].column_names,\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "84de3db7-2afe-438a-b816-03822fda0b7a", "metadata": {}, "outputs": [], "source": [ "# train_dir = 'datasets/train'\n", "# valid_dir = 'datasets/valid'\n", "# !rm -rf {train_dir} {valid_dir}\n", "\n", "# os.makedirs(train_dir, exist_ok=True)\n", "# os.makedirs(valid_dir, exist_ok=True) \n", "\n", "train_dataset = tokenized_datasets['train']\n", "valid_dataset = tokenized_datasets['validation']\n", "\n", "if not os.listdir(args.train_dir):\n", "    train_dataset.save_to_disk(args.train_dir)\n", "if not os.listdir(args.valid_dir):\n", "    valid_dataset.save_to_disk(args.valid_dir)\n", "\n", "# from datasets import load_from_disk\n", "# train_dataset = load_from_disk(args.train_dir)\n", "# valid_dataset = load_from_disk(args.valid_dir)" ] }, { "cell_type": "markdown", "id": "931727f4-79eb-4ecd-a509-59c2518a6cb2", "metadata": {}, "source": [ "
\n", "\n", "## 3. Training (Fine-tuning)\n", "---\n", "\n", "### Define Custom metric\n", "Define a custom function that computes evaluation metrics on the validation dataset at regular intervals (e.g., every epoch or every N steps). Typical metrics include precision, recall, F1 score, and accuracy; for translation, this notebook uses BLEU.\n", "\n", "The first argument of the custom function is an `EvalPrediction` object, which contains the predictions (`predictions`) and the ground-truth labels (`label_ids`). For details, see the page below.\n", "https://huggingface.co/transformers/internal/trainer_utils.html#transformers.EvalPrediction\n", "\n", "### BLEU (Bilingual Evaluation Understudy) metric\n", "BLEU is a standard metric for comparing how similar a machine translation is to a human translation. It is an n-gram-based precision metric that quantifies how many tokens of the generated sentence also appear in the reference (label) sentence.\n", "\n", "SacreBLEU reports the BLEU score on a 0-100 scale, where higher is better. The score is penalized when the output keeps repeating the same token more often than it occurs in the reference, or when the output is shorter than the reference.\n", "\n", "- SacreBLEU: https://github.com/mjpost/sacreBLEU" ] }, { "cell_type": "code", "execution_count": 14, "id": "0d2553cf-bc7e-4f85-b428-664eb472f1bd", "metadata": {}, "outputs": [], "source": [ "metric = load_metric(\"sacrebleu\")\n", "\n", "def compute_metrics(eval_preds):\n", "    preds, labels = eval_preds\n", "    if isinstance(preds, tuple):\n", "        preds = preds[0]\n", "\n", "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", "    # Replace -100 (ignored label positions) with the pad token id so the labels can be decoded\n", "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", "    decoded_preds = [pred.strip() for pred in decoded_preds]\n", "    # SacreBLEU expects a list of references per prediction\n", "    decoded_labels = [[label.strip()] for label in decoded_labels]\n", "\n", "    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n", "    return {\"bleu\": result[\"score\"]}" ] }, { "cell_type": "markdown", "id": "d2a576f2-b868-4f95-b2fb-706b2f38a119", "metadata": {}, "source": [ "### Pre-trained model\n", "\n", "Now that the feature set needed for training is ready, load the pre-trained model and fine-tune it." 
] }, { "cell_type": "code", "execution_count": 15, "id": "a1700cd9-52b9-4caf-b3b0-fadcf7afb2c6", "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSeq2SeqLM.from_pretrained(args.model_id)" ] }, { "cell_type": "markdown", "id": "a2aea3f0-ef5a-447f-af86-326fb8adb622", "metadata": {}, "source": [ "### Data Collation\n", "\n", "Variable-length inputs are usually handled with padding, and the data collators provided by Hugging Face make this convenient. Because the translation model is Seq2Seq-based, use `DataCollatorForSeq2Seq`, which pads the inputs with the tokenizer's pad token and pads the labels with -100 so that padded label positions are ignored by the loss." ] }, { "cell_type": "code", "execution_count": 16, "id": "04789557-af5f-4a8d-9193-ae93c01f6e06", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)" ] }, { "cell_type": "markdown", "id": "220f26a8-be78-4000-b991-aee40f784627", "metadata": {}, "source": [ "#### Check sample data" ] }, { "cell_type": "code", "execution_count": 17, "id": "c4f1ceaa-7304-4ff2-ad3b-90587b6eba88", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input_ids': [33361, 66, 0],\n", " 'attention_mask': [1, 1, 1],\n", " 'labels': [31874, 66, 0]}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized_datasets['train'][0]" ] }, { "cell_type": "code", "execution_count": 18, "id": "7ed7858f-7550-40d2-accf-ccc6ce8ef3c9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[33361, 66, 0, 65000, 65000, 65000, 65000, 65000, 65000, 65000,\n", " 65000, 65000, 65000, 65000, 65000, 65000, 65000, 65000, 65000, 65000],\n", " [11616, 9768, 1160, 182, 168, 802, 103, 956, 11616, 11964,\n", " 9, 49537, 1160, 132, 248, 2, 8548, 438, 2, 0]]),\n", " tensor([[31874, 66, 0, -100, -100, -100, -100, -100, -100, -100,\n", " -100, -100, -100],\n", " [21961, 17769, 2092, 2406, 13, 17769, 1911, 9159, 2, 56436,\n", " 24, 2, 0]]))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(0, 2)])\n", 
"batch['input_ids'], batch['labels']" ] }, { "cell_type": "markdown", "id": "95c7e873-5da6-40f0-ad68-bf8d49aac1f0", "metadata": {}, "source": [ "### Training Preparation\n", "\n", "When evaluating a Seq2Seq model, remember to set `predict_with_generate=True` so that predictions are produced with `generate()` and can be decoded for the BLEU computation.\n", "\n", "Reference: https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments" ] }, { "cell_type": "code", "execution_count": 19, "id": "21cc7a91-42a9-40d1-b368-44212ee75aa6", "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", "    output_dir=args.chkpt_dir,\n", "    # Overwrite the output directory only when a previous checkpoint already exists\n", "    overwrite_output_dir=get_last_checkpoint(args.chkpt_dir) is not None,\n", "    num_train_epochs=args.epochs,\n", "    per_device_train_batch_size=args.train_batch_size,\n", "    per_device_eval_batch_size=args.eval_batch_size,\n", "    weight_decay=0.01,\n", "    learning_rate=float(args.learning_rate),\n", "    save_total_limit=3,\n", "    predict_with_generate=True,\n", "    fp16=args.fp16,\n", "    disable_tqdm=args.disable_tqdm,\n", "    evaluation_strategy=\"no\",\n", "    save_strategy=\"epoch\",\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "id": "e08420d2-f06b-4906-b28f-2be7e7f8e604", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{arrow_dataset.py:3399} WARNING - Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac/cache-ae608e7ad10939de.arrow\n", "[{arrow_dataset.py:3399} WARNING - Loading cached shuffled indices for dataset at /home/ec2-user/.cache/huggingface/datasets/kde4/en-ko-lang1=en,lang2=ko/0.0.0/243129fb2398d5b0b4f7f6831ab27ad84774b7ce374cf10f60f6e1ff331648ac/cache-8c284c0eda7cf936.arrow\n" ] } ], "source": [ "# For debug only\n", "args.debug = True\n", "if args.debug:\n", " train_dataset = 
train_dataset.shuffle(seed=42).select(range(3000))\n", " valid_dataset = valid_dataset.shuffle(seed=42).select(range(300))" ] }, { "cell_type": "code", "execution_count": 21, "id": "d0a9aceb-e578-4072-b29f-ea6f48f8da37", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using cuda_amp half precision backend\n" ] } ], "source": [ "trainer = Seq2SeqTrainer(\n", "    model,\n", "    training_args,\n", "    train_dataset=train_dataset,\n", "    eval_dataset=valid_dataset,\n", "    data_collator=data_collator,\n", "    tokenizer=tokenizer,\n", "    compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "markdown", "id": "59b9a409-8b52-4578-8f9f-a73ced1abedd", "metadata": {}, "source": [ "### Training\n", "Run the training. A GPU is practically essential for training deep learning-based NLP models, and multi-GPU or distributed training is recommended for full-scale training. If multiple GPUs are available, the Trainer sets the total batch size to (per-device batch size x number of GPUs) and applies data parallelism automatically. In this run, for example, a per-device batch size of 32 on 4 GPUs gives a total train batch size of 128." ] }, { "cell_type": "code", "execution_count": 22, "id": "d5afad66-b0ca-40e8-b06c-82aa9aa07fa1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 3000\n", " Num Epochs = 1\n", " Instantaneous batch size per device = 32\n", " Total train batch size (w. 
parallel, distributed & accumulation) = 128\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 24\n", "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [24/24 00:12, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Step | Training Loss

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Saving model checkpoint to chkpt/checkpoint-24\n", "Configuration saved in chkpt/checkpoint-24/config.json\n", "Model weights saved in chkpt/checkpoint-24/pytorch_model.bin\n", "tokenizer config file saved in chkpt/checkpoint-24/tokenizer_config.json\n", "Special tokens file saved in chkpt/checkpoint-24/special_tokens_map.json\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 24.8 s, sys: 6.46 s, total: 31.3 s\n", "Wall time: 19.9 s\n" ] } ], "source": [ "%%time\n", "# train model\n", "if get_last_checkpoint(args.chkpt_dir) is not None:\n", " logger.info(\"***** Continue Training *****\")\n", " last_checkpoint = get_last_checkpoint(args.chkpt_dir)\n", " trainer.train(resume_from_checkpoint=last_checkpoint)\n", "else:\n", " trainer.train()" ] }, { "cell_type": "markdown", "id": "10a955a6-fbba-4992-b0ff-d21b51889854", "metadata": {}, "source": [ "
\n", "\n", "## 4. Evaluation\n", "---\n", "\n", "Evaluate the fine-tuned model on the validation dataset." ] }, { "cell_type": "code", "execution_count": 23, "id": "673e411f-f515-4ceb-867d-ee0db5aa8ead", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "***** Running Prediction *****\n", " Num examples = 300\n", " Batch size = 128\n", "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn('Was asked to gather along dimension 0, but all '\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [3/3 00:05]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "***** Evaluation results at /home/ec2-user/SageMaker/sm-kornlp-usecases/translation/data *****\n", "[{4005816495.py:9} INFO - test_bleu = 43.00990584009949\n", "\n", "[{4005816495.py:9} INFO - test_loss = 1.6864901781082153\n", "\n", "[{4005816495.py:9} INFO - test_runtime = 14.0713\n", "\n", "[{4005816495.py:9} INFO - test_samples_per_second = 21.32\n", "\n", "[{4005816495.py:9} INFO - test_steps_per_second = 0.213\n", "\n" ] } ], "source": [ "outputs = trainer.predict(valid_dataset)\n", "eval_results = outputs.metrics\n", "\n", "# Write the evaluation results to a file that can be accessed later in the S3 output\n", "with open(os.path.join(args.output_data_dir, \"eval_results.txt\"), \"w\") as writer:\n", "    print(f\"***** Evaluation results at {args.output_data_dir} *****\")\n", "    for key, value in sorted(eval_results.items()):\n", "        writer.write(f\"{key} = {value}\\n\")\n", "        logger.info(f\"{key} = {value}\\n\")" ] }, { "cell_type": "markdown", "id": "f0de9e0b-af14-4c85-971a-9cbda4b429ae", "metadata": {}, "source": [ "
\n", "\n", "## 5. Prediction\n", "---\n", "\n", "Write sample sentences of your own and try running inference on them." ] }, { "cell_type": "code", "execution_count": 24, "id": "03a6dbf7-3030-4da9-bf0c-45f8819726ba", "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "translator = pipeline(\n", "    task=\"translation\",\n", "    model=model,\n", "    tokenizer=tokenizer,\n", "    device=0\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "7b196a44-cfca-4452-9795-75452eb72190", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'translation_text': \"It's easy and fast to develop an ML model through the Amazon SageMaker, a fully managed service.\"}]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "translator(\"머신 러닝 완전 관리형 서비스인 Amazon SageMaker를 통해 쉽고 빠르게 ML모델을 개발하세요\")" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }