{ "cells": [ { "cell_type": "markdown", "id": "8af3794b", "metadata": {}, "source": [ "# Chat completion: Run Llama 2 models in SageMaker JumpStart" ] }, { "cell_type": "markdown", "id": "d7ea02ec", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "446b1b24", "metadata": {}, "source": [ "---\n", "In this demo notebook, we demonstrate how to use the SageMaker Python SDK to deploy a Llama 2 text generation model from JumpStart that is fine-tuned and optimized for dialogue use cases.\n", "\n", "To perform inference on these models, you need to pass custom_attributes='accept_eula=true' as part of the request header. By passing this value, you confirm that you have read and accept the end-user license agreement (EULA) of the model. The EULA can be found in the model card description or at https://ai.meta.com/resources/models-and-libraries/llama-downloads/. By default, this notebook sets custom_attributes='accept_eula=false', so all inference requests will fail until you explicitly change this custom attribute.\n", "\n", "Note: The custom attributes used to pass the EULA are key/value pairs. The key and value are separated by '=' and pairs are separated by ';'. If the user passes the same key more than once, the last value is kept and passed to the script handler (i.e., in this case, used for conditional logic). 
For example, if 'accept_eula=false; accept_eula=true' is passed to the server, then 'accept_eula=true' is kept and passed to the script handler.\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "35642ab2", "metadata": {}, "source": [ "## Setup\n", "\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "6b55e677-3429-4668-b100-bd63d2a4c401", "metadata": { "tags": [] }, "outputs": [], "source": [ "%pip install --upgrade --quiet sagemaker" ] }, { "cell_type": "markdown", "id": "7d458cf0-02e2-4066-927b-25fa5ef2a07e", "metadata": {}, "source": [ "***\n", "You can continue with the default model or choose a different one. This notebook runs with the following model IDs:\n", "- `meta-textgeneration-llama-2-7b-f`\n", "- `meta-textgeneration-llama-2-13b-f`\n", "- `meta-textgeneration-llama-2-70b-f`\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "014424a8-7f8f-46a7-8963-2c3d454878b8", "metadata": { "jumpStartAlterations": [ "modelIdVersion" ], "tags": [] }, "outputs": [], "source": [ "(\n", " model_id,\n", " model_version,\n", ") = (\n", " \"meta-textgeneration-llama-2-7b-f\",\n", " \"*\",\n", ")" ] }, { "cell_type": "markdown", "id": "11eef0dd", "metadata": {}, "source": [ "## Deploy model\n", "\n", "***\n", "You can now deploy the model using SageMaker JumpStart.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "9e52afae-868d-4736-881f-7180f393003a", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker.jumpstart.model import JumpStartModel\n", "\n", "model = JumpStartModel(model_id=model_id, model_version=model_version)\n", "predictor = model.deploy()" ] }, { "cell_type": "markdown", "id": "47b4d109", "metadata": {}, "source": [ "### Changing instance type\n", "---\n", "\n", "\n", "The models are supported on the following instance types:\n", "\n", " - Llama 2 7B and 7B-F: `ml.g5.2xlarge`, `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.g5.12xlarge`, `ml.g5.24xlarge`, `ml.g5.48xlarge`, `ml.p4d.24xlarge`\n", " - Llama 2 13B and 
13B-F: `ml.g5.12xlarge`, `ml.g5.24xlarge`, `ml.g5.48xlarge`, `ml.p4d.24xlarge`\n", " - Llama 2 70B and 70B-F: `ml.g5.48xlarge`, `ml.p4d.24xlarge`\n", "\n", "By default, the `JumpStartModel` class selects a default instance type available in your region. To use a different instance type, pass the `instance_type` argument when you create the `JumpStartModel`:\n", "\n", "`my_model = JumpStartModel(model_id=model_id, instance_type=\"ml.g5.12xlarge\")`\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "5ef7207e-01ba-4ac2-b4a9-c8f6f0e1c498", "metadata": { "tags": [] }, "source": [ "## Invoke the endpoint\n", "\n", "***\n", "### Supported Parameters\n", "This model supports the following inference payload parameters:\n", "\n", "* **max_new_tokens:** The model generates text until the output length (excluding the input context length) reaches `max_new_tokens`. If specified, it must be a positive integer.\n", "* **temperature:** Controls the randomness of the output. A higher temperature produces output sequences that include lower-probability words, while a lower temperature produces output sequences dominated by high-probability words. As `temperature` approaches 0, sampling approaches greedy decoding. If specified, it must be a positive float.\n", "* **top_p:** At each step of text generation, sample from the smallest possible set of words whose cumulative probability is at least `top_p`. If specified, it must be a float between 0 and 1.\n", "\n", "You may specify any subset of these parameters when invoking the endpoint.\n", "\n", "***\n", "### Notes\n", "- This model supports only the 'system', 'user', and 'assistant' roles. A dialog can optionally start with a 'system' message, followed by 'user' and 'assistant' messages in alternation, beginning with 'user' (u/a/u/a/u...).\n", "- If `max_new_tokens` is not defined, the model may generate up to the maximum total tokens allowed, which is 4K for these models. This can cause endpoint query timeout errors, so you should set `max_new_tokens` whenever possible. 
For the 7B, 13B, and 70B models, we recommend setting `max_new_tokens` to no more than 1500, 1000, and 500, respectively, while keeping the total number of tokens below 4K.\n", "- To support a 4K context length, this model restricts query payloads to a batch size of 1. Payloads with larger batch sizes receive an endpoint error before inference.\n", "\n", "***" ] },
{ "cell_type": "code", "execution_count": null, "id": "c5adf9b4-c7e1-4090-aefe-9cae0d096968", "metadata": { "tags": [] }, "outputs": [], "source": [ "def print_dialog(payload, response):\n", "    # Print each turn of the input dialog, then the model's generated reply.\n", "    dialog = payload[\"inputs\"][0]\n", "    for msg in dialog:\n", "        print(f\"{msg['role'].capitalize()}: {msg['content']}\\n\")\n", "    print(\n", "        f\"> {response[0]['generation']['role'].capitalize()}: {response[0]['generation']['content']}\"\n", "    )\n", "    print(\"\\n==================================\\n\")" ] },
{ "cell_type": "markdown", "id": "c2fbb9af", "metadata": {}, "source": [ "### Example 1" ] },
{ "cell_type": "code", "execution_count": null, "id": "4cbde5e7-1068-41f9-999a-70ef04e1cbbb", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%time\n", "\n", "payload = {\n", "    \"inputs\": [\n", "        [\n", "            {\"role\": \"user\", \"content\": \"what is the recipe of mayonnaise?\"},\n", "        ]\n", "    ],\n", "    \"parameters\": {\"max_new_tokens\": 512, \"top_p\": 0.9, \"temperature\": 0.6},\n", "}\n", "try:\n", "    # Change to \"accept_eula=true\" after you have read and accepted the EULA.\n", "    response = predictor.predict(payload, custom_attributes=\"accept_eula=false\")\n", "    print_dialog(payload, response)\n", "except Exception as e:\n", "    print(e)" ] },
{ "cell_type": "markdown", "id": "5574e4e2", "metadata": {}, "source": [ "### Example 2" ] },
{ "cell_type": "code", "execution_count": null, "id": "cda81ccf-0188-4117-8355-801ef98aaa48", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%time\n", "\n", "payload = {\n", "    \"inputs\": [\n", "        [\n", "            {\"role\": \"user\", \"content\": \"I am going to Paris, what should I see?\"},\n", "            {\n", "                \"role\": \"assistant\",\n", "                \"content\": \"\"\"\\\n", "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n", "\n", "1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n", "2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n", "3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n", "\n", "These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.\"\"\",\n", "            },\n", "            {\"role\": \"user\", \"content\": \"What is so great about #1?\"},\n", "        ]\n", "    ],\n", "    \"parameters\": {\"max_new_tokens\": 512, \"top_p\": 0.9, \"temperature\": 0.6},\n", "}\n", "try:\n", "    response = predictor.predict(payload, custom_attributes=\"accept_eula=false\")\n", "    print_dialog(payload, response)\n", "except Exception as e:\n", "    print(e)" ] },
{ "cell_type": "markdown", "id": "0aa8d152", "metadata": {}, "source": [ "### Example 3" ] },
{ "cell_type": "code", "execution_count": null, "id": "de6e8250-88c8-4b1c-a70b-ae5a4976e6ad", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%time\n", "\n", "payload = {\n", "    \"inputs\": [\n", "        [\n", "            {\"role\": \"system\", \"content\": \"Always answer with Haiku\"},\n", "            {\"role\": \"user\", \"content\": \"I am going to Paris, what should I see?\"},\n", "        ]\n", "    ],\n", "    \"parameters\": {\"max_new_tokens\": 512, \"top_p\": 0.9, \"temperature\": 0.6},\n", "}\n", "try:\n", "    response = predictor.predict(payload, custom_attributes=\"accept_eula=false\")\n", "    print_dialog(payload, response)\n", "except Exception as e:\n", "    print(e)" ] },
{ "cell_type": "markdown", "id": "076644d4", "metadata": {}, "source": [ "### Example 4" ] },
{ "cell_type": "code", "execution_count": null, "id": "2da83b4a-1e61-495c-b509-38266f5c44eb", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%time\n", "\n", "payload = {\n", "    \"inputs\": [\n", "        [\n", "            {\n", "                \"role\": \"system\",\n", "                \"content\": \"Always answer with emojis\",\n", "            },\n", "            {\"role\": \"user\", \"content\": \"How to go from Beijing to NY?\"},\n", "        ]\n", "    ],\n", "    \"parameters\": {\"max_new_tokens\": 512, \"top_p\": 0.9, \"temperature\": 0.6},\n", "}\n", "try:\n", "    response = predictor.predict(payload, custom_attributes=\"accept_eula=false\")\n", "    print_dialog(payload, response)\n", "except Exception as e:\n", "    print(e)" ] },
{ "cell_type": "markdown", "id": "5e062d29", "metadata": {}, "source": [ "## Clean up the endpoint" ] },
{ "cell_type": "code", "execution_count": null, "id": "24cc5560", "metadata": {}, "outputs": [], "source": [ "# Delete the SageMaker model and endpoint\n", "predictor.delete_model()\n", "predictor.delete_endpoint()" ] },
{ "cell_type": "markdown", "id": "008bb89c", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n", "\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This us-east-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|llama-2-chat-completion.ipynb)\n", "\n" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 2.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/sagemaker-data-science-38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }