{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "hMBwJT0r38tj"
},
"source": [
"# Lab 1-1: Korean NER (Named Entity Recognition) Training on AWS"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Amc52mb94Jzq"
},
"source": [
"\n",
"## Introduction\n",
"---\n",
"\n",
"This module uses the Hugging Face transformers library to train a Korean named entity recognizer (Korean NER; Named Entity Recognition). NER is the task of finding named entities in a sentence. It is a downstream task that predicts the correct entity label for every token.\n",
"\n",
"***[Note] You can run this demo on SageMaker Studio Lab, SageMaker Studio, a SageMaker notebook instance, or your local machine. If you use SageMaker Studio Lab, enable the GPU.***\n",
"\n",
"***[Caution] This dataset may not be used for commercial purposes. Use this hands-on for research and reference only, and build your own dataset for model training.***\n",
"\n",
"### References\n",
"- Hugging Face Tutorial: https://huggingface.co/docs/transformers/training\n",
"- NLP Challenge by Naver and Changwon National University, GitHub: https://github.com/naver/nlp-challenge\n",
"- NLP Challenge by Naver and Changwon National University, leaderboard and license: http://air.changwon.ac.kr/?page_id=10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 1. Setup Environments\n",
"---\n",
"\n",
"### Import modules"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!rm -rf chkpt data model ner_train ner_valid train_data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import json\n",
"import logging\n",
"import argparse\n",
"import torch\n",
"from torch import nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from sklearn.model_selection import train_test_split\n",
"from transformers import (\n",
" BertTokenizer, BertTokenizerFast, BertConfig, BertForTokenClassification, \n",
" Trainer, TrainingArguments, set_seed\n",
")\n",
"from transformers.trainer_utils import get_last_checkpoint\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO, \n",
" format='[{%(filename)s:%(lineno)d} %(levelname)s - %(message)s',\n",
" handlers=[\n",
" logging.StreamHandler(sys.stdout)\n",
" ]\n",
")\n",
"logger = logging.getLogger(__name__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Argument parser"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def parser_args(train_notebook=False):\n",
"    parser = argparse.ArgumentParser()\n",
"\n",
"    # Default Setting\n",
"    parser.add_argument(\"--epochs\", type=int, default=3)\n",
"    parser.add_argument(\"--seed\", type=int, default=42)\n",
"    parser.add_argument(\"--train_batch_size\", type=int, default=32)\n",
"    parser.add_argument(\"--eval_batch_size\", type=int, default=64)\n",
"    parser.add_argument(\"--warmup_steps\", type=int, default=100)\n",
"    parser.add_argument(\"--learning_rate\", type=float, default=5e-5)\n",
"    # type=bool would turn the string 'False' into True, so parse the text explicitly\n",
"    parser.add_argument(\"--disable_tqdm\", type=lambda x: str(x).lower() in ('true', '1'), default=False)\n",
"    #parser.add_argument(\"--fp16\", type=bool, default=True)\n",
"    parser.add_argument(\"--tokenizer_id\", type=str, default='bert-base-multilingual-cased')\n",
"    #parser.add_argument(\"--model_id\", type=str, default='distilbert-base-multilingual-cased')\n",
"    parser.add_argument(\"--model_id\", type=str, default='bert-base-multilingual-cased')\n",
"\n",
"    # SageMaker Container environment\n",
"    parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n",
"    parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n",
"    parser.add_argument(\"--n_gpus\", type=int, default=os.environ[\"SM_NUM_GPUS\"])\n",
"    parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n",
"    parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n",
"    parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints')\n",
"\n",
"    if train_notebook:\n",
"        args = parser.parse_args([])  # ignore sys.argv inside Jupyter\n",
"    else:\n",
"        args = parser.parse_args()\n",
"    return args"
]
},
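{
"cell_type": "markdown",
"metadata": {},
"source": [
"`parse_args([])` parses an empty argument list, so every option falls back to its default; that is what lets `parser_args(train_notebook=True)` run inside Jupyter, while `parse_args()` reads `sys.argv` when the script is launched from the command line. Because argparse receives every command-line value as a string, an option declared with `type=bool` turns the string `'False'` into `True`. The minimal sketch below illustrates both points with a small stand-alone parser; `str2bool` is a hypothetical helper, not part of this lab.\n",
"\n",
"```python\n",
"import argparse\n",
"\n",
"def str2bool(value):\n",
"    # Hypothetical helper: common spellings of 'true' become True, anything else False\n",
"    return str(value).strip().lower() in ('true', '1', 'yes', 'y')\n",
"\n",
"parser = argparse.ArgumentParser()\n",
"parser.add_argument('--naive_flag', type=bool, default=False)\n",
"parser.add_argument('--safe_flag', type=str2bool, default=False)\n",
"parser.add_argument('--epochs', type=int, default=3)\n",
"\n",
"# Empty list: sys.argv is ignored and all defaults are used (notebook mode)\n",
"defaults = parser.parse_args([])\n",
"print(defaults.naive_flag, defaults.safe_flag, defaults.epochs)  # False False 3\n",
"\n",
"# Explicit strings, the way a launcher passes hyperparameters on the command line\n",
"cli = parser.parse_args(['--naive_flag', 'False', '--safe_flag', 'False', '--epochs', '5'])\n",
"print(cli.naive_flag, cli.safe_flag, cli.epochs)  # True False 5\n",
"```"
]
},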
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_dir = 'ner_train'\n",
"valid_dir = 'ner_valid'\n",
"!rm -rf {train_dir} {valid_dir}\n",
"os.makedirs(train_dir, exist_ok=True)\n",
"os.makedirs(valid_dir, exist_ok=True) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Arguments\n",
"\n",
"Load the settings so that the code can run directly in a Jupyter notebook. Outside the notebook, you can also run the training script from the command line with `cd scripts && python3 train.py`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{204499775.py:21} INFO - ***** Arguments *****\n",
"[{204499775.py:22} INFO - epochs=3\n",
"seed=42\n",
"train_batch_size=32\n",
"eval_batch_size=64\n",
"warmup_steps=100\n",
"learning_rate=5e-05\n",
"disable_tqdm=False\n",
"tokenizer_id=bert-base-multilingual-cased\n",
"model_id=bert-base-multilingual-cased\n",
"output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/named-entity-recognition/data\n",
"model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/named-entity-recognition/model\n",
"n_gpus=4\n",
"train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/named-entity-recognition/ner_train\n",
"valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/named-entity-recognition/ner_valid\n",
"chkpt_dir=chkpt\n",
"\n"
]
}
],
"source": [
"chkpt_dir = 'chkpt'\n",
"model_dir = 'model'\n",
"output_data_dir = 'data'\n",
"num_gpus = torch.cuda.device_count()\n",
"\n",
"!rm -rf {chkpt_dir} {model_dir} {output_data_dir} \n",
"\n",
"if os.environ.get('SM_CURRENT_HOST') is None:\n",
" is_sm_container = False\n",
"\n",
" #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n",
" src_dir = os.getcwd()\n",
" os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n",
" os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n",
" os.environ['SM_NUM_GPUS'] = str(num_gpus)\n",
" os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n",
" os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n",
"\n",
"args = parser_args(train_notebook=True) \n",
"args.chkpt_dir = chkpt_dir\n",
"logger.info(\"***** Arguments *****\")\n",
"logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n",
"\n",
"os.makedirs(args.chkpt_dir, exist_ok=True) \n",
"os.makedirs(args.model_dir, exist_ok=True)\n",
"os.makedirs(args.output_data_dir, exist_ok=True) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"## 2. Preparation\n",
"---\n",
"\n",
"### Dataset\n",
"\n",
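"This hands-on uses the NER training data (`train_data`) from the NLP Challenge hosted by Naver and Changwon National University; see References above for the GitHub repository and the license terms.\n"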
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2Xv5IzTBeecA",
"outputId": "af8096fe-f6ef-4fab-e71b-28c9486c7d74"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-07-05 22:45:03-- https://github.com/naver/nlp-challenge/raw/master/missions/ner/data/train/train_data\n",
"Resolving github.com (github.com)... 140.82.112.4\n",
"Connecting to github.com (github.com)|140.82.112.4|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/naver/nlp-challenge/master/missions/ner/data/train/train_data [following]\n",
"--2022-07-05 22:45:03-- https://raw.githubusercontent.com/naver/nlp-challenge/master/missions/ner/data/train/train_data\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 16945023 (16M) [text/plain]\n",
"Saving to: ‘train_data’\n",
"\n",
"100%[======================================>] 16,945,023 --.-K/s in 0.07s \n",
"\n",
"2022-07-05 22:45:03 (223 MB/s) - ‘train_data’ saved [16945023/16945023]\n",
"\n"
]
}
],
"source": [
"!wget https://github.com/naver/nlp-challenge/raw/master/missions/ner/data/train/train_data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
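"Tags ending in `_B` / `_I` mean that the word is at the beginning / inside of an entity, while an empty or single-character tag (e.g. `O`, `-`) means that the word does not belong to any entity.\n",
"\n",
"The index increases sequentially (1, 2, 3, ...) and then restarts at 1 (1, 2, ...); index 1 marks the first word of a new sentence."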
]
},
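{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the tag scheme concrete, the sketch below groups `(index, word, tag)` rows into sentences whenever the index resets to 1, and merges a `_B` tag with the `_I` tags of the same type that follow it into one entity span. A few rows are taken from the sample displayed below; the `ORG_I` sentence is made up for illustration, and `split_sentences` / `extract_entities` are hypothetical names. The lab itself trains directly on the per-word tags.\n",
"\n",
"```python\n",
"def split_sentences(rows):\n",
"    # rows: (index, word, tag) triples; index 1 starts a new sentence\n",
"    sentences, current = [], []\n",
"    for index, word, tag in rows:\n",
"        if index == 1 and current:\n",
"            sentences.append(current)\n",
"            current = []\n",
"        current.append((word, tag))\n",
"    if current:\n",
"        sentences.append(current)  # keep the last sentence too\n",
"    return sentences\n",
"\n",
"def extract_entities(sentence):\n",
"    # Merge X_B with the following X_I tags; a stray X_I without a matching X_B is ignored\n",
"    entities, words, etype = [], [], None\n",
"    for word, tag in sentence:\n",
"        if tag.endswith('_I') and words and tag[:-2] == etype:\n",
"            words.append(word)\n",
"            continue\n",
"        if words:\n",
"            entities.append((' '.join(words), etype))\n",
"        words, etype = ([word], tag[:-2]) if tag.endswith('_B') else ([], None)\n",
"    if words:\n",
"        entities.append((' '.join(words), etype))\n",
"    return entities\n",
"\n",
"rows = [\n",
"    (1, '비토리오', 'PER_B'), (2, '양일', 'DAT_B'), (3, '만에', '-'), (4, '영사관', 'ORG_B'),\n",
"    (1, '이', '-'), (2, '음경동맥의', '-'),\n",
"    (1, '서울', 'ORG_B'), (2, '대학교', 'ORG_I'), (3, '에서', '-'),  # made-up example\n",
"]\n",
"for sentence in split_sentences(rows):\n",
"    print(extract_entities(sentence))\n",
"# [('비토리오', 'PER'), ('양일', 'DAT'), ('영사관', 'ORG')]\n",
"# []\n",
"# [('서울 대학교', 'ORG')]\n",
"```"
]
},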
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 415
},
"id": "45hQv0Bienkt",
"outputId": "e4614d58-3e74-4f22-81b4-1461e4c599ed"
},
"outputs": [
{
"data": {
"text/plain": [
" index src tar\n",
"0 1 비토리오 PER_B\n",
"1 2 양일 DAT_B\n",
"2 3 만에 -\n",
"3 4 영사관 ORG_B\n",
"4 5 감호 CVL_B\n",
"5 6 용퇴, -\n",
"6 7 항룡 -\n",
"7 8 압력설 -\n",
"8 9 의심만 -\n",
"9 10 가율 -\n",
"10 1 이 -\n",
"11 2 음경동맥의 -\n",
"12 3 직경이 -\n",
"13 4 8 NUM_B\n",
"14 5 19mm입니다 NUM_B"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
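"# The file has three tab-separated columns (word index, word, tag). With only two names,\n",
"# pandas uses the first column as the index, and reset_index() turns it into the 'index' column.\n",
"raw_data = pd.read_csv(\"train_data\", names=['src', 'tar'], sep=\"\\t\")\n",
"raw_data = raw_data.reset_index()\n",
"raw_data.head(15)"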
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eC2QfGRbhmJJ"
},
"source": [
"### Data Cleansing"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IJHg2kIUht7L"
},
"source": [
"Remove every character other than Hangul, digits, Latin letters, and `.`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "XnsxeMY9Z7vv"
},
"outputs": [],
"source": [
"raw_data['src'] = raw_data['src'].astype(str)\n",
"raw_data['tar'] = raw_data['tar'].astype(str)\n",
"# Keep only Hangul (jamo and syllables), digits, Latin letters and periods\n",
"raw_data['src'] = raw_data['src'].str.replace(r'[^ㄱ-ㅣ가-힣0-9a-zA-Z.]+', \"\", regex=True)"
]
},
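{
"cell_type": "markdown",
"metadata": {},
"source": [
"The pattern `[^ㄱ-ㅣ가-힣0-9a-zA-Z.]+` matches any run of characters that are not Hangul compatibility jamo (ㄱ-ㅣ), Hangul syllables (가-힣), digits, Latin letters or periods, and the replacement deletes those runs. The short sketch below applies the same pattern with Python's built-in `re` module to a few sample words (the first two come from the data shown above, the rest are made up). A word consisting only of removed characters becomes an empty string, which should contribute no tokens when the words are tokenized later.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Same character class as the cleaning step above\n",
"PATTERN = re.compile('[^ㄱ-ㅣ가-힣0-9a-zA-Z.]+')\n",
"\n",
"def clean(word):\n",
"    return PATTERN.sub('', word)\n",
"\n",
"for word in ['용퇴,', '19mm입니다', '(주)삼성', 'U.S.A.', '!!']:\n",
"    print(repr(word), '->', repr(clean(word)))\n",
"# '용퇴,' -> '용퇴'\n",
"# '19mm입니다' -> '19mm입니다'\n",
"# '(주)삼성' -> '주삼성'\n",
"# 'U.S.A.' -> 'U.S.A.'\n",
"# '!!' -> ''\n",
"```"
]
},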
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create dictionaries for converting between entity tags and integer labels."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{99085830.py:4} INFO - {'PER_B': 0, 'DAT_B': 1, '-': 2, 'ORG_B': 3, 'CVL_B': 4, 'NUM_B': 5, 'LOC_B': 6, 'EVT_B': 7, 'TRM_B': 8, 'TRM_I': 9, 'EVT_I': 10, 'PER_I': 11, 'CVL_I': 12, 'NUM_I': 13, 'TIM_B': 14, 'TIM_I': 15, 'ORG_I': 16, 'DAT_I': 17, 'ANM_B': 18, 'MAT_B': 19, 'MAT_I': 20, 'AFW_B': 21, 'FLD_B': 22, 'LOC_I': 23, 'AFW_I': 24, 'PLT_B': 25, 'FLD_I': 26, 'ANM_I': 27, 'PLT_I': 28}\n",
"[{99085830.py:5} INFO - {0: 'PER_B', 1: 'DAT_B', 2: '-', 3: 'ORG_B', 4: 'CVL_B', 5: 'NUM_B', 6: 'LOC_B', 7: 'EVT_B', 8: 'TRM_B', 9: 'TRM_I', 10: 'EVT_I', 11: 'PER_I', 12: 'CVL_I', 13: 'NUM_I', 14: 'TIM_B', 15: 'TIM_I', 16: 'ORG_I', 17: 'DAT_I', 18: 'ANM_B', 19: 'MAT_B', 20: 'MAT_I', 21: 'AFW_B', 22: 'FLD_B', 23: 'LOC_I', 24: 'AFW_I', 25: 'PLT_B', 26: 'FLD_I', 27: 'ANM_I', 28: 'PLT_I'}\n"
]
}
],
"source": [
"unique_tags = raw_data['tar'].unique().tolist()\n",
"tag2id = {tag: id for id, tag in enumerate(unique_tags)}\n",
"id2tag = {id: tag for tag, id in tag2id.items()}\n",
"logger.info(tag2id)\n",
"logger.info(id2tag)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QekBoou4h1HK"
},
"source": [
"To make entities easy to read at inference time, create a dictionary that maps each entity tag to a human-readable entity description."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def get_tag2entity(tag2id):\n",
"    tag2entity = {}\n",
"\n",
"    for tag in tag2id:\n",
"        if \"PER\" in tag:\n",
"            entity = 'Person'\n",
"        elif \"FLD\" in tag:\n",
"            entity = \"Field\"\n",
"        elif \"AFW\" in tag:\n",
"            entity = \"Artifacts_works\"\n",
"        elif \"ORG\" in tag:\n",
"            entity = \"Organization\"\n",
"        elif \"LOC\" in tag:\n",
"            entity = \"Location\"\n",
"        elif \"CVL\" in tag:\n",
"            entity = \"Civilization\"\n",
"        elif \"DAT\" in tag:\n",
"            entity = \"Date\"\n",
"        elif \"TIM\" in tag:\n",
"            entity = \"Time\"\n",
"        elif \"NUM\" in tag:\n",
"            entity = \"Number\"\n",
"        elif \"EVT\" in tag:\n",
"            entity = \"Event\"\n",
"        elif \"ANM\" in tag:\n",
"            entity = \"Animal\"\n",
"        elif \"PLT\" in tag:\n",
"            entity = \"Plant\"\n",
"        elif \"MAT\" in tag:\n",
"            entity = \"Material\"\n",
"        elif \"TRM\" in tag:\n",
"            entity = \"Term\"\n",
"        else:\n",
"            entity = tag\n",
"\n",
"        tag2entity[tag] = entity\n",
"    return tag2entity\n",
"\n",
"tag2entity = get_tag2entity(tag2id)"
]
},
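{
"cell_type": "markdown",
"metadata": {},
"source": [
"Every tag in this dataset has the form `XXX_B` or `XXX_I`, where `XXX` is a three-letter type such as `PER`, so the `if`/`elif` chain in `get_tag2entity` can also be expressed as a lookup table keyed on that prefix. The sketch below is an equivalent alternative for these tags, not something the lab requires; `ENTITY_NAMES` and `get_tag2entity_lookup` are names introduced here for illustration.\n",
"\n",
"```python\n",
"ENTITY_NAMES = {\n",
"    'PER': 'Person', 'FLD': 'Field', 'AFW': 'Artifacts_works',\n",
"    'ORG': 'Organization', 'LOC': 'Location', 'CVL': 'Civilization',\n",
"    'DAT': 'Date', 'TIM': 'Time', 'NUM': 'Number', 'EVT': 'Event',\n",
"    'ANM': 'Animal', 'PLT': 'Plant', 'MAT': 'Material', 'TRM': 'Term',\n",
"}\n",
"\n",
"def get_tag2entity_lookup(tags):\n",
"    # 'PER_B' -> 'PER' -> 'Person'; tags without a known prefix (e.g. '-') map to themselves\n",
"    return {tag: ENTITY_NAMES.get(tag.split('_')[0], tag) for tag in tags}\n",
"\n",
"print(get_tag2entity_lookup(['PER_B', 'ORG_I', '-']))\n",
"# {'PER_B': 'Person', 'ORG_I': 'Organization', '-': '-'}\n",
"```"
]
},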
{
"cell_type": "markdown",
"metadata": {
"id": "MqwmtXbBihQM"
},
"source": [
"Split the data into sentences and their entity labels."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"tups = []\n",
"temp_tup = []\n",
"data = [list(x) for x in raw_data[['index', 'src', 'tar']].to_numpy()]\n",
"\n",
"for idx, (i, token, entity) in enumerate(data):\n",
"    if i == 1:  # first token in a sentence\n",
"        if idx != 0:\n",
"            tups.append(temp_tup)\n",
"        temp_tup = []\n",
"    temp_tup.append((token, tag2id[entity]))\n",
"\n",
"# Append the final sentence, which the loop above never reaches\n",
"if temp_tup:\n",
"    tups.append(temp_tup)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"sentences = []\n",
"labels = []\n",
"\n",
"for tup in tups:\n",
"    # Wrap every sentence in [CLS] ... [SEP]; both special tokens get the non-entity tag '-'\n",
"    sentence = [\"[CLS]\"]\n",
"    label = [tag2id['-']]\n",
"\n",
"    for t, l in tup:\n",
"        sentence.append(t)\n",
"        label.append(l)\n",
"\n",
"    sentence.append(\"[SEP]\")\n",
"    label.append(tag2id['-'])\n",
"\n",
"    sentences.append(sentence)\n",
"    labels.append(label)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(['[CLS]',\n",
" '비토리오',\n",
" '양일',\n",
" '만에',\n",
" '영사관',\n",
" '감호',\n",
" '용퇴',\n",
" '항룡',\n",
" '압력설',\n",
" '의심만',\n",
" '가율',\n",
" '[SEP]'],\n",
" [2, 0, 1, 2, 3, 4, 2, 2, 2, 2, 2, 2])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentences[0], labels[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"## 3. Construct Feature set\n",
"---\n",
"\n",
"### Tokenization and Labeling\n",
"\n",
"#### Tokenization\n",
"To train a natural language processing model, the corpus (a large body of text collected for NLP) must first be split into token sequences through tokenization. NLP models before BERT mostly relied on tokenizers built on expert linguistic knowledge, such as the Mecab and Kkma morphological analyzers, whereas the tokenizers used to train BERT learn, without any domain knowledge, to split text into the subwords that occur frequently in the corpus. GPT-based models use a statistical technique called BPE (Byte-Pair Encoding), and BERT- and ELECTRA-based models use WordPiece, which is similar to BPE.\n",
"\n",
"#### Labeling\n",
"For the NER task, the ground-truth tags are assigned per word rather than per token, so they must be expanded to the token level.\n",
"\n",
"```python\n",
"\n",
"tokenizer.tokenize('2006년 아마존')\n",
">> ['2006년', '아', '##마', '##존']\n",
"\n",
"Original word-level labels:\n",
"('2006년', 'DAT'), ('아마존', 'ORG')\n",
"\n",
"Converted to token-level labels:\n",
"('2006년', 'DAT'), ('아', 'ORG'), ('##마', 'ORG'), ('##존', 'ORG')\n",
"```"
]
},
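{
"cell_type": "markdown",
"metadata": {},
"source": [
"WordPiece splits each word by greedy longest-match-first lookup in its vocabulary, marking every piece after the first with `##`. The sketch below reproduces that idea with a four-entry toy vocabulary (an assumption chosen so the output matches the example above; the real `bert-base-multilingual-cased` vocabulary is far larger and may split these words differently) and then copies each word's label to all of its pieces, which is the same labeling rule the next cell applies with the real tokenizer. `wordpiece` and `expand_labels` are hypothetical names.\n",
"\n",
"```python\n",
"def wordpiece(word, vocab, unk='[UNK]'):\n",
"    # Greedy longest-match-first: take the longest prefix found in vocab, then repeat on the rest\n",
"    pieces, start = [], 0\n",
"    while start < len(word):\n",
"        end, piece = len(word), None\n",
"        while start < end:\n",
"            candidate = word[start:end] if start == 0 else '##' + word[start:end]\n",
"            if candidate in vocab:\n",
"                piece = candidate\n",
"                break\n",
"            end -= 1\n",
"        if piece is None:\n",
"            return [unk]  # no valid split: the whole word becomes [UNK]\n",
"        pieces.append(piece)\n",
"        start = end\n",
"    return pieces\n",
"\n",
"def expand_labels(words, word_labels, vocab):\n",
"    # Copy each word-level label onto every subword of that word\n",
"    tokens, labels = [], []\n",
"    for word, label in zip(words, word_labels):\n",
"        pieces = wordpiece(word, vocab)\n",
"        tokens.extend(pieces)\n",
"        labels.extend([label] * len(pieces))\n",
"    return tokens, labels\n",
"\n",
"TOY_VOCAB = {'2006년', '아', '##마', '##존'}  # toy vocabulary, for illustration only\n",
"print(expand_labels(['2006년', '아마존'], ['DAT', 'ORG'], TOY_VOCAB))\n",
"# (['2006년', '아', '##마', '##존'], ['DAT', 'ORG', 'ORG', 'ORG'])\n",
"```"
]
},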
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_and_preserve_labels(sentence, text_labels, tokenizer, max_len=128):\n",
"\n",
"    \"\"\"\n",
"    Word piece tokenization makes it difficult to match word labels\n",
"    back up with individual word pieces. This function tokenizes each\n",
"    word one at a time so that it is easier to preserve the correct\n",
"    label for each subword. It is slower than tokenizing the whole\n",
"    sentence at once, but keeps every subword aligned with its word's label.\n",
"    \"\"\"\n",
"\n",
"    tokenized_sentence = []\n",
"    labels = []\n",
"\n",
"    for word, label in zip(sentence, text_labels):\n",
"\n",
"        # Guard against extremely long words by truncating them (in characters)\n",
"        if len(word) > max_len:\n",
"            word = word[:max_len]\n",
"\n",
"        # Tokenize the word and count # of subwords the word is broken into\n",
"        tokenized_word = tokenizer.tokenize(word)\n",
"        n_subwords = len(tokenized_word)\n",
"\n",
"        # Add the tokenized word to the final tokenized word list\n",
"        tokenized_sentence.extend(tokenized_word)\n",
"\n",
"        # Add the same label to the new list of labels `n_subwords` times\n",
"        labels.extend([label] * n_subwords)\n",
"\n",
"    return tokenized_sentence, labels"
]
},
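{
"cell_type": "markdown",
"metadata": {},
"source": [
"A common alternative to tokenizing word by word is to tokenize the whole pre-split sentence once with a fast tokenizer, e.g. `tokenizer(words, is_split_into_words=True)`, and read `encoding.word_ids()`, which gives the source word index of every token (`None` for special tokens). Labels are then aligned from that list; a frequent convention gives the word's label only to its first subword and -100 to the rest so the loss ignores them. The sketch below shows only the alignment step on a hand-written `word_ids` list, so it runs without downloading a tokenizer. `align_labels` is a hypothetical name, and this notebook itself copies the label to every subword (the `label_all_subwords=True` case).\n",
"\n",
"```python\n",
"def align_labels(word_ids, word_labels, label_all_subwords=False, ignore_index=-100):\n",
"    # word_ids: one entry per token, None for special tokens such as [CLS] / [SEP]\n",
"    aligned, previous = [], None\n",
"    for wid in word_ids:\n",
"        if wid is None:\n",
"            aligned.append(ignore_index)      # special token\n",
"        elif wid != previous:\n",
"            aligned.append(word_labels[wid])  # first subword of a word\n",
"        else:\n",
"            aligned.append(word_labels[wid] if label_all_subwords else ignore_index)\n",
"        previous = wid\n",
"    return aligned\n",
"\n",
"# Hand-written word_ids for: [CLS] 2006년 아 ##마 ##존 [SEP]\n",
"word_ids = [None, 0, 1, 1, 1, None]\n",
"word_labels = [1, 3]  # e.g. DAT_B=1, ORG_B=3 from tag2id\n",
"print(align_labels(word_ids, word_labels))                           # [-100, 1, 3, -100, -100, -100]\n",
"print(align_labels(word_ids, word_labels, label_all_subwords=True))  # [-100, 1, 3, 3, 3, -100]\n",
"```"
]
},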
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "0U6gAmqhCvRF"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3.05 s, sys: 369 ms, total: 3.41 s\n",
"Wall time: 3.97 s\n"
]
}
],
"source": [
"%%time\n",
"from itertools import repeat\n",
"from multiprocessing import cpu_count, Pool\n",
"\n",
"num_cores = 16 if cpu_count() > 16 else cpu_count() \n",
"tokenizer = BertTokenizerFast.from_pretrained(args.tokenizer_id)\n",
"\n",
"with Pool(processes=num_cores) as pool:\n",
" tokenized_texts_and_labels = pool.starmap(tokenize_and_preserve_labels, \n",
" zip(sentences, labels, repeat(tokenizer)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Converting Input Ids and Labels\n",
"\n",
"Convert the tokens into their vocabulary indices and the labels into tensors."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"ids, labels = [], []\n",
"\n",
"max_len = 128\n",
"for t, l in tokenized_texts_and_labels:\n",
"    if len(l) > max_len:\n",
"        # Truncate, but keep the closing [SEP] token and its label\n",
"        t = t[:max_len - 1] + t[-1:]\n",
"        l = l[:max_len - 1] + l[-1:]\n",
"\n",
"    ids.append(torch.tensor(tokenizer.convert_tokens_to_ids(t)))\n",
"    labels.append(torch.tensor(l))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
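"Apply padding. The label for `'[PAD]'` tokens is conventionally -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so padded positions are excluded from the loss."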
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from torch.nn.utils.rnn import pad_sequence\n",
"ids = pad_sequence(ids, batch_first=True, padding_value=0)\n",
"labels = pad_sequence(labels, batch_first=True, padding_value=-100)\n",
"attention_masks = (labels != -100).long()"
]
},
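{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what `pad_sequence(..., batch_first=True)` returns, the pure-Python sketch below right-pads variable-length sequences to the longest one, using 0 (the `[PAD]` id in BERT vocabularies) for the inputs and -100 for the labels, and derives the attention mask the same way the cell above does. `pad_batch` is a hypothetical name and the token ids are arbitrary.\n",
"\n",
"```python\n",
"def pad_batch(ids, labels, pad_id=0, label_pad=-100):\n",
"    max_len = max(len(seq) for seq in ids)\n",
"    padded_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in ids]\n",
"    padded_labels = [seq + [label_pad] * (max_len - len(seq)) for seq in labels]\n",
"    # Same rule as the notebook: a position is real unless its label is -100\n",
"    masks = [[int(label != label_pad) for label in seq] for seq in padded_labels]\n",
"    return padded_ids, padded_labels, masks\n",
"\n",
"toy_ids = [[11, 12, 13, 14], [21, 22, 23]]  # arbitrary token ids\n",
"toy_labels = [[2, 0, 11, 2], [2, 5, 2]]\n",
"padded_ids, padded_labels, masks = pad_batch(toy_ids, toy_labels)\n",
"print(padded_ids)     # [[11, 12, 13, 14], [21, 22, 23, 0]]\n",
"print(padded_labels)  # [[2, 0, 11, 2], [2, 5, 2, -100]]\n",
"print(masks)          # [[1, 1, 1, 1], [1, 1, 1, 0]]\n",
"```"
]
},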
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Split into Training set and validation set"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"random_state = 42\n",
"test_size = 0.2\n",
"# Split all three tensors in one call so they share the same shuffled indices\n",
"(train_ids, valid_ids,\n",
" train_attention_masks, valid_attention_masks,\n",
" train_labels, valid_labels) = train_test_split(ids, attention_masks, labels,\n",
"                                               random_state=random_state, test_size=test_size)"
]
},
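{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split stays aligned because `ids`, `attention_masks` and `labels` are all indexed with the same shuffled permutation. The pure-Python sketch below makes that explicit with a single permutation drawn from `random.Random(seed)`; it only illustrates the idea and does not reproduce scikit-learn's shuffling, so it will not select the same rows. `split_aligned` is a hypothetical name.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"\n",
"def split_aligned(*arrays, test_size=0.2, seed=42):\n",
"    n = len(arrays[0])\n",
"    if any(len(a) != n for a in arrays):\n",
"        raise ValueError('all arrays must have the same length')\n",
"    order = list(range(n))\n",
"    random.Random(seed).shuffle(order)  # one permutation shared by every array\n",
"    n_test = math.ceil(n * test_size)\n",
"    test_idx, train_idx = order[:n_test], order[n_test:]\n",
"    result = []\n",
"    for a in arrays:\n",
"        result.append([a[i] for i in train_idx])\n",
"        result.append([a[i] for i in test_idx])\n",
"    return result\n",
"\n",
"toy_ids = list(range(10))\n",
"toy_masks = [i * 10 for i in toy_ids]\n",
"toy_labels = [i * 100 for i in toy_ids]\n",
"(toy_train_ids, toy_valid_ids, toy_train_masks, toy_valid_masks,\n",
" toy_train_labels, toy_valid_labels) = split_aligned(toy_ids, toy_masks, toy_labels)\n",
"print(len(toy_train_ids), len(toy_valid_ids))  # 8 2\n",
"print(all(m == i * 10 for i, m in zip(toy_train_ids, toy_train_masks)))  # True\n",
"```"
]
},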
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save Training/Evaluation data and metadata"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"train_dict = {'input_ids': train_ids, 'attention_mask': train_attention_masks, 'labels': train_labels}\n",
"valid_dict = {'input_ids': valid_ids, 'attention_mask': valid_attention_masks, 'labels': valid_labels}\n",
"\n",
"torch.save(train_dict, os.path.join(train_dir, 'train_features.pt'))\n",
"\n",
"with open(os.path.join(train_dir, 'tag2id.json'), 'w') as f:\n",
" json.dump(tag2id, f) \n",
"\n",
"with open(os.path.join(train_dir, 'id2tag.json'), 'w') as f:\n",
" json.dump(id2tag, f)\n",
"\n",
"with open(os.path.join(train_dir, 'tag2entity.json'), 'w') as f:\n",
" json.dump(tag2entity, f) \n",
"\n",
"torch.save(valid_dict, os.path.join(valid_dir, 'valid_features.pt'))"
]
},
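{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail to keep in mind when these files are loaded again: JSON object keys are always strings, so `id2tag` comes back with keys such as `'0'` instead of `0`. The sketch below shows the round trip and a conversion back to integer keys; the in-memory buffer stands in for `id2tag.json`, and `toy_id2tag` is a small excerpt of the real mapping.\n",
"\n",
"```python\n",
"import io\n",
"import json\n",
"\n",
"toy_id2tag = {0: 'PER_B', 1: 'DAT_B', 2: '-'}  # small excerpt of the real mapping\n",
"\n",
"buffer = io.StringIO()  # stands in for the id2tag.json file\n",
"json.dump(toy_id2tag, buffer)\n",
"buffer.seek(0)\n",
"loaded = json.load(buffer)\n",
"print(loaded)  # {'0': 'PER_B', '1': 'DAT_B', '2': '-'}\n",
"\n",
"restored = {int(k): v for k, v in loaded.items()}\n",
"print(restored == toy_id2tag)  # True\n",
"```"
]
},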
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Custom Dataset\n",
"\n",
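"Define a class that builds the custom dataset used for training and validation. BERT-based models usually take the inputs below; if you only use single sentences, you can safely omit `token_type_ids`.\n",
"- `input_ids`: the sentence converted into a token sequence of indices (numbers that map to entries of a specific vocab)\n",
"- `attention_mask`: marks whether each token is a padding token or not\n",
"- `token_type_ids`: segment IDs (with a two-sentence input, marks whether each token belongs to the first sentence)\n",
"- `labels`: the ground-truth entity label mapped to each token"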
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class NERDataset(torch.utils.data.Dataset):\n",
"    def __init__(self, input_ids, attention_masks, labels=None, max_len=128):\n",
"        self.input_ids = input_ids\n",
"        self.attention_masks = attention_masks\n",
"        self.labels = labels\n",
"        self.max_len = max_len\n",
"\n",
"    def __getitem__(self, idx):\n",
"        item = {}\n",
"        item['input_ids'] = self.input_ids[idx]\n",
"        item['attention_mask'] = self.attention_masks[idx]\n",
"        if self.labels is not None:  # labels are optional, e.g. for inference\n",
"            item['labels'] = self.labels[idx]\n",
"        return item\n",
"\n",
"    def __len__(self):\n",
"        return len(self.input_ids)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{1494477010.py:3} INFO - num_train samples=52416, num_valid samples=13104\n"
]
}
],
"source": [
"train_dataset = NERDataset(train_ids, train_attention_masks, train_labels)\n",
"valid_dataset = NERDataset(valid_ids, valid_attention_masks, valid_labels)\n",
"logger.info(f'num_train samples={len(train_dataset)}, num_valid samples={len(valid_dataset)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"## 4. Training\n",
"---\n",
"\n",
"### Define Custom metric\n",
"Define a custom function that computes precision, recall, F1 score, and accuracy on the validation set at regular intervals (e.g. every epoch or every N steps)."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
"\n",
"def compute_metrics(p):\n",
"    logits = p.predictions\n",
"    labels = p.label_ids.ravel()\n",
"    preds = logits.argmax(-1).ravel()\n",
"\n",
"    # Ignore padded positions (label -100)\n",
"    preds = preds[labels != -100]\n",
"    labels = labels[labels != -100]\n",
"\n",
"    # Note: with average='micro' over all classes, precision, recall and F1 all equal accuracy\n",
"    prec, rec, f1, _ = precision_recall_fscore_support(labels, preds, average='micro')\n",
"    acc = accuracy_score(labels, preds)\n",
"\n",
"    metrics = {\n",
"        'precision': prec,\n",
"        'recall': rec,\n",
"        'f1': f1,\n",
"        'accuracy': acc\n",
"    }\n",
"\n",
"    return metrics"
]
},
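{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `average='micro'` over all classes, every token counts as exactly one prediction and one true label, so precision, recall and F1 all collapse to accuracy, which is why those four columns are identical in the training output further down. Leaving out the non-entity tag `'-'` gives a more informative picture. The pure-Python sketch below computes token-level micro precision/recall/F1 over entity tags only; it should match scikit-learn's `precision_recall_fscore_support` with `average='micro'` and a `labels=` list that omits the `'-'` id. `token_prf` and the toy ids are illustrative, and entity-level (span) scoring would be stricter still.\n",
"\n",
"```python\n",
"def token_prf(labels, preds, outside_label, ignore_label=-100):\n",
"    # Drop padded positions, like compute_metrics does\n",
"    pairs = [(l, p) for l, p in zip(labels, preds) if l != ignore_label]\n",
"    accuracy = sum(l == p for l, p in pairs) / len(pairs)\n",
"    # Micro-averaged scores over entity tags only: the outside tag is never a positive class\n",
"    tp = sum(1 for l, p in pairs if l == p and l != outside_label)\n",
"    predicted = sum(1 for _, p in pairs if p != outside_label)\n",
"    actual = sum(1 for l, _ in pairs if l != outside_label)\n",
"    precision = tp / predicted if predicted else 0.0\n",
"    recall = tp / actual if actual else 0.0\n",
"    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
"    return precision, recall, f1, accuracy\n",
"\n",
"# Toy ids following tag2id above: 0 = PER_B, 2 = '-', 5 = NUM_B, 11 = PER_I\n",
"toy_labels = [2, 0, 11, 2, 2, -100]\n",
"toy_preds = [2, 0, 2, 2, 0, 5]\n",
"print(token_prf(toy_labels, toy_preds, outside_label=2))  # (0.5, 0.5, 0.5, 0.6)\n",
"```"
]
},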
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training Preparation\n",
"\n",
"This hands-on uses the `BertForTokenClassification` model included in the Hugging Face transformers library. This model makes predictions at the token level rather than at the sentence level."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
"- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"model = BertForTokenClassification.from_pretrained(args.model_id, num_labels=len(tag2id))\n",
"model.config.id2label = id2tag\n",
"model.config.label2id = tag2id\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=args.chkpt_dir, # output directory\n",
" overwrite_output_dir=True if get_last_checkpoint(args.chkpt_dir) is not None else False,\n",
" num_train_epochs=args.epochs, # total number of training epochs\n",
" per_device_train_batch_size=args.train_batch_size, # batch size per device during training\n",
" per_device_eval_batch_size=args.eval_batch_size, # batch size for evaluation\n",
" warmup_steps=args.warmup_steps, # number of warmup steps for learning rate scheduler\n",
" weight_decay=0.01, # strength of weight decay\n",
" logging_dir=f\"{args.output_data_dir}/logs\", # directory for storing logs\n",
" eval_steps=100,\n",
" learning_rate=float(args.learning_rate),\n",
" #load_best_model_at_end=True,\n",
" save_strategy=\"epoch\",\n",
" evaluation_strategy=\"steps\",\n",
" metric_for_best_model=\"f1\",\n",
")"
]
},
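{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `BertForTokenClassification` puts a linear classifier on top of every token's hidden state, its output holds one row of logits per token, shaped (batch, sequence length, `num_labels`). Turning that into tags is a per-token argmax followed by an `id2tag` lookup, skipping special tokens. The sketch below does this in plain Python on made-up logits for three labels; in the notebook the logits would come from `model(...).logits`, and `decode_tokens` is a hypothetical name.\n",
"\n",
"```python\n",
"def decode_tokens(tokens, logits, id2tag, skip=('[CLS]', '[SEP]', '[PAD]')):\n",
"    results = []\n",
"    for token, row in zip(tokens, logits):\n",
"        if token in skip:\n",
"            continue\n",
"        best = max(range(len(row)), key=lambda i: row[i])  # argmax over labels\n",
"        results.append((token, id2tag[best]))\n",
"    return results\n",
"\n",
"toy_id2tag = {0: 'PER_B', 1: 'DAT_B', 2: '-'}\n",
"toy_tokens = ['[CLS]', '비토리오', '양일', '만에', '[SEP]']\n",
"toy_logits = [  # made-up scores, one row per token\n",
"    [0.1, 0.2, 3.0],\n",
"    [2.5, 0.3, 0.1],\n",
"    [0.2, 2.2, 0.5],\n",
"    [0.1, 0.4, 1.7],\n",
"    [0.0, 0.1, 2.9],\n",
"]\n",
"print(decode_tokens(toy_tokens, toy_logits, toy_id2tag))\n",
"# [('비토리오', 'PER_B'), ('양일', 'DAT_B'), ('만에', '-')]\n",
"```"
]
},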
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instantiate the `Trainer` class that runs the training."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_dataset, \n",
" eval_dataset=valid_dataset,\n",
" tokenizer=tokenizer,\n",
" compute_metrics=compute_metrics\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training\n",
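"Run the training. A GPU is practically essential for training deep-learning NLP models, and multi-GPU or distributed training is recommended for full-scale runs. When multiple GPUs are available, Trainer sets the total batch size to the per-device batch size x the number of GPUs and applies data parallelism automatically."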
]
},
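{
"cell_type": "markdown",
"metadata": {},
"source": [
"The step counts in the training log can be checked by hand. With 4 GPUs (`n_gpus=4` in the arguments output) and `per_device_train_batch_size=32`, the effective batch is 128, so one epoch over 52,416 training samples takes ceil(52416 / 128) = 410 steps and three epochs take 1,230 steps, matching `checkpoint-410`, `checkpoint-820` and `Total optimization steps = 1230`; the same rule gives the evaluation batch size of 64 x 4 = 256. The small sketch below assumes no gradient accumulation; `training_steps` is a hypothetical name.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def training_steps(num_samples, per_device_batch, num_gpus, epochs):\n",
"    # Single-process data parallelism: one dataloader batch is split across all GPUs\n",
"    total_batch = per_device_batch * max(num_gpus, 1)\n",
"    steps_per_epoch = math.ceil(num_samples / total_batch)  # the last partial batch is kept\n",
"    return total_batch, steps_per_epoch, steps_per_epoch * epochs\n",
"\n",
"print(training_steps(52416, 32, 4, 3))  # (128, 410, 1230)\n",
"```"
]
},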
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n",
"***** Running training *****\n",
" Num examples = 52416\n",
" Num Epochs = 3\n",
" Instantaneous batch size per device = 32\n",
" Total train batch size (w. parallel, distributed & accumulation) = 128\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 1230\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"data": {
"text/plain": [
"[1230/1230 14:28, Epoch 3/3]\n",
"\n",
"Step  Training Loss  Validation Loss  Precision  Recall    F1        Accuracy\n",
"100   No log         0.664300         0.812209   0.812209  0.812209  0.812209\n",
"200   No log         0.522309         0.851770   0.851770  0.851770  0.851770\n",
"300   No log         0.461020         0.866471   0.866471  0.866471  0.866471\n",
"400   No log         0.416925         0.876684   0.876684  0.876684  0.876684\n",
"500   0.673100       0.407849         0.880267   0.880267  0.880267  0.880267\n",
"600   0.673100       0.393547         0.882919   0.882919  0.882919  0.882919\n",
"700   0.673100       0.380853         0.887344   0.887344  0.887344  0.887344\n",
"800   0.673100       0.362687         0.891400   0.891400  0.891400  0.891400\n",
"900   0.673100       0.373592         0.890279   0.890279  0.890279  0.890279\n",
"1000  0.343400       0.367624         0.891256   0.891256  0.891256  0.891256\n",
"1100  0.343400       0.360553         0.894566   0.894566  0.894566  0.894566\n",
"1200  0.343400       0.358742         0.894731   0.894731  0.894731  0.894731"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"Saving model checkpoint to chkpt/checkpoint-410\n",
"Configuration saved in chkpt/checkpoint-410/config.json\n",
"Model weights saved in chkpt/checkpoint-410/pytorch_model.bin\n",
"tokenizer config file saved in chkpt/checkpoint-410/tokenizer_config.json\n",
"Special tokens file saved in chkpt/checkpoint-410/special_tokens_map.json\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"Saving model checkpoint to chkpt/checkpoint-820\n",
"Configuration saved in chkpt/checkpoint-820/config.json\n",
"Model weights saved in chkpt/checkpoint-820/pytorch_model.bin\n",
"tokenizer config file saved in chkpt/checkpoint-820/tokenizer_config.json\n",
"Special tokens file saved in chkpt/checkpoint-820/special_tokens_map.json\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"***** Running Evaluation *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n",
"Saving model checkpoint to chkpt/checkpoint-1230\n",
"Configuration saved in chkpt/checkpoint-1230/config.json\n",
"Model weights saved in chkpt/checkpoint-1230/pytorch_model.bin\n",
"tokenizer config file saved in chkpt/checkpoint-1230/tokenizer_config.json\n",
"Special tokens file saved in chkpt/checkpoint-1230/special_tokens_map.json\n",
"\n",
"\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 19min 35s, sys: 2min 57s, total: 22min 33s\n",
"Wall time: 14min 35s\n"
]
}
],
"source": [
"%%time\n",
"# Train the model, resuming from the latest checkpoint in args.chkpt_dir if one exists\n",
"last_checkpoint = get_last_checkpoint(args.chkpt_dir)\n",
"if last_checkpoint is not None:\n",
"    logger.info(\"***** Continue Training *****\")\n",
"    trainer.train(resume_from_checkpoint=last_checkpoint)\n",
"else:\n",
"    trainer.train()"
]
},
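{
"cell_type": "markdown",
"metadata": {},
"source": [
"The resume logic above relies on `get_last_checkpoint`, which scans `args.chkpt_dir` for subdirectories named `checkpoint-<step>` (such as `chkpt/checkpoint-1230` in the training log) and returns the one with the highest step number. The next cell is a minimal sketch of that selection rule, assuming only directories matching `checkpoint-<integer>` count. `find_last_checkpoint` is a hypothetical helper for illustration only; in practice, keep using `transformers.trainer_utils.get_last_checkpoint`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import tempfile\n",
"\n",
"# Hypothetical helper that mirrors the selection rule of get_last_checkpoint (assumption)\n",
"CHECKPOINT_RE = re.compile(r'^checkpoint-(\\d+)$')\n",
"\n",
"def find_last_checkpoint(folder):\n",
"    # Keep only subdirectories named checkpoint-<step>\n",
"    candidates = [\n",
"        name for name in os.listdir(folder)\n",
"        if CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(folder, name))\n",
"    ]\n",
"    if not candidates:\n",
"        return None\n",
"    # Compare steps as integers; as strings, 'checkpoint-820' would sort after 'checkpoint-1230'\n",
"    last = max(candidates, key=lambda name: int(CHECKPOINT_RE.match(name).group(1)))\n",
"    return os.path.join(folder, last)\n",
"\n",
"# Usage: recreate the checkpoint layout from the training log in a temporary directory\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    for step in (410, 820, 1230):\n",
"        os.makedirs(os.path.join(tmp, f'checkpoint-{step}'))\n",
"    print(os.path.basename(find_last_checkpoint(tmp)))  # checkpoint-1230"
]
},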
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## 5. Evaluation\n",
"---\n",
"\n",
"Evaluate the trained model on the validation set."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"***** Running Prediction *****\n",
" Num examples = 13104\n",
" Batch size = 256\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
" warnings.warn('Was asked to gather along dimension 0, but all '\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"***** Evaluation results at /home/ec2-user/SageMaker/sm-kornlp-usecases/named-entity-recognition/data *****\n",
"[{2675283658.py:10} INFO - test_accuracy = 0.8945943620394088\n",
"\n",
"[{2675283658.py:10} INFO - test_f1 = 0.8945943620394088\n",
"\n",
"[{2675283658.py:10} INFO - test_loss = 0.35873326659202576\n",
"\n",
"[{2675283658.py:10} INFO - test_precision = 0.8945943620394088\n",
"\n",
"[{2675283658.py:10} INFO - test_recall = 0.8945943620394088\n",
"\n",
"[{2675283658.py:10} INFO - test_runtime = 15.2768\n",
"\n",
"[{2675283658.py:10} INFO - test_samples_per_second = 857.77\n",
"\n",
"[{2675283658.py:10} INFO - test_steps_per_second = 3.404\n",
"\n"
]
}
],
"source": [
"outputs = trainer.predict(valid_dataset)\n",
"eval_results = outputs.metrics\n",
"\n",
"# Write the eval results to a file that can be accessed later in the S3 output\n",
"with open(os.path.join(args.output_data_dir, \"eval_results.txt\"), \"w\") as writer:\n",
"    print(f\"***** Evaluation results at {args.output_data_dir} *****\")\n",
"    for key, value in sorted(eval_results.items()):\n",
"        writer.write(f\"{key} = {value}\\n\")\n",
"        logger.info(f\"{key} = {value}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'precision': 0.8945943620394088,\n",
" 'recall': 0.8945943620394088,\n",
" 'f1': 0.8945943620394088,\n",
" 'accuracy': 0.8945943620394088}"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"compute_metrics(outputs)"
]
},
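{
"cell_type": "markdown",
"metadata": {},
"source": [
"All four metrics above share the value 0.8945943620394088. This is what you would expect if `compute_metrics` (defined earlier in the notebook, not shown in this section) micro-averages over every token tag, including the non-entity tag `-`: in single-label multiclass classification, each wrong token is one false positive for the predicted tag and one false negative for the true tag, so micro precision, micro recall, and micro F1 all reduce to correct tokens / total tokens, which is accuracy. The next cell checks this with scikit-learn on made-up labels, using the same `-100` masking as the evaluation code. The `toy_` variable names are illustrative and chosen so they do not overwrite the notebook's own variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
"\n",
"# Made-up token label ids; -100 marks positions to ignore (e.g. special tokens, padding)\n",
"toy_true = np.array([2, 6, 6, -100, 3, 3, 0, 2, -100])\n",
"toy_pred = np.array([2, 6, 4, 5, 3, 16, 0, 2, 1])\n",
"\n",
"# Keep only labeled positions\n",
"mask = toy_true != -100\n",
"toy_true, toy_pred = toy_true[mask], toy_pred[mask]\n",
"\n",
"micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(toy_true, toy_pred, average='micro')\n",
"toy_acc = accuracy_score(toy_true, toy_pred)\n",
"# All four equal 5/7 (5 of 7 labeled tokens are correct), up to floating-point rounding\n",
"print(micro_p, micro_r, micro_f1, toy_acc)"
]
},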
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
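"# Flatten logits and labels over all sentences and tokens\n",
"pred_logits = outputs.predictions\n",
"true = outputs.label_ids.ravel()\n",
"pred = pred_logits.argmax(-1).ravel()\n",
"# Drop positions labeled -100 (e.g. special tokens and padding), which the loss ignores\n",
"pred = pred[true != -100]\n",
"true = true[true != -100]\n",
"\n",
"# Map label ids back to tag strings\n",
"true_tag = [id2tag[x] for x in true]\n",
"pred_tag = [id2tag[x] for x in pred]"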
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n",
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" PER_B 0.85 0.85 0.85 22546\n",
" DAT_B 0.90 0.88 0.89 8588\n",
" - 0.94 0.96 0.95 280106\n",
" ORG_B 0.86 0.85 0.85 28424\n",
" CVL_B 0.78 0.76 0.77 31916\n",
" NUM_B 0.94 0.94 0.94 22513\n",
" LOC_B 0.79 0.75 0.77 12026\n",
" EVT_B 0.78 0.79 0.78 7349\n",
" TRM_B 0.78 0.70 0.73 12164\n",
" TRM_I 0.56 0.46 0.50 1721\n",
" EVT_I 0.73 0.79 0.76 3441\n",
" PER_I 0.75 0.72 0.74 2972\n",
" CVL_I 0.47 0.39 0.42 1724\n",
" NUM_I 0.71 0.80 0.75 2857\n",
" TIM_B 0.79 0.87 0.83 1133\n",
" TIM_I 0.89 0.94 0.92 442\n",
" ORG_I 0.65 0.63 0.64 3122\n",
" DAT_I 0.82 0.88 0.85 2080\n",
" ANM_B 0.70 0.69 0.69 2803\n",
" MAT_B 0.45 0.22 0.30 149\n",
" MAT_I 0.00 0.00 0.00 2\n",
" AFW_B 0.64 0.49 0.56 2774\n",
" FLD_B 0.52 0.54 0.53 1208\n",
" LOC_I 0.00 0.00 0.00 83\n",
" AFW_I 0.56 0.42 0.48 952\n",
" PLT_B 0.00 0.00 0.00 106\n",
" FLD_I 0.00 0.00 0.00 18\n",
" ANM_I 0.00 0.00 0.00 28\n",
" PLT_I 0.00 0.00 0.00 2\n",
"\n",
" accuracy 0.89 453249\n",
" macro avg 0.58 0.56 0.57 453249\n",
"weighted avg 0.89 0.89 0.89 453249\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"print(classification_report(true_tag, pred_tag, labels=unique_tags))"
]
},
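{
"cell_type": "markdown",
"metadata": {},
"source": [
"The report above is token-level: every sub-word token counts, so long entities weigh more and a partially tagged entity still earns some credit. NER is often scored at the entity level instead, where a predicted entity counts only if its type and span match exactly (libraries such as `seqeval` do this for BIO-style tags). The next cell is a minimal sketch of exact-match entity scoring for this dataset's `<TYPE>_B` / `<TYPE>_I` / `-` tag scheme; `tags_to_spans` and `entity_prf` are hypothetical helpers. Apply it per sentence, because `true_tag` and `pred_tag` above are flattened across sentences."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helpers: entity-level scoring for tags like 'PER_B', 'PER_I', '-'\n",
"def tags_to_spans(tags):\n",
"    # Return a set of (entity_type, start, end) spans with end exclusive\n",
"    spans, start, etype = set(), None, None\n",
"    for i, tag in enumerate(list(tags) + ['-']):  # sentinel '-' closes the last span\n",
"        if tag.endswith('_I') and etype == tag[:-2]:\n",
"            continue  # continuation of the current entity\n",
"        if etype is not None:\n",
"            spans.add((etype, start, i))\n",
"            etype = None\n",
"        if tag.endswith('_B') or tag.endswith('_I'):\n",
"            # A stray _I without a matching _B is treated as the start of a new entity\n",
"            start, etype = i, tag[:-2]\n",
"    return spans\n",
"\n",
"def entity_prf(true_tags, pred_tags):\n",
"    true_spans = tags_to_spans(true_tags)\n",
"    pred_spans = tags_to_spans(pred_tags)\n",
"    tp = len(true_spans & pred_spans)  # exact type + boundary matches\n",
"    precision = tp / len(pred_spans) if pred_spans else 0.0\n",
"    recall = tp / len(true_spans) if true_spans else 0.0\n",
"    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0\n",
"    return precision, recall, f1\n",
"\n",
"# Usage on one made-up sentence: the PER span boundary differs, so precision = recall = 2/3\n",
"toy_true_tags = ['ORG_B', 'ORG_I', '-', 'PER_B', '-', 'NUM_B']\n",
"toy_pred_tags = ['ORG_B', 'ORG_I', '-', 'PER_B', 'PER_I', 'NUM_B']\n",
"print(entity_prf(toy_true_tags, toy_pred_tags))"
]
},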
{
"cell_type": "markdown",
"metadata": {
"id": "emi8JP4I4lxX"
},
"source": [
"\n",
"## 6. Prediction\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mvYrdoG8t8UV",
"outputId": "7e64b739-52c4-45c2-9bed-8e13437bf666"
},
"outputs": [],
"source": [
"def predict_fn_token(example, model):\n",
"    # Tokenize and keep character offsets so each token can be mapped back to the text\n",
"    inputs_with_offsets = tokenizer(example, return_offsets_mapping=True, return_tensors='pt')\n",
"    inputs = inputs_with_offsets.copy()\n",
"    inputs.pop('offset_mapping')  # the model does not take offset_mapping as input\n",
"    tokens = inputs_with_offsets.tokens()\n",
"    offsets = inputs_with_offsets[\"offset_mapping\"]\n",
"    \n",
"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"    inputs = inputs.to(device)\n",
"    model = model.to(device)\n",
"    model.eval()  # disable dropout for inference\n",
"    \n",
"    with torch.no_grad():  # gradients are not needed for inference\n",
"        output = model(**inputs)\n",
"    softmax_fn = nn.Softmax(dim=-1)\n",
"    softmax_output = softmax_fn(output['logits'])\n",
"    prob, pred = torch.max(softmax_output, dim=-1)\n",
"    pred_str_lst = [model.config.id2label[id] for id in pred.squeeze().tolist()]\n",
"    prob_lst, pred_lst = prob.squeeze().tolist(), pred.squeeze().tolist()\n",
"    offsets_lst = offsets.squeeze().tolist()\n",
"    \n",
"    df = pd.DataFrame(zip(tokens, pred_str_lst, prob_lst, pred_lst, offsets_lst), \n",
"                      columns=['token', 'tag', 'score', 'label', 'offset'])\n",
"    \n",
"    return df\n",
"\n",
"\n",
"def predict_fn_word(example, model):\n",
"    # Word-level prediction with the token-classification (NER) pipeline\n",
"    from transformers import pipeline\n",
"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"    device_id = -1 if device.type == \"cpu\" else 0\n",
"    \n",
"    nlp = pipeline(\"ner\", model=model.to(device), device=device_id, \n",
"                   tokenizer=tokenizer, aggregation_strategy='average')\n",
"    results = nlp(example)\n",
"    \n",
"    entity_lst, score_lst, word_lst, start_lst, end_lst = [], [], [], [], []\n",
"    tag2entity[''] = '-'\n",
"\n",
"    for result in results:\n",
"        entity = tag2entity[result['entity_group']]\n",
"        score = result['score']\n",
"        word = result['word']\n",
"        start = result['start']\n",
"        end = result['end']\n",
"\n",
"        entity_lst.append(entity)\n",
"        score_lst.append(score)\n",
"        word_lst.append(word)\n",
"        start_lst.append(start)\n",
"        end_lst.append(end)\n",
"\n",
"    df = pd.DataFrame(zip(word_lst, entity_lst, score_lst, start_lst, end_lst), \n",
"                      columns=['word', 'entity', 'score', 'start', 'end'])\n",
"    return df"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"example = '잉글랜드 프로축구 프리미어리그 토트넘 홋스퍼가 손흥민의 A매치 100경기 이상 출전 센추리클럽 가입에 축하를 보냈다.'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prediction by token\n",
"\n",
"Predictions here are made per token, so the model's outputs are sub-word tokens rather than words. See the section below for word-level prediction."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LP1sNHYuvXKS",
"outputId": "5a7c5100-b68d-4d5a-a47b-93051c977a3a"
},
"outputs": [
{
"data": {
"text/plain": [
" token tag score label offset\n",
"0 [CLS] - 0.999501 2 [0, 0]\n",
"1 잉 LOC_B 0.951433 6 [0, 1]\n",
"2 ##글 LOC_B 0.956234 6 [1, 2]\n",
"3 ##랜드 LOC_B 0.962042 6 [2, 4]\n",
"4 프로 CVL_B 0.731096 4 [5, 7]\n",
"5 ##축 CVL_B 0.785559 4 [7, 8]\n",
"6 ##구 CVL_B 0.755286 4 [8, 9]\n",
"7 프 ORG_B 0.497772 3 [10, 11]\n",
"8 ##리 ORG_B 0.487815 3 [11, 12]\n",
"9 ##미 ORG_B 0.498659 3 [12, 13]\n",
"10 ##어 ORG_B 0.502374 3 [13, 14]\n",
"11 ##리그 ORG_B 0.518754 3 [14, 16]\n",
"12 토 ORG_B 0.965497 3 [17, 18]\n",
"13 ##트 ORG_B 0.964893 3 [18, 19]\n",
"14 ##넘 ORG_B 0.967747 3 [19, 20]\n",
"15 홋 ORG_I 0.974019 16 [21, 22]\n",
"16 ##스 ORG_I 0.974807 16 [22, 23]\n",
"17 ##퍼 ORG_I 0.975250 16 [23, 24]\n",
"18 ##가 ORG_I 0.976544 16 [24, 25]\n",
"19 손 PER_B 0.994671 0 [26, 27]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_df = predict_fn_token(example, model)\n",
"token_df.head(20)"
]
},
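{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row above is a WordPiece token: pieces starting with `##` continue the previous word, and `offset` gives the character span in `example`. One simple way to map token predictions back to words is to merge the `##` pieces and keep the tag of the first piece. The next cell sketches that 'first sub-token' rule on rows copied from the table above; `merge_wordpieces` is a hypothetical helper, and this rule differs from the `aggregation_strategy='average'` used in the next section. With a fast tokenizer, `BatchEncoding.word_ids()` is a more robust way to find word boundaries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rows copied from the token_df output above: (token, tag)\n",
"rows = [('[CLS]', '-'), ('잉', 'LOC_B'), ('##글', 'LOC_B'), ('##랜드', 'LOC_B'),\n",
"        ('프로', 'CVL_B'), ('##축', 'CVL_B'), ('##구', 'CVL_B')]\n",
"\n",
"def merge_wordpieces(rows):\n",
"    # Hypothetical helper: merge '##' pieces into the previous word, keeping the first piece's tag\n",
"    words = []\n",
"    for token, tag in rows:\n",
"        if token in ('[CLS]', '[SEP]', '[PAD]'):\n",
"            continue  # skip special tokens\n",
"        if token.startswith('##') and words:\n",
"            words[-1][0] += token[2:]\n",
"        else:\n",
"            words.append([token, tag])\n",
"    return [tuple(word) for word in words]\n",
"\n",
"print(merge_wordpieces(rows))  # [('잉글랜드', 'LOC_B'), ('프로축구', 'CVL_B')]\n",
"# On the full table: merge_wordpieces(token_df[['token', 'tag']].itertuples(index=False))"
]
},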
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prediction by word\n",
"\n",
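"Word-level prediction uses the `pipeline` function together with `aggregation_strategy`, which specifies how token scores are aggregated into word scores.\n",
"With a fast tokenizer (e.g. `BertTokenizerFast`), you can get word-level predictions conveniently, without implementing the grouping from scratch."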
]
},
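{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `aggregation_strategy='average'`, the pipeline averages the per-label probabilities over a word's sub-tokens and then takes the highest-scoring label, whereas `'first'` uses only the first sub-token's label. The next cell sketches the difference for a single word, using made-up probabilities over three labels (`id2label_demo` is a hypothetical mapping, not the model's). Adjacent words with the same tag are then grouped, and the group's score appears to be the mean of the word scores: averaging the `ORG_B` scores in the token table gives about 0.501 for 프리미어리그 and 0.966 for 토트넘, and their mean (about 0.7336) is consistent with the 0.733560 reported below for `프리미어리그 토트넘`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Made-up softmax outputs for the 3 sub-tokens of one word over 3 labels\n",
"id2label_demo = {0: 'PER_B', 1: 'ORG_B', 2: '-'}\n",
"token_probs = np.array([\n",
"    [0.50, 0.40, 0.10],\n",
"    [0.10, 0.70, 0.20],\n",
"    [0.20, 0.60, 0.20],\n",
"])\n",
"\n",
"# 'first': use the label of the first sub-token only\n",
"first_id = int(token_probs[0].argmax())\n",
"print('first  :', id2label_demo[first_id])  # first  : PER_B\n",
"\n",
"# 'average': average probabilities across sub-tokens, then take the best label\n",
"avg_probs = token_probs.mean(axis=0)\n",
"avg_id = int(avg_probs.argmax())\n",
"print('average:', id2label_demo[avg_id], round(float(avg_probs[avg_id]), 4))  # average: ORG_B 0.5667"
]
},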
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" word entity score start end\n",
"0 잉글랜드 Location 0.956570 0 4\n",
"1 프로축구 Civilization 0.757313 5 9\n",
"2 프리미어리그 토트넘 Organization 0.733560 10 20\n",
"3 홋스퍼가 Organization 0.975155 21 25\n",
"4 손흥민의 Person 0.994924 26 30\n",
"5 A매치 Event 0.936316 31 34\n",
"6 100경기 Number 0.996811 35 40\n",
"7 이상 출전 - 0.944591 41 46\n",
"8 센추리클럽 Organization 0.979096 47 52\n",
"9 가입에 축하를 보냈다. - 0.997992 53 65"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"word_df = predict_fn_word(example, model)\n",
"word_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example\n",
"Try running inference on your own sample sentences."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" word entity score start end\n",
"0 아마존 Organization 0.693333 0 3\n",
"1 SageMaker는 머신 Term 0.709337 4 17\n",
"2 러닝 Term 0.552327 18 20\n",
"3 통합 - 0.768233 21 23\n",
"4 엔드투엔드 Term 0.454221 24 29\n",
"5 관리형 서비스로 - 0.648227 30 38\n",
"6 2017년 Date 0.990274 39 44\n",
"7 런칭되었다. - 0.999204 45 51"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example = '아마존 SageMaker는 머신 러닝 통합 엔드투엔드 관리형 서비스로 2017년 런칭되었다.'\n",
"word_df = predict_fn_word(example, model)\n",
"word_df"
]
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"authorship_tag": "ABX9TyPT/32fR6YbrNgmG6aLi8U7",
"include_colab_link": true,
"machine_shape": "hm",
"name": "5_(BERT_실습)한국어 개체명 인식.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "conda_pytorch_p38",
"language": "python",
"name": "conda_pytorch_p38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"1ac7cea5aaba45af9eddeaaee02e1e5a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2ef88e8c35374ca69203a64d209745ea": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"32c340873ce247e88df66c73309eecdc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"4a4ff12bb4604faf8c1cd79156713854": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"51eafe68808a4ffbac05605381c2d5a3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"598f160635264f138769ae94a127455c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Downloading: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_1ac7cea5aaba45af9eddeaaee02e1e5a",
"max": 1961828,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b61fb58de9be4c84b7767bf69e32c5d9",
"value": 1961828
}
},
"5b0843766d3f4ac785c7dba85254d605": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_ca07bf481be7473ab1b22babaa76c3da",
"IPY_MODEL_ab60a81459a84ae19fa922aa4ce27e8a"
],
"layout": "IPY_MODEL_6356e0609f9f49d5996ef7f4f77fbd2d"
}
},
"6356e0609f9f49d5996ef7f4f77fbd2d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"693887ccba30416586e2085b7e36118b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_9169c98d7ee5423ba5b6eb3c4dbbeda7",
"placeholder": "",
"style": "IPY_MODEL_cc353722b52045efb9009ef79c7d56b7",
"value": " 1.96M/1.96M [00:00<00:00, 6.43MB/s]"
}
},
"6cb6badcbbd34359be9dad2c8af93098": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"88f28f34de9e41cc948c7aebb4035589": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8d879e2bbea04536aaa1ee5d356bb7c3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2ef88e8c35374ca69203a64d209745ea",
"placeholder": "",
"style": "IPY_MODEL_32c340873ce247e88df66c73309eecdc",
"value": " 29.0/29.0 [00:00<00:00, 50.2B/s]"
}
},
"9169c98d7ee5423ba5b6eb3c4dbbeda7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"93234e1bd6444d819b130d83402d2d7b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e6e0a2274e7f420f91dc97a143315da0",
"IPY_MODEL_8d879e2bbea04536aaa1ee5d356bb7c3"
],
"layout": "IPY_MODEL_ad295e710380441588473f810a9210d7"
}
},
"a0dbf235f20c497186d319b5b1558dd9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"ab60a81459a84ae19fa922aa4ce27e8a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_51eafe68808a4ffbac05605381c2d5a3",
"placeholder": "",
"style": "IPY_MODEL_6cb6badcbbd34359be9dad2c8af93098",
"value": " 996k/996k [00:00<00:00, 1.67MB/s]"
}
},
"ad295e710380441588473f810a9210d7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b61fb58de9be4c84b7767bf69e32c5d9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"ca07bf481be7473ab1b22babaa76c3da": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Downloading: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_88f28f34de9e41cc948c7aebb4035589",
"max": 995526,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_4a4ff12bb4604faf8c1cd79156713854",
"value": 995526
}
},
"cc353722b52045efb9009ef79c7d56b7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"e4dcb08aab3748b18a10d0f5daaf3554": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_598f160635264f138769ae94a127455c",
"IPY_MODEL_693887ccba30416586e2085b7e36118b"
],
"layout": "IPY_MODEL_f1ed2fc28a3e499fa784d5aa1777a77b"
}
},
"e6e0a2274e7f420f91dc97a143315da0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Downloading: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_ed834b7f997141479ab90216655e230a",
"max": 29,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_a0dbf235f20c497186d319b5b1558dd9",
"value": 29
}
},
"ed834b7f997141479ab90216655e230a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f1ed2fc28a3e499fa784d5aa1777a77b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}