{ "cells": [ { "cell_type": "markdown", "id": "94ed44ac-98e5-4200-8f1c-3adedec2866b", "metadata": {}, "source": [ "# Lab 1: Sentence-BERT (SBERT) Training\n", "\n", "### Fine-tuning Sentence-BERT (SBERT) & SBERT Embedding for applications\n", "---" ] }, { "cell_type": "markdown", "id": "2d897cd6-a127-4a1c-9980-aa033c332f6a", "metadata": {}, "source": [ "\n", "## Introduction\n", "---\n", "\n", "본 모듈에서는 문장 임베딩을 산출하는 Sentence-BERT 모델을 STS 데이터셋으로 파인튜닝해 봅니다.\n", "SentenceTransformers 패키지를 사용하면 파인튜닝을 쉽게 수행할 수 있습니다. 다만, 현 시점에는 분산 훈련 기능 지원이 잘 되지 않으므로, 대용량 데이터셋으로 파인튜닝하는 니즈가 있다면 커스텀 훈련 코드를 직접 작성하셔야 합니다.\n", "\n", "***[Note] SageMaker Studio Lab, SageMaker Studio, SageMaker 노트북 인스턴스, 또는 여러분의 로컬 머신에서 이 데모를 실행할 수 있습니다. SageMaker Studio Lab을 사용하는 경우 GPU를 활성화하세요.***\n", "\n", "### References\n", "\n", "- Hugging Face Tutorial: https://huggingface.co/docs/transformers/training\n", "- Sentence-BERT paper: https://arxiv.org/abs/1908.10084\n", "- SentenceTransformers: https://www.sbert.net" ] }, { "cell_type": "markdown", "id": "577d681d-0f93-44ea-a3d1-0e0a9bb5e063", "metadata": {}, "source": [ "\n", "## 1. Setup Environments\n", "---\n", "\n", "### Import modules" ] }, { "cell_type": "code", "execution_count": 1, "id": "91641d7f-95e5-473b-b144-823d4a19a299", "metadata": {}, "outputs": [], "source": [ "# !pip install sentence_transformers datasets faiss-gpu progressbar" ] }, { "cell_type": "code", "execution_count": 2, "id": "c6d4b001-3508-4ca1-b73b-76b11fdcce7a", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import json\n", "import logging\n", "import argparse\n", "import torch\n", "import gzip\n", "import csv\n", "import math\n", "import urllib\n", "from torch import nn\n", "import numpy as np\n", "import pandas as pd\n", "from tqdm import tqdm\n", "\n", "from datetime import datetime\n", "from datasets import load_dataset\n", "from torch.utils.data import DataLoader\n", "from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses, models, util\n", "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n", "from sentence_transformers.readers import InputExample\n", "from transformers.trainer_utils import get_last_checkpoint\n", "\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format='%(asctime)s - %(message)s',\n", " datefmt='%Y-%m-%d %H:%M:%S',\n", " handlers=[LoggingHandler()]\n", ")\n", "\n", "logger = logging.getLogger(__name__)" ] }, { "cell_type": "markdown", "id": "a1215e31-b560-4ea9-8fd2-eb097c9dd30b", "metadata": {}, "source": [ "### Argument parser" ] }, { "cell_type": "code", "execution_count": 3, "id": "c512803f-d734-4edd-a868-0f369f6b3888", "metadata": {}, "outputs": [], "source": [ "def parser_args(train_notebook=False):\n", " parser = argparse.ArgumentParser()\n", "\n", " # Default Setting\n", " parser.add_argument(\"--epochs\", type=int, default=1)\n", " parser.add_argument(\"--seed\", type=int, default=42)\n", " parser.add_argument(\"--train_batch_size\", type=int, default=32)\n", " parser.add_argument(\"--eval_batch_size\", type=int, default=32)\n", " parser.add_argument(\"--warmup_steps\", type=int, default=100)\n", " parser.add_argument(\"--logging_steps\", type=int, default=100)\n", " parser.add_argument(\"--learning_rate\", type=str, default=5e-5)\n", " parser.add_argument(\"--disable_tqdm\", type=bool, default=False)\n", " parser.add_argument(\"--fp16\", type=bool, default=True)\n", " parser.add_argument(\"--tokenizer_id\", type=str, default='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n", " parser.add_argument(\"--model_id\", type=str, default='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n", " \n", " # SageMaker Container environment\n", " parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n", " parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n", " parser.add_argument(\"--n_gpus\", type=str, default=os.environ[\"SM_NUM_GPUS\"])\n", " parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n", " parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n", " parser.add_argument(\"--test_dir\", type=str, default=os.environ[\"SM_CHANNEL_TEST\"]) \n", " parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints') \n", "\n", " if train_notebook:\n", " args = parser.parse_args([])\n", " else:\n", " args = parser.parse_args()\n", " return args" ] }, { "cell_type": "code", "execution_count": 4, "id": "970aca02-6ba6-4111-943a-8ffa15e5ead9", "metadata": {}, "outputs": [], "source": [ "train_dir = 'train'\n", "valid_dir = 'valid'\n", "test_dir = 'test'\n", "!rm -rf {train_dir} {valid_dir} {test_dir} \n", "os.makedirs(train_dir, exist_ok=True)\n", "os.makedirs(valid_dir, exist_ok=True) \n", "os.makedirs(test_dir, exist_ok=True) " ] }, { "cell_type": "markdown", "id": "9635cf55-6781-4f0b-bed4-c4c49a76619c", "metadata": {}, "source": [ "### Load Arguments\n", "\n", "주피터 노트북에서 곧바로 실행할 수 있도록 설정값들을 로드합니다. 물론 노트북 환경이 아닌 커맨드라인에서도 `cd scripts & python3 train.py` 커맨드로 훈련 스크립트를 실행할 수 있습니다." ] }, { "cell_type": "code", "execution_count": 5, "id": "1825c0c8-2106-445b-bb7e-6d255dd3bd1c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:42:26 - ***** Arguments *****\n", "2022-07-20 08:42:26 - epochs=1\n", "seed=42\n", "train_batch_size=32\n", "eval_batch_size=32\n", "warmup_steps=100\n", "logging_steps=100\n", "learning_rate=5e-05\n", "disable_tqdm=False\n", "fp16=True\n", "tokenizer_id=salti/bert-base-multilingual-cased-finetuned-squad\n", "model_id=salti/bert-base-multilingual-cased-finetuned-squad\n", "output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/data\n", "model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model\n", "n_gpus=4\n", "train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/train\n", "valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/valid\n", "test_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/test\n", "chkpt_dir=chkpt\n", "\n" ] } ], "source": [ "chkpt_dir = 'chkpt'\n", "model_dir = 'model'\n", "output_data_dir = 'data'\n", "num_gpus = torch.cuda.device_count()\n", "\n", "!rm -rf {chkpt_dir} {model_dir} {output_data_dir} \n", "\n", "if os.environ.get('SM_CURRENT_HOST') is None:\n", " is_sm_container = False\n", "\n", " #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n", " src_dir = os.getcwd()\n", " os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n", " os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n", " os.environ['SM_NUM_GPUS'] = str(num_gpus)\n", " os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n", " os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n", " os.environ['SM_CHANNEL_TEST'] = f'{src_dir}/{test_dir}'\n", " \n", "args = parser_args(train_notebook=True) \n", "args.chkpt_dir = chkpt_dir\n", "logger.info(\"***** Arguments *****\")\n", "logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n", "\n", "os.makedirs(args.chkpt_dir, exist_ok=True) \n", "os.makedirs(args.model_dir, exist_ok=True)\n", "os.makedirs(args.output_data_dir, exist_ok=True) " ] }, { "cell_type": "markdown", "id": "5478a4d9-eaad-4652-8b1a-141d40b18d88", "metadata": {}, "source": [ "
\n", "\n", "## 2. Preparation\n", "---\n", "본 핸즈온에서 사용할 데이터셋은 KorSTS (https://github.com/kakaobrain/KorNLUDatasets) 와 KLUE-STS (https://github.com/KLUE-benchmark/KLUE) 입니다.\n", "단일 데이터셋으로 훈련해도 무방하지만, 두 데이터셋을 모두 활용하여 훈련 시, 약간의 성능 향상이 있습니다.\n", "\n", "### Training Tips\n", "SBERT 훈련은 일반적으로 아래 3가지 방법들을 베이스라인으로 사용합니다.\n", "1. NLI 데이터셋으로 훈련\n", "2. STS 데이터셋으로 훈련\n", "3. NLI 데이터셋으로 훈련 후 STS 데이터셋으로 파인튜닝\n", "\n", "한국어 데이터의 경우, STS의 훈련 데이터가 상대적으로 적음에도 불구하고 NLI 기반 모델보다 예측 성능이 우수합니다. 따라서, 2번째 방법으로 진행합니다.
\n", "다만, STS보다 조금 더 좋은 예측 성능을 원한다면 NLI 데이터로 먼저 훈련하고 STS 데이터셋으로 이어서 훈련하는 것을 권장합니다." ] }, { "cell_type": "markdown", "id": "9aa546b9-2072-42f3-b886-18a06069edf0", "metadata": {}, "source": [ "### KLUE-STS 데이터셋 다운로드 및 피쳐셋 생성\n", "KLUE-STS 데이터셋을 허깅페이스 데이터셋 허브에서 다운로드 후, SBERT 훈련에 필요한 피쳐셋을 생성합니다." ] }, { "cell_type": "code", "execution_count": 6, "id": "fc5b4880-9907-44f0-9e45-2e67f3938a60", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:42:30 - Read KLUE-STS train/dev dataset\n", "2022-07-20 08:42:30 - Reusing dataset klue (/home/ec2-user/.cache/huggingface/datasets/klue/sts/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6fd3d65e0c0a482b9cdb08ba4409c3c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repo = 'https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS'\n", "urllib.request.urlretrieve(f'{repo}/sts-train.tsv', filename=f'{args.train_dir}/sts-train.tsv')\n", "urllib.request.urlretrieve(f'{repo}/sts-dev.tsv', filename=f'{args.valid_dir}/sts-dev.tsv')\n", "urllib.request.urlretrieve(f'{repo}/sts-test.tsv', filename=f'{args.test_dir}/sts-test.tsv')\n", "\n", "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-train.tsv -O {train_dir}/sts-train.tsv\n", "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-dev.tsv -O {valid_dir}/sts-dev.tsv\n", "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-test.tsv -O {test_dir}/sts-test.tsv" ] }, { "cell_type": "code", "execution_count": 8, "id": "0ad5f3ed-a0b9-4970-9b11-e3740cd92aec", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:42:36 - Read KorSTS train dataset\n", "2022-07-20 08:42:36 - Read KorSTS dev dataset\n" ] } ], "source": [ "logger.info(\"Read KorSTS train dataset\")\n", "\n", "with open(f'{args.train_dir}/sts-train.tsv', 'rt', encoding='utf8') as fIn:\n", " reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", " for row in reader:\n", " if row[\"sentence1\"] and row[\"sentence2\"]: \n", " score = float(row['score']) / 5.0 # Normalize score to range 0 ... 1\n", " inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n", " train_samples.append(inp_example)\n", " \n", "logging.info(\"Read KorSTS dev dataset\") \n", "with open(f'{args.valid_dir}/sts-dev.tsv', 'rt', encoding='utf8') as fIn:\n", " reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", " for row in reader:\n", " if row[\"sentence1\"] and row[\"sentence2\"]: \n", " score = float(row['score']) / 5.0 # Normalize score to range 0 ... 1\n", " inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n", " dev_samples.append(inp_example) " ] }, { "cell_type": "markdown", "id": "118a5b4e-81b3-48ad-b97e-b0cb8f98b099", "metadata": {}, "source": [ "
\n", "\n", "## 3. Training\n", "---\n", "\n", "### Training Preparation" ] }, { "cell_type": "markdown", "id": "57d7fdc0-6779-44d1-b9b6-c2a7a5d43706", "metadata": {}, "source": [ "### Model" ] }, { "cell_type": "code", "execution_count": 9, "id": "5ef7a33c-a8bc-427c-a64f-c5d7dfa5a09c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:43:44 - /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n" ] } ], "source": [ "model_name = 'sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens'\n", "\n", "train_batch_size = args.train_batch_size\n", "num_epochs = args.epochs\n", "model_save_path = f'{args.model_dir}/training_sts_'+model_name.replace(\"/\", \"-\")+'-'+datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n", "logger.info(model_save_path)\n", "\n", "# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings\n", "word_embedding_model = models.Transformer(model_name)" ] }, { "cell_type": "markdown", "id": "4b92c650-3ae4-4319-a756-8457b8f03c97", "metadata": {}, "source": [ "문장 임베딩을 계산하기 위한 Pooler를 정의합니다. BERT로 분류 태스크를 수행할 때는 첫 번째 [CLS] 토큰의 출력 벡터를 임베딩 벡터로 사용하지만, SBERT에서는 BERT의 모든 토큰들의 출력 벡터들을 사용하여 임베딩 벡터를 계산합니다. 이 때 mean pooling이나 max pooling을 사용할 수 있으며, 본 예제에서는 mean pooling을 사용합니다." ] }, { "cell_type": "code", "execution_count": 11, "id": "b01d0e7c-0ac6-410a-b07d-9fdbea08162c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:49:35 - Use pytorch device: cuda\n" ] } ], "source": [ "# Apply mean pooling to get one fixed sized sentence vector\n", "pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),\n", " pooling_mode_mean_tokens=True,\n", " pooling_mode_cls_token=False,\n", " pooling_mode_max_tokens=False)\n", "\n", "model = SentenceTransformer(modules=[word_embedding_model, pooling_model])" ] }, { "cell_type": "markdown", "id": "c317c627-4ee2-4eca-bd12-dc25a76dfef2", "metadata": {}, "source": [ "모델 훈련 및 검증에 필요한 클래스 인스턴스를 생성합니다. 베이스라인으로 사용되는 검증 지표는 두 문장의 임베딩 벡터의 유사도를 산출하는 코사인 유사도입니다." ] }, { "cell_type": "code", "execution_count": 12, "id": "f20d8263-04b6-498e-aa67-55dd266274b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:49:36 - Warmup-steps: 55\n" ] } ], "source": [ "train_dataset = SentencesDataset(train_samples, model)\n", "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)\n", "train_loss = losses.CosineSimilarityLoss(model=model)\n", "\n", "evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')\n", "\n", "warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up\n", "logger.info(\"Warmup-steps: {}\".format(warmup_steps))" ] }, { "cell_type": "markdown", "id": "5e178954-f94a-4f8c-8a05-6c396d10f9e6", "metadata": {}, "source": [ "훈련을 수행합니다. 분산 훈련을 수행하지는 않지만, 데이터 볼륨이 크지 않으므로 수 분 내에 훈련이 완료됩니다." ] }, { "cell_type": "markdown", "id": "3e990ec0-0bd7-40dd-b276-0d91a3758188", "metadata": {}, "source": [ "### Start Training" ] }, { "cell_type": "code", "execution_count": 13, "id": "b3905a80-e2c9-4295-a34b-70f12920dad5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5bc3d1b44f9c467cb1dfda0553397652", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/1 [00:00\n", "\n", "## 4. Evaluation\n", "---\n", "훈련이 완료되었다면, 테스트 데이터셋으로 예측 성능을 볼 수 있는 지표들을 산출합니다." ] }, { "cell_type": "code", "execution_count": 14, "id": "1b94aa7e-ff24-453c-a0c5-28ee104b89c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:51:47 - Read KorSTS test dataset\n" ] } ], "source": [ "test_samples = []\n", "logger.info(\"Read KorSTS test dataset\") \n", "with open(f'{args.test_dir}/sts-test.tsv', 'rt', encoding='utf8') as fIn:\n", " reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n", " for row in reader:\n", " if row[\"sentence1\"] and row[\"sentence2\"]: \n", " score = float(row['score']) / 5.0 # Normalize score to range 0 ... 1\n", " inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n", " test_samples.append(inp_example) " ] }, { "cell_type": "code", "execution_count": 15, "id": "98a96778-da38-41d7-b617-25db5b721c19", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-07-20 08:51:48 - Load pretrained SentenceTransformer: /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n", "2022-07-20 08:51:51 - Use pytorch device: cuda\n", "2022-07-20 08:51:51 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-test dataset:\n", "2022-07-20 08:51:53 - Cosine-Similarity :\tPearson: 0.8287\tSpearman: 0.8310\n", "2022-07-20 08:51:53 - Manhattan-Distance:\tPearson: 0.8242\tSpearman: 0.8283\n", "2022-07-20 08:51:53 - Euclidean-Distance:\tPearson: 0.8245\tSpearman: 0.8287\n", "2022-07-20 08:51:53 - Dot-Product-Similarity:\tPearson: 0.7619\tSpearman: 0.7608\n" ] }, { "data": { "text/plain": [ "0.8309806357819561" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "##############################################################################\n", "# Load the stored model and evaluate its performance on STS benchmark dataset\n", "##############################################################################\n", "\n", "model = SentenceTransformer(model_save_path)\n", "test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')\n", "test_evaluator(model, output_path=model_save_path)" ] }, { "cell_type": "markdown", "id": "a545fcc7-785c-47ff-8585-f148ef4b9786", "metadata": {}, "source": [ "
\n", "\n", "## 5. Applications\n", "---" ] }, { "cell_type": "code", "execution_count": 16, "id": "239fbb3b-892b-4833-8720-754cb251f456", "metadata": {}, "outputs": [], "source": [ "import random\n", "import time\n", "from operator import itemgetter \n", "\n", "def get_faiss_index(emb, data, dim=768):\n", " import faiss\n", " n_gpus = torch.cuda.device_count()\n", "\n", " if n_gpus == 0:\n", " # Create the Inner Product Index\n", " index = faiss.IndexFlatIP(dim)\n", " else:\n", " flat_config = []\n", " res = [faiss.StandardGpuResources() for i in range(n_gpus)]\n", " for i in range(n_gpus):\n", " cfg = faiss.GpuIndexFlatConfig()\n", " cfg.useFloat16 = False\n", " cfg.device = i\n", " flat_config.append(cfg)\n", "\n", " index = faiss.GpuIndexFlatIP(res[0], dim, flat_config[0])\n", "\n", " index = faiss.IndexIDMap(index)\n", " index.add_with_ids(emb, np.array(range(0, len(data)))) \n", " return index\n", "\n", "\n", "def search(model, query, data, index, k=5, random_select=False, verbose=True):\n", " t = time.time()\n", " query_vector = model.encode(query)\n", " dists, top_k_inds = index.search(query_vector, k)\n", " if verbose:\n", " print('total time: {}'.format(time.time() - t))\n", " results = [itemgetter(*ind)(data) for ind in top_k_inds] \n", " \n", " if random_select:\n", " return [random.choice(r) for r in results]\n", " else:\n", " return results" ] }, { "cell_type": "markdown", "id": "7aa1864d-6c57-499f-a7ac-fa4ba1e37672", "metadata": {}, "source": [ "### Chatbot\n", "\n", "챗봇은 크게 두 가지 형태로 개발합니다. 1) 생성 모델을 사용하여 해당 질문에 대한 창의적인 답변을 생성하거나, 2) 수많은 질문-답변 리스트들 중 질문에 부합하는 질문 후보들을 추린 다음 해당 후보에 적합한 답변을 찾는 방식이죠.\n", "본 핸즈온은 2)의 방법으로 간단하게 챗봇 예시를 보여드립니다. 질문 텍스트를 입력으로 받으면, 해당 질문의 임베딩을 계산하여 질문 임베딩과 모든 질문 리스트의 임베딩을 비교하여 유사도가 가장 높은 질문 후보들을 찾고, 각 후보에 매칭되는 답변을 찾습니다.\n", "\n", "코사인 유사도를 직접 계산할 수도 있지만, 페이스북에서 개발한 Faiss 라이브러리 (https://github.com/facebookresearch/faiss) 를 사용하면 훨씬 빠른 속도로 계산할 수 있습니다. Faiss는 Product 양자화 알고리즘을 GPU로 더욱 빠르게 구현한 라이브러리로, 정보 손실을 가급적 줄이면서 임베딩 벡터를 인덱싱합니다.\n", "\n", "References\n", "- Billion-scale similarity search with GPUs: https://arxiv.org/pdf/1702.08734.pdf\n", "- Product Quantizers for k-NN Tutorial Part 1: https://mccormickml.com/2017/10/13/product-quantizer-tutorial-part-1\n", "- Product Quantizers for k-NN Tutorial Part 1: http://mccormickml.com/2017/10/22/product-quantizer-tutorial-part-2" ] }, { "cell_type": "markdown", "id": "f8cb6a25-6e3a-4606-8bcb-f9bc647bc28e", "metadata": {}, "source": [ "#### Preparing chatbot dataset" ] }, { "cell_type": "code", "execution_count": 17, "id": "acbe5f54-c885-4324-8c56-f617c13085e5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
QAlabel
012시 땡!하루가 또 가네요.0
11지망 학교 떨어졌어위로해 드립니다.0
23박4일 놀러가고 싶다여행은 언제나 좋죠.0
33박4일 정도 놀러가고 싶다여행은 언제나 좋죠.0
4PPL 심하네눈살이 찌푸려지죠.0
\n", "
" ], "text/plain": [ " Q A label\n", "0 12시 땡! 하루가 또 가네요. 0\n", "1 1지망 학교 떨어졌어 위로해 드립니다. 0\n", "2 3박4일 놀러가고 싶다 여행은 언제나 좋죠. 0\n", "3 3박4일 정도 놀러가고 싶다 여행은 언제나 좋죠. 0\n", "4 PPL 심하네 눈살이 찌푸려지죠. 0" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv\", \n", " filename=f\"{args.train_dir}/chatbot-train.csv\")\n", "chatbot_df = pd.read_csv(f'{args.train_dir}/chatbot-train.csv')\n", "chatbot_df.head()" ] }, { "cell_type": "markdown", "id": "08b8b5b7-b8bf-467f-adc8-728a1719cff3", "metadata": {}, "source": [ "#### Embedding" ] }, { "cell_type": "code", "execution_count": 18, "id": "3770cf95-ef8e-47d5-b158-b5f7f523228d", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aed7a672408642c6b7a315db01f13670", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/185 [00:00)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "url = 'http://nlp.kookmin.ac.kr/kcc/KCCq28_Korean_sentences_EUCKR_v2.zip'\n", "news_path = f'{args.train_dir}/KCCq28_Korean_sentences_EUCKR_v2.zip'\n", "urllib.request.urlretrieve(url, news_path, MyProgressBar())" ] }, { "cell_type": "code", "execution_count": 40, "id": "3323ca38-b3d0-4d21-bf22-f527b164e6a0", "metadata": {}, "outputs": [], "source": [ "import zipfile\n", "with zipfile.ZipFile(news_path, 'r') as zip_ref:\n", " zip_ref.extractall(train_dir)" ] }, { "cell_type": "code", "execution_count": 41, "id": "e07ea44c-59c0-413d-8ec4-c544287e24d7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", "To disable this warning, you can either:\n", "\t- Avoid using `tokenizers` before the fork if possible\n", "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" ] } ], "source": [ "!rm -rf {news_path}" ] }, { "cell_type": "code", "execution_count": 42, "id": "5679f7cb-28a4-4d68-bfc3-ed127cc91809", "metadata": {}, "outputs": [], "source": [ "news_data = []\n", "f = open(f'{args.train_dir}/KCCq28_Korean_sentences_EUCKR_v2.txt', 'rt', encoding='cp949')\n", "lines = f.readlines()\n", "for line in lines:\n", " line = line.strip()\n", " news_data.append(line)\n", "f.close()" ] }, { "cell_type": "code", "execution_count": 43, "id": "65705cff-abe8-47ae-a6ac-6b7a2d289a1e", "metadata": {}, "outputs": [], "source": [ "news_data = news_data[:10000] # For debug purpose" ] }, { "cell_type": "markdown", "id": "3ac03cb7-b95b-416d-af12-5b81491c7ca8", "metadata": {}, "source": [ "#### Embedding" ] }, { "cell_type": "code", "execution_count": 44, "id": "489d2ba2-3c58-4bb1-8587-67ea81fd42dc", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a08694d72f084b4b94724c87ddd4ac15", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batches: 0%| | 0/157 [00:00