{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "94ed44ac-98e5-4200-8f1c-3adedec2866b",
   "metadata": {},
   "source": [
    "# Lab 1: Sentence-BERT (SBERT) Training\n",
    "\n",
    "### Fine-tuning Sentence-BERT (SBERT) & SBERT Embedding for applications\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d897cd6-a127-4a1c-9980-aa033c332f6a",
   "metadata": {},
   "source": [
    "\n",
    "## Introduction\n",
    "---\n",
    "\n",
    "본 모듈에서는 문장 임베딩을 산출하는 Sentence-BERT 모델을 STS 데이터셋으로 파인튜닝해 봅니다.\n",
    "SentenceTransformers 패키지를 사용하면 파인튜닝을 쉽게 수행할 수 있습니다. 다만, 현 시점에는 분산 훈련 기능 지원이 잘 되지 않으므로, 대용량 데이터셋으로 파인튜닝하는 니즈가 있다면 커스텀 훈련 코드를 직접 작성하셔야 합니다.\n",
    "\n",
    "***[Note] SageMaker Studio Lab, SageMaker Studio, SageMaker 노트북 인스턴스, 또는 여러분의 로컬 머신에서 이 데모를 실행할 수 있습니다. SageMaker Studio Lab을 사용하는 경우 GPU를 활성화하세요.***\n",
    "\n",
    "### References\n",
    "\n",
    "- Hugging Face Tutorial: https://huggingface.co/docs/transformers/training\n",
    "- Sentence-BERT paper: https://arxiv.org/abs/1908.10084\n",
    "- SentenceTransformers: https://www.sbert.net"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "577d681d-0f93-44ea-a3d1-0e0a9bb5e063",
   "metadata": {},
   "source": [
    "\n",
    "## 1. Setup Environments\n",
    "---\n",
    "\n",
    "### Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "91641d7f-95e5-473b-b144-823d4a19a299",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install sentence_transformers datasets faiss-gpu progressbar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c6d4b001-3508-4ca1-b73b-76b11fdcce7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import logging\n",
    "import argparse\n",
    "import torch\n",
    "import gzip\n",
    "import csv\n",
    "import math\n",
    "import urllib\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "from datetime import datetime\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses, models, util\n",
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "from sentence_transformers.readers import InputExample\n",
    "from transformers.trainer_utils import get_last_checkpoint\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format='%(asctime)s - %(message)s',\n",
    "    datefmt='%Y-%m-%d %H:%M:%S',\n",
    "    handlers=[LoggingHandler()]\n",
    ")\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1215e31-b560-4ea9-8fd2-eb097c9dd30b",
   "metadata": {},
   "source": [
    "### Argument parser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c512803f-d734-4edd-a868-0f369f6b3888",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parser_args(train_notebook=False):\n",
    "    parser = argparse.ArgumentParser()\n",
    "\n",
    "    # Default Setting\n",
    "    parser.add_argument(\"--epochs\", type=int, default=1)\n",
    "    parser.add_argument(\"--seed\", type=int, default=42)\n",
    "    parser.add_argument(\"--train_batch_size\", type=int, default=32)\n",
    "    parser.add_argument(\"--eval_batch_size\", type=int, default=32)\n",
    "    parser.add_argument(\"--warmup_steps\", type=int, default=100)\n",
    "    parser.add_argument(\"--logging_steps\", type=int, default=100)\n",
    "    parser.add_argument(\"--learning_rate\", type=str, default=5e-5)\n",
    "    parser.add_argument(\"--disable_tqdm\", type=bool, default=False)\n",
    "    parser.add_argument(\"--fp16\", type=bool, default=True)\n",
    "    parser.add_argument(\"--tokenizer_id\", type=str, default='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n",
    "    parser.add_argument(\"--model_id\", type=str, default='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')\n",
    "    \n",
    "    # SageMaker Container environment\n",
    "    parser.add_argument(\"--output_data_dir\", type=str, default=os.environ[\"SM_OUTPUT_DATA_DIR\"])\n",
    "    parser.add_argument(\"--model_dir\", type=str, default=os.environ[\"SM_MODEL_DIR\"])\n",
    "    parser.add_argument(\"--n_gpus\", type=str, default=os.environ[\"SM_NUM_GPUS\"])\n",
    "    parser.add_argument(\"--train_dir\", type=str, default=os.environ[\"SM_CHANNEL_TRAIN\"])\n",
    "    parser.add_argument(\"--valid_dir\", type=str, default=os.environ[\"SM_CHANNEL_VALID\"])\n",
    "    parser.add_argument(\"--test_dir\", type=str, default=os.environ[\"SM_CHANNEL_TEST\"])    \n",
    "    parser.add_argument('--chkpt_dir', type=str, default='/opt/ml/checkpoints')     \n",
    "\n",
    "    if train_notebook:\n",
    "        args = parser.parse_args([])\n",
    "    else:\n",
    "        args = parser.parse_args()\n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "970aca02-6ba6-4111-943a-8ffa15e5ead9",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dir = 'train'\n",
    "valid_dir = 'valid'\n",
    "test_dir = 'test'\n",
    "!rm -rf {train_dir} {valid_dir} {test_dir} \n",
    "os.makedirs(train_dir, exist_ok=True)\n",
    "os.makedirs(valid_dir, exist_ok=True) \n",
    "os.makedirs(test_dir, exist_ok=True) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9635cf55-6781-4f0b-bed4-c4c49a76619c",
   "metadata": {},
   "source": [
    "### Load Arguments\n",
    "\n",
    "주피터 노트북에서 곧바로 실행할 수 있도록 설정값들을 로드합니다. 물론 노트북 환경이 아닌 커맨드라인에서도 `cd scripts & python3 train.py` 커맨드로 훈련 스크립트를 실행할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1825c0c8-2106-445b-bb7e-6d255dd3bd1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:42:26 - ***** Arguments *****\n",
      "2022-07-20 08:42:26 - epochs=1\n",
      "seed=42\n",
      "train_batch_size=32\n",
      "eval_batch_size=32\n",
      "warmup_steps=100\n",
      "logging_steps=100\n",
      "learning_rate=5e-05\n",
      "disable_tqdm=False\n",
      "fp16=True\n",
      "tokenizer_id=salti/bert-base-multilingual-cased-finetuned-squad\n",
      "model_id=salti/bert-base-multilingual-cased-finetuned-squad\n",
      "output_data_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/data\n",
      "model_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model\n",
      "n_gpus=4\n",
      "train_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/train\n",
      "valid_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/valid\n",
      "test_dir=/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/test\n",
      "chkpt_dir=chkpt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "chkpt_dir = 'chkpt'\n",
    "model_dir = 'model'\n",
    "output_data_dir = 'data'\n",
    "num_gpus = torch.cuda.device_count()\n",
    "\n",
    "!rm -rf {chkpt_dir} {model_dir} {output_data_dir} \n",
    "\n",
    "if os.environ.get('SM_CURRENT_HOST') is None:\n",
    "    is_sm_container = False\n",
    "\n",
    "    #src_dir = '/'.join(os.getcwd().split('/')[:-1])\n",
    "    src_dir = os.getcwd()\n",
    "    os.environ['SM_MODEL_DIR'] = f'{src_dir}/{model_dir}'\n",
    "    os.environ['SM_OUTPUT_DATA_DIR'] = f'{src_dir}/{output_data_dir}'\n",
    "    os.environ['SM_NUM_GPUS'] = str(num_gpus)\n",
    "    os.environ['SM_CHANNEL_TRAIN'] = f'{src_dir}/{train_dir}'\n",
    "    os.environ['SM_CHANNEL_VALID'] = f'{src_dir}/{valid_dir}'\n",
    "    os.environ['SM_CHANNEL_TEST'] = f'{src_dir}/{test_dir}'\n",
    "    \n",
    "args = parser_args(train_notebook=True) \n",
    "args.chkpt_dir = chkpt_dir\n",
    "logger.info(\"***** Arguments *****\")\n",
    "logger.info(''.join(f'{k}={v}\\n' for k, v in vars(args).items()))\n",
    "\n",
    "os.makedirs(args.chkpt_dir, exist_ok=True) \n",
    "os.makedirs(args.model_dir, exist_ok=True)\n",
    "os.makedirs(args.output_data_dir, exist_ok=True) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5478a4d9-eaad-4652-8b1a-141d40b18d88",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "## 2. Preparation\n",
    "---\n",
    "본 핸즈온에서 사용할 데이터셋은 KorSTS (https://github.com/kakaobrain/KorNLUDatasets) 와 KLUE-STS (https://github.com/KLUE-benchmark/KLUE) 입니다.\n",
    "단일 데이터셋으로 훈련해도 무방하지만, 두 데이터셋을 모두 활용하여 훈련 시, 약간의 성능 향상이 있습니다.\n",
    "\n",
    "### Training Tips\n",
    "SBERT 훈련은 일반적으로 아래 3가지 방법들을 베이스라인으로 사용합니다.\n",
    "1. NLI 데이터셋으로 훈련\n",
    "2. STS 데이터셋으로 훈련\n",
    "3. NLI 데이터셋으로 훈련 후 STS 데이터셋으로 파인튜닝\n",
    "\n",
    "한국어 데이터의 경우, STS의 훈련 데이터가 상대적으로 적음에도 불구하고 NLI 기반 모델보다 예측 성능이 우수합니다. 따라서, 2번째 방법으로 진행합니다. <br>\n",
    "다만, STS보다 조금 더 좋은 예측 성능을 원한다면 NLI 데이터로 먼저 훈련하고 STS 데이터셋으로 이어서 훈련하는 것을 권장합니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aa546b9-2072-42f3-b886-18a06069edf0",
   "metadata": {},
   "source": [
    "### KLUE-STS 데이터셋 다운로드 및 피쳐셋 생성\n",
    "KLUE-STS 데이터셋을 허깅페이스 데이터셋 허브에서 다운로드 후, SBERT 훈련에 필요한 피쳐셋을 생성합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fc5b4880-9907-44f0-9e45-2e67f3938a60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:42:30 - Read KLUE-STS train/dev dataset\n",
      "2022-07-20 08:42:30 - Reusing dataset klue (/home/ec2-user/.cache/huggingface/datasets/klue/sts/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6fd3d65e0c0a482b9cdb08ba4409c3c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "logger.info(\"Read KLUE-STS train/dev dataset\")\n",
    "datasets = load_dataset(\"klue\", \"sts\")\n",
    "\n",
    "train_samples = []\n",
    "dev_samples = []\n",
    "\n",
    "for phase in [\"train\", \"validation\"]:\n",
    "    examples = datasets[phase]\n",
    "\n",
    "    for example in examples:\n",
    "        score = float(example[\"labels\"][\"label\"]) / 5.0  # 0.0 ~ 1.0 스케일로 유사도 정규화\n",
    "        inp_example = InputExample(texts=[example[\"sentence1\"], example[\"sentence2\"]], label=score)\n",
    "\n",
    "        if phase == \"validation\":\n",
    "            dev_samples.append(inp_example)\n",
    "        else:\n",
    "            train_samples.append(inp_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "972b0261-3a0e-4612-b1ff-128a46fb37fd",
   "metadata": {},
   "source": [
    "\n",
    "### KorSTS 데이터셋 다운로드 및 피쳐셋 생성\n",
    "KorSTS 데이터셋은 허깅페이스에도 등록되어 있지만, 향후 여러분의 커스텀 데이터셋을 같이 사용하는 유즈케이스를 고려하여 GitHub의 데이터셋을 다운로드받아 사용하겠습니다. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "589cb12d-9a43-4651-912c-466dbbfcd75d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/test/sts-test.tsv',\n",
       " <http.client.HTTPMessage at 0x7fee52a7f970>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "repo = 'https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS'\n",
    "urllib.request.urlretrieve(f'{repo}/sts-train.tsv', filename=f'{args.train_dir}/sts-train.tsv')\n",
    "urllib.request.urlretrieve(f'{repo}/sts-dev.tsv', filename=f'{args.valid_dir}/sts-dev.tsv')\n",
    "urllib.request.urlretrieve(f'{repo}/sts-test.tsv', filename=f'{args.test_dir}/sts-test.tsv')\n",
    "\n",
    "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-train.tsv -O {train_dir}/sts-train.tsv\n",
    "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-dev.tsv -O {valid_dir}/sts-dev.tsv\n",
    "# !wget https://raw.githubusercontent.com/kakaobrain/KorNLUDatasets/master/KorSTS/sts-test.tsv -O {test_dir}/sts-test.tsv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0ad5f3ed-a0b9-4970-9b11-e3740cd92aec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:42:36 - Read KorSTS train dataset\n",
      "2022-07-20 08:42:36 - Read KorSTS dev dataset\n"
     ]
    }
   ],
   "source": [
    "logger.info(\"Read KorSTS train dataset\")\n",
    "\n",
    "with open(f'{args.train_dir}/sts-train.tsv', 'rt', encoding='utf8') as fIn:\n",
    "    reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n",
    "    for row in reader:\n",
    "        if row[\"sentence1\"] and row[\"sentence2\"]:          \n",
    "            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1\n",
    "            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n",
    "            train_samples.append(inp_example)\n",
    "            \n",
    "logging.info(\"Read KorSTS dev dataset\")            \n",
    "with open(f'{args.valid_dir}/sts-dev.tsv', 'rt', encoding='utf8') as fIn:\n",
    "    reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n",
    "    for row in reader:\n",
    "        if row[\"sentence1\"] and row[\"sentence2\"]:        \n",
    "            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1\n",
    "            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n",
    "            dev_samples.append(inp_example)            "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "118a5b4e-81b3-48ad-b97e-b0cb8f98b099",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "## 3. Training\n",
    "---\n",
    "\n",
    "### Training Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57d7fdc0-6779-44d1-b9b6-c2a7a5d43706",
   "metadata": {},
   "source": [
    "### Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5ef7a33c-a8bc-427c-a64f-c5d7dfa5a09c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:43:44 - /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n"
     ]
    }
   ],
   "source": [
    "model_name = 'sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens'\n",
    "\n",
    "train_batch_size = args.train_batch_size\n",
    "num_epochs = args.epochs\n",
    "model_save_path = f'{args.model_dir}/training_sts_'+model_name.replace(\"/\", \"-\")+'-'+datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
    "logger.info(model_save_path)\n",
    "\n",
    "# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings\n",
    "word_embedding_model = models.Transformer(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b92c650-3ae4-4319-a756-8457b8f03c97",
   "metadata": {},
   "source": [
    "문장 임베딩을 계산하기 위한 Pooler를 정의합니다. BERT로 분류 태스크를 수행할 때는 첫 번째 [CLS] 토큰의 출력 벡터를 임베딩 벡터로 사용하지만, SBERT에서는 BERT의 모든 토큰들의 출력 벡터들을 사용하여 임베딩 벡터를 계산합니다. 이 때 mean pooling이나 max pooling을 사용할 수 있으며, 본 예제에서는 mean pooling을 사용합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b01d0e7c-0ac6-410a-b07d-9fdbea08162c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:49:35 - Use pytorch device: cuda\n"
     ]
    }
   ],
   "source": [
    "# Apply mean pooling to get one fixed sized sentence vector\n",
    "pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),\n",
    "                               pooling_mode_mean_tokens=True,\n",
    "                               pooling_mode_cls_token=False,\n",
    "                               pooling_mode_max_tokens=False)\n",
    "\n",
    "model = SentenceTransformer(modules=[word_embedding_model, pooling_model])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c317c627-4ee2-4eca-bd12-dc25a76dfef2",
   "metadata": {},
   "source": [
    "모델 훈련 및 검증에 필요한 클래스 인스턴스를 생성합니다. 베이스라인으로 사용되는 검증 지표는 두 문장의 임베딩 벡터의 유사도를 산출하는 코사인 유사도입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f20d8263-04b6-498e-aa67-55dd266274b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:49:36 - Warmup-steps: 55\n"
     ]
    }
   ],
   "source": [
    "train_dataset = SentencesDataset(train_samples, model)\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)\n",
    "train_loss = losses.CosineSimilarityLoss(model=model)\n",
    "\n",
    "evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')\n",
    "\n",
    "warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) # 10% of train data for warm-up\n",
    "logger.info(\"Warmup-steps: {}\".format(warmup_steps))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e178954-f94a-4f8c-8a05-6c396d10f9e6",
   "metadata": {},
   "source": [
    "훈련을 수행합니다. 분산 훈련을 수행하지는 않지만, 데이터 볼륨이 크지 않으므로 수 분 내에 훈련이 완료됩니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e990ec0-0bd7-40dd-b276-0d91a3758188",
   "metadata": {},
   "source": [
    "### Start Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b3905a80-e2c9-4295-a34b-70f12920dad5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bc3d1b44f9c467cb1dfda0553397652",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86c0606e96594b7b81c4ca1e21dd240c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/545 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:50:25 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset in epoch 0 after 272 steps:\n",
      "2022-07-20 08:50:28 - Cosine-Similarity :\tPearson: 0.8458\tSpearman: 0.8462\n",
      "2022-07-20 08:50:28 - Manhattan-Distance:\tPearson: 0.8333\tSpearman: 0.8371\n",
      "2022-07-20 08:50:28 - Euclidean-Distance:\tPearson: 0.8339\tSpearman: 0.8380\n",
      "2022-07-20 08:50:28 - Dot-Product-Similarity:\tPearson: 0.8095\tSpearman: 0.8114\n",
      "2022-07-20 08:50:28 - Save model to /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n",
      "2022-07-20 08:51:12 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset in epoch 0 after 544 steps:\n",
      "2022-07-20 08:51:15 - Cosine-Similarity :\tPearson: 0.8511\tSpearman: 0.8513\n",
      "2022-07-20 08:51:15 - Manhattan-Distance:\tPearson: 0.8378\tSpearman: 0.8416\n",
      "2022-07-20 08:51:15 - Euclidean-Distance:\tPearson: 0.8383\tSpearman: 0.8425\n",
      "2022-07-20 08:51:15 - Dot-Product-Similarity:\tPearson: 0.8112\tSpearman: 0.8140\n",
      "2022-07-20 08:51:15 - Save model to /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n",
      "2022-07-20 08:51:24 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-dev dataset after epoch 0:\n",
      "2022-07-20 08:51:27 - Cosine-Similarity :\tPearson: 0.8511\tSpearman: 0.8513\n",
      "2022-07-20 08:51:27 - Manhattan-Distance:\tPearson: 0.8378\tSpearman: 0.8416\n",
      "2022-07-20 08:51:27 - Euclidean-Distance:\tPearson: 0.8383\tSpearman: 0.8425\n",
      "2022-07-20 08:51:27 - Dot-Product-Similarity:\tPearson: 0.8112\tSpearman: 0.8140\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "model.fit(\n",
    "    train_objectives=[(train_dataloader, train_loss)],\n",
    "    evaluator=evaluator,\n",
    "    epochs=num_epochs,\n",
    "    evaluation_steps=int(len(train_dataloader)*0.5),\n",
    "    warmup_steps=warmup_steps,\n",
    "    output_path=model_save_path,\n",
    "    use_amp=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e370591f-cb30-4b2e-a637-c664ef8fc918",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "## 4. Evaluation\n",
    "---\n",
    "훈련이 완료되었다면, 테스트 데이터셋으로 예측 성능을 볼 수 있는 지표들을 산출합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1b94aa7e-ff24-453c-a0c5-28ee104b89c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:51:47 - Read KorSTS test dataset\n"
     ]
    }
   ],
   "source": [
    "test_samples = []\n",
    "logger.info(\"Read KorSTS test dataset\")            \n",
    "with open(f'{args.test_dir}/sts-test.tsv', 'rt', encoding='utf8') as fIn:\n",
    "    reader = csv.DictReader(fIn, delimiter='\\t', quoting=csv.QUOTE_NONE)\n",
    "    for row in reader:\n",
    "        if row[\"sentence1\"] and row[\"sentence2\"]:        \n",
    "            score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1\n",
    "            inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)\n",
    "            test_samples.append(inp_example)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "98a96778-da38-41d7-b617-25db5b721c19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:51:48 - Load pretrained SentenceTransformer: /home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/model/training_sts_sentence-transformers-xlm-r-100langs-bert-base-nli-stsb-mean-tokens-2022-07-20_08-43-44\n",
      "2022-07-20 08:51:51 - Use pytorch device: cuda\n",
      "2022-07-20 08:51:51 - EmbeddingSimilarityEvaluator: Evaluating the model on sts-test dataset:\n",
      "2022-07-20 08:51:53 - Cosine-Similarity :\tPearson: 0.8287\tSpearman: 0.8310\n",
      "2022-07-20 08:51:53 - Manhattan-Distance:\tPearson: 0.8242\tSpearman: 0.8283\n",
      "2022-07-20 08:51:53 - Euclidean-Distance:\tPearson: 0.8245\tSpearman: 0.8287\n",
      "2022-07-20 08:51:53 - Dot-Product-Similarity:\tPearson: 0.7619\tSpearman: 0.7608\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.8309806357819561"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "##############################################################################\n",
    "# Load the stored model and evaluate its performance on STS benchmark dataset\n",
    "##############################################################################\n",
    "\n",
    "model = SentenceTransformer(model_save_path)\n",
    "test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')\n",
    "test_evaluator(model, output_path=model_save_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a545fcc7-785c-47ff-8585-f148ef4b9786",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "## 5. Applications\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "239fbb3b-892b-4833-8720-754cb251f456",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import time\n",
    "from operator import itemgetter \n",
    "\n",
    "def get_faiss_index(emb, data, dim=768):\n",
    "    import faiss\n",
    "    n_gpus = torch.cuda.device_count()\n",
    "\n",
    "    if n_gpus == 0:\n",
    "        # Create the Inner Product Index\n",
    "        index = faiss.IndexFlatIP(dim)\n",
    "    else:\n",
    "        flat_config = []\n",
    "        res = [faiss.StandardGpuResources() for i in range(n_gpus)]\n",
    "        for i in range(n_gpus):\n",
    "            cfg = faiss.GpuIndexFlatConfig()\n",
    "            cfg.useFloat16 = False\n",
    "            cfg.device = i\n",
    "            flat_config.append(cfg)\n",
    "\n",
    "        index = faiss.GpuIndexFlatIP(res[0], dim, flat_config[0])\n",
    "\n",
    "    index = faiss.IndexIDMap(index)\n",
    "    index.add_with_ids(emb, np.array(range(0, len(data)))) \n",
    "    return index\n",
    "\n",
    "\n",
    "def search(model, query, data, index, k=5, random_select=False, verbose=True):\n",
    "    t = time.time()\n",
    "    query_vector = model.encode(query)\n",
    "    dists, top_k_inds = index.search(query_vector, k)\n",
    "    if verbose:\n",
    "        print('total time: {}'.format(time.time() - t))\n",
    "    results = [itemgetter(*ind)(data) for ind in top_k_inds] \n",
    "    \n",
    "    if random_select:\n",
    "        return [random.choice(r) for r in results]\n",
    "    else:\n",
    "        return results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aa1864d-6c57-499f-a7ac-fa4ba1e37672",
   "metadata": {},
   "source": [
    "### Chatbot\n",
    "\n",
    "챗봇은 크게 두 가지 형태로 개발합니다. 1) 생성 모델을 사용하여 해당 질문에 대한 창의적인 답변을 생성하거나, 2) 수많은 질문-답변 리스트들 중 질문에 부합하는 질문 후보들을 추린 다음 해당 후보에 적합한 답변을 찾는 방식이죠.\n",
    "본 핸즈온은 2)의 방법으로 간단하게 챗봇 예시를 보여드립니다. 질문 텍스트를 입력으로 받으면, 해당 질문의 임베딩을 계산하여 질문 임베딩과 모든 질문 리스트의 임베딩을 비교하여 유사도가 가장 높은 질문 후보들을 찾고, 각 후보에 매칭되는 답변을 찾습니다.\n",
    "\n",
    "코사인 유사도를 직접 계산할 수도 있지만, 페이스북에서 개발한 Faiss 라이브러리 (https://github.com/facebookresearch/faiss) 를 사용하면 훨씬 빠른 속도로 계산할 수 있습니다. Faiss는 Product 양자화 알고리즘을 GPU로 더욱 빠르게 구현한 라이브러리로, 정보 손실을 가급적 줄이면서 임베딩 벡터를 인덱싱합니다.\n",
    "\n",
    "References\n",
    "- Billion-scale similarity search with GPUs: https://arxiv.org/pdf/1702.08734.pdf\n",
    "- Product Quantizers for k-NN Tutorial Part 1: https://mccormickml.com/2017/10/13/product-quantizer-tutorial-part-1\n",
    "- Product Quantizers for k-NN Tutorial Part 1: http://mccormickml.com/2017/10/22/product-quantizer-tutorial-part-2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8cb6a25-6e3a-4606-8bcb-f9bc647bc28e",
   "metadata": {},
   "source": [
    "#### Preparing chatbot dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "acbe5f54-c885-4324-8c56-f617c13085e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Q</th>\n",
       "      <th>A</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>12시 땡!</td>\n",
       "      <td>하루가 또 가네요.</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1지망 학교 떨어졌어</td>\n",
       "      <td>위로해 드립니다.</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3박4일 놀러가고 싶다</td>\n",
       "      <td>여행은 언제나 좋죠.</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3박4일 정도 놀러가고 싶다</td>\n",
       "      <td>여행은 언제나 좋죠.</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>PPL 심하네</td>\n",
       "      <td>눈살이 찌푸려지죠.</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                 Q            A  label\n",
       "0           12시 땡!   하루가 또 가네요.      0\n",
       "1      1지망 학교 떨어졌어    위로해 드립니다.      0\n",
       "2     3박4일 놀러가고 싶다  여행은 언제나 좋죠.      0\n",
       "3  3박4일 정도 놀러가고 싶다  여행은 언제나 좋죠.      0\n",
       "4          PPL 심하네   눈살이 찌푸려지죠.      0"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "urllib.request.urlretrieve(\"https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv\", \n",
    "                           filename=f\"{args.train_dir}/chatbot-train.csv\")\n",
    "chatbot_df = pd.read_csv(f'{args.train_dir}/chatbot-train.csv')\n",
    "chatbot_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08b8b5b7-b8bf-467f-adc8-728a1719cff3",
   "metadata": {},
   "source": [
    "#### Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3770cf95-ef8e-47d5-b158-b5f7f523228d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aed7a672408642c6b7a315db01f13670",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/185 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "chatbot_q_data = chatbot_df['Q'].tolist()\n",
    "chatbot_a_data = chatbot_df['A'].tolist()\n",
    "chatbot_emb = model.encode(chatbot_q_data, normalize_embeddings=True, batch_size=64, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38d7cf2c-664d-413a-9f52-83cd415cae34",
   "metadata": {},
   "source": [
    "#### Indexing the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "179195c4-137d-4c75-b77d-f30f4b3b1155",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2022-07-20 08:52:14 - Loading faiss with AVX2 support.\n",
      "2022-07-20 08:52:14 - Could not load library with AVX2 support due to:\n",
      "ModuleNotFoundError(\"No module named 'faiss.swigfaiss_avx2'\")\n",
      "2022-07-20 08:52:14 - Loading faiss.\n",
      "2022-07-20 08:52:14 - Successfully loaded faiss.\n"
     ]
    }
   ],
   "source": [
    "chatbot_index = get_faiss_index(chatbot_emb, chatbot_q_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5d8384b-7623-4344-bbd1-dabb1a31b0c6",
   "metadata": {},
   "source": [
    "#### Inference\n",
    "샘플 질문들에 대한 추론을 수행합니다. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "84f161c7-c078-41ce-9a5b-1ba9aff13fb8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2eb8a3d695f44c5b959a59685fab41c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total time: 0.03067159652709961\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['좋은 시간 보내시길 바라요.']"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = ['커피 라떼 마시고 싶어']\n",
    "search(model, query, chatbot_a_data, chatbot_index, random_select=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "e8e46dbf-7e0b-4b62-a700-fdd554264087",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4467da19586c426bb86720a793707dd3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total time: 0.0308225154876709\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['낮잠을 잠깐 자도 괜찮아요.', '같이 놀아요.']"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query = ['너무 졸려', '놀고 싶어']\n",
    "search(model, query, chatbot_a_data, chatbot_index, random_select=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cce7a866-49e9-4415-af15-6ec389a9138f",
   "metadata": {},
   "source": [
    "### Semantic Search (News)\n",
    "\n",
    "\n",
    "시멘틱(의미) 검색은 검색 쿼리가 키워드를 찾는 것뿐만 아니라, 검색에 사용되는 단어의 의도와 문맥적 의미를 파악하는 것을 목표로 합니다.\n",
    "시멘틱 유사도 검색을 또한 상기 챗봇 예시와 마찬가지로, 해당 검색 쿼리를 입력하면, 검색 쿼리의 임베딩을 계산하여 모든 문서(예: 뉴스 제목/요약, 웹페이지 제목/요약) 리스트의 임베딩을 비교하여 가장 유사도가 높은 문서 후보들을 찾습니다.\n",
    "\n",
    "References\n",
    "- Billion-scale semantic similarity search with FAISS+SBERT: https://towardsdatascience.com/billion-scale-semantic-similarity-search-with-faiss-sbert-c845614962e2\n",
    "- Korean Contemporary Corpus of Written Sentences: http://nlp.kookmin.ac.kr/kcc/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "832c06db-3e7f-4497-90f9-b70f5f1ba81c",
   "metadata": {},
   "source": [
    "#### Preparing news dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "f912ef7e-afd0-4d2d-b25e-f48ee59dfe99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import progressbar\n",
    "\n",
    "class MyProgressBar():\n",
    "    def __init__(self):\n",
    "        self.pbar = None\n",
    "\n",
    "    def __call__(self, block_num, block_size, total_size):\n",
    "        if not self.pbar:\n",
    "            self.pbar=progressbar.ProgressBar(maxval=total_size)\n",
    "            self.pbar.start()\n",
    "\n",
    "        downloaded = block_num * block_size\n",
    "        if downloaded < total_size:\n",
    "            self.pbar.update(downloaded)\n",
    "        else:\n",
    "            self.pbar.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6b5b0d49-16fa-4080-bf88-21cd64ae1cbe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% |########################################################################|\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('/home/ec2-user/SageMaker/sm-kornlp-usecases/sentence-bert-finetuning/train/KCCq28_Korean_sentences_EUCKR_v2.zip',\n",
       " <http.client.HTTPMessage at 0x7f53a416d3a0>)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "url = 'http://nlp.kookmin.ac.kr/kcc/KCCq28_Korean_sentences_EUCKR_v2.zip'\n",
    "news_path = f'{args.train_dir}/KCCq28_Korean_sentences_EUCKR_v2.zip'\n",
    "urllib.request.urlretrieve(url, news_path, MyProgressBar())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "3323ca38-b3d0-4d21-bf22-f527b164e6a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import zipfile\n",
    "with zipfile.ZipFile(news_path, 'r') as zip_ref:\n",
    "    zip_ref.extractall(train_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e07ea44c-59c0-413d-8ec4-c544287e24d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
     ]
    }
   ],
   "source": [
    "!rm -rf {news_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "5679f7cb-28a4-4d68-bfc3-ed127cc91809",
   "metadata": {},
   "outputs": [],
   "source": [
    "news_data = []\n",
    "f = open(f'{args.train_dir}/KCCq28_Korean_sentences_EUCKR_v2.txt', 'rt', encoding='cp949')\n",
    "lines = f.readlines()\n",
    "for line in lines:\n",
    "    line = line.strip()\n",
    "    news_data.append(line)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "65705cff-abe8-47ae-a6ac-6b7a2d289a1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "news_data = news_data[:10000] # For debug purpose"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ac03cb7-b95b-416d-af12-5b81491c7ca8",
   "metadata": {},
   "source": [
    "#### Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "489d2ba2-3c58-4bb1-8587-67ea81fd42dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a08694d72f084b4b94724c87ddd4ac15",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/157 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "news_emb = model.encode(news_data, normalize_embeddings=True, batch_size=64, show_progress_bar=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dfc9adf-e1e8-412e-91c0-46778dbd608f",
   "metadata": {},
   "source": [
    "#### Indexing the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "d6bdc627-114d-47d5-844a-2702c6343af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "news_index = get_faiss_index(news_emb, news_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d1a914e-7ccc-49e9-adc3-174d6231eb46",
   "metadata": {},
   "source": [
    "#### Inference\n",
    "샘플 질문들에 대한 추론을 수행합니다. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "3203cbdd-a1d4-4c4e-8f08-2de03702a247",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e949ee48ecb447aa633119f28f20897",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total time: 0.031549692153930664\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('아이서플라이는 \"아이폰4에 사용한 \\'A4\\'와 마찬가지로 \\'다이 마크\\'로 볼 때 삼성전자가 만든 것으로 보인다\"고 했다.',\n",
       "  '아이스크림을 서비스하고 있는 시공그룹의 박기석 회장은 \"콘텐츠제공서비스 품질인증으로 아이스크림 서비스와 콘텐츠에 대해 신뢰도와 공신력을 인정받았다.',\n",
       "  '엘리엇의 자회사인 블레이크 캐피탈과 포터 캐피탈은 홍보대행사를 통해 \"삼성전자가 제시한 개략적인 주주가치 제고 방안이 향후 회사에 건설적인 첫 걸음이 될 것으로 보고 있다\"고 밝혔다.',\n",
       "  '검토 결과에 따라 결론을 내릴 것\"이라고 말했다.',\n",
       "  '조 대사는 \"대미 의존도를 줄이려고 노력하고 있는 캐나다는 눈길을 아시아·태평양 지역으로 돌리고 있다\"며 \"특히 이 지역에서는 캐나다와 FTA를 맺은 국가가 없어 더욱 상징적 의미가 있다\"고 말했다.',\n",
       "  '박 대통령은 또 \"적극적인 세일즈 외교 대통령으로 나서겠다\"는 점을 분명히 했다.',\n",
       "  '애플이 협력업체의 인력을 영입하는 이유에 대해 아이폰인 캐나다는 \"애플이 GPU 자체 개발을 서두르기 위한 수순\"이라고 설명했다.'),\n",
       " ('이회성 IPCC 의장은 \"온실가스 배출량이 늘고 있다는 사실이 지구온난화의 과학적 근거\"라고 말했다.',\n",
       "  '다음은 일문 일답.파리기후협정의 의미는 무엇인가.\"2015년 12월 프랑스 파리에서 열린 기후변화협약 당사국총회에서 채택됐고, 지난해 11월 발효된 파리기후협정은 2020년부터 선진국뿐만 아니라 개발도상국도 의무적으로 온실가스 감축에 참여하도록 했다.',\n",
       "  '미국에서 활동하고 있는 중국계 학자 유 사오카이는 이 논문에서 \"도심의 고층빌딩 꼭대기에서 아래쪽으로 미세한 물방울을 뿌린다면 중국의 극심한 초미세먼지까지 낮출 수 있다\"고 주장했다.',\n",
       "  '그렇다면 대기 중 미세먼지 농도가 심한 날엔 어떻게 해야 할까.류연기 환경부 생활환경과장은 \"대기 중 미세먼지 농도가 높은 날에 요리를 할 경우엔 우선 주방 환풍기를 사용해 환기하고, 요리 후엔 잠시 동안 창문을 열어 두는 것이 좋다\"고 말했다.',\n",
       "  '국방부 관계자는 \"폐탄약 처리 시설은 한국 환경 관련 법률에 따라 친환경적이고 한미 양국의 합의각서에 따라 엄격하게 운영될 것이다\"고 말했다.',\n",
       "  'SK하이닉스 관계자는 \"압축공기보다 미량의 질소를 지속적으로 흘려 스크러버 내부의 유해물질을 닦아내고 있는데, 질소가 직접적인 원인인지는 수사당국 등의 정밀한 조사가 필요하다\"고 설명했다.',\n",
       "  '회사 측은 \"슬러지 100%자가처리, 에너지비용, 온실가스 배출량 절감을 위한 목적\"이라고 설명했다.')]"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query =['아이스 라떼', '미세먼지']\n",
    "search(model, query, news_data, news_index, k=7, random_select=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p38",
   "language": "python",
   "name": "conda_pytorch_p38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}