{ "cells": [ { "cell_type": "markdown", "id": "03050374-58d4-4388-a334-d06a395bfd40", "metadata": {}, "source": [ "# Fine-tuning Foundation Models - Hugging Face Text2Text FLAN" ] }, { "cell_type": "markdown", "id": "362b27d1-cd77-46d7-88f0-68748156703d", "metadata": {}, "source": [ "In this demo notebook, we use the SageMaker Python SDK to **fine-tune a Text2Text model**. Such a model takes a text prompt as input and generates text as output. The prompt can include a task description in natural language. Accordingly, the model can be used for a variety of NLP tasks (e.g., text summarization or question answering).\n", "\n", "We will fine-tune a pre-trained **FLAN T5 model** from [Hugging Face](https://huggingface.co/docs/transformers/model_doc/flan-t5). While pre-trained FLAN T5 models can be used \"as is\" for many tasks, fine-tuning can improve model performance on a particular task or language domain. As an example, we will fine-tune the model for a task that was not used for pre-training. After fine-tuning, we will deploy two inference endpoints, one serving the pre-trained model and one serving the fine-tuned model. We will then run the same inference query against both endpoints and compare the results." ] }, { "cell_type": "markdown", "id": "d1706e56-3d74-4f2c-b5cd-695acca57d5c", "metadata": {}, "source": [ "#### In this notebook:\n", "1. [Setting up](#1.-Setting-up)\n", "1. [Fine-tuning a model](#2.-Fine-tuning-a-model)\n", "1. [Deploying inference endpoints](#3.-Deploying-inference-endpoints)\n", "1. [Running inference queries](#4.-Running-inference-queries)\n", "1. [Cleaning up resources](#5.-Cleaning-up-resources)" ] }, { "cell_type": "markdown", "id": "1c30630a-685b-4be1-9c32-e50252a77b87", "metadata": {}, "source": [ "### 1. Setting up" ] }, { "cell_type": "markdown", "id": "5b97ab4c-f05c-4696-8040-f7b0a17ba21e", "metadata": {}, "source": [ "We begin by installing and upgrading the necessary packages. The `pip` commands in the cell below are commented out; uncomment them if your environment needs these package versions, then restart the kernel after executing the cell." 
] }, { "cell_type": "code", "execution_count": 3, "id": "ef9bf59d-f7fb-47e1-b2cf-7be6a819be96", "metadata": { "tags": [] }, "outputs": [], "source": [ "#!pip install nest-asyncio==1.5.5 --quiet\n", "#!pip install ipywidgets==8.0.4 --quiet\n", "#!pip install sagemaker==2.148.0 --quiet" ] }, { "cell_type": "markdown", "id": "714a91ac-431d-4adc-8285-4e1dd73721d1", "metadata": { "tags": [] }, "source": [ "We will use the following variables throughout the notebook: the current AWS region, the execution role associated with the current notebook instance, and the default S3 bucket for outputs. We then select the FLAN T5 model size as well as the training and inference instance types." ] }, { "cell_type": "code", "execution_count": 4, "id": "c83140fe-f5d2-49ee-a433-70109cf23bd5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1maws_region:\u001b[0m us-east-1\n", "\u001b[1maws_role:\u001b[0m arn:aws:iam::509957658284:role/service-role/AmazonSageMaker-ExecutionRole-20211126T131684\n", "\u001b[1moutput_bucket:\u001b[0m sagemaker-us-east-1-509957658284\n" ] } ], "source": [ "import boto3\n", "import sagemaker\n", "\n", "# Get current region, role, and default bucket\n", "aws_region = boto3.Session().region_name\n", "aws_role = sagemaker.session.Session().get_caller_identity_arn()\n", "output_bucket = sagemaker.Session().default_bucket()\n", "\n", "# This will be useful for printing\n", "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", "\n", "print(f\"{bold}aws_region:{unbold} {aws_region}\")\n", "print(f\"{bold}aws_role:{unbold} {aws_role}\")\n", "print(f\"{bold}output_bucket:{unbold} {output_bucket}\")" ] }, { "cell_type": "markdown", "id": "c5f41438-e53c-4d04-bea7-a2a1f9408128", "metadata": { "tags": [] }, "source": [ "#### Selecting a FLAN T5 model" ] }, { "cell_type": "code", "execution_count": 5, "id": "23ec3986-9905-4c40-b81e-ca54f78311cf", "metadata": { "tags": [] }, "outputs": [], "source": [ "import IPython\n", "from ipywidgets import Dropdown\n", "from 
sagemaker.jumpstart.filters import And\n", "from sagemaker.jumpstart.notebook_utils import list_jumpstart_models\n", "\n", "# Default model choice\n", "model_id = \"huggingface-text2text-flan-t5-small\"\n", "model_version = \"*\"" ] }, { "cell_type": "code", "execution_count": 6, "id": "905ddba9-de01-4059-af82-3c178ecbed60", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mmodel_id:\u001b[0m huggingface-text2text-flan-t5-small\n", "\u001b[1mtraining_instance_type:\u001b[0m ml.p3.2xlarge\n", "\u001b[1minference_instance_type:\u001b[0m ml.g5.2xlarge\n" ] } ], "source": [ "from sagemaker.instance_types import retrieve_default\n", "\n", "# JumpStart default training instance type for this model\n", "training_instance_type = retrieve_default(\n", "    model_id=model_id, model_version=model_version, scope=\"training\"\n", ")\n", "\n", "# Override the default with explicit instance types for training and inference\n", "training_instance_type = \"ml.p3.2xlarge\"\n", "inference_instance_type = \"ml.g5.2xlarge\"\n", "print(f\"{bold}model_id:{unbold} {model_id}\")\n", "print(f\"{bold}training_instance_type:{unbold} {training_instance_type}\")\n", "print(f\"{bold}inference_instance_type:{unbold} {inference_instance_type}\")" ] }, { "cell_type": "markdown", "id": "237c7754-ffc3-40a0-9f85-bda3b25161d0", "metadata": {}, "source": [ "### 2. Fine-tuning a model" ] }, { "cell_type": "markdown", "id": "d898c437-0b3a-47b2-9b25-827d824f83ec", "metadata": {}, "source": [ "FLAN T5 models were pre-trained on a variety of tasks. In this demo, we fine-tune a model for a new task: given a piece of text, the model is asked to generate questions that are relevant to the text but cannot be answered from the information it provides. Examples are given in the inference section of this notebook." ] }, { "cell_type": "markdown", "id": "0c377423-c7d5-4087-a175-717227d47936", "metadata": { "tags": [] }, "source": [ "#### 2.1. Preparing training data\n", "We will use a subset of SQuAD2.0 for supervised fine-tuning. 
This dataset contains questions posed by human annotators on a set of Wikipedia articles. In addition to questions with answers, SQuAD2.0 contains about 50k unanswerable questions. Such questions are plausible but cannot be directly answered from the articles' content. We only use the unanswerable questions for our task.\n", "\n", "*Citation: @article{rajpurkar2018know, title={Know what you don't know: Unanswerable questions for SQuAD},\n", "author={Rajpurkar, Pranav and Jia, Robin and Liang, Percy}, journal={arXiv preprint arXiv:1806.03822}, year={2018} }*\n", "\n", "License: [Creative Commons Attribution-ShareAlike License (CC BY-SA 4.0)](https://creativecommons.org/licenses/by-sa/4.0/legalcode)" ] }, { "cell_type": "code", "execution_count": 7, "id": "96bb8fb1-84e1-43d3-981b-769856e1c204", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.s3 import S3Downloader\n", "\n", "# Download the SQuAD2.0 training set to the current directory\n", "original_data_file = \"train-v2.0.json\"\n", "original_data_location = f\"s3://sagemaker-sample-files/datasets/text/squad2.0/{original_data_file}\"\n", "S3Downloader.download(original_data_location, \".\")" ] }, { "cell_type": "markdown", "id": "7d9c3efd-63ae-4d01-9b13-cd9c48c1af92", "metadata": {}, "source": [ "The Text2Text generation model can be fine-tuned on any text data, provided that the data is in the expected format. The data must include a training part and, optionally, a validation part. The best model is selected according to the validation loss, calculated at the end of each epoch. If no validation set is given, a percentage of the training data (set by the `validation_split_ratio` hyperparameter) is automatically split off and used for validation.\n", "\n", "The training data must be formatted as JSON lines (`.jsonl`), where each line is a dictionary representing a single data sample. All training data must be in a single folder; however, it can be split across multiple `.jsonl` files. The `.jsonl` file extension is mandatory. 
The training folder can also contain a `template.json` file describing the input and output formats.\n", "\n", "If no template file is given, the following default template will be used:\n", "```json\n", "{\n", "    \"prompt\": \"{prompt}\",\n", "    \"completion\": \"{completion}\"\n", "}\n", "```\n", "In this case, each JSON lines entry must include `prompt` and `completion` fields.\n", "\n", "In this demo, we use a custom template that builds the prompt from the `context` field and the completion from the `question` field of each entry (see below)." ] }, { "cell_type": "code", "execution_count": 8, "id": "3002bc68-2779-4fc0-a68b-cbdd0b3663bd", "metadata": { "tags": [] }, "outputs": [], "source": [ "import json\n", "\n", "local_data_file = \"task-data.jsonl\"  # any name with .jsonl extension\n", "\n", "# Load the original SQuAD2.0 training set\n", "with open(original_data_file) as f:\n", "    data = json.load(f)\n", "\n", "# Write one JSON object per line, keeping only unanswerable questions\n", "with open(local_data_file, \"w\") as f:\n", "    for article in data[\"data\"]:\n", "        for paragraph in article[\"paragraphs\"]:\n", "            # iterate over questions for a given paragraph\n", "            for qas in paragraph[\"qas\"]:\n", "                if qas[\"is_impossible\"]:\n", "                    # the question is relevant, but cannot be answered\n", "                    example = {\"context\": paragraph[\"context\"], \"question\": qas[\"question\"]}\n", "                    json.dump(example, f)\n", "                    f.write(\"\\n\")\n", "\n", "template = {\n", "    \"prompt\": \"Ask a question which is related to the following text, but cannot be answered based on the text. 
Text: {context}\",\n", "    \"completion\": \"{question}\",\n", "}\n", "with open(\"template.json\", \"w\") as f:\n", "    json.dump(template, f)" ] }, { "cell_type": "markdown", "id": "67f18513-50ab-49a0-9ed1-cf436a6b65dc", "metadata": {}, "source": [ "#### Uploading the training data and template to S3" ] }, { "cell_type": "code", "execution_count": 9, "id": "0fc435de-a9c3-403e-a24a-f5ef032755cb", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mtraining data:\u001b[0m s3://sagemaker-us-east-1-509957658284/train_data\n" ] } ], "source": [ "from sagemaker.s3 import S3Uploader\n", "\n", "train_data_location = f\"s3://{output_bucket}/train_data\"\n", "S3Uploader.upload(local_data_file, train_data_location)\n", "S3Uploader.upload(\"template.json\", train_data_location)\n", "print(f\"{bold}training data:{unbold} {train_data_location}\")" ] }, { "cell_type": "markdown", "id": "36c448df-769f-41de-ba3b-e9ed0dccb041", "metadata": {}, "source": [ "#### 2.2. Start training\n", "\n", "Before launching the training job, we retrieve the training container image, the pre-trained model artifacts, and the training script for the selected model." 
] }, { "cell_type": "code", "execution_count": 10, "id": "2e35da74-e166-41e8-b7d3-3804911b522e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mimage uri:\u001b[0m 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:1.10.2-transformers4.17.0-gpu-py38-cu113-ubuntu20.04\n", "\u001b[1mmodel uri:\u001b[0m s3://jumpstart-cache-prod-us-east-1/huggingface-training/train-huggingface-text2text-flan-t5-small.tar.gz\n", "\u001b[1mscript uri:\u001b[0m s3://jumpstart-cache-prod-us-east-1/source-directory-tarballs/huggingface/transfer_learning/text2text/prepack/v1.0.3/sourcedir.tar.gz\n", "\u001b[1moutput location:\u001b[0m s3://sagemaker-us-east-1-509957658284/demo-fine-tune-flan-t5/\n" ] } ], "source": [ "from sagemaker import image_uris, model_uris, script_uris\n", "\n", "# Training instance will use this image\n", "train_image_uri = image_uris.retrieve(\n", " region=aws_region,\n", " framework=None, # automatically inferred from model_id\n", " model_id=model_id,\n", " model_version=model_version,\n", " image_scope=\"training\",\n", " instance_type=training_instance_type,\n", ")\n", "\n", "# Pre-trained model\n", "train_model_uri = model_uris.retrieve(\n", " model_id=model_id, model_version=model_version, model_scope=\"training\"\n", ")\n", "\n", "# Script to execute on the training instance\n", "train_script_uri = script_uris.retrieve(\n", " model_id=model_id, model_version=model_version, script_scope=\"training\"\n", ")\n", "\n", "output_location = f\"s3://{output_bucket}/demo-fine-tune-flan-t5/\"\n", "\n", "print(f\"{bold}image uri:{unbold} {train_image_uri}\")\n", "print(f\"{bold}model uri:{unbold} {train_model_uri}\")\n", "print(f\"{bold}script uri:{unbold} {train_script_uri}\")\n", "print(f\"{bold}output location:{unbold} {output_location}\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "1ed0df7c-0414-49d9-953d-1d48675236c4", "metadata": { "tags": [] }, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "{'epochs': '2', 'max_steps': '-1', 'seed': '42', 'batch_size': '64', 'learning_rate': '0.0001', 'lr_scheduler_type': 'constant_with_warmup', 'warmup_ratio': '0.0', 'warmup_steps': '0', 'validation_split_ratio': '0.05', 'train_data_split_seed': '0', 'max_train_samples': '-1', 'max_eval_samples': '-1', 'max_input_length': '-1', 'max_output_length': '128', 'pad_to_max_length': 'True', 'gradient_accumulation_steps': '1', 'weight_decay': '0.0', 'adam_beta1': '0.9', 'adam_beta2': '0.999', 'adam_epsilon': '1e-08', 'max_grad_norm': '1.0', 'load_best_model_at_end': 'True', 'early_stopping_patience': '3', 'early_stopping_threshold': '0.0', 'label_smoothing_factor': '0', 'logging_strategy': 'steps', 'logging_first_step': 'False', 'logging_steps': '500', 'logging_nan_inf_filter': 'True', 'save_strategy': 'epoch', 'save_steps': '500', 'save_total_limit': '2', 'dataloader_drop_last': 'False', 'dataloader_num_workers': '0', 'evalaution_strategy': 'epoch', 'eval_steps': '500', 'eval_accumulation_steps': 'None', 'gradient_checkpointing': 'True', 'auto_find_batch_size': 'False', 'preprocessing_num_workers': 'None'}\n" ] } ], "source": [ "# Import the module under an alias so that re-running this cell does not fail\n", "# once the name `hyperparameters` holds the returned dict\n", "from sagemaker import hyperparameters as sm_hyperparameters\n", "\n", "# Retrieve the default hyperparameters for fine-tuning the model\n", "hyperparameters = sm_hyperparameters.retrieve_default(model_id=model_id, model_version=model_version)\n", "\n", "# Override some default hyperparameters with custom values (all values are strings)\n", "hyperparameters[\"epochs\"] = \"2\"\n", "print(hyperparameters)" ] }, { "cell_type": "markdown", "id": "6fdf3377-4cd7-40df-876b-e0fa2f57dc8f", "metadata": {}, "source": [ "We are now ready to start the training job. Training can take from 20 minutes to several hours, depending on the model size, the amount of data, and the number of epochs (e.g., a few hours for the xl model with 40k examples and 3 epochs)." 
] }, { "cell_type": "code", "execution_count": 12, "id": "068d6c5a-00a1-48d0-a2cd-86bcdd15cb0b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1mjob name:\u001b[0m js-demo-flan-t5-small-2-2023-05-29-03-24-42-284\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:sagemaker:Creating training-job with name: js-demo-flan-t5-small-2-2023-05-29-03-24-42-284\n" ] } ], "source": [ "from sagemaker.estimator import Estimator\n", "from sagemaker.utils import name_from_base\n", "\n", "model_name = \"-\".join(model_id.split(\"-\")[2:])  # most informative part of the ID, e.g. \"flan-t5-small\"\n", "training_job_name = name_from_base(f\"js-demo-{model_name}-{hyperparameters['epochs']}\")\n", "print(f\"{bold}job name:{unbold} {training_job_name}\")\n", "\n", "# Regexes that SageMaker applies to the training logs to extract metrics\n", "training_metric_definitions = [\n", "    {\"Name\": \"val_loss\", \"Regex\": \"'eval_loss': ([0-9\\\\.]+)\"},\n", "    {\"Name\": \"train_loss\", \"Regex\": \"'loss': ([0-9\\\\.]+)\"},\n", "    {\"Name\": \"epoch\", \"Regex\": \"'epoch': ([0-9\\\\.]+)\"},\n", "]\n", "\n", "# Create SageMaker Estimator instance\n", "sm_estimator = Estimator(\n", "    role=aws_role,\n", "    image_uri=train_image_uri,\n", "    model_uri=train_model_uri,\n", "    source_dir=train_script_uri,\n", "    entry_point=\"transfer_learning.py\",\n", "    instance_count=1,\n", "    instance_type=training_instance_type,\n", "    volume_size=300,  # storage volume size in GB\n", "    max_run=360000,  # maximum runtime in seconds (100 hours)\n", "    hyperparameters=hyperparameters,\n", "    output_path=output_location,\n", "    metric_definitions=training_metric_definitions,\n", ")\n", "\n", "# Launch a SageMaker training job over data located in the given S3 path.\n", "# Training jobs can take hours, so we set wait=False\n", "# and monitor the job status through the SageMaker console.\n", "sm_estimator.fit({\"training\": train_data_location}, job_name=training_job_name, wait=False)" ] }, { "cell_type": "markdown", "id": "4524dead-769d-4b5d-88b2-e9587dab474e", "metadata": {}, "source": [ "Performance metrics 
such as training and validation loss can be accessed through CloudWatch during training. We can also fetch the most recent snapshot of metrics as follows." ] }, { "cell_type": "code", "execution_count": 14, "id": "2683497e-dc98-4fe3-ac15-d09cd58817e6", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:sagemaker.analytics:Warning: No metrics called train_loss found\n" ] }, { "data": { "text/html": [ "
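<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>timestamp</th>\n", "      <th>metric_name</th>\n", "      <th>value</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>0.0</td>\n", "      <td>val_loss</td>\n", "      <td>2.475903</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>0.0</td>\n", "      <td>epoch</td>\n", "      <td>1.666667</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "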