{ "cells": [ { "cell_type": "markdown", "id": "9adbbd88", "metadata": {}, "source": [ "# Fine-tuning BERT for information retrieval using Amazon SageMaker" ] }, { "cell_type": "markdown", "id": "b60614e2", "metadata": {}, "source": [ "## Runtime\n", "This notebook takes approximately 30 minutes to run.\n", "\n" ] }, { "cell_type": "markdown", "id": "9933f5b8", "metadata": {}, "source": [ "## Contents\n", "Background\n", "1. Development environment and permissions\n", "   - Installation\n", "   - Permissions\n", "2. Training\n", "   - Downloading data\n", "   - Preparing the data\n", "   - Bi-Encoder Transformer Neural Network\n", "3. Inference\n", "   - Offline scoring\n", "   - Realtime endpoint\n", "4. OpenSearch\n", "   - OpenSearch Client\n", "   - Index and mapping\n", "   - Ingestion of documents\n", "5. Simulated Semantic Search Application\n", "   - Search Widget\n", "   - Pipeline\n", "\n", "### Terminology: *sentence*, *document*, and *passage* are used interchangeably; each means a response to a query" ] },
{ "attachments": { "huggingfact-SBERT.jpeg": { "image/jpeg": "" } }, "cell_type": "markdown", "id": "263d8baf", "metadata": {}, "source": [ "## Background\n", "\n", "The Transformer deep learning architecture has proven very successful and has spawned several state-of-the-art model families. One of them is Bidirectional Encoder Representations from Transformers (BERT) [1], whose large variant, BERT-Large, has 340 million parameters.\n", "\n", "With Transformers, the “pretrain, then fine-tune” recipe has emerged as the standard approach for applying BERT to specific downstream tasks such as classification, sequence labeling, information retrieval, and ranking. Typically, we start with a “base” pretrained Transformer model, such as the BERT-Base and BERT-Large checkpoints that can be downloaded directly from **SBERT** or the Hugging Face Transformers library. This model is then fine-tuned on task-specific labeled data drawn from the same distribution as the target task.\n", "\n", "\n", "\n", "Information retrieval (search) systems traditionally use lexical search algorithms such as BM25 and TF-IDF to find passages that match a query. Using pre-trained language models such as BERT in a search system can improve search relevance, because these models can find **semantic matches** rather than just **term matches** for a query. At the same time, the original BERT model should be fine-tuned before it is used for a specific downstream task such as information retrieval, so that it adapts to that task.\n", "\n", "The SBERT framework, which is based on PyTorch and Transformers, offers a large collection of pre-trained models tuned for various tasks. In this notebook, we fine-tune a BERT model for the information retrieval (search) use case, following the approach of the original research paper Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks [2].\n", "\n", "**References**\n", "\n", "- [1] Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. (2018). BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.\n", "- [2] Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence embeddings using Siamese BERT-networks. arXiv preprint arXiv:1908.10084." ] },
{ "cell_type": "markdown", "id": "c09b2282", "metadata": {}, "source": [ "## 1. Development environment and permissions\n", "\n", "Let's start by setting up the development environment and permissions. First, make sure that the kernel is set to \"conda_amazonei_pytorch_latest_p36\". Once the kernel is ready, install and import all the required libraries."
] }, { "cell_type": "markdown", "id": "9af7c7cb", "metadata": {}, "source": [ "### Install and import dependencies" ] }, { "cell_type": "code", "execution_count": null, "id": "5b257eb2", "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "import sys\n", "\n", "def install(package):\n", "    # Run pip with the notebook's own interpreter; -q (after 'install') keeps pip quiet\n", "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", package])\n", "\n", "install('sentence_transformers')\n", "install('opensearch-py')\n", "install('requests_aws4auth')\n", "\n", "# Standard library\n", "import argparse\n", "import gzip\n", "import json\n", "import logging\n", "import os\n", "import pickle\n", "import random\n", "import tarfile\n", "from collections import defaultdict\n", "from datetime import datetime\n", "\n", "# Third-party libraries\n", "import boto3\n", "import requests\n", "import tqdm\n", "from torch.utils.data import DataLoader, Dataset, IterableDataset\n", "from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample\n", "\n", "# SageMaker SDK\n", "import sagemaker\n", "from sagemaker.pytorch import PyTorch\n", "from sagemaker import get_execution_role" ] }, { "cell_type": "markdown", "id": "9434966e", "metadata": {}, "source": [ "### Set up the SageMaker session, region, and IAM role\n", "\n", "This notebook is already configured with an execution role that grants SageMaker permission to access other services on our behalf, such as S3, SageMaker model training, and SageMaker endpoints.\n", "\n", "We have created an S3 bucket for this notebook to store all the model artifacts. In the following code, we save the execution role ARN and the S3 bucket name as variables for later use." ] },
{ "cell_type": "code", "execution_count": null, "id": "f4785d94", "metadata": {}, "outputs": [], "source": [ "# The execution role ARN has the form arn:aws:iam::<account-id>:role/<role-name>\n", "role = get_execution_role()\n", "account = role.split('::')[1].split(':')[0]\n", "# S3 bucket created for this notebook to hold the model artifacts\n", "bucket = \"sagemaker-nlp-\" + account\n", "boto3_session = boto3.session.Session()\n", "my_region = boto3_session.region_name\n", "output_path = \"s3://\" + bucket + \"/nlp-dualencoder\"\n", "output_path" ] }, { "cell_type": "markdown", "id": "c52a530a", "metadata": {}, "source": [ "## 2. Training\n", "For model training, we use the SageMaker PyTorch framework with a custom training script (nlp_loader_test.py). The script performs the following steps:" ] }, { "cell_type": "markdown", "id": "5f110b64", "metadata": {}, "source": [ "### Downloading the data\n", "\n", "We use the MS MARCO dataset (https://microsoft.github.io/msmarco/Datasets), a large dataset for training information retrieval models. It consists of about 500k real search queries from the Bing search engine, each paired with the relevant text passages that answer it, in descending order of relevance.\n", "\n", "The dataset has two attributes:\n", "\n", "
| Attribute | Type | Description |\n", "|---|---|---|\n", "| Query | Text | The question asked in the search engine |\n", "| Passage(s) | Array of texts | The responses that the user voted as relevant to the query. They are ordered so that the most relevant response comes first and responses with poor or no relevance come last. |\n", "