{ "cells": [ { "cell_type": "markdown", "id": "94ed44ac-98e5-4200-8f1c-3adedec2866b", "metadata": {}, "source": [ "# Lab 1: Sentence-BERT (SBERT) Training\n", "\n", "### Fine-tuning Sentence-BERT (SBERT) & SBERT Embedding for applications\n", "---" ] }, { "cell_type": "markdown", "id": "2d897cd6-a127-4a1c-9980-aa033c332f6a", "metadata": {}, "source": [ "\n", "## Introduction\n", "---\n", "\n", "In this module, we fine-tune a Sentence-BERT model, which produces sentence embeddings, on STS (Semantic Textual Similarity) datasets.\n", "The SentenceTransformers package makes fine-tuning straightforward. Note, however, that its support for distributed training is limited at the time of writing, so if you need to fine-tune on a very large dataset you will have to write custom training code yourself.\n", "\n", "***[Note] You can run this demo on SageMaker Studio Lab, SageMaker Studio, a SageMaker notebook instance, or your local machine. If you are using SageMaker Studio Lab, make sure the GPU is enabled.***\n", "\n", "### References\n", "\n", "- Hugging Face Tutorial: https://huggingface.co/docs/transformers/training\n", "- Sentence-BERT paper: https://arxiv.org/abs/1908.10084\n", "- SentenceTransformers: https://www.sbert.net" ] }, { "cell_type": "markdown", "id": "577d681d-0f93-44ea-a3d1-0e0a9bb5e063", "metadata": {}, "source": [ "\n", "## 1. Setup Environments\n", "---\n", "\n", "### Import modules" ] }, { "cell_type": "code", "execution_count": 1, "id": "91641d7f-95e5-473b-b144-823d4a19a299", "metadata": {}, "outputs": [], "source": [ "# !pip install sentence_transformers datasets faiss-gpu progressbar" ] }, { "cell_type": "code", "execution_count": 2, "id": "c6d4b001-3508-4ca1-b73b-76b11fdcce7a", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import json\n", "import logging\n", "import argparse\n", "import torch\n", "import gzip\n", "import csv\n", "import math\n", "import urllib.request\n", "from torch import nn\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm\n", "\n", "from datetime import datetime\n", "from datasets import load_dataset\n", "from torch.utils.data import DataLoader\n", "from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses, models, util\n", "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n", "from sentence_transformers.readers import InputExample\n", "from transformers.trainer_utils import get_last_checkpoint\n", "\n", "logging.basicConfig(\n", "    level=logging.INFO,\n", "    format='%(asctime)s - %(message)s',\n", "    datefmt='%Y-%m-%d %H:%M:%S',\n", "    handlers=[LoggingHandler()]\n", ")\n", "\n", "logger = logging.getLogger(__name__)" ] }, { "cell_type": "markdown", "id": "a1215e31-b560-4ea9-8fd2-eb097c9dd30b", "metadata": {}, "source": [ "### Argument parser" ] }, { "cell_type": "code", "execution_count": 3, "id": "c512803f-d734-4edd-a868-0f369f6b3888", "metadata": {}, "outputs": [], "source": [ "def parser_args(train_notebook=False):\n", "    parser = argparse.ArgumentParser()\n", "\n", "    # Default Setting\n", "    parser.add_argument(\"--epochs\", type=int, default=1)\n", "    parser.add_argument(\"--seed\", type=int, default=42)\n", "    parser.add_argument(\"--train_batch_size\", type=int, default=32)\n", "    parser.add_argument(\"--eval_batch_size\", type=int, default=32)\n", "    parser.add_argument(\"--warmup_steps\", type=int, default=100)\n", "    parser.add_argument(\"--logging_steps\", type=int, default=100)\n", "    parser.add_argument(\"--learning_rate\", type=float, default=5e-5)\n", "    parser.add_argument(\"--disable_tqdm\", type=bool, default=False)\n",
"    parser.add_argument(\"--fp16\", type=bool, default=True)\n", "    parser.add_argument(\"--tokenizer_id\", type=str, default='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n", "    parser.add_argument(\"--model_id\", type=str, default='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n", "\n", "    # SageMaker Container environment\n", "    parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n", "    parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n", "    parser.add_argument(\"--n_gpus\", type=str, default=os.environ[\"SM_NUM_GPUS\"])\n", "    parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n", "    parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n", "    parser.add_argument(\"--test_dir\", type=str, default=os.environ[\"SM_CHANNEL_TEST\"])\n", "    parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints')\n", "\n", "    if train_notebook:\n", "        args = parser.parse_args([])\n", "    else:\n", "        args = parser.parse_args()\n", "    return args" ] }, { "cell_type": "code", "execution_count": 4, "id": "970aca02-6ba6-4111-943a-8ffa15e5ead9", "metadata": {}, "outputs": [], "source": [ "train_dir = 'train'\n", "valid_dir = 'valid'\n", "test_dir = 'test'\n", "!rm -rf {train_dir} {valid_dir} {test_dir}\n", "os.makedirs(train_dir, exist_ok=True)\n", "os.makedirs(valid_dir, exist_ok=True)\n", "os.makedirs(test_dir, exist_ok=True)" ] }, { "cell_type": "markdown", "id": "9635cf55-6781-4f0b-bed4-c4c49a76619c", "metadata": {}, "source": [ "### Load Arguments\n", "\n", "Load the settings so that they can be used directly in this Jupyter notebook. Outside the notebook, you can also run the training script from the command line with `cd scripts && python3 train.py`."
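] }, { "cell_type": "markdown", "id": "sm-estimator-note", "metadata": {}, "source": [ "If you want to run the same `scripts/train.py` as a SageMaker training job instead of inside the notebook, a launch sketch is shown below. It is illustrative only and intentionally left commented out: the instance type, framework versions, and S3 channel paths are assumptions you would adapt to your own account. The `train`/`valid`/`test` channels map to the `SM_CHANNEL_*` environment variables parsed above." ] }, { "cell_type": "code", "execution_count": null, "id": "sm-estimator-sketch", "metadata": {}, "outputs": [], "source": [ "# Hedged sketch (not executed in this lab): launch scripts/train.py as a SageMaker training job.\n", "# Versions, instance type, and S3 paths below are assumptions - adjust them for your environment.\n", "# import sagemaker\n", "# from sagemaker.huggingface import HuggingFace\n", "#\n", "# role = sagemaker.get_execution_role()\n", "# estimator = HuggingFace(\n", "#     entry_point='train.py',\n", "#     source_dir='scripts',\n", "#     role=role,\n", "#     instance_type='ml.p3.2xlarge',\n", "#     instance_count=1,\n", "#     transformers_version='4.17',\n", "#     pytorch_version='1.10',\n", "#     py_version='py38',\n", "#     hyperparameters={'epochs': 1, 'train_batch_size': 32}\n", "# )\n", "# estimator.fit({'train': 's3://<bucket>/sts/train',\n", "#                'valid': 's3://<bucket>/sts/valid',\n", "#                'test': 's3://<bucket>/sts/test'})"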
] }, { "cell_type": "code", "execution_count": 5, "id": "1825c0c8-2106-445b-bb7e-6d255dd3bd1c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:42:26 - ***** Arguments *****\n", "2022-07-20 08:42:26 - epochs=1\n", "seed=42\n", "train_batch_size=32\n", "eval_batch_size=32\n", "warmup_steps=100\n", "logging_steps=100\n", "learning_rate=5e-05\n", "disable_tqdm=False\n", "fp16=True\n", "tokenizer_id=salti/bert-base-multilingual-cased-finetuned-squad\n", "model_id=salti/bert-base-multilingual-cased-finetuned-squad\n", "output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/data\n", "model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model\n", "n_gpus=4\n", "train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/train\n", "valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/valid\n", "test_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/test\n", "chkpt_dir=chkpt\n", "\n" ] } ], "source": [ "chkpt_dir = 'chkpt'\n", "model_dir = 'model'\n", "output_data_dir = 'data'\n", "num_gpus = torch.cuda.device_count()\n", "\n", "!rm -rf {chkpt_dir} {model_dir} {output_data_dir}\n", "\n", "if os.environ.get('SM_CURRENT_HOST') is None:\n", "    is_sm_container = False\n", "\n", "    #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n", "    src_dir = os.getcwd()\n", "    os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n", "    os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n", "    os.environ['SM_NUM_GPUS'] = str(num_gpus)\n", "    os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n", "    os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n", "    os.environ['SM_CHANNEL_TEST'] = f'{src_dir}/{test_dir}'\n", "\n", "args = parser_args(train_notebook=True)\n", "args.chkpt_dir = chkpt_dir\n", "logger.info(\"***** Arguments *****\")\n", "logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n", "\n", "os.makedirs(args.chkpt_dir, exist_ok=True)\n", "os.makedirs(args.model_dir, exist_ok=True)\n", "os.makedirs(args.output_data_dir, exist_ok=True)" ] }, { "cell_type": "markdown", "id": "5478a4d9-eaad-4652-8b1a-141d40b18d88", "metadata": {}, "source": [ "<br>\n", "\n", "## 2. Preparation\n", "---\n", "This hands-on lab uses the KorSTS (https://github.com/kakaobrain/KorNLUDatasets) and KLUE-STS (https://github.com/KLUE-benchmark/KLUE) datasets.\n", "Training on a single dataset is perfectly fine, but training on both datasets gives a small performance boost.\n", "\n", "### Training Tips\n", "SBERT training typically uses one of the following three recipes as a baseline:\n", "1. Train on an NLI dataset\n", "2. Train on an STS dataset\n", "3. Train on an NLI dataset first, then fine-tune on an STS dataset\n", "\n", "For Korean data, STS-trained models outperform NLI-based models even though the STS training data is comparatively small, so this lab follows the second recipe. <br>\n", "If you want slightly better predictive performance than STS alone, we recommend training on NLI data first and then continuing on the STS dataset; a short sketch of this two-stage recipe is shown right below."
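] }, { "cell_type": "markdown", "id": "nli-sts-recipe-note", "metadata": {}, "source": [ "For reference, the cell below sketches recipe 3 (NLI pre-training followed by STS fine-tuning). It is a hedged illustration rather than part of this lab: the two NLI pairs are hypothetical placeholders standing in for a real NLI dataset such as KorNLI (labels 0/1/2 for contradiction/entailment/neutral), and the final `fit` call is left commented out so that running the cell only constructs the training objects." ] }, { "cell_type": "code", "execution_count": null, "id": "nli-sts-recipe-sketch", "metadata": {}, "outputs": [], "source": [ "# Sketch of recipe 3: train with a softmax classification head on NLI pairs first,\n", "# then continue with CosineSimilarityLoss on STS pairs (as done in section 3 below).\n", "# The two samples are placeholders; plug in a real NLI dataset such as KorNLI.\n", "from torch.utils.data import DataLoader\n", "from sentence_transformers import SentenceTransformer, InputExample, losses\n", "\n", "nli_model = SentenceTransformer('sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n", "nli_samples = [\n", "    InputExample(texts=['한 남자가 말을 타고 있다', '한 남자가 야외에 있다'], label=1),  # entailment\n", "    InputExample(texts=['한 남자가 말을 타고 있다', '한 남자가 식당에 있다'], label=0),  # contradiction\n", "]\n", "nli_loader = DataLoader(nli_samples, shuffle=True, batch_size=2)\n", "nli_loss = losses.SoftmaxLoss(\n", "    model=nli_model,\n", "    sentence_embedding_dimension=nli_model.get_sentence_embedding_dimension(),\n", "    num_labels=3\n", ")\n", "# nli_model.fit(train_objectives=[(nli_loader, nli_loss)], epochs=1, warmup_steps=100)\n", "# ...then fine-tune nli_model on the STS pairs with losses.CosineSimilarityLoss, as in section 3 below."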
] }, { "cell_type": "markdown", "id": "9aa546b9-2072-42f3-b886-18a06069edf0", "metadata": {}, "source": [ "### Download the KLUE-STS dataset and build the feature set\n", "Download the KLUE-STS dataset from the Hugging Face dataset hub, then build the feature set required for SBERT training." ] }, { "cell_type": "code", "execution_count": 6, "id": "fc5b4880-9907-44f0-9e45-2e67f3938a60", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:42:30 - Read KLUE-STS train/dev dataset\n", "2022-07-20 08:42:30 - Reusing dataset klue (/home/ec2-user/.cache/huggingface/datasets/klue/sts/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fd3d65e0c0a482b9cdb08ba4409c3c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "logger.info(\"Read KLUE-STS train/dev dataset\")\n", "datasets = load_dataset(\"klue\", \"sts\")\n", "\n", "train_samples = []\n", "dev_samples = []\n", "\n", "for phase in [\"train\", \"validation\"]:\n", "    examples = datasets[phase]\n", "\n", "    for example in examples:\n", "        score = float(example[\"labels\"][\"label\"]) / 5.0  # Normalize the similarity score to the 0.0-1.0 range\n", "        inp_example = InputExample(texts=[example[\"sentence1\"], example[\"sentence2\"]], label=score)\n", "\n", "        if phase == \"validation\":\n", "            dev_samples.append(inp_example)\n", "        else:\n", "            train_samples.append(inp_example)" ] }, { "cell_type": "markdown", "id": "972b0261-3a0e-4612-b1ff-128a46fb37fd", "metadata": {}, "source": [ "\n", "### Download the KorSTS dataset and build the feature set\n", "KorSTS is also available on the Hugging Face hub, but to cover the use case where you later mix in your own custom dataset, we download it from GitHub instead." ] }, { "cell_type": "code", "execution_count": 7, "id": "589cb12d-9a43-4651-912c-466dbbfcd75d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/test/sts-test.tsv',\n", " <http.client.HTTPMessage at 0x7fee52a7f970>)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repo = 'https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS'\n", "urllib.request.urlretrieve(f'{repo}/sts-train.tsv', filename=f'{args.train_dir}/sts-train.tsv')\n", "urllib.request.urlretrieve(f'{repo}/sts-dev.tsv', filename=f'{args.valid_dir}/sts-dev.tsv')\n", "urllib.request.urlretrieve(f'{repo}/sts-test.tsv', filename=f'{args.test_dir}/sts-test.tsv')\n", "\n", "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-train.tsv -O {train_dir}/sts-train.tsv\n", "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-dev.tsv -O {valid_dir}/sts-dev.tsv\n", "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-test.tsv -O {test_dir}/sts-test.tsv" ] }, { "cell_type": "code", "execution_count": 8, "id": "0ad5f3ed-a0b9-4970-9b11-e3740cd92aec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:42:36 - Read KorSTS train dataset\n", "2022-07-20 08:42:36 - Read KorSTS dev dataset\n" ] } ], "source": [ "logger.info(\"Read KorSTS train dataset\")\n", "\n", "with open(f'{args.train_dir}/sts-train.tsv', 'rt', encoding='utf8') as fIn:\n", "    reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", "    for row in reader:\n", "        if row[\"sentence1\"] and row[\"sentence2\"]:\n", "            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1\n",
"            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n", "            train_samples.append(inp_example)\n", "\n", "logger.info(\"Read KorSTS dev dataset\")\n", "with open(f'{args.valid_dir}/sts-dev.tsv', 'rt', encoding='utf8') as fIn:\n", "    reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", "    for row in reader:\n", "        if row[\"sentence1\"] and row[\"sentence2\"]:\n", "            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1\n", "            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n", "            dev_samples.append(inp_example)" ] }, { "cell_type": "markdown", "id": "118a5b4e-81b3-48ad-b97e-b0cb8f98b099", "metadata": {}, "source": [ "<br>\n", "\n", "## 3. Training\n", "---\n", "\n", "### Training Preparation" ] }, { "cell_type": "markdown", "id": "57d7fdc0-6779-44d1-b9b6-c2a7a5d43706", "metadata": {}, "source": [ "### Model" ] }, { "cell_type": "code", "execution_count": 9, "id": "5ef7a33c-a8bc-427c-a64f-c5d7dfa5a09c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:43:44 - /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n" ] } ], "source": [ "model_name = 'sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens'\n", "\n", "train_batch_size = args.train_batch_size\n", "num_epochs = args.epochs\n", "model_save_path = f'{args.model_dir}/training_sts_'+model_name.replace(\"/\", \"-\")+'-'+datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n", "logger.info(model_save_path)\n", "\n", "# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings\n", "word_embedding_model = models.Transformer(model_name)" ] }, { "cell_type": "markdown", "id": "4b92c650-3ae4-4319-a756-8457b8f03c97", "metadata": {}, "source": [ "Define the pooler that turns token vectors into a sentence embedding. For classification tasks with BERT, the output vector of the first [CLS] token is used as the embedding, whereas SBERT computes the sentence embedding from the output vectors of all tokens. Either mean pooling or max pooling can be used; this example uses mean pooling." ] }, { "cell_type": "code", "execution_count": 11, "id": "b01d0e7c-0ac6-410a-b07d-9fdbea08162c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:49:35 - Use pytorch device: cuda\n" ] } ], "source": [ "# Apply mean pooling to get one fixed sized sentence vector\n", "pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),\n", "                               pooling_mode_mean_tokens=True,\n", "                               pooling_mode_cls_token=False,\n", "                               pooling_mode_max_tokens=False)\n", "\n", "model = SentenceTransformer(modules=[word_embedding_model, pooling_model])" ] }, { "cell_type": "markdown", "id": "c317c627-4ee2-4eca-bd12-dc25a76dfef2", "metadata": {}, "source": [ "Create the class instances needed for training and validation. The baseline validation metric is the cosine similarity between the embedding vectors of the two sentences."
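] }, { "cell_type": "markdown", "id": "pooling-cossim-note", "metadata": {}, "source": [ "As a quick aside, the cell below is a minimal sketch of what the `Pooling` module and the cosine-similarity metric do under the hood: it mean-pools the token vectors of two hypothetical sentences (masking out padding via the attention mask) and compares the resulting sentence embeddings with `util.cos_sim`. It assumes a recent sentence-transformers 2.x API and uses the `model` defined above (before fine-tuning); it is for illustration only." ] }, { "cell_type": "code", "execution_count": null, "id": "pooling-cossim-sketch", "metadata": {}, "outputs": [], "source": [ "# Illustration of mean pooling + cosine similarity on two hypothetical sentences.\n", "sentences = ['오늘 날씨가 좋네요', '날씨가 정말 맑아요']\n", "\n", "device = next(model.parameters()).device\n", "features = model.tokenize(sentences)                      # input_ids, attention_mask, ...\n", "features = {k: v.to(device) for k, v in features.items()}\n", "\n", "with torch.no_grad():\n", "    token_emb = model[0].auto_model(**features).last_hidden_state  # [2, seq_len, hidden]\n", "\n", "# Mean pooling: average the token vectors, ignoring padding positions\n", "mask = features['attention_mask'].unsqueeze(-1).float()\n", "sent_emb = (token_emb * mask).sum(dim=1) / mask.sum(dim=1)\n", "\n", "# Cosine similarity between the two sentence embeddings (the metric the evaluator reports)\n", "print(util.cos_sim(sent_emb[0], sent_emb[1]))"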
] }, { "cell_type": "code", "execution_count": 12, "id": "f20d8263-04b6-498e-aa67-55dd266274b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:49:36 - Warmup-steps: 55\n" ] } ], "source": [ "train_dataset = SentencesDataset(train_samples, model)\n", "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)\n", "train_loss = losses.CosineSimilarityLoss(model=model)\n", "\n", "evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')\n", "\n", "warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up\n", "logger.info(\"Warmup-steps: {}\".format(warmup_steps))" ] }, { "cell_type": "markdown", "id": "5e178954-f94a-4f8c-8a05-6c396d10f9e6", "metadata": {}, "source": [ "Run the training. We do not use distributed training here, but because the data volume is small, training completes within a few minutes." ] }, { "cell_type": "markdown", "id": "3e990ec0-0bd7-40dd-b276-0d91a3758188", "metadata": {}, "source": [ "### Start Training" ] }, { "cell_type": "code", "execution_count": 13, "id": "b3905a80-e2c9-4295-a34b-70f12920dad5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5bc3d1b44f9c467cb1dfda0553397652", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "86c0606e96594b7b81c4ca1e21dd240c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Iteration: 0%| | 0/545 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:50:25 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset in epoch 0 after 272 steps:\n", "2022-07-20 08:50:28 - Cosine-Similarity :\tPearson: 0.8458\tSpearman: 0.8462\n", "2022-07-20 08:50:28 - Manhattan-Distance:\tPearson: 0.8333\tSpearman: 0.8371\n", "2022-07-20 08:50:28 - Euclidean-Distance:\tPearson: 0.8339\tSpearman: 0.8380\n", "2022-07-20 08:50:28 - Dot-Product-Similarity:\tPearson: 0.8095\tSpearman: 0.8114\n", "2022-07-20 08:50:28 - Save model to /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n", "2022-07-20 08:51:12 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset in epoch 0 after 544 steps:\n", "2022-07-20 08:51:15 - Cosine-Similarity :\tPearson: 0.8511\tSpearman: 0.8513\n", "2022-07-20 08:51:15 - Manhattan-Distance:\tPearson: 0.8378\tSpearman: 0.8416\n", "2022-07-20 08:51:15 - Euclidean-Distance:\tPearson: 0.8383\tSpearman: 0.8425\n", "2022-07-20 08:51:15 - Dot-Product-Similarity:\tPearson: 0.8112\tSpearman: 0.8140\n", "2022-07-20 08:51:15 - Save model to /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n", "2022-07-20 08:51:24 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset after epoch 0:\n", "2022-07-20 08:51:27 - Cosine-Similarity :\tPearson: 0.8511\tSpearman: 0.8513\n", "2022-07-20 08:51:27 - Manhattan-Distance:\tPearson: 0.8378\tSpearman: 0.8416\n", "2022-07-20 08:51:27 - Euclidean-Distance:\tPearson: 0.8383\tSpearman: 0.8425\n", "2022-07-20 08:51:27 - Dot-Product-Similarity:\tPearson: 0.8112\tSpearman: 0.8140\n"
] } ], "source": [ "# Train the model\n", "model.fit(\n", "    train_objectives=[(train_dataloader, train_loss)],\n", "    evaluator=evaluator,\n", "    epochs=num_epochs,\n", "    evaluation_steps=int(len(train_dataloader)*0.5),\n", "    warmup_steps=warmup_steps,\n", "    output_path=model_save_path,\n", "    use_amp=True\n", ")" ] }, { "cell_type": "markdown", "id": "e370591f-cb30-4b2e-a637-c664ef8fc918", "metadata": {}, "source": [ "<br>\n", "\n", "## 4. Evaluation\n", "---\n", "Once training is complete, compute metrics on the test dataset to assess the model's predictive performance." ] }, { "cell_type": "code", "execution_count": 14, "id": "1b94aa7e-ff24-453c-a0c5-28ee104b89c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:51:47 - Read KorSTS test dataset\n" ] } ], "source": [ "test_samples = []\n", "logger.info(\"Read KorSTS test dataset\")\n", "with open(f'{args.test_dir}/sts-test.tsv', 'rt', encoding='utf8') as fIn:\n", "    reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", "    for row in reader:\n", "        if row[\"sentence1\"] and row[\"sentence2\"]:\n", "            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1\n", "            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n", "            test_samples.append(inp_example)" ] }, { "cell_type": "code", "execution_count": 15, "id": "98a96778-da38-41d7-b617-25db5b721c19", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:51:48 - Load pretrained SentenceTransformer: /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n", "2022-07-20 08:51:51 - Use pytorch device: cuda\n", "2022-07-20 08:51:51 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-test dataset:\n", "2022-07-20 08:51:53 - Cosine-Similarity :\tPearson: 0.8287\tSpearman: 0.8310\n", "2022-07-20 08:51:53 - Manhattan-Distance:\tPearson: 0.8242\tSpearman: 0.8283\n", "2022-07-20 08:51:53 - Euclidean-Distance:\tPearson: 0.8245\tSpearman: 0.8287\n", "2022-07-20 08:51:53 - Dot-Product-Similarity:\tPearson: 0.7619\tSpearman: 0.7608\n" ] }, { "data": { "text/plain": [ "0.8309806357819561" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "##############################################################################\n", "# Load the stored model and evaluate its performance on STS benchmark dataset\n", "##############################################################################\n", "\n", "model = SentenceTransformer(model_save_path)\n", "test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')\n", "test_evaluator(model, output_path=model_save_path)" ] }, { "cell_type": "markdown", "id": "a545fcc7-785c-47ff-8585-f148ef4b9786", "metadata": {}, "source": [ "<br>\n", "\n", "## 5. Applications\n",
"---" ] }, { "cell_type": "code", "execution_count": 16, "id": "239fbb3b-892b-4833-8720-754cb251f456", "metadata": {}, "outputs": [], "source": [ "import random\n", "import time\n", "from operator import itemgetter\n", "\n", "def get_faiss_index(emb, data, dim=768):\n", "    import faiss\n", "    n_gpus = torch.cuda.device_count()\n", "\n", "    if n_gpus == 0:\n", "        # Create the Inner Product Index\n", "        index = faiss.IndexFlatIP(dim)\n", "    else:\n", "        flat_config = []\n", "        res = [faiss.StandardGpuResources() for i in range(n_gpus)]\n", "        for i in range(n_gpus):\n", "            cfg = faiss.GpuIndexFlatConfig()\n", "            cfg.useFloat16 = False\n", "            cfg.device = i\n", "            flat_config.append(cfg)\n", "\n", "        index = faiss.GpuIndexFlatIP(res[0], dim, flat_config[0])\n", "\n", "    index = faiss.IndexIDMap(index)\n", "    index.add_with_ids(emb, np.array(range(0, len(data))))\n", "    return index\n", "\n", "\n", "def search(model, query, data, index, k=5, random_select=False, verbose=True):\n", "    t = time.time()\n", "    query_vector = model.encode(query)\n", "    dists, top_k_inds = index.search(query_vector, k)\n", "    if verbose:\n", "        print('total time: {}'.format(time.time() - t))\n", "    results = [itemgetter(*ind)(data) for ind in top_k_inds]\n", "\n", "    if random_select:\n", "        return [random.choice(r) for r in results]\n", "    else:\n", "        return results" ] }, { "cell_type": "markdown", "id": "7aa1864d-6c57-499f-a7ac-fa4ba1e37672", "metadata": {}, "source": [ "### Chatbot\n", "\n", "Chatbots are generally built in one of two ways: 1) use a generative model to produce a creative answer to the question, or 2) from a large pool of question-answer pairs, shortlist the stored questions that best match the input question and return the answer attached to the best candidate.\n", "This hands-on shows a simple chatbot example based on approach 2). Given an input question, we compute its embedding, compare it against the embeddings of every stored question to find the most similar candidates, and look up the answer matched to each candidate.\n", "\n", "You could compute the cosine similarities directly, but the Faiss library developed at Facebook (https://github.com/facebookresearch/faiss) makes the computation much faster. Faiss implements product-quantization-based indexing with GPU acceleration, indexing the embedding vectors while keeping the information loss as small as possible.\n", "\n", "References\n", "- Billion-scale similarity search with GPUs: https://arxiv.org/pdf/1702.08734.pdf\n", "- Product Quantizers for k-NN Tutorial Part 1: https://mccormickml.com/2017/10/13/product-quantizer-tutorial-part-1\n", "- Product Quantizers for k-NN Tutorial Part 2: http://mccormickml.com/2017/10/22/product-quantizer-tutorial-part-2"
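] }, { "cell_type": "markdown", "id": "bruteforce-cossim-note", "metadata": {}, "source": [ "The next two cells are optional sketches that unpack the ideas above. The first shows the brute-force variant of the retrieval flow: encode the query, compute cosine similarity against every candidate question with `util.cos_sim`, and return the answer paired with the best match. The question-answer pairs are hypothetical placeholders; the real chatbot dataset is prepared in the next subsection." ] }, { "cell_type": "code", "execution_count": null, "id": "bruteforce-cossim-sketch", "metadata": {}, "outputs": [], "source": [ "# Brute-force retrieval sketch on toy placeholder data. This is fine for small corpora;\n", "# the Faiss index built below scales to much larger ones.\n", "toy_questions = ['너무 졸리다', '여행 가고 싶다']\n", "toy_answers = ['잠깐 눈을 붙이세요.', '어디로 갈지 정해 보세요.']\n", "\n", "cand_emb = model.encode(toy_questions, convert_to_tensor=True)\n", "query_emb = model.encode('요즘 너무 피곤해', convert_to_tensor=True)\n", "\n", "scores = util.cos_sim(query_emb, cand_emb)[0]   # cosine similarity to every candidate question\n", "best = int(scores.argmax())\n", "print(toy_answers[best], float(scores[best]))" ] }, { "cell_type": "markdown", "id": "ivfpq-note", "metadata": {}, "source": [ "The second sketch shows the approximate, product-quantization-based kind of index that the paragraph above alludes to. This lab itself uses an exact flat inner-product index (see `get_faiss_index`), which is plenty fast at this corpus size; IVF+PQ trades a little accuracy for memory and speed at much larger scales. Random vectors stand in for real embeddings, and the `nlist`/`m`/`nbits` values are illustrative, not tuned." ] }, { "cell_type": "code", "execution_count": null, "id": "ivfpq-sketch", "metadata": {}, "outputs": [], "source": [ "# Sketch: an IVF+PQ (product quantization) index as an approximate alternative to a flat index.\n", "import faiss\n", "\n", "d, n = 768, 10000\n", "xb = np.random.rand(n, d).astype('float32')   # stand-in for (normalized) sentence embeddings\n", "\n", "nlist, m, nbits = 100, 64, 8                  # clusters, sub-quantizers, bits per code (illustrative)\n", "quantizer = faiss.IndexFlatIP(d)\n", "pq_index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT)\n", "pq_index.train(xb)                            # IVF/PQ indexes require a training pass over the vectors\n", "pq_index.add(xb)\n", "dists, ids = pq_index.search(xb[:2], 5)       # used like the flat index, with some accuracy loss\n", "print(ids)"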
] }, { "cell_type": "markdown", "id": "f8cb6a25-6e3a-4606-8bcb-f9bc647bc28e", "metadata": {}, "source": [ "#### Preparing chatbot dataset" ] }, { "cell_type": "code", "execution_count": 17, "id": "acbe5f54-c885-4324-8c56-f617c13085e5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", "    .dataframe tbody tr th:only-of-type {\n", "        vertical-align: middle;\n", "    }\n", "\n", "    .dataframe tbody tr th {\n", "        vertical-align: top;\n", "    }\n", "\n", "    .dataframe thead th {\n", "        text-align: right;\n", "    }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>Q</th>\n", "      <th>A</th>\n", "      <th>label</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>12시 땡!</td>\n", "      <td>하루가 또 가네요.</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>1지망 학교 떨어졌어</td>\n", "      <td>위로해 드립니다.</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>3박4일 놀러가고 싶다</td>\n", "      <td>여행은 언제나 좋죠.</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>3</th>\n", "      <td>3박4일 정도 놀러가고 싶다</td>\n", "      <td>여행은 언제나 좋죠.</td>\n", "      <td>0</td>\n", "    </tr>\n", "    <tr>\n", "      <th>4</th>\n", "      <td>PPL 심하네</td>\n", "      <td>눈살이 찌푸려지죠.</td>\n", "      <td>0</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ "                 Q            A  label\n", "0           12시 땡!   하루가 또 가네요.      0\n", "1      1지망 학교 떨어졌어    위로해 드립니다.      0\n", "2     3박4일 놀러가고 싶다  여행은 언제나 좋죠.      0\n", "3  3박4일 정도 놀러가고 싶다  여행은 언제나 좋죠.      0\n", "4          PPL 심하네   눈살이 찌푸려지죠.      0"
] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv\",\n", "                           filename=f\"{args.train_dir}/chatbot-train.csv\")\n", "chatbot_df = pd.read_csv(f'{args.train_dir}/chatbot-train.csv')\n", "chatbot_df.head()" ] }, { "cell_type": "markdown", "id": "08b8b5b7-b8bf-467f-adc8-728a1719cff3", "metadata": {}, "source": [ "#### Embedding" ] }, { "cell_type": "code", "execution_count": 18, "id": "3770cf95-ef8e-47d5-b158-b5f7f523228d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aed7a672408642c6b7a315db01f13670", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/185 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "chatbot_q_data = chatbot_df['Q'].tolist()\n", "chatbot_a_data = chatbot_df['A'].tolist()\n", "chatbot_emb = model.encode(chatbot_q_data, normalize_embeddings=True, batch_size=64, show_progress_bar=True)" ] }, { "cell_type": "markdown", "id": "38d7cf2c-664d-413a-9f52-83cd415cae34", "metadata": {}, "source": [ "#### Indexing the dataset" ] }, { "cell_type": "code", "execution_count": 19, "id": "179195c4-137d-4c75-b77d-f30f4b3b1155", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:52:14 - Loading faiss with AVX2 support.\n", "2022-07-20 08:52:14 - Could not load library with AVX2 support due to:\n", "ModuleNotFoundError(\"No module named 'faiss.swigfaiss_avx2'\")\n", "2022-07-20 08:52:14 - Loading faiss.\n", "2022-07-20 08:52:14 - Successfully loaded faiss.\n" ] } ], "source": [ "chatbot_index = get_faiss_index(chatbot_emb, chatbot_q_data)" ] }, { "cell_type": "markdown", "id": "e5d8384b-7623-4344-bbd1-dabb1a31b0c6", "metadata": {}, "source": [ "#### Inference\n", "Run inference on a few sample questions." ] }, { "cell_type": "code", "execution_count": 36, "id": "84f161c7-c078-41ce-9a5b-1ba9aff13fb8", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2eb8a3d695f44c5b959a59685fab41c1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "total time: 0.03067159652709961\n" ] }, { "data": { "text/plain": [ "['좋은 시간 보내시길 바라요.']" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query = ['커피 라떼 마시고 싶어']\n", "search(model, query, chatbot_a_data, chatbot_index, random_select=True)" ] }, { "cell_type": "code", "execution_count": 37, "id": "e8e46dbf-7e0b-4b62-a700-fdd554264087", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4467da19586c426bb86720a793707dd3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "total time: 0.0308225154876709\n" ] }, { "data": { "text/plain": [ "['낮잠을 잠깐 자도 괜찮아요.', '같이 놀아요.']" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query = ['너무 졸려', '놀고 싶어']\n", "search(model, query, chatbot_a_data, chatbot_index, random_select=True)" ] }, { "cell_type": "markdown", "id": "cce7a866-49e9-4415-af15-6ec389a9138f", "metadata": {}, "source": [ "### Semantic Search (News)\n", "\n", "Semantic search aims to capture the intent and contextual meaning of the words used in a query, rather than merely matching keywords.\n", "As in the chatbot example above, given a search query we compute its embedding, compare it against the embeddings of all documents (e.g., news titles/summaries, web page titles/summaries), and return the most similar document candidates.\n", "\n", "References\n", "- Billion-scale semantic similarity search with FAISS+SBERT: https://towardsdatascience.com/billion-scale-semantic-similarity-search-with-faiss-sbert-c845614962e2\n", "- Korean Contemporary Corpus of Written Sentences: http://nlp.kookmin.ac.kr/kcc/" ] }, { "cell_type": "markdown", "id": "832c06db-3e7f-4497-90f9-b70f5f1ba81c", "metadata": {}, "source": [ "#### Preparing news dataset" ] }, { "cell_type": "code", "execution_count": 38, "id": "f912ef7e-afd0-4d2d-b25e-f48ee59dfe99", "metadata": {}, "outputs": [], "source": [ "import progressbar\n", "\n", "class MyProgressBar():\n", "    def __init__(self):\n", "        self.pbar = None\n", "\n", "    def __call__(self, block_num, block_size, total_size):\n", "        if not self.pbar:\n", "            self.pbar = progressbar.ProgressBar(maxval=total_size)\n", "            self.pbar.start()\n", "\n", "        downloaded = block_num * block_size\n", "        if downloaded < total_size:\n", "            self.pbar.update(downloaded)\n", "        else:\n", "            self.pbar.finish()" ] }, { "cell_type": "code", "execution_count": 39, "id": "6b5b0d49-16fa-4080-bf88-21cd64ae1cbe", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100% |########################################################################|\n" ] }, { "data": { "text/plain": [ "('/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/train/KCCq28_Korean_sentences_EUCKR_v2.zip',\n", " <http.client.HTTPMessage at 0x7f53a416d3a0>)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "url = 'http://nlp.kookmin.ac.kr/kcc/KCCq28_Korean_sentences_EUCKR_v2.zip'\n", "news_path = f'{args.train_dir}/KCCq28_Korean_sentences_EUCKR_v2.zip'\n", "urllib.request.urlretrieve(url, news_path, MyProgressBar())" ] }, { "cell_type": "code", "execution_count": 40, "id": "3323ca38-b3d0-4d21-bf22-f527b164e6a0", "metadata": {}, "outputs": [], "source": [ "import zipfile\n", "with zipfile.ZipFile(news_path, 'r') as zip_ref:\n", "    zip_ref.extractall(train_dir)" ] }, { "cell_type": "code", "execution_count": 41, "id": "e07ea44c-59c0-413d-8ec4-c544287e24d7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] } ], "source": [ "!rm -rf {news_path}" ] }, { "cell_type": "code", "execution_count": 42, "id": "5679f7cb-28a4-4d68-bfc3-ed127cc91809", "metadata": {}, "outputs": [], "source": [ "news_data = []\n", "f = open(f'{args.train_dir}/KCCq28_Korean_sentences_EUCKR_v2.txt', 'rt', encoding='cp949')\n", "lines = f.readlines()\n", "for line in lines:\n", "    line = line.strip()\n", "    news_data.append(line)\n", "f.close()" ] }, { "cell_type": "code", "execution_count": 43, "id": "65705cff-abe8-47ae-a6ac-6b7a2d289a1e", "metadata": {}, "outputs": [], "source": [ "news_data = news_data[:10000]  # For debug purposes, use only the first 10,000 sentences" ] }, { "cell_type": "markdown", "id": "3ac03cb7-b95b-416d-af12-5b81491c7ca8", "metadata": {}, "source": [ "#### Embedding" ] }, { "cell_type": "code", "execution_count": 44, "id": "489d2ba2-3c58-4bb1-8587-67ea81fd42dc", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a08694d72f084b4b94724c87ddd4ac15", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/157 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "news_emb = model.encode(news_data, normalize_embeddings=True, batch_size=64, show_progress_bar=True)" ] }, { "cell_type": "markdown", "id": "3dfc9adf-e1e8-412e-91c0-46778dbd608f", "metadata": {}, "source": [ "#### Indexing the dataset" ] }, { "cell_type": "code", "execution_count": 45, "id": "d6bdc627-114d-47d5-844a-2702c6343af9", "metadata": {}, "outputs": [], "source": [ "news_index = get_faiss_index(news_emb, news_data)" ] }, { "cell_type": "markdown", "id": "0d1a914e-7ccc-49e9-adc3-174d6231eb46", "metadata": {}, "source": [ "#### Inference\n", "Run inference on a few sample queries. 
" ] }, { "cell_type": "code", "execution_count": 46, "id": "3203cbdd-a1d4-4c4e-8f08-2de03702a247", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5e949ee48ecb447aa633119f28f20897", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/1 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "total time: 0.031549692153930664\n" ] }, { "data": { "text/plain": [ "[('ì•„ì´ì„œí”Œë¼ì´ëŠ” \"ì•„ì´í°4ì— ì‚¬ìš©í•œ \\'A4\\'와 마찬가지로 \\'ë‹¤ì´ ë§ˆí¬\\'ë¡œ ë³¼ ë•Œ ì‚¼ì„±ì „ìžê°€ ë§Œë“ ê²ƒìœ¼ë¡œ ë³´ì¸ë‹¤\"ê³ í–ˆë‹¤.',\n", " 'ì•„ì´ìŠ¤í¬ë¦¼ì„ ì„œë¹„ìŠ¤í•˜ê³ ìžˆëŠ” ì‹œê³µê·¸ë£¹ì˜ ë°•ê¸°ì„ íšŒìž¥ì€ \"콘í…ì¸ ì œê³µì„œë¹„ìŠ¤ 품질ì¸ì¦ìœ¼ë¡œ ì•„ì´ìŠ¤í¬ë¦¼ 서비스와 콘í…ì¸ ì— ëŒ€í•´ ì‹ ë¢°ë„와 ê³µì‹ ë ¥ì„ ì¸ì •ë°›ì•˜ë‹¤.',\n", " 'ì—˜ë¦¬ì—‡ì˜ ìžíšŒì‚¬ì¸ ë¸”ë ˆì´í¬ ìºí”¼íƒˆê³¼ í¬í„° ìºí”¼íƒˆì€ í™ë³´ëŒ€í–‰ì‚¬ë¥¼ 통해 \"ì‚¼ì„±ì „ìžê°€ ì œì‹œí•œ 개략ì ì¸ ì£¼ì£¼ê°€ì¹˜ ì œê³ ë°©ì•ˆì´ í–¥í›„ íšŒì‚¬ì— ê±´ì„¤ì ì¸ ì²« 걸ìŒì´ ë 것으로 ë³´ê³ ìžˆë‹¤\"ê³ ë°í˜”다.',\n", " 'ê²€í† ê²°ê³¼ì— ë”°ë¼ ê²°ë¡ ì„ ë‚´ë¦´ 것\"ì´ë¼ê³ ë§í–ˆë‹¤.',\n", " 'ì¡° 대사는 \"대미 ì˜ì¡´ë„를 줄ì´ë ¤ê³ ë…¸ë ¥í•˜ê³ ìžˆëŠ” ìºë‚˜ë‹¤ëŠ” ëˆˆê¸¸ì„ ì•„ì‹œì•„Â·íƒœí‰ì–‘ 지ì—으로 ëŒë¦¬ê³ 있다\"ë©° \"특히 ì´ ì§€ì—ì—서는 ìºë‚˜ë‹¤ì™€ FTA를 ë§ºì€ êµê°€ê°€ 없어 ë”ìš± ìƒì§•ì ì˜ë¯¸ê°€ 있다\"ê³ ë§í–ˆë‹¤.',\n", " 'ë°• ëŒ€í†µë ¹ì€ ë˜ \"ì ê·¹ì ì¸ ì„¸ì¼ì¦ˆ ì™¸êµ ëŒ€í†µë ¹ìœ¼ë¡œ ë‚˜ì„œê² ë‹¤\"는 ì ì„ ë¶„ëª…ížˆ 했다.',\n", " 'ì• í”Œì´ í˜‘ë ¥ì—…ì²´ì˜ ì¸ë ¥ì„ ì˜ìž…하는 ì´ìœ ì— ëŒ€í•´ ì•„ì´í°ì¸ ìºë‚˜ë‹¤ëŠ” \"ì• í”Œì´ GPU ìžì²´ ê°œë°œì„ ì„œë‘르기 위한 수순\"ì´ë¼ê³ 설명했다.'),\n", " ('ì´íšŒì„± IPCC ì˜ìž¥ì€ \"온실가스 ë°°ì¶œëŸ‰ì´ ëŠ˜ê³ ìžˆë‹¤ëŠ” ì‚¬ì‹¤ì´ ì§€êµ¬ì˜¨ë‚œí™”ì˜ ê³¼í•™ì 근거\"ë¼ê³ ë§í–ˆë‹¤.',\n", " '다ìŒì€ ì¼ë¬¸ ì¼ë‹µ.íŒŒë¦¬ê¸°í›„í˜‘ì •ì˜ ì˜ë¯¸ëŠ” 무엇ì¸ê°€.\"2015ë…„ 12ì›” 프랑스 파리ì—ì„œ 열린 기후변화협약 당사êµì´íšŒì—ì„œ 채íƒëê³ , 지난해 11ì›” ë°œíš¨ëœ íŒŒë¦¬ê¸°í›„í˜‘ì •ì€ 2020년부터 ì„ ì§„êµë¿ë§Œ ì•„ë‹ˆë¼ ê°œë°œë„ìƒêµë„ ì˜ë¬´ì 으로 온실가스 ê°ì¶•ì— 참여하ë„ë¡ í–ˆë‹¤.',\n", " '미êµì—ì„œ 활ë™í•˜ê³ 있는 중êµê³„ í•™ìž ìœ ì‚¬ì˜¤ì¹´ì´ëŠ” ì´ ë…¼ë¬¸ì—ì„œ \"ë„ì‹¬ì˜ ê³ ì¸µë¹Œë”© ê¼ëŒ€ê¸°ì—ì„œ 아래쪽으로 미세한 ë¬¼ë°©ìš¸ì„ ë¿Œë¦°ë‹¤ë©´ 중êµì˜ 극심한 초미세먼지까지 낮출 수 있다\"ê³ ì£¼ìž¥í–ˆë‹¤.',\n", " 'ê·¸ë ‡ë‹¤ë©´ 대기 중 미세먼지 ë†ë„ê°€ 심한 ë‚ ì—” 어떻게 해야 í• ê¹Œ.류연기 환경부 ìƒí™œí™˜ê²½ê³¼ìž¥ì€ \"대기 중 미세먼지 ë†ë„ê°€ ë†’ì€ ë‚ ì— ìš”ë¦¬ë¥¼ í• ê²½ìš°ì—” ìš°ì„ ì£¼ë°© 환í’기를 사용해 í™˜ê¸°í•˜ê³ , 요리 후엔 ìž ì‹œ ë™ì•ˆ ì°½ë¬¸ì„ ì—´ì–´ ë‘는 ê²ƒì´ ì¢‹ë‹¤\"ê³ ë§í–ˆë‹¤.',\n", " 'êµë°©ë¶€ 관계ìžëŠ” \"í탄약 처리 ì‹œì„¤ì€ í•œêµ í™˜ê²½ ê´€ë ¨ ë²•ë¥ ì— ë”°ë¼ ì¹œí™˜ê²½ì ì´ê³ 한미 ì–‘êµì˜ í•©ì˜ê°ì„œì— ë”°ë¼ ì—„ê²©í•˜ê²Œ ìš´ì˜ë 것ì´ë‹¤\"ê³ ë§í–ˆë‹¤.',\n", " 'SK하ì´ë‹‰ìŠ¤ 관계ìžëŠ” \"압축공기보다 ë¯¸ëŸ‰ì˜ ì§ˆì†Œë¥¼ 지ì†ì 으로 í˜ë ¤ 스í¬ëŸ¬ë²„ ë‚´ë¶€ì˜ ìœ í•´ë¬¼ì§ˆì„ ë‹¦ì•„ë‚´ê³ ìžˆëŠ”ë°, 질소가 ì§ì ‘ì ì¸ ì›ì¸ì¸ì§€ëŠ” ìˆ˜ì‚¬ë‹¹êµ ë“±ì˜ ì •ë°€í•œ 조사가 필요하다\"ê³ ì„¤ëª…í–ˆë‹¤.',\n", " '회사 ì¸¡ì€ \"슬러지 100%ìžê°€ì²˜ë¦¬, ì—너지비용, 온실가스 배출량 ì ˆê°ì„ 위한 목ì \"ì´ë¼ê³ 설명했다.')]" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "query =['ì•„ì´ìŠ¤ ë¼ë–¼', '미세먼지']\n", "search(model, query, news_data, news_index, k=7, random_select=False)" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p38", "language": "python", "name": "conda_pytorch_p38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }