{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning Passage Embeddings with GenQ\n", "\n", "State-of-the-art semantic search applications like [extractive question answering](https://docs.pinecone.io/docs/extractive-question-answering) and [Retrieval Augmented Generation](https://arxiv.org/abs/2005.11401) (RAG) have gained significant attention recently, allowing users to query a large document corpus using natural language questions. One of the most important components in these systems is the **embedding model**, which maps input text into a vector, providing a numeric representation of the semantic meaning of the text. An effective model is tuned to place passages with similar meaning close together in the embedding space. In a process known as **retrieval**, this space is efficiently searched to find the most relevant documents given an input query.\n", "\n", "In practice, general-purpose language models often do not perform well on this retrieval task, particularly when dealing with documents related to a specialized domain, like medicine or law. The typical solution is to \"fine-tune\" a generic model to the particular task and field of interest by feeding it thousands of labeled training examples. For a semantic search engine, the most valuable dataset would be human-generated (query, passage) pairs.\n", "\n", "For example, this excerpt from the [SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/train-model.html) contains an appropriate answer for the following query:\n", "\n", "```\n", "Query: What instance type should I use for an NLP model?\n", "\n", "Relevant Passage: Choose the right instance type and infrastructure management tools for your use case. You can start from a small instance and scale up depending on your workload. For training a model on a tabular dataset, start with the smallest CPU instance of the C4 or C5 instance families. 
For training a large model for computer vision or natural language processing, start with the smallest GPU instance of the P2, P3, G4dn or G5 instance families.\n", "```\n", "\n", "However, it would be extremely expensive and time-consuming to have humans generate questions for thousands of documents. In this series of notebooks, we will demonstrate an approach called [**GenQ**](https://www.sbert.net/examples/unsupervised_learning/query_generation/README.html), in which we leverage a separate model to automatically generate so-called \"synthetic\" queries for our dataset. For more details about GenQ, see the [NeurIPS publication](https://arxiv.org/abs/2104.08663).\n", "\n", "This demonstration is split into three notebooks, which are designed to be completed in order.\n", "\n", "1. **01-Data-Preparation** (this notebook): initialization, dataset preparation, generating synthetic queries\n", "2. 02-Finetune-and-Deploy-Model: fine-tuning the embedding model and deploying it to a SageMaker endpoint\n", "3. 03-Applications-and-Evaluation: comparing the fine-tuned model to a baseline and using it for RAG and extractive QA\n", "\n", "Throughout the demo, we will leverage frameworks like [LangChain](https://python.langchain.com/en/latest/) and [Hugging Face](https://huggingface.co/) to simplify and operationalize the workflow. We will also use [Amazon SageMaker](https://aws.amazon.com/sagemaker/) for fully-managed and scalable inference, training, and deployment.\n", "\n", "## Part 1: Data Preparation\n", "\n", "In this notebook, we will download a document dataset, break it into individual passages, and use a GenQ model to generate synthetic queries for each passage." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1.0. Install dependencies\n", "\n", "**Make sure to use the `conda_pytorch_p39` SageMaker kernel, which comes with many dependencies pre-installed.** The few additional packages this demo needs are listed in the `requirements.txt` file." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T16:10:42.893295Z", "start_time": "2023-05-31T16:09:32.371972Z" } }, "outputs": [], "source": [ "!pip install -r requirements.txt -q" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1.1. Load and preprocess dataset\n", "\n", "We will use PubMed abstracts as an example of a domain-specific corpus for which generic embedding models may not produce good retrieval results. The following cell uses the Hugging Face [`load_dataset`](https://huggingface.co/docs/datasets/loading#hugging-face-hub) utility to download the abstracts and cache them on your instance.\n", "\n", "Dataset reference: Cohan, A., Dernoncourt, F., Kim, D. S., Bui, T., Kim, S., Chang, W., & Goharian, N. (2018, June). A Discourse-Aware Attention Model for Abstractive Summarization of Long Documents. Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers), 615–621. doi:10.18653/v1/N18-2097\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from IPython.display import display, HTML\n", "import os\n", "\n", "# avoid tokenizer parallelism warnings in forked processes\n", "os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n", "\n", "# load abstracts of pubmed articles\n", "dataset = load_dataset('ccdv/pubmed-summarization', 'document', split='train[:2500]')\n", "articles = [{'id': str(i).zfill(6), 'abstract': d['abstract']} for i, d in enumerate(dataset)]\n", "\n", "print(f\"Number of articles: {len(articles)}\")\n", "\n", "# print an example article\n", "display(HTML(articles[0]['abstract']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In semantic search applications, it is common practice to split the input documents into smaller passages. 
Choosing the passage size means balancing accuracy against specificity: for example, larger passages are more likely to contain a full answer to your query, but they take longer for users (and downstream models) to parse. The most appropriate `chunk_size` will depend on your dataset and processing pipeline. LangChain supports a number of different [Text Splitters](https://python.langchain.com/en/latest/modules/indexes/text_splitters.html). In this case we will use the [`NLTKTextSplitter`](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/nltk.html), which uses NLTK's sentence tokenizer to split input texts at sentence boundaries." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# preprocessing - create passages by splitting articles\n", "from langchain.text_splitter import NLTKTextSplitter\n", "import nltk\n", "\n", "nltk.download('punkt')\n", "separator = '\\n\\n'\n", "text_splitter = NLTKTextSplitter(separator=separator, chunk_size=1000)\n", "\n", "passages, metadatas = [], []\n", "\n", "for a in articles:\n", "    split_texts = text_splitter.split_text(a['abstract'])\n", "    for p in split_texts:\n", "        # split each chunk back on the separator so every piece becomes its own passage\n", "        new_passages = p.split(separator)\n", "        passages += new_passages\n", "        # record the source article id for each passage\n", "        metadatas += [{'source': a['id']} for _ in new_passages]\n", "\n", "print(f'Number of passages: {len(passages)}\\n\\n')\n", "\n", "print('Example Passages:\\n')\n", "for p in passages[:10]:\n", "    print(p + '\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1.2. Prepare GenQ model and use it to generate potential queries for each passage\n", "\n", "We will use [`BeIR/query-gen-msmarco-t5-large-v1`](https://huggingface.co/BeIR/query-gen-msmarco-t5-large-v1), a T5-based model for synthetic query generation. The SageMaker [Hugging Face](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/index.html) integration greatly simplifies this deployment. 
To define the [`HuggingFaceModel`](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model), we need to supply an appropriate image URI, environment variables for the model and task, and an appropriate execution role." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# prepare (query, passage) pairs with GenQ\n", "from sagemaker.huggingface import HuggingFaceModel\n", "import sagemaker\n", "\n", "role = sagemaker.get_execution_role()\n", "hub = {\n", "    'HF_MODEL_ID':'BeIR/query-gen-msmarco-t5-large-v1',\n", "    'HF_TASK':'text2text-generation'\n", "}\n", "\n", "transformers_version = '4.26.0'\n", "pytorch_version = '1.13.1'\n", "py_version = 'py39'\n", "use_gpu = True\n", "\n", "# https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-inference-containers\n", "# NOTE: the registry region below is hard-coded to us-east-1; change it if your SageMaker session runs in another region\n", "image_uri = f\"763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-inference:{pytorch_version}-transformers{transformers_version}-{'gpu' if use_gpu else 'cpu'}-{py_version}{'-cu117' if use_gpu else ''}-ubuntu20.04\"\n", "\n", "# create Hugging Face Model Class\n", "huggingface_model = HuggingFaceModel(\n", "    image_uri=image_uri,\n", "    env=hub,\n", "    role=role\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that the GenQ model has been defined, we can use it to generate queries for each passage, giving us (query, passage) pairs. Because we only need the model to be temporarily deployed for this specific task, it is an excellent candidate for a [SageMaker Transform Job](https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html). 
After generating text for our input dataset, the resources will automatically be shut down to limit costs.\n", "\n", "**Set up SageMaker Batch Transform Job to Generate Queries Asynchronously**\n", "\n", "At this time, the [Hugging Face Inference DLC only supports the `.jsonl` format](https://huggingface.co/docs/sagemaker/inference#run-batch-transform-with-transformers-and-sagemaker) for batch transforms. We will pass in additional parameters (`num_return_sequences` and `max_length`) to ensure that 3 concise queries are generated for each passage. The `input.jsonl` file will be copied to S3." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "# write one JSON record per passage, each with its own generation parameters\n", "with open('input.jsonl', 'w') as f:\n", "    for p in passages:\n", "        f.write(json.dumps({\n", "            'inputs': p,\n", "            'parameters': {\n", "                'max_length': 32,\n", "                'num_return_sequences': 3,\n", "                'top_p': 0.95,\n", "                'do_sample': True\n", "            }}) + '\\n')\n", "\n", "sess = sagemaker.Session()\n", "bucket = sess.default_bucket()\n", "\n", "prefix = 'pubmed-finetuning'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws s3 cp input.jsonl s3://{bucket}/{prefix}/data/batch_input/input.jsonl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html) for more details about the configuration of the transform job. In this case, we set `join_source='Input'` so that each output record contains both the input passage and its generated queries, which keeps the (query, passage) pairs together." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# transform job takes about 3 hours with p3.2xlarge\n", "\n", "batch_job = huggingface_model.transformer(\n", "    instance_count=1,\n", "    instance_type='ml.p3.2xlarge',\n", "    strategy='SingleRecord',\n", "    accept='application/json',\n", "    assemble_with='Line',\n", "    output_path=f's3://{bucket}/{prefix}/data/batch_output'\n", ")\n", "\n", "# starts batch transform job and uses S3 data as input\n", "batch_job.transform(\n", "    data=f's3://{bucket}/{prefix}/data/batch_input/input.jsonl',\n", "    content_type='application/json',\n", "    split_type='Line',\n", "    join_source='Input'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**(Optional) Deploy query generation model to SageMaker endpoint**\n", "\n", "As an alternative to the batch transform job, you can also deploy the model to a SageMaker endpoint using the SDK. This is useful for rapid testing. The cell below shows example queries generated for a single passage." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Passage:\n", "background : the present study was carried out to assess the effects of community nutrition intervention based on advocacy approach on malnutrition status among school - aged children in shiraz , iran.materials and methods : this case - control nutritional intervention has been done between 2008 and 2009 on 2897 primary and secondary school boys and girls ( 7 - 13 years old ) based on advocacy approach in shiraz , iran .\n", "\n", "Generated Queries:\n", "\n", "what is nutritional intervention based on advocacy\n", "why community nutrition approach\n", "what is nutrition advocacy intervention\n" ] } ], "source": [ "# deploy model to SageMaker Inference\n", "predictor = huggingface_model.deploy(\n", "    initial_instance_count=1, # number of instances\n", "    instance_type='ml.g4dn.xlarge', # ec2 instance type\n", ")\n", "\n", "# synchronous inference\n", "psg = passages[0]\n", "\n", "queries = predictor.predict({\n", "    'inputs': psg,\n", "    'parameters': {\n", "        'max_length': 32,\n", "        'num_return_sequences': 3,\n", "        'top_p': 0.95,\n", "        'do_sample': True\n", "    }\n", "})\n", "\n", "print(f'\\nPassage:\\n{psg}\\n\\nGenerated Queries:\\n')\n", "for q in queries:\n", "    print(q['generated_text'])\n", "\n", "# make sure to delete the endpoint when you're done to limit cost\n", "# predictor.delete_endpoint()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1.3. Parse the batch output into individual (query, passage) pairs\n", "\n", "Once the queries have been generated, we need to do some light transformations to parse the batch output into a format suitable for the fine-tuning job. The training pairs will be batched into `.tsv` files containing 1024 samples each. In the next notebook, we will use this dataset to fine-tune a generic embedding model." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole\n" ] } ], "source": [ "# transform outputs into format suitable for training job\n", "local_path = 'batch_output'\n", "\n", "sagemaker.s3.S3Downloader.download(f'{batch_job.output_path}/input.jsonl.out', local_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "training_data_path = 'pubmed_training_pairs'\n", "batch_size = 1024  # number of (query, passage) pairs per .tsv file\n", "\n", "os.makedirs(training_data_path, exist_ok=True)\n", "\n", "pairs = []\n", "file_count = 0\n", "\n", "with open(f'{local_path}/input.jsonl.out', 'r') as f:\n", "    for r in tqdm(f.readlines()):\n", "        obj = json.loads(r)\n", "        # tabs and newlines would break the one-pair-per-line .tsv layout\n", "        psg = obj['inputs'].replace('\\t', ' ').replace('\\n', ' ')\n", "\n", "        for q in obj['SageMakerOutput']:\n", "            pairs.append(q['generated_text'].replace('\\t', ' ') + '\\t' + psg)\n", "\n", "        # write out full batches of exactly batch_size pairs\n", "        while len(pairs) >= batch_size:\n", "            with open(f'{training_data_path}/pairs_{file_count:04}.tsv', 'w', encoding='utf-8') as fp:\n", "                fp.write('\\n'.join(pairs[:batch_size]))\n", "\n", "            file_count += 1\n", "            pairs = pairs[batch_size:]\n", "\n", "if pairs:\n", "    # save the final batch, which holds fewer than batch_size pairs\n", "    with open(f'{training_data_path}/pairs_{file_count:04}.tsv', 'w', encoding='utf-8') as fp:\n", "        fp.write('\\n'.join(pairs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws s3 cp pubmed_training_pairs s3://{bucket}/{prefix}/data/training --recursive" ] } ], "metadata": { "kernelspec": { "name": "python3", "language": "python", "display_name": "Python 3 (ipykernel)" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 1 }