{ "cells": [ { "cell_type": "markdown", "id": "e960490f", "metadata": {}, "source": [ "# SageMaker JumpStart Foundation Models - Fine-tuning text generation GPT-J 6B model on domain specific dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "96f2327e", "metadata": {}, "source": [ "---\n", "Welcome to [Amazon SageMaker Built-in Algorithms](https://sagemaker.readthedocs.io/en/stable/algorithms/index.html)! You can use SageMaker Built-in algorithms to solve many Machine Learning tasks through [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html). You can also use these algorithms through one-click in SageMaker Studio via [JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html).\n", "\n", "In this demo notebook, we demonstrate how to use the SageMaker Python SDK for finetuning Foundation Models and deploying the trained model for inference. The Foundation models perform Text Generation task. It takes a text string as input and predicts next words in the sequence.\n", "\n", "* **How to run inference on [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6b) model without finetuning.**\n", "* **How to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6b) model on a domain specific dataset, and then run inference on the fine-tuned model. In particular, the example dataset we demonstrated is [publicly available SEC filing](https://www.sec.gov/edgar/searchedgar/companysearch) of Amazon from year 2021 to 2022. The expectation is that after fine-tuning, the model should be able to generate insightful text in financial domain.**\n", "* **We compare the inference result for GPT-J 6B before finetuning and after finetuning.**\n", "\n", "Note: This notebook was tested on ml.t3.medium instance in Amazon SageMaker Studio with Python 3 (Data Science) kernel and in Amazon SageMaker Notebook instance with conda_python3 kernel.\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "8091e1f6", "metadata": {}, "source": [ "1. [Set Up](#1.-Set-Up)\n", "2. [Select Text Generation Model GTP-J 6B](#2.-Select-Text-Generation-Model-GTP-J-6B)\n", "3. [Run Inference on the Pre-trained Model without finetuning](#3.-Run-Inference-on-the-Pre-trained-Model-without-finetuning)\n", " * [Retrieve Artifacts & Deploy an Endpoint](#3.1.-Retrieve-Artifacts-&-Deploy-an-Endpoint)\n", " * [Query endpoint and parse response](#3.2.-Query-endpoint-and-parse-response)\n", " * [Clean up the endpoint](#3.3.-Clean-up-the-endpoint)\n", "4. 
[Fine-tune the pre-trained model on a custom dataset](#4.-Fine-tune-the-pre-trained-model-on-a-custom-dataset)\n", " * [Retrieve Training artifacts](#4.1.-Retrieve-Training-artifacts)\n", " * [Set Training parameters](#4.2.-Set-Training-parameters)\n", " * [Train with Automatic Model Tuning](#4.3.-Train-with-Automatic-Model-Tuning-([HPO]))\n", " * [Start Training](#4.4.-Start-Training)\n", " * [Extract Training performance metrics](#4.5.-Extract-Training-performance-metrics)\n", " * [Deploy & run Inference on the fine-tuned model](#4.6.-Deploy-&-run-Inference-on-the-fine-tuned-model)" ] }, { "cell_type": "markdown", "id": "2007b31a", "metadata": {}, "source": [ "## 1. Set Up\n", "Before executing the notebook, there are some initial steps required for setup." ] }, { "cell_type": "code", "execution_count": null, "id": "39b943ff", "metadata": {}, "outputs": [], "source": [ "!pip install ipywidgets==7.0.0 --quiet\n", "!pip install --upgrade sagemaker --quiet" ] }, { "cell_type": "markdown", "id": "9b1051f6", "metadata": {}, "source": [ "To train and host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook instance as the AWS account role with SageMaker access. It has the necessary permissions, including access to your data in S3. " ] }, { "cell_type": "code", "execution_count": null, "id": "a5a3eb07", "metadata": {}, "outputs": [], "source": [ "import sagemaker, boto3, json\n", "from sagemaker.session import Session\n", "\n", "sagemaker_session = Session()\n", "aws_role = sagemaker_session.get_caller_identity_arn()\n", "aws_region = boto3.Session().region_name\n", "sess = sagemaker.Session()" ] }, { "cell_type": "markdown", "id": "ee983c64", "metadata": {}, "source": [ "## 2. Select Text Generation Model GPT-J 6B\n" ] }, { "cell_type": "code", "execution_count": null, "id": "960ca9c8", "metadata": {}, "outputs": [], "source": [ "model_id, model_version = \"huggingface-textgeneration1-gpt-j-6b\", \"*\"" ] }, { "cell_type": "markdown", "id": "c43e49a7", "metadata": {}, "source": [ "## 3. Run Inference on the Pre-trained Model without finetuning\n", "\n", "Using SageMaker, we can directly perform inference on the pre-trained [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6b) model. GPT-J 6B is an open-source 6-billion-parameter model released by EleutherAI. GPT-J 6B has been trained on a large corpus of text data ([the Pile](https://pile.eleuther.ai/) dataset) and is capable of performing various natural language processing tasks such as text generation, text classification, and text summarization. " ] }, { "cell_type": "markdown", "id": "c45f9242", "metadata": {}, "source": [ "### 3.1. Retrieve Artifacts & Deploy an Endpoint\n", "\n", "We retrieve the deploy_image_uri and model_uri for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it."
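, "\n", "\n", "Deploying the GPT-J 6B endpoint on an `ml.g5.12xlarge` instance can take several minutes. If your kernel restarts after the endpoint is already up, you can re-attach a predictor to the running endpoint (for example, to query it or clean it up later) instead of deploying it again. The snippet below is a minimal sketch, not executed as part of this notebook, and assumes `endpoint_name` and `sess` are still defined:\n", "\n", "```\n", "from sagemaker.predictor import Predictor\n", "\n", "# Re-attach to an existing endpoint by name instead of calling model.deploy() again\n", "base_model_predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sess)\n", "```"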
] }, { "cell_type": "code", "execution_count": null, "id": "fecbe672", "metadata": {}, "outputs": [], "source": [ "from sagemaker import image_uris, model_uris, script_uris\n", "from sagemaker.model import Model\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "\n", "endpoint_name = name_from_base(f\"jumpstart-example-{model_id}\")\n", "\n", "inference_instance_type = \"ml.g5.12xlarge\"\n", "\n", "# Retrieve the inference docker container uri.\n", "deploy_image_uri = image_uris.retrieve(\n", " region=None,\n", " framework=None,\n", " image_scope=\"inference\",\n", " model_id=model_id,\n", " model_version=model_version,\n", " instance_type=inference_instance_type,\n", ")\n", "\n", "# Retrieve the model uri.\n", "model_uri = model_uris.retrieve(\n", " model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", ")\n", "\n", "# Create the SageMaker model instance. Note that we need to pass Predictor class when we deploy model through Model class,\n", "# for being able to run inference through the sagemaker API.\n", "model = Model(\n", " image_uri=deploy_image_uri,\n", " model_data=model_uri,\n", " role=aws_role,\n", " predictor_cls=Predictor,\n", " name=endpoint_name,\n", ")\n", "\n", "# deploy the Model. TODO\n", "base_model_predictor = model.deploy(\n", " initial_instance_count=1,\n", " instance_type=inference_instance_type,\n", " endpoint_name=endpoint_name,\n", ")" ] }, { "cell_type": "markdown", "id": "cdce1248-9fa9-4617-9dfa-d432c03ad804", "metadata": {}, "source": [ "### 3.2. Query endpoint and parse response\n", "The model takes a text string as input and predicts next words in the sequence. We use three of following input examples.\n", "\n", "1. `This Form 10-K report shows that`\n", "2. `We serve consumers through`\n", "3. `Our vision is`\n", "\n", "**The input examples are related to company's perforamnce in financial report. 
You will see that the outputs from the model without fine-tuning are limited in providing insightful content.**" ] }, { "cell_type": "code", "execution_count": null, "id": "4f12c38b-fd58-4696-824d-be15e34f5a16", "metadata": {}, "outputs": [], "source": [ "import json\n", "import boto3\n", "\n", "\n", "def query_endpoint_with_json_payload(encoded_json, endpoint_name):\n", " client = boto3.client(\"runtime.sagemaker\")\n", " response = client.invoke_endpoint(\n", " EndpointName=endpoint_name, ContentType=\"application/json\", Body=encoded_json\n", " )\n", " return response\n", "\n", "\n", "def parse_response_multiple_texts(query_response):\n", " model_predictions = json.loads(query_response[\"Body\"].read())\n", " return model_predictions[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "1c63f536-be9b-4e52-a464-d206998e7071", "metadata": {}, "outputs": [], "source": [ "parameters = {\n", " \"max_length\": 400,\n", " \"num_return_sequences\": 1,\n", " \"top_k\": 250,\n", " \"top_p\": 0.8,\n", " \"do_sample\": True,\n", " \"temperature\": 1,\n", "}\n", "\n", "res_gpt_before_finetune = []\n", "for quote_text in [\n", " \"This Form 10-K report shows that\",\n", " \"We serve consumers through\",\n", " \"Our vision is\",\n", "]:\n", " payload = {\"text_inputs\": f\"{quote_text}:\", **parameters}\n", "\n", " query_response = query_endpoint_with_json_payload(\n", " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", " )\n", " generated_texts = parse_response_multiple_texts(query_response)[0][\"generated_text\"]\n", " res_gpt_before_finetune.append(generated_texts)\n", " print(generated_texts)\n", " print(\"\\n\")" ] }, { "cell_type": "markdown", "id": "a60dbad7", "metadata": {}, "source": [ "### 3.3. Clean up the endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "45f8225e", "metadata": { "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "# Delete the SageMaker endpoint and the attached resources\n", "base_model_predictor.delete_model()\n", "base_model_predictor.delete_endpoint()" ] }, { "cell_type": "markdown", "id": "70950bf9", "metadata": {}, "source": [ "## 4. Fine-tune the pre-trained model on a custom dataset" ] }, { "cell_type": "markdown", "id": "1d1f6f8c", "metadata": {}, "source": [ "Fine-tuning refers to the process of taking a pre-trained language model and retraining it for a different but related task using specific data. This approach is also known as transfer learning, which involves transferring the knowledge learned from one task to another. Large language models (LLMs) like GPT-J 6B are trained on massive amounts of unlabeled data and can be fine-tuned on domain-specific datasets, making the model perform better on that specific domain. \n", "\n", "We will use financial text from SEC filings to fine-tune the LLM GPT-J 6B for financial applications. \n", "\n", "\n", "\n", "- **Input**: A train and an optional validation directory. Each directory contains a CSV/JSON/TXT file.\n", " - For CSV/JSON files, the train or validation data is used from the column called 'text' or the first column if no column called 'text' is found.\n", " - The number of files under train and validation (if provided) should each be exactly one.\n", "- **Output**: A trained model that can be deployed for inference.\n", "\n", "Below is an example of a TXT file for fine-tuning the Text Generation model. 
The TXT file contains SEC filings of Amazon from 2021 to 2022.\n", "\n", "---\n", "```\n", "This report includes estimates, projections, statements relating to our\n", "business plans, objectives, and expected operating results that are \u201cforward-\n", "looking statements\u201d within the meaning of the Private Securities Litigation\n", "Reform Act of 1995, Section 27A of the Securities Act of 1933, and Section 21E\n", "of the Securities Exchange Act of 1934. Forward-looking statements may appear\n", "throughout this report, including the following sections: \u201cBusiness\u201d (Part I,\n", "Item 1 of this Form 10-K), \u201cRisk Factors\u201d (Part I, Item 1A of this Form 10-K),\n", "and \u201cManagement\u2019s Discussion and Analysis of Financial Condition and Results\n", "of Operations\u201d (Part II, Item 7 of this Form 10-K). These forward-looking\n", "statements generally are identified by the words \u201cbelieve,\u201d \u201cproject,\u201d\n", "\u201cexpect,\u201d \u201canticipate,\u201d \u201cestimate,\u201d \u201cintend,\u201d \u201cstrategy,\u201d \u201cfuture,\u201d\n", "\u201copportunity,\u201d \u201cplan,\u201d \u201cmay,\u201d \u201cshould,\u201d \u201cwill,\u201d \u201cwould,\u201d \u201cwill be,\u201d \u201cwill\n", "continue,\u201d \u201cwill likely result,\u201d and similar expressions. Forward-looking\n", "statements are based on current expectations and assumptions that are subject\n", "to risks and uncertainties that may cause actual results to differ materially.\n", "We describe risks and uncertainties that could cause actual results and events\n", "to differ materially in \u201cRisk Factors,\u201d \u201cManagement\u2019s Discussion and Analysis\n", "of Financial Condition and Results of Operations,\u201d and \u201cQuantitative and\n", "Qualitative Disclosures about Market Risk\u201d (Part II, Item 7A of this Form\n", "10-K). Readers are cautioned not to place undue reliance on forward-looking\n", "statements, which speak only as of the date they are made. We undertake no\n", "obligation to update or revise publicly any forward-looking statements,\n", "whether because of new information, future events, or otherwise.\n", "\n", "GENERAL\n", "\n", "Embracing Our Future ...\n", "```\n", "---\n", "The SEC filings data of Amazon is downloaded from the publicly available [EDGAR](https://www.sec.gov/edgar/searchedgar/companysearch). Instructions for accessing the data are available [here](https://www.sec.gov/os/accessing-edgar-data)." ] }, { "cell_type": "markdown", "id": "f4552dd0", "metadata": {}, "source": [ "### 4.1. Retrieve Training artifacts\n", "Here, for the selected model, we retrieve the training docker container, the training algorithm source, the pre-trained model, and a Python dictionary of the training hyper-parameters that the algorithm accepts with their default values. Note that model_version=\"*\" fetches the latest model. Also, we need to specify the training_instance_type to fetch train_image_uri."
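, "\n", "\n", "The fine-tuning job in this notebook reads its data from a sample SEC-filings dataset hosted in a public JumpStart S3 bucket; the exact train and validation paths are set in the next section. If you would rather fine-tune on your own domain text, one way to stage it in your default SageMaker bucket is sketched below. This is a minimal, optional sketch: the local `train/` directory (containing a single TXT file) and the key prefix are placeholders for your own data.\n", "\n", "```\n", "# Upload a local directory containing one TXT file, then point\n", "# training_dataset_s3_path at the returned S3 prefix instead of the sample bucket.\n", "custom_train_s3_path = sess.upload_data(\n", "    path=\"train\",\n", "    bucket=sess.default_bucket(),\n", "    key_prefix=\"jumpstart-example-tg-train/data/train\",\n", ")\n", "```"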
] }, { "cell_type": "code", "execution_count": null, "id": "c3d9fd0c", "metadata": {}, "outputs": [], "source": [ "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", "\n", "training_instance_type = \"ml.g5.12xlarge\"\n", "\n", "# Retrieve the docker image\n", "train_image_uri = image_uris.retrieve(\n", " region=None,\n", " framework=None,\n", " model_id=model_id,\n", " model_version=model_version,\n", " image_scope=\"training\",\n", " instance_type=training_instance_type,\n", ")\n", "# Retrieve the training script\n", "train_source_uri = script_uris.retrieve(\n", " model_id=model_id, model_version=model_version, script_scope=\"training\"\n", ")\n", "# Retrieve the pre-trained model tarball to further fine-tune\n", "train_model_uri = model_uris.retrieve(\n", " model_id=model_id, model_version=model_version, model_scope=\"training\"\n", ")" ] }, { "cell_type": "markdown", "id": "c0bfe06b", "metadata": {}, "source": [ "### 4.2. Set Training parameters\n", "Now that we are done with all the setup that is needed, we are ready to fine-tune our Text Classification model. To begin, let us create a [``sageMaker.estimator.Estimator``](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html) object. This estimator will launch the training job. \n", "\n", "There are two kinds of parameters that need to be set for training. \n", "\n", "The first one are the parameters for the training job. These include: (i) Training data path. This is S3 folder in which the input data is stored, (ii) Output path: This the s3 folder in which the training output is stored. (iii) Training instance type: This indicates the type of machine on which to run the training. Typically, we use GPU instances for these training. We defined the training instance type above to fetch the correct train_image_uri. \n", "***\n", "The second set of parameters are algorithm specific training hyper-parameters. It is also used for sepcifying the model name if we want to fine-tune on the model which is not present in the dropdown list.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "036bac37", "metadata": {}, "outputs": [], "source": [ "# Sample training data is available in this bucket\n", "data_bucket = f\"jumpstart-cache-prod-{aws_region}\"\n", "data_prefix = \"training-datasets/sec_data\"\n", "\n", "training_dataset_s3_path = f\"s3://{data_bucket}/{data_prefix}/train/\"\n", "validation_dataset_s3_path = f\"s3://{data_bucket}/{data_prefix}/validation/\"\n", "\n", "output_bucket = sess.default_bucket()\n", "output_prefix = \"jumpstart-example-tg-train\"\n", "\n", "s3_output_location = f\"s3://{output_bucket}/{output_prefix}/output\"" ] }, { "cell_type": "markdown", "id": "8ad02cf3", "metadata": {}, "source": [ "***\n", "For algorithm specific hyper-parameters, we start by fetching python dictionary of the training hyper-parameters that the algorithm accepts with their default values. 
This can then be overridden to custom values.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "651f68c9", "metadata": {}, "outputs": [], "source": [ "from sagemaker import hyperparameters\n", "\n", "# Retrieve the default hyper-parameters for fine-tuning the model\n", "hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version)\n", "\n", "# [Optional] Override default hyperparameters with custom values\n", "hyperparameters[\"epoch\"] = \"3\"\n", "hyperparameters[\"per_device_train_batch_size\"] = \"4\"\n", "print(hyperparameters)" ] }, { "cell_type": "markdown", "id": "a5051d41", "metadata": {}, "source": [ "### 4.3. Train with Automatic Model Tuning ([HPO](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html)) \n", "***\n", "Amazon SageMaker automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset using the algorithm and ranges of hyperparameters that you specify. It then chooses the hyperparameter values that result in a model that performs the best, as measured by a metric that you choose. We will use a [HyperparameterTuner](https://sagemaker.readthedocs.io/en/stable/api/training/tuner.html) object to interact with Amazon SageMaker hyperparameter tuning APIs.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "8c271247", "metadata": {}, "outputs": [], "source": [ "from sagemaker.tuner import ContinuousParameter\n", "\n", "# Use AMT for tuning and selecting the best model\n", "use_amt = False\n", "\n", "# Define objective metric, based on which the best model will be selected.\n", "amt_metric_definitions = {\n", " \"metrics\": [{\"Name\": \"eval:loss\", \"Regex\": \"'eval_loss': ([0-9]+\\.[0-9]+)\"}],\n", " \"type\": \"Minimize\",\n", "}\n", "\n", "# You can select from the hyperparameters supported by the model, and configure ranges of values to be searched for training the optimal model.(https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html)\n", "hyperparameter_ranges = {\n", " \"learning_rate\": ContinuousParameter(0.00001, 0.0001, scaling_type=\"Logarithmic\")\n", "}\n", "\n", "# Increase the total number of training jobs run by AMT, for increased accuracy (and training time).\n", "max_jobs = 6\n", "# Change parallel training jobs run by AMT to reduce total training time, constrained by your account limits.\n", "# if max_jobs=max_parallel_jobs then Bayesian search turns to Random.\n", "max_parallel_jobs = 2" ] }, { "cell_type": "markdown", "id": "0d9d2622", "metadata": {}, "source": [ "### 4.4. 
Start Training\n", "***\n", "We start by creating the estimator object with all the required assets and then launch the training job.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "973d923c", "metadata": {}, "outputs": [], "source": [ "from sagemaker.estimator import Estimator\n", "from sagemaker.utils import name_from_base\n", "from sagemaker.tuner import HyperparameterTuner\n", "\n", "training_job_name = name_from_base(f\"jumpstart-example-{model_id}-transfer-learning\")\n", "\n", "metric_definitions = [\n", " {\"Name\": \"train:loss\", \"Regex\": \"'loss': ([0-9]+\\\.[0-9]+)\"},\n", " {\"Name\": \"eval:loss\", \"Regex\": \"'eval_loss': ([0-9]+\\\.[0-9]+)\"},\n", " {\"Name\": \"eval:runtime\", \"Regex\": \"'eval_runtime': ([0-9]+\\\.[0-9]+)\"},\n", " {\"Name\": \"eval:samples_per_second\", \"Regex\": \"'eval_samples_per_second': ([0-9]+\\\.[0-9]+)\"},\n", " {\"Name\": \"eval:eval_steps_per_second\", \"Regex\": \"'eval_steps_per_second': ([0-9]+\\\.[0-9]+)\"},\n", "]\n", "\n", "\n", "# Create SageMaker Estimator instance\n", "tg_estimator = Estimator(\n", " role=aws_role,\n", " image_uri=train_image_uri,\n", " source_dir=train_source_uri,\n", " model_uri=train_model_uri,\n", " entry_point=\"transfer_learning.py\",\n", " instance_count=1,\n", " instance_type=training_instance_type,\n", " max_run=360000,\n", " hyperparameters=hyperparameters,\n", " output_path=s3_output_location,\n", " base_job_name=training_job_name,\n", " metric_definitions=metric_definitions,\n", ")\n", "\n", "if use_amt:\n", " hp_tuner = HyperparameterTuner(\n", " tg_estimator,\n", " amt_metric_definitions[\"metrics\"][0][\"Name\"],\n", " hyperparameter_ranges,\n", " amt_metric_definitions[\"metrics\"],\n", " max_jobs=max_jobs,\n", " max_parallel_jobs=max_parallel_jobs,\n", " objective_type=amt_metric_definitions[\"type\"],\n", " base_tuning_job_name=training_job_name,\n", " )\n", "\n", " # Launch a SageMaker Tuning job to search for the best hyperparameters\n", " hp_tuner.fit({\"train\": training_dataset_s3_path, \"validation\": validation_dataset_s3_path})\n", "else:\n", " # Launch a SageMaker Training job by passing s3 path of the training data\n", " tg_estimator.fit(\n", " {\"train\": training_dataset_s3_path, \"validation\": validation_dataset_s3_path}, logs=True\n", " )" ] }, { "cell_type": "markdown", "id": "97ed581b", "metadata": {}, "source": [ "### 4.5. Extract Training performance metrics\n", "***\n", "Performance metrics such as training loss and validation accuracy/loss can be accessed through Amazon CloudWatch while the training is in progress. We can also fetch these metrics and analyze them within the notebook.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "ce268cd7", "metadata": {}, "outputs": [], "source": [ "from sagemaker import TrainingJobAnalytics\n", "\n", "if use_amt:\n", " training_job_name = hp_tuner.best_training_job()\n", "else:\n", " training_job_name = tg_estimator.latest_training_job.job_name\n", "\n", "df = TrainingJobAnalytics(training_job_name=training_job_name).dataframe()\n", "df.head(10)" ] }, { "cell_type": "markdown", "id": "bd2d20f9", "metadata": {}, "source": [ "### 4.6. Deploy & run Inference on the fine-tuned model\n", "***\n", "A trained model does nothing on its own. We now want to use the model to perform inference. For this example, that means generating text that continues an input prompt. We follow the same steps as in [3. Run Inference on the Pre-trained Model without finetuning](#3.-Run-Inference-on-the-Pre-trained-Model-without-finetuning). 
We start by retrieving the artifacts for deploying an endpoint. However, instead of the pre-trained model we deployed earlier as `base_model_predictor`, we now deploy the fine-tuned `tg_estimator`.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "ce738168", "metadata": {}, "outputs": [], "source": [ "inference_instance_type = \"ml.g5.12xlarge\"\n", "\n", "# Retrieve the inference docker container uri\n", "deploy_image_uri = image_uris.retrieve(\n", " region=None,\n", " framework=None,\n", " image_scope=\"inference\",\n", " model_id=model_id,\n", " model_version=model_version,\n", " instance_type=inference_instance_type,\n", ")\n", "\n", "endpoint_name_after_finetune = name_from_base(f\"jumpstart-example-{model_id}-\")\n", "\n", "# Use the estimator from the previous step to deploy to a SageMaker endpoint\n", "finetuned_predictor = (hp_tuner if use_amt else tg_estimator).deploy(\n", " initial_instance_count=1,\n", " instance_type=inference_instance_type,\n", " image_uri=deploy_image_uri,\n", " endpoint_name=endpoint_name_after_finetune,\n", ")" ] }, { "cell_type": "markdown", "id": "ea22eef2", "metadata": {}, "source": [ "Next, we query the fine-tuned model using the same set of examples as above, parse the response, and print the predictions." ] }, { "cell_type": "code", "execution_count": null, "id": "097903dd", "metadata": {}, "outputs": [], "source": [ "import json\n", "import boto3\n", "\n", "\n", "def query_endpoint_with_json_payload(encoded_json, endpoint_name):\n", " client = boto3.client(\"runtime.sagemaker\")\n", " response = client.invoke_endpoint(\n", " EndpointName=endpoint_name, ContentType=\"application/json\", Body=encoded_json\n", " )\n", " return response\n", "\n", "\n", "def parse_response_multiple_texts(query_response):\n", " model_predictions = json.loads(query_response[\"Body\"].read())\n", " return model_predictions[0]" ] }, { "cell_type": "markdown", "id": "f541ccf7-d26e-4088-b52a-8fd0bf93bafc", "metadata": {}, "source": [ "The outputs from the fine-tuned model are shown below. We can see that after fine-tuning, the model can generate more insightful content related to the financial domain." ] }, { "cell_type": "code", "execution_count": null, "id": "7a6aa706-99d1-47c2-b712-c279a89eb63e", "metadata": {}, "outputs": [], "source": [ "parameters = {\n", " \"max_length\": 400,\n", " \"num_return_sequences\": 1,\n", " \"top_k\": 250,\n", " \"top_p\": 0.8,\n", " \"do_sample\": True,\n", " \"temperature\": 1,\n", "}\n", "\n", "res_gpt_finetune = []\n", "for quote_text in [\n", " \"This Form 10-K report shows that\",\n", " \"We serve consumers through\",\n", " \"Our vision is\",\n", "]:\n", " payload = {\"text_inputs\": f\"{quote_text}:\", **parameters}\n", "\n", " query_response = query_endpoint_with_json_payload(\n", " json.dumps(payload).encode(\"utf-8\"), endpoint_name_after_finetune\n", " )\n", " generated_texts = parse_response_multiple_texts(query_response)[0][\"generated_text\"]\n", " res_gpt_finetune.append(generated_texts)\n", " print(generated_texts)\n", " print(\"\\n\")" ] }, { "cell_type": "markdown", "id": "dc10989f-71f7-4f0e-ba5c-5dd7b375cf69", "metadata": {}, "source": [ "We compare the outputs of the model before and after fine-tuning."
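, "\n", "\n", "The generated passages can be fairly long, and pandas truncates long cell contents by default. If you want to see the full generated text in the comparison table below, you can optionally relax the column-width limit first (a small display tweak, not required for the comparison itself):\n", "\n", "```\n", "import pandas as pd\n", "\n", "# Show full generated text instead of truncating each cell\n", "pd.set_option(\"display.max_colwidth\", None)\n", "```"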
] }, { "cell_type": "code", "execution_count": null, "id": "280a1d7c-4236-4dd5-aeab-912ff582d498", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "pd.DataFrame(\n", " {\n", " \"Input example\": [\n", " \"This Form 10-K report shows that\",\n", " \"We serve consumers through\",\n", " \"Our vision is\",\n", " ],\n", " \"Output before finetuning\": res_gpt_before_finetune,\n", " \"Output after finetuning\": res_gpt_finetune,\n", " }\n", ")" ] }, { "cell_type": "markdown", "id": "5a3d9168", "metadata": {}, "source": [ "---\n", "Next, we clean up the deployed endpoint.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "49f98c21", "metadata": {}, "outputs": [], "source": [ "# Delete the SageMaker endpoint and the attached resources\n", "finetuned_predictor.delete_model()\n", "finetuned_predictor.delete_endpoint()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|domain-adaption-finetuning-gpt-j-6b.ipynb)\n" ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }