{ "cells": [ { "cell_type": "markdown", "id": "3586d98e", "metadata": {}, "source": [ "# Text Generation (BLOOM-1b7)" ] }, { "cell_type": "markdown", "id": "519fbc71", "metadata": {}, "source": [ "---\n", "Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use JumpStart to solve many Machine Learning tasks with one click in SageMaker Studio, or through the [SageMaker JumpStart API](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart). \n", "\n", "In this demo notebook, we demonstrate how to use the JumpStart API for Text Generation. Text generation is the task of generating text that appears indistinguishable from human-written text. It is also sometimes known as \"natural language generation\". Here, we show how to use state-of-the-art pre-trained BLOOM models for Text Generation.\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "ce5ec8c8", "metadata": {}, "source": [ "1. [Set Up](#1.-Set-Up)\n", "2. [Select a model](#2.-Select-a-model)\n", "3. [Retrieve JumpStart Artifacts & Deploy an Endpoint](#3.-Retrieve-JumpStart-Artifacts-&-Deploy-an-Endpoint)\n", "4. [Query endpoint and parse response](#4.-Query-endpoint-and-parse-response)\n", "5. [Advanced features](#5.-Advanced-features)\n", "6. [Clean up the endpoint](#6.-Clean-up-the-endpoint)" ] }, { "cell_type": "markdown", "id": "8146adc3", "metadata": {}, "source": [ "Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in an Amazon SageMaker Notebook instance with the conda_python3 kernel.\n", "\n", "Please see the steps in [Onboard to Amazon SageMaker Domain Using Quick setup](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html) to make use of SageMaker Studio." ] }, { "cell_type": "markdown", "id": "61c8ad03", "metadata": {}, "source": [ "### 1. 
Set Up" ] }, { "cell_type": "markdown", "id": "10eac9d9", "metadata": {}, "source": [ "---\n", "Before executing the notebook, there are some initial steps required for set up. This notebook requires the latest version of the `sagemaker` package.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "305266f7", "metadata": {}, "outputs": [], "source": [ "!pip install sagemaker --upgrade --quiet" ] }, { "cell_type": "markdown", "id": "b382c38e", "metadata": {}, "source": [ "#### Permissions and environment variables\n", "\n", "---\n", "To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access. \n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "d2251212", "metadata": {}, "outputs": [], "source": [ "import sagemaker, boto3, json\n", "from sagemaker import get_execution_role\n", "\n", "aws_role = get_execution_role()\n", "aws_region = boto3.Session().region_name\n", "sess = sagemaker.Session()" ] }, { "cell_type": "markdown", "id": "bad654eb", "metadata": {}, "source": [ "### 2. Select a model\n", "\n", "***\n", "Here, we select a Text Generation model for inference by specifying its JumpStart `model_id` and `model_version`. \n", "***" ] }, { "cell_type": "markdown", "id": "cf7ea6ef", "metadata": {}, "source": [ "#### Choose a model for Inference" ] }, { "cell_type": "code", "execution_count": null, "id": "032af855", "metadata": {}, "outputs": [], "source": [ "# model_version=\"*\" fetches the latest version of the model\n", "model_id, model_version = \"huggingface-textgeneration-bloom-1b7\", \"*\"" ] }, { "cell_type": "markdown", "id": "3d699480", "metadata": {}, "source": [ "### 3. 
Retrieve JumpStart Artifacts & Deploy an Endpoint\n", "\n", "***\n", "\n", "Using JumpStart, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `deploy_image_uri`, `deploy_source_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take some time. Kindly do not stop the cell or restart the kernel.\n", "\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "37ef324f", "metadata": {}, "outputs": [], "source": [ "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", "from sagemaker.model import Model\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "\n", "\n", "endpoint_name = name_from_base(f\"jumpstart-console-infer-{model_id}\")\n", "\n", "inference_instance_type = \"ml.p3.2xlarge\"\n", "\n", "# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.\n", "deploy_image_uri = image_uris.retrieve(\n", "    region=None,\n", "    framework=None,  # automatically inferred from model_id\n", "    image_scope=\"inference\",\n", "    model_id=model_id,\n", "    model_version=model_version,\n", "    instance_type=inference_instance_type,\n", ")\n", "\n", "# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.\n", "deploy_source_uri = script_uris.retrieve(\n", "    model_id=model_id, model_version=model_version, script_scope=\"inference\"\n", ")\n", "\n", "\n", "# Retrieve the model uri. 
This includes the pre-trained BLOOM model and parameters.\n", "model_uri = model_uris.retrieve(\n", "    model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", ")\n", "\n", "\n", "# Create the SageMaker model instance\n", "model = Model(\n", "    image_uri=deploy_image_uri,\n", "    source_dir=deploy_source_uri,\n", "    model_data=model_uri,\n", "    entry_point=\"inference.py\",  # entry point file in source_dir and present in deploy_source_uri\n", "    role=aws_role,\n", "    predictor_cls=Predictor,\n", "    name=endpoint_name,\n", ")\n", "\n", "# Deploy the model. Note that we need to pass the Predictor class when we deploy the model through the Model class,\n", "# in order to run inference through the SageMaker API.\n", "model_predictor = model.deploy(\n", "    initial_instance_count=1,\n", "    instance_type=inference_instance_type,\n", "    predictor_cls=Predictor,\n", "    endpoint_name=endpoint_name,\n", ")" ] }, { "cell_type": "markdown", "id": "d29b4ff3", "metadata": {}, "source": [ "### 4. Query endpoint and parse response\n", "\n", "---\n", "Input to the endpoint is any string of text serialized as JSON and encoded in `utf-8` format. 
Output of the endpoint is a JSON object containing the generated text.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "4e9c4f20", "metadata": {}, "outputs": [], "source": [ "def query(model_predictor, text):\n", "    \"\"\"Query the model predictor.\"\"\"\n", "\n", "    encoded_text = json.dumps(text).encode(\"utf-8\")\n", "\n", "    query_response = model_predictor.predict(\n", "        encoded_text,\n", "        {\n", "            \"ContentType\": \"application/x-text\",\n", "            \"Accept\": \"application/json\",\n", "        },\n", "    )\n", "    return query_response\n", "\n", "\n", "def parse_response(query_response):\n", "    \"\"\"Parse response and return the generated text.\"\"\"\n", "\n", "    model_predictions = json.loads(query_response)\n", "    generated_text = model_predictions[\"generated_text\"]\n", "    return generated_text" ] }, { "cell_type": "markdown", "id": "ae1846ba", "metadata": {}, "source": [ "---\n", "Below, we put in some example input text. You can put in any text and the model predicts the next words in the sequence. Longer sequences of text can be generated by calling the model repeatedly.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "7ff56136", "metadata": {}, "outputs": [], "source": [ "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", "\n", "text1 = \"As far as I am concerned, I will\"\n", "text2 = \"The movie is\"\n", "\n", "for text in [text1, text2]:\n", "    query_response = query(model_predictor, text)\n", "    generated_text = parse_response(query_response)\n", "    print(f\"Input text: {text}{newline}\" f\"Generated text: {bold}{generated_text}{unbold}{newline}\")" ] }, { "cell_type": "markdown", "id": "c8ee7e6c", "metadata": {}, "source": [ "### 5. Advanced features\n", "\n", "***\n", "This model also supports many advanced parameters while performing inference. They include:\n", "\n", "* **max_length:** Model generates text until the output length (which includes the input context length) reaches `max_length`. 
If specified, it must be a positive integer.\n", "* **num_return_sequences:** Number of output sequences returned. If specified, it must be a positive integer.\n", "* **num_beams:** Number of beams used in the beam search. If specified, it must be an integer greater than or equal to `num_return_sequences`.\n", "* **no_repeat_ngram_size:** Model ensures that a sequence of words of `no_repeat_ngram_size` is not repeated in the output sequence. If specified, it must be a positive integer greater than 1.\n", "* **temperature:** Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If `temperature` -> 0, it results in greedy decoding. If specified, it must be a positive float.\n", "* **early_stopping:** If True, text generation is finished when all beam hypotheses reach the end-of-sentence token. If specified, it must be a boolean.\n", "* **do_sample:** If True, sample the next word as per the likelihood. If specified, it must be a boolean.\n", "* **top_k:** In each step of text generation, sample from only the `top_k` most likely words. If specified, it must be a positive integer.\n", "* **top_p:** In each step of text generation, sample from the smallest possible set of words with cumulative probability `top_p`. If specified, it must be a float between 0 and 1.\n", "* **seed:** Fix the randomized state for reproducibility. If specified, it must be an integer.\n", "\n", "We may specify any subset of the parameters mentioned above while invoking an endpoint. 
Next, we show an example of how to invoke the endpoint with these arguments.\n", "\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "1168643e", "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "payload = {\n", "    \"text_inputs\": \"My name is Lewis and I like to\",\n", "    \"max_length\": 50,\n", "    \"num_return_sequences\": 3,\n", "    \"top_k\": 50,\n", "    \"top_p\": 0.95,\n", "    \"do_sample\": True,\n", "}\n", "\n", "\n", "def query_endpoint_with_json_payload(model_predictor, payload):\n", "    \"\"\"Query the model predictor with json payload.\"\"\"\n", "\n", "    encoded_payload = json.dumps(payload).encode(\"utf-8\")\n", "\n", "    query_response = model_predictor.predict(\n", "        encoded_payload,\n", "        {\n", "            \"ContentType\": \"application/json\",\n", "            \"Accept\": \"application/json\",\n", "        },\n", "    )\n", "    return query_response\n", "\n", "\n", "def parse_response_multiple_texts(query_response):\n", "    \"\"\"Parse response and return the generated texts.\"\"\"\n", "\n", "    model_predictions = json.loads(query_response)\n", "    generated_texts = model_predictions[\"generated_texts\"]\n", "    return generated_texts\n", "\n", "\n", "query_response = query_endpoint_with_json_payload(model_predictor, payload)\n", "generated_texts = parse_response_multiple_texts(query_response)\n", "# Print the prompt from the payload (not the leftover loop variable `text` from the earlier cell)\n", "print(f\"Input text: {payload['text_inputs']}{newline}\" f\"Generated text: {bold}{generated_texts}{unbold}{newline}\")" ] }, { "cell_type": "markdown", "id": "619635cb", "metadata": {}, "source": [ "### 6. 
Clean up the endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "7dbab4f6", "metadata": {}, "outputs": [], "source": [ "# Delete the SageMaker model and endpoint\n", "model_predictor.delete_model()\n", "model_predictor.delete_endpoint()" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 5 }