{ "cells": [ { "cell_type": "markdown", "id": "486225e8-12a5-499f-a620-ea41ccef1815", "metadata": {}, "source": [ "# Retrieval Augmented Generation and Chatbot Application\n", "\n", "LangChain is a framework for developing applications powered by language models. The key aspects of this framework allow us to augment large language models so that we can perform tasks that meet our goals and enable our use cases. At a high level, LangChain provides:\n", "\n", "Data: Connect a language model to other sources of data\n", "Agent: Allow a language model to interact with its environment\n", "\n", "LangChain can be used in two major ways:\n", "\n", "
  • Individual Components: LangChain provides modular abstractions for the components necessary to work with language models. LangChain also has collections of implementations for all these abstractions. The components are designed to be easy to use, regardless of whether you are using the rest of the LangChain framework or not.\n", "\n", "
  • Use-Case Specific Chains: Chains can be thought of as assembling these components in particular ways in order to best accomplish a particular use case. These are intended to be a higher level interface through which people can easily get started with a specific use case. These chains are also designed to be customizable." ] }, { "cell_type": "markdown", "id": "4789df3c-aff9-4957-a32d-80086b1f7ddb", "metadata": {}, "source": [ "## Topics covered:\n", "\n", "In this notebook we will be covering the below topics:\n", "\n", "- **LLM** Examine running an LLM in bare form to check the output\n", "- **Vector DB** Examine various vector databases like FAISS or Chroma and leverage them to produce better results using RAG\n", "- **Prompt template** Examine the use of prompt templates\n", "- **Question Answering** Retrieval Augmented Generation (RAG)\n", "- **Chatbot** Build an Interactive Chatbot with Memory" ] }, { "cell_type": "markdown", "id": "6f1176e9-9a60-4713-b72f-9e54d2a259b8", "metadata": {}, "source": [ "## Key points for consideration\n", "\n", "1. Long documents that exceed the token limit? Ability to chain: Map-Reduce, Refine, Map-Rerank\n", "2. Cost per token -- minimize the tokens and send only relevant tokens to the model\n", "3. Which model to use --\n", " - Cohere, AI21, Huggingface Hub, Manifest, Goose AI, Writer, Banana, Modal, StochasticAI, Cerebrium, Petals, Forefront AI, Anthropic, DeepInfra, and self-hosted Models.\n", " - Example LLM cohere = Cohere(model='command-xlarge')\n", " - Example LLM flan = HuggingFaceHub(repo_id=\"google/flan-t5-xl\")\n", "4. Input data sources: PDF, web pages, CSV, S3, EFS\n", "5. Orchestration with external tasks\n", " - External tasks - Agent SerpApi, search engines\n", " - Math calculator\n", "6. Conversation management and history" ] }, { "cell_type": "markdown", "id": "7de785d0-3b27-4699-87be-a34484c429fa", "metadata": {}, "source": [ "### Key components of LangChain\n", "\n", "Let us examine the key components of LangChain. At the heart and center is the large language model.\n", "\n", "There are several main modules that LangChain provides support for. For each module there are examples to get started, how-to guides, reference docs, and conceptual guides. These modules are, in increasing order of complexity:\n", "\n", "**Models**: The various model types and model integrations LangChain supports.\n", "\n", "**Prompts**: This includes prompt management, prompt optimization, and prompt serialization.\n", "\n", "**Memory**: Memory is the concept of persisting state between calls of a chain/agent. LangChain provides a standard interface for memory, a collection of memory implementations, and examples of chains/agents that use memory.\n", "\n", "**Indexes**: Language models are often more powerful when combined with your own text data - this module covers best practices for doing exactly that.\n", "\n", "**Chains**: Chains go beyond just a single LLM call, and are sequences of calls (whether to an LLM or a different utility). LangChain provides a standard interface for chains, lots of integrations with other tools, and end-to-end chains for common applications.\n", "\n", "**Agents**: Agents involve an LLM making decisions about which Actions to take, taking that Action, seeing an Observation, and repeating that until done. LangChain provides a standard interface for agents, a selection of agents to choose from, and examples of end-to-end agents.\n", "\n", "**Callbacks**: It can be difficult to track all that occurs inside a chain or agent. Callbacks help add a level of observability and introspection." ] }
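, { "cell_type": "markdown", "id": "7de785d1-0001-4a0a-9a0a-111111111111", "metadata": {}, "source": [ "To make these components concrete, here is a minimal, hedged sketch of how a Prompt, a Model, and a Chain compose. It is illustrative only: `FakeListLLM` is a stand-in model so the snippet runs without any endpoint; substitute any real LLM integration.\n", "\n", "```python\n", "from langchain import PromptTemplate, LLMChain\n", "from langchain.llms.fake import FakeListLLM\n", "\n", "# Prompt component: a reusable template with one input variable\n", "prompt = PromptTemplate(input_variables=[\"product\"], template=\"Name a company that makes {product}.\")\n", "\n", "# Model component: a fake LLM that returns canned responses (stand-in for a real model)\n", "llm = FakeListLLM(responses=[\"Acme Sockworks\"])\n", "\n", "# Chain component: composes the prompt and the model into one callable unit\n", "chain = LLMChain(llm=llm, prompt=prompt)\n", "print(chain.run(product=\"socks\"))  # -> Acme Sockworks\n", "```" ] }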
, { "cell_type": "markdown", "id": "402714bf-14b6-4481-8e33-fc3d0b8a81f4", "metadata": {}, "source": [ "### Chat Bot key elements\n", "\n", "The first process in a chat bot is to generate embeddings. Typically you will have an ingestion process which runs your documents through an embedding model and stores the resulting embeddings in a vector store. In this example we are using a GPT-J embeddings model for this.\n", "\n", "The second process is the user request orchestration: interaction, invoking the model, and returning the results.\n", "\n", "For processes which need deeper analysis of the conversation history, we will need to summarize every interaction to keep it succinct; one such flow uses Pinecone as the vector store. A schematic sketch of the two core processes follows." ] }
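, { "cell_type": "markdown", "id": "402714c0-0002-4b0b-9b0b-222222222222", "metadata": {}, "source": [ "The two processes above can be sketched in a few lines of pseudocode. The `store` object here is a hypothetical vector-store interface (`add`/`search`), and `embedding_model` and `llm` stand for whichever implementations you configure later in this notebook:\n", "\n", "```python\n", "# Process 1: ingestion -- embed each document once and store its vector (hypothetical interface)\n", "def ingest(documents, embedding_model, store):\n", "    for doc in documents:\n", "        store.add(vector=embedding_model.embed(doc), payload=doc)\n", "\n", "# Process 2: orchestration -- embed the query, retrieve similar documents, ask the LLM\n", "def answer(query, embedding_model, store, llm, k=3):\n", "    context = store.search(embedding_model.embed(query), top_k=k)\n", "    return llm(f\"Answer using this context: {context}\\nQuestion: {query}\")\n", "```" ] }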
, { "cell_type": "markdown", "id": "0aed6880-101b-457a-9e99-25cc421ee8c5", "metadata": {}, "source": [ "# Pre-Requisites\n", "\n", "There are a few pre-reqs to be completed when running this notebook. The key one is setting up the LLM to be used.\n", "  • Either have a FLAN-T5 model deployed in SageMaker using Lab5 at Deploy FlanT5-XXL from https://github.com/aws/amazon-sagemaker-examples/tree/main/inference/generativeai/llm-workshop\n", "
  • Have an Anthropic model key. You can choose to set up both or either one; however, certain cells might not work if you have just one, so you can choose to ignore those errors as part of the run\n", "\n" ] }, { "cell_type": "markdown", "id": "af78b2be-9a3f-446a-b30e-597493664257", "metadata": {}, "source": [ "### LLM model deploy in SageMaker\n", "\n", "Make sure that you have run the Notebook `1_deploy-flan-t5-xl.ipynb`.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "18101583-5b29-4f1e-8d41-8777472f4d69", "metadata": { "tags": [] }, "outputs": [], "source": [ "%store -r endpoint_name" ] }, { "cell_type": "code", "execution_count": 8, "id": "375416f1-819f-42e8-be69-b25e7e2f880b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "huggingface-text2text-flan-t5-xl-1686836752\n", "huggingface-text2text-flan-t5-xl-1686836752\n" ] } ], "source": [ "import os\n", "os.environ[\"FLAN_XL_ENDPOINT\"]=endpoint_name\n", "print(os.environ[\"FLAN_XL_ENDPOINT\"])" ] }, { "cell_type": "code", "execution_count": 9, "id": "4e6ec361-a696-4897-81f4-e83e3d724e8e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Hit:1 http://deb.debian.org/debian buster InRelease\n", "Hit:2 http://deb.debian.org/debian buster-updates InRelease\n", "Get:3 http://security.debian.org/debian-security buster/updates InRelease [34.8 kB]\n", "Get:4 http://security.debian.org/debian-security buster/updates/main amd64 Packages [515 kB]\n", "Fetched 550 kB in 0s (1541 kB/s) \u001b[0m\u001b[33m\u001b[33m\n", "Reading package lists... Done\n", "Building dependency tree \n", "Reading state information... Done\n", "74 packages can be upgraded. Run 'apt list --upgradable' to see them.\n" ] } ], "source": [ "!apt update" ] }, { "cell_type": "code", "execution_count": 10, "id": "12130b4c-6242-4b81-866c-a1b8de3c577c", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading package lists... Done\n", "Building dependency tree \n", "Reading state information... Done\n", "wkhtmltopdf is already the newest version (0.12.5-1+deb10u1).\n", "0 upgraded, 0 newly installed, 0 to remove and 74 not upgraded.\n" ] } ], "source": [ "!apt install wkhtmltopdf -y" ] }, { "cell_type": "markdown", "id": "aa11828a-243d-4808-9c92-e8caf4cebd37", "metadata": {}, "source": [ "### Install certain libraries which are needed for this run. \n", "\n", "These are provided in the requirements.txt, or you can run these cells to control exactly which libraries you need" ] }, { "cell_type": "code", "execution_count": 11, "id": "c900e8c1-9f2d-4b14-9904-481a0ac5442e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pip in /opt/conda/lib/python3.8/site-packages (23.1.2)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install --upgrade pip" ] }, { "cell_type": "code", "execution_count": 12, "id": "ee2be60b-480a-4524-8a1d-3529ebcb812d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install langchain==0.0.161 --quiet" ] }, { "cell_type": "code", "execution_count": 13, "id": "3c8923e9-69f8-4561-8df3-8eca59c965fc", "metadata": { "tags": [] }, "outputs": [], "source": [ "# !pip install chromadb==0.3.21 --quiet" ] }, { "cell_type": "code", "execution_count": 14, "id": "e828474e-c07a-4dba-badd-068ac2ff6d6d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install langchain==0.0.161 boto3 html2text jinja2 --quiet" ] }, { "cell_type": "code", "execution_count": 15, "id": "50365364-45aa-4f80-b78e-89771afebc66", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install faiss-cpu==1.7.4 --quiet" ] }, { "cell_type": "code", "execution_count": 16, "id": "597ebd16-11c0-446b-8b4c-547b1e956c30", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install pypdf==3.8.1 --quiet" ] }, { "cell_type": "code", "execution_count": 17, "id": "f4f1b765-26a6-4bde-9491-31a2c804d26b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install transformers==4.24.0 --quiet" ] }, { "cell_type": "code", "execution_count": 18, "id": "eae906df-a917-4b04-a36e-e8f184e6860f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install sentence_transformers==2.2.2 --quiet" ] }, { "cell_type": "code", "execution_count": 19, "id": "c34a25aa-7b04-48e1-81e2-7a8ee8e5649e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: pdfkit in /opt/conda/lib/python3.8/site-packages (1.0.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install pdfkit" ] }, { "cell_type": "code", "execution_count": 20, "id": "7c2064f1-3cfa-4f19-b6cd-e14c7f16ec1f", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'2.2.2'" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sentence_transformers \n", "sentence_transformers.__version__" ] }, { "cell_type": "code", "execution_count": 21, "id": "c698a788-66e1-4409-9c90-0994cc072eb4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "all libraries installed\n" ] } ], "source": [ "print(\"all libraries installed\")" ] }, { "cell_type": "markdown", "id": "59f3a19c-78d7-4e60-9ba0-cec3c856ad1a", "metadata": {}, "source": [ "### Import statements for our chain and indexers. We are not using any explicit agent here" ] }, { "cell_type": "code", "execution_count": 22, "id": "d23a816c-b828-4d6d-9bc1-ecd0c29ddc2d", "metadata": { "tags": [] }, "outputs": [], "source": [ "#from aws_langchain.kendra_index_retriever import KendraIndexRetriever\n", "from langchain.chains import ConversationalRetrievalChain\n", "from langchain import SagemakerEndpoint\n", "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", "from langchain.prompts import PromptTemplate\n", "import sys\n", "import json\n", "import os\n", "import time\n", "import sagemaker, boto3, json\n", "from sagemaker.session import Session\n", "from sagemaker.model import Model\n", "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "from typing import Any, Dict, List, Optional\n", "from langchain.embeddings import SagemakerEndpointEmbeddings" ] }, { "cell_type": "code", "execution_count": 23, "id": "67664670-35af-4561-a8af-1eb5967bd382", "metadata": { "tags": [] }, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "import jinja2\n", "role = sagemaker.get_execution_role() # execution role for the endpoint" ] }, { "cell_type": "markdown", "id": "12f60eb1-c3e5-41c8-ac57-e9db8123f74f", "metadata": {}, "source": [ "### [Optional] - Deploy a GPT-J embeddings Model - so we can use that to generate the embeddings for the documents\n", "\n", "This section requires a bigger instance type `ml.g5.24xlarge` which is not available in the workshop setting. If you are running in your own account and have access to `ml.g5.24xlarge`, you can uncomment the below code to deploy the GPT-J model to use it as an embeddings model. \n", "\n", "This will be used for the RAG [document search capability](https://labelbox.com/blog/how-vector-similarity-search-works/) and needs a g5.24xlarge instance to run\n", "\n", "Other possible embeddings are listed at 
[LangChain Embeddings](https://python.langchain.com/en/latest/reference/modules/embeddings.html)" ] }, { "cell_type": "code", "execution_count": 24, "id": "786dd885-b173-4bed-977e-8492ada4e6ab", "metadata": { "tags": [] }, "outputs": [], "source": [ "# _MODEL_CONFIG_ = {\n", "# \"huggingface-textembedding-gpt-j-6b\": {\n", "# \"instance type\": \"ml.g5.24xlarge\",\n", "# \"env\": {\"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", "# },\n", "# }\n", "# # - Uncomment and set these values in case you have an instance of GPT-J deployed already \n", "# model_id = \"huggingface-textembedding-gpt-j-6b\"\n", "# # _MODEL_CONFIG_[model_id][\"endpoint_name\"] = '' \n", "# # print( f'24xlarge::{_MODEL_CONFIG_[model_id][\"endpoint_name\"]}')\n", "# #" ] }, { "cell_type": "code", "execution_count": 25, "id": "0c25a785-8832-4d27-a992-868cf8e06a62", "metadata": { "tags": [] }, "outputs": [], "source": [ "# newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", "\n", "# for model_id in _MODEL_CONFIG_:\n", "# endpoint_name = name_from_base(f\"jumpstart-example-embedding-{model_id}\")\n", "# inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n", "\n", "# # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n", "# deploy_image_uri = image_uris.retrieve(\n", "# region=None,\n", "# framework=None, # automatically inferred from model_id\n", "# image_scope=\"inference\",\n", "# model_id=model_id,\n", "# model_version=model_version,\n", "# instance_type=inference_instance_type,\n", "# )\n", "# # Retrieve the model uri.\n", "# model_uri = model_uris.retrieve(\n", "# model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", "# )\n", "# model_inference = Model(\n", "# image_uri=deploy_image_uri,\n", "# model_data=model_uri,\n", "# role=role,\n", "# predictor_cls=Predictor,\n", "# name=endpoint_name,\n", "# env=_MODEL_CONFIG_[model_id][\"env\"],\n", "# )\n", "# model_predictor_inference = model_inference.deploy(\n", "# initial_instance_count=1,\n", "# instance_type=inference_instance_type,\n", "# predictor_cls=Predictor,\n", "# endpoint_name=endpoint_name,\n", "# )\n", "# print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n", "# _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name" ] }, { "cell_type": "code", "execution_count": 26, "id": "05a437da-b56d-403c-a720-60f1f407a87a", "metadata": {}, "outputs": [], "source": [ "# from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler, SagemakerEndpointEmbeddings\n", "# from langchain.embeddings.base import Embeddings\n", "# from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", "# import numpy as np\n", "# import boto3\n", "# import os\n", "\n", "# class SagemakerEndpointEmbeddingsLMI(SagemakerEndpointEmbeddings):\n", "# def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:\n", "# \"\"\"Compute doc embeddings using a SageMaker Inference Endpoint.\n", "\n", "# Args:\n", "# texts: The list of texts to embed.\n", "# chunk_size: The chunk size defines how many input texts will\n", "# be grouped together as request. 
If None, will use the\n", "# chunk size specified by the class.\n", "\n", "# Returns:\n", "# List of embeddings, one for each text.\n", "# \"\"\"\n", "# results = []\n", "# _chunk_size = len(texts) if chunk_size > len(texts) else chunk_size\n", "\n", "# for i in range(0, len(texts), _chunk_size):\n", "# response = self._embedding_func(texts[i : i + _chunk_size])\n", "# print()\n", "# results.extend(response)\n", "# return results\n", "\n", "\n", "# class ContentHandlerEmbdSM(EmbeddingsContentHandler): #ContentHandlerBase):\n", "# content_type = \"application/json\"\n", "# accepts = \"application/json\"\n", "\n", "# def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n", "# input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", "# #input_str = json.dumps({\"inputs\": prompt, \"parameters\": model_kwargs})\n", "# return input_str.encode(\"utf-8\")\n", "\n", "# def transform_output(self, output: bytes) -> str:\n", "# response_json = json.loads(output.read().decode(\"utf-8\"))\n", "# #print(f\"EMBEDDINGS::RESPONSE:{response_json}::\")\n", "# embeddings = response_json[\"embedding\"]\n", "# print(f\"EMBEDDINGS::RESPONSE::len[0]:{len(embeddings[0])}::current shape -- > {np.array(embeddings).shape}:: shape after unsqueeze -- > {np.array([embeddings]).shape}\")\n", "# if len(embeddings) == 1: # for the query embeddings - should be 1D vector because faiss will unsqueeze it \n", "# print(f\"EMBEDDINGS::returning:NO:SQUEEZE:: RESPONSE:{np.array(embeddings).shape}::\")\n", "# return embeddings #[0]\n", "# return embeddings # embeddings expected to be of shape 2D List[List[float]] -- >array 1 row with n dimensions\n", "\n", "\n", "# assumed_role = os.getenv('LANGCHAIN_ASSUME_ROLE', None)\n", "# print(assumed_role)\n", "# boto3_kwargs = {}\n", "# session = boto3.Session()\n", "# if assumed_role:\n", "# sts = session.client(\"sts\")\n", "# response = sts.assume_role(\n", "# RoleArn=str(assumed_role), #\"arn:aws:iam::425576326687:role/SageMakerStudioDomainNoAuth-SageMakerExecutionRole-3RBLN6GPZ46O\",\n", "# RoleSessionName=\"langchain-llm-1\"\n", "# )\n", "# print(response)\n", "# boto3_kwargs = dict(\n", "# aws_access_key_id=response['Credentials']['AccessKeyId'],\n", "# aws_secret_access_key=response['Credentials']['SecretAccessKey'],\n", "# aws_session_token=response['Credentials']['SessionToken']\n", "# )\n", "\n", "# boto3_sm_client = boto3.client(\n", "# \"sagemaker-runtime\",\n", "# **boto3_kwargs\n", "# )\n", "# print(boto3_sm_client)\n", "# content_handler_embd_sm = ContentHandlerEmbdSM()\n", "# hf_embeddings = SagemakerEndpointEmbeddingsLMI(\n", "# client = boto3_sm_client,\n", "# endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"], #os.environ[\"FLAN_XXL_ENDPOINT\"],\n", "# region_name='us-east-1',\n", "# content_handler=content_handler_embd_sm,\n", "# )\n", "# hf_embeddings" ] }, { "cell_type": "markdown", "id": "59969792-5aeb-420d-9d2c-c522e2fc87bb", "metadata": {}, "source": [ "### Use HuggingFaceEmbeddings in the workshop setting. \n", "If you are in a workshop, please use the below code. If you are using the GPT-J model for generating the embeddings, please comment out the below cell. 
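\n", "\n", "As a quick sanity check (a hedged sketch, assuming the `all-mpnet-base-v2` model downloads successfully), you can confirm the embedding dimensionality that the vector store will index:\n", "\n", "```python\n", "from langchain.embeddings import HuggingFaceEmbeddings\n", "\n", "emb = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-mpnet-base-v2\")\n", "vec = emb.embed_query(\"What is Amazon SageMaker?\")\n", "print(len(vec))  # all-mpnet-base-v2 produces 768-dimensional vectors\n", "```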
" ] }, { "cell_type": "code", "execution_count": 27, "id": "454875bf-dde7-4e7e-b61d-e0ed5e25c7cc", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.embeddings import HuggingFaceEmbeddings\n", "from typing import Any, Dict, List, Optional\n", "from pydantic import BaseModel, Extra, Field\n", "from langchain.embeddings.base import Embeddings\n", "import numpy as np\n", "\n", "model_name = \"sentence-transformers/all-mpnet-base-v2\"\n", "model_kwargs = {'device': 'cpu'}\n", "\n", "\n", "class CustomHFEmbeddings(HuggingFaceEmbeddings):\n", " def embed_documents(self, texts: List[str]) -> List[List[float]]:\n", " \"\"\"Compute doc embeddings using a HuggingFace transformer model.\n", "\n", " Args:\n", " texts: The list of texts to embed.\n", "\n", " Returns:\n", " List of embeddings, one for each text.\n", " \"\"\"\n", " texts = list(map(lambda x: x.replace(\"\\n\", \" \"), texts))\n", " embeddings = self.client.encode(texts, **self.encode_kwargs)\n", " #- (22, 1536)\n", " print(f\"CustomHFEmbeddings::embed_documents::shape:returned -- > {embeddings.shape}:\")\n", " \n", " return embeddings.tolist()\n", " def embed_query(self, text: str) -> List[float]:\n", " \"\"\"Compute query embeddings using a HuggingFace transformer model.\n", "\n", " Args:\n", " text: The text to embed.\n", "\n", " Returns:\n", " Embeddings for the text.\n", " \"\"\"\n", " text = text.replace(\"\\n\", \" \")\n", " embedding = self.client.encode(text, **self.encode_kwargs)\n", " print(f\"CustomHFEmbeddings::QUERY::shape:returned -- > {embedding.shape}:\")\n", " return embedding.tolist()\n", "\n", "hf_embeddings = CustomHFEmbeddings(model_name=model_name, model_kwargs=model_kwargs)" ] }, { "cell_type": "markdown", "id": "cf7b5d9d-154a-48bb-98e5-1b7019bb9b11", "metadata": { "tags": [] }, "source": [ "### Test the flanT5 model \n", "Testing Flan T5 model for answering a random question." ] }, { "cell_type": "code", "execution_count": 28, "id": "97fdec15-11fb-4a83-b59d-b6438c9bfec4", "metadata": { "tags": [] }, "outputs": [], "source": [ "MAX_LENGTH = 256\n", "NUM_RETURN_SEQUENCES = 1\n", "TOP_K = 0\n", "TOP_P = 0.7\n", "DO_SAMPLE = True " ] }, { "cell_type": "code", "execution_count": 29, "id": "c36ee7d5-f46c-4158-b9c9-8a290f32f642", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Question being asked is -- > Answer this question below, How can it help me? :\n" ] }, { "data": { "text/plain": [ "'{\"generated_texts\": [\"The iron will be used to raise the metal plate, which will be heated and the tin will melt.\"]}'" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "boto3_kwargs = {}\n", "session = boto3.Session()\n", "\n", "boto3_sm_client = boto3.client(\"sagemaker-runtime\")\n", "print(boto3_sm_client)\n", "prompt = f\"Answer this question below, How can it help me? 
\"\n", "print(f\"Question being asked is -- > {prompt}:\")\n", "\n", "payload = {'text_inputs': prompt, \n", " 'max_length': MAX_LENGTH, \n", " 'num_return_sequences': NUM_RETURN_SEQUENCES,\n", " 'top_k': TOP_K,\n", " 'top_p': TOP_P,\n", " 'do_sample': DO_SAMPLE}\n", "\n", "payload = json.dumps(payload).encode('utf-8')\n", "\n", "boto3_sm_client.invoke_endpoint(\n", " EndpointName=os.environ[\"FLAN_XL_ENDPOINT\"],\n", " Body=payload,\n", " ContentType=\"application/json\",\n", ")[\"Body\"].read().decode(\"utf8\")" ] }, { "cell_type": "markdown", "id": "a044aaa4-4d2b-44a8-aa7d-e690e08e4682", "metadata": {}, "source": [ "## Section 2: Use LangChain\n", "\n", "We will follow this pattern for the rest of the section\n", "\n", "
  • Exploring vector databases\n", "
  • Basics of QA exploring simple chains\n", "
  • Basics of chatbot\n", "
  • Going to prompt templates\n", "
  • Exploring Chains\n" ] }, { "cell_type": "markdown", "id": "d711c743-3d72-4c46-bcb4-6870f1d78c5e", "metadata": {}, "source": [ "### Exploring Vector Databases and Creating the Embeddings\n", "\n", "Leverage the SageMaker GPT-J embeddings model, or the HuggingFace embeddings model set up above" ] }, { "cell_type": "markdown", "id": "9f6618e7-646a-4db0-b6c5-d2a1642aa1f6", "metadata": {}, "source": [ "#### Use the file-based document to retrieve based on embeddings\n", "\n", "Run the cells below to visualize the dataset" ] }, { "cell_type": "markdown", "id": "aff84015-923b-4e8f-b89a-4543f6755210", "metadata": {}, "source": [ "#### Pull in the data set" ] }, { "cell_type": "code", "execution_count": 30, "id": "c5b2b889-d6bd-4117-ad56-5dd5ef13d1f0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "download: s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/Amazon_SageMaker_FAQs.csv to rag_data/Amazon_SageMaker_FAQs.csv\n" ] } ], "source": [ "original_data = \"s3://jumpstart-cache-prod-us-east-2/training-datasets/Amazon_SageMaker_FAQs/\"\n", "\n", "!mkdir -p rag_data\n", "!aws s3 cp --recursive $original_data rag_data" ] }, { "cell_type": "code", "execution_count": 31, "id": "70d18454-8d28-4662-a396-fb5f70730b80", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(154, 2)\n" ] }, { "data": { 
"text/plain": [ " Question \\\n", "0 What is Amazon SageMaker? \n", "1 In which Regions is Amazon SageMaker available... \n", "\n", " Answer \n", "0 Amazon SageMaker is a fully managed service to... \n", "1 For a list of the supported Amazon SageMaker A... " ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import glob\n", "import os\n", "import pandas as pd\n", "\n", "all_files = glob.glob(os.path.join(\"rag_data/\", \"*.csv\"))\n", "\n", "df_knowledge = pd.concat(\n", " (pd.read_csv(f, header=None, names=[\"Question\", \"Answer\"]) for f in all_files),\n", " axis=0,\n", " ignore_index=True,\n", ")\n", "\n", "#- drop \n", "df_answer = df_knowledge.drop([\"Question\"], axis=1)\n", "\n", "print(df_knowledge.shape)\n", "df_knowledge.head(2)" ] }, { "cell_type": "code", "execution_count": 32, "id": "a233338b-1fda-4608-9116-617bc674819d", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#convert to pdf\n", "import pdfkit\n", "pdfkit.from_url('https://aws.amazon.com/sagemaker/faqs/', 'rag_data/Amazon_SageMaker_FAQs.pdf')" ] }, { "cell_type": "code", "execution_count": 33, "id": "4615a351-cb75-47f8-ba7c-1108934dae61", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.chains import RetrievalQA\n", "from langchain.document_loaders import TextLoader\n", "from langchain.indexes import VectorstoreIndexCreator\n", "from langchain.vectorstores import Chroma, AtlasDB, FAISS\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain import PromptTemplate\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.document_loaders.csv_loader import CSVLoader" ] }, { "cell_type": "code", "execution_count": 34, "id": "f79e90b0-7fa6-49d5-8fad-ffa7362c75be", "metadata": { "tags": [] }, "outputs": [], "source": [ "import time\n", "import sagemaker, boto3, json\n", "from sagemaker.session import Session\n", "from sagemaker.model import Model\n", "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "from typing import Any, Dict, List, Optional\n", "from langchain.embeddings import SagemakerEndpointEmbeddings\n", "from langchain.llms.sagemaker_endpoint import ContentHandlerBase" ] }, { "cell_type": "markdown", "id": "e57a8df3-0285-4fed-ab51-5071058225cc", "metadata": {}, "source": [ "#### Create the embeddings for document search" ] }, { "cell_type": "code", "execution_count": 35, "id": "ca9e3a2e-1c4d-4892-b4d6-06f0fdc2f68d", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.indexes import VectorstoreIndexCreator" ] }, { "cell_type": "markdown", "id": "2ad8c530-e385-4129-8912-1d2677c1b52c", "metadata": {}, "source": [ "#### Vector store indexer. \n", "\n", "This is what stores and matches the embeddings. This notebook showcases Chroma and FAISS, which will be transient and in memory. The VectorStore APIs are available [here](https://python.langchain.com/en/harrison-docs-refactor-3-24/reference/modules/vectorstore.html)\n", "\n", "We will use our own custom implementation of SageMaker embeddings, which needs a reference to the SageMaker endpoint to call the model that returns the embeddings. These will be used by FAISS or Chroma to store vectors in memory and will be consulted whenever the user runs a query. A hedged preview of the high-level path follows." ] }
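, { "cell_type": "markdown", "id": "2ad8c531-0003-4c0c-9c0c-333333333333", "metadata": {}, "source": [ "Here is that preview -- a minimal sketch of the one-call `VectorstoreIndexCreator` path. It assumes the `hf_embeddings` object defined earlier, the FAQ PDF generated above, and the `sm_llm` LLM that is set up in the cells below, so treat it as illustrative rather than the exact workshop flow:\n", "\n", "```python\n", "from langchain.indexes import VectorstoreIndexCreator\n", "from langchain.vectorstores import FAISS\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.document_loaders import PyPDFLoader\n", "\n", "# One call: load -> split -> embed -> index into an in-memory FAISS store\n", "index = VectorstoreIndexCreator(\n", "    vectorstore_cls=FAISS,\n", "    embedding=hf_embeddings,\n", "    text_splitter=CharacterTextSplitter(chunk_size=300, chunk_overlap=0),\n", ").from_loaders([PyPDFLoader(\"rag_data/Amazon_SageMaker_FAQs.pdf\")])\n", "\n", "# Query the index with the SageMaker-hosted LLM (defined below)\n", "index.query(\"What is Amazon SageMaker?\", llm=sm_llm)\n", "```" ] }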
, { "cell_type": "markdown", "id": "cffaae9b-0070-460b-a095-7680dd5ca6cc", "metadata": {}, "source": [ "#### Use LangChain to leverage a SageMaker LLM \n", "\n", "Let's break down the above VectorstoreIndexCreator and see what's happening under the hood. Furthermore, we will see how to incorporate a customized prompt rather than using the default prompt with VectorstoreIndexCreator.\n", "\n", "First, we generate embeddings for each document in the knowledge library with the SageMaker embedding model.\n" ] }, { "cell_type": "code", "execution_count": 36, "id": "4b3fabc3-e5ee-452b-b8b8-c91195d07ef4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "SageMaker LLM created at \u001b[1mSagemakerEndpoint\u001b[0m\n", "Params: {'endpoint_name': 'huggingface-text2text-flan-t5-xl-1686836752', 'model_kwargs': {'max_length': 200, 'num_return_sequences': 1, 'top_k': 250, 'top_p': 0.95, 'do_sample': False, 'temperature': 1}}::\n" ] } ], "source": [ "from langchain.llms.sagemaker_endpoint import SagemakerEndpoint\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "import ast\n", "\n", "parameters = {\n", " \"max_length\": 200,\n", " \"num_return_sequences\": 1,\n", " \"top_k\": 250,\n", " \"top_p\": 0.95,\n", " \"do_sample\": False,\n", " \"temperature\": 1,\n", "}\n", "MAX_CHARACTER_TRUNCATION=10000 # at 20k it produced garbage results\n", "\n", "class ContentHandlerSMLMI(LLMContentHandler):\n", " content_type = \"application/json\"\n", " accepts = \"application/json\"\n", "\n", " def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n", " #input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", " print(f\"ContentHandlerSMLMI::LangChain:::LEN:input_str={len(prompt)}:: will truncate if > {MAX_CHARACTER_TRUNCATION}::\")\n", " if len(prompt) > MAX_CHARACTER_TRUNCATION:\n", " prompt=prompt[:MAX_CHARACTER_TRUNCATION]\n", " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", " #print(f\"ContentHandlerSMLMI::LangChain:::LEN:input_str={len(input_str)}::\")\n", " return input_str.encode(\"utf-8\")\n", "\n", " def transform_output(self, output: bytes) -> str:\n", " response_json_dict = json.loads(output.read().decode(\"utf-8\"))\n", " print(f\"ContentHandlerSMLMI::LangChain::output={response_json_dict}:\")\n", " return response_json_dict[list(response_json_dict.keys())[0]][0]\n", "\n", "\n", "content_handler_sm_llm = ContentHandlerSMLMI()\n", "session = boto3.Session()\n", "boto3_sm_client = boto3.client(\n", " \"sagemaker-runtime\"\n", " # **boto3_kwargs\n", ")\n", "print(boto3_sm_client)\n", "\n", "\n", "sm_llm = SagemakerEndpoint(\n", " client = boto3_sm_client,\n", " endpoint_name=os.environ[\"FLAN_XL_ENDPOINT\"],\n", " region_name='us-east-1',\n", " model_kwargs=parameters,\n", " content_handler=content_handler_sm_llm,\n", ")\n", "\n", "print(f\"SageMaker LLM created at {sm_llm}::\")" ] }, { "cell_type": "markdown", "id": "89c314cc-d05d-425d-81ad-159eeb2d3c5b", "metadata": {}, "source": [ "#### Load the Data from our Documents Source. \n", "\n", "Then we will feed this into the VectorStore to create the embeddings, using loaders like the ones [here](https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/directory_loader.html). First we will try the SageMaker FAQ PDF document and also the IRS PDF files.\n", "\n", "We will create three loaders and three sets of documents after splitting them: the first loader for the Amazon FAQ, the second for some of the IRS PDFs, and the third for a random example. Plain text takes a separate loader -- a text loader rather than a PDF loader; a hedged sketch of these alternatives follows the next cell." ] }, { "cell_type": "code", "execution_count": 37, "id": "ea20d1a0-cee5-45c8-be7f-a12f50191e12", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.document_loaders import TextLoader\n", "from langchain.document_loaders.csv_loader import CSVLoader\n", "\n", "from langchain.document_loaders import PyPDFLoader\n", "\n", "loader = PyPDFLoader(\"rag_data/Amazon_SageMaker_FAQs.pdf\")\n", "documents_aws = loader.load() # -- gives 2 docs\n", "documents_split = loader.load_and_split() # - gives 22 docs" ] }
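, { "cell_type": "markdown", "id": "ea20d1a1-0004-4d0d-9d0d-444444444444", "metadata": {}, "source": [ "As referenced above, a hedged sketch of the alternative loaders: the CSV path exists from the earlier download, while the text-file path is hypothetical and shown only for illustration.\n", "\n", "```python\n", "from langchain.document_loaders import TextLoader\n", "from langchain.document_loaders.csv_loader import CSVLoader\n", "\n", "# CSV: one Document per row (this file was downloaded earlier in the notebook)\n", "csv_docs = CSVLoader(file_path=\"rag_data/Amazon_SageMaker_FAQs.csv\").load()\n", "print(len(csv_docs))\n", "\n", "# Plain text: one Document per file (hypothetical path, for illustration only)\n", "# txt_docs = TextLoader(\"rag_data/some_notes.txt\").load()\n", "```" ] }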
, { "cell_type": "markdown", "id": "c99f2a61-9c05-4a22-af32-df533918d719", "metadata": {}, "source": [ "#### VectorStore as FAISS \n", "\n", "You can read up about [FAISS](https://arxiv.org/pdf/1702.08734.pdf), an in-memory vector store, here; for our example the usage is the same regardless of the store.\n", "\n", "Chroma\n", "\n", "[Chroma](https://www.trychroma.com/) is a super simple vector search database. The core API consists of just four functions, allowing users to build an in-memory document-vector store. By default Chroma uses the Hugging Face transformers library to vectorize documents.\n", "\n", "Weaviate\n", "\n", "[Weaviate](https://github.com/weaviate/weaviate) is a very posh looking tool - not only does Weaviate offer a GraphQL API with support for vector search, it also allows users to vectorize their content using Weaviate's inbuilt modules or custom modules." ] }
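, { "cell_type": "markdown", "id": "c99f2a62-0005-4e0e-9e0e-555555555555", "metadata": {}, "source": [ "Because the vector stores share one interface, Chroma is close to a drop-in replacement for the FAISS call in the next cell. A minimal sketch, assuming `chromadb` is installed (its install cell is commented out above):\n", "\n", "```python\n", "from langchain.vectorstores import Chroma\n", "from langchain.text_splitter import CharacterTextSplitter\n", "\n", "# Same documents and embeddings as the FAISS cell below, different in-memory store\n", "vectorstore_chroma_aws = Chroma.from_documents(\n", "    CharacterTextSplitter(chunk_size=300, chunk_overlap=0).split_documents(documents_aws),\n", "    hf_embeddings,\n", ")\n", "vectorstore_chroma_aws.similarity_search(\"What is Amazon SageMaker?\", k=3)\n", "```" ] }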
, { "cell_type": "code", "execution_count": 38, "id": "0e86402e-da63-4d2b-b477-c7fdc48812ca", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::embed_documents::shape:returned -- > (1, 768):\n", "CPU times: user 131 ms, sys: 8.07 ms, total: 139 ms\n", "Wall time: 86.8 ms\n" ] } ], "source": [ "%%time\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.document_loaders import TextLoader\n", "from langchain.document_loaders.csv_loader import CSVLoader\n", "from langchain.document_loaders import PyPDFLoader\n", "from langchain.document_loaders import DirectoryLoader\n", "from langchain.vectorstores import Chroma, AtlasDB, FAISS\n", "from langchain.text_splitter import CharacterTextSplitter\n", "import glob\n", "import os\n", "import pandas as pd\n", "\n", "from langchain.indexes import VectorstoreIndexCreator\n", "from langchain.indexes.vectorstore import VectorStoreIndexWrapper\n", "\n", "k_args = {\"k\": 1}\n", "# - sub_docs = self.text_splitter.split_documents(docs)\n", "# - create Vectorstore\n", "vectorstore_faiss_aws = FAISS.from_documents(\n", " CharacterTextSplitter(chunk_size=300, chunk_overlap=0).split_documents(documents_aws), \n", " hf_embeddings, \n", " #k=1\n", " #**k_args\n", ")\n", "\n", "wrapper_store_faiss = VectorStoreIndexWrapper(vectorstore=vectorstore_faiss_aws)" ] }, { "cell_type": "markdown", "id": "218075ec-587e-4c1a-b99d-d882d264b377", "metadata": { "tags": [] }, "source": [ "#### First way of running the Query. High Level abstraction\n", "\n", "Leverage VectorstoreIndexCreator, which wraps around RetrievalQA and provides a high-level API abstraction to generate the response. This is a wrapper around the underlying APIs which we will explore below" ] }, { "cell_type": "code", "execution_count": 39, "id": "c0f9af51-fe45-494e-afe1-a26a0ce0798c", "metadata": { "tags": [] }, "outputs": [], "source": [ "#query=\"Simplified method for business use of home deduction\"\n", "query=\"What is SageMaker Spot Instances\"" ] }, { "cell_type": "code", "execution_count": 40, "id": "0825562b-295c-48dc-8e48-4bb17f686fcf", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=264:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': ['a cloud computing service']}:\n" ] }, { "data": { "text/plain": [ "'a cloud computing service'" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wrapper_store_faiss.query(question=\"What is Amazon SageMaker Managed Spot Instances?\",llm=sm_llm)" ] }, { "cell_type": "markdown", "id": "b8cc84bd-bc99-4840-80ee-9cbc424fedf7", "metadata": {}, "source": [ "##### Visualize manually what is going on \n", "\n", "First we get the relevant documents for the query by using the embeddings; these docs can then be fed into the LLM to summarize and predict the answer. Here we can specify the search type ('similarity' or relevance) and the k param" ] }, { "cell_type": "code", "execution_count": 41, "id": "09559113-a106-4638-b278-6d33174d602e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=6250:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': [' ']}:\n", "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "1\n" ] } ], "source": [ "wrapper_store_faiss = VectorStoreIndexWrapper(vectorstore=vectorstore_faiss_aws)\n", "result_docs = wrapper_store_faiss.query_with_sources(\n", " question=\"What is Amazon SageMaker Managed Spot Instances?\",\n", " llm=sm_llm,\n", " chain_type=\"stuff\"\n", ")\n", "result_docs\n", "\n", "# - or you can use similarity scores\n", "retriever = vectorstore_faiss_aws.as_retriever(search_type='similarity', search_kwargs={\"k\": 3})\n", "relevant_docs = retriever.get_relevant_documents(query) \n", "print(len(relevant_docs))" ] }, { "cell_type": "markdown", "id": "ead7f7b4-def1-4f3d-a89a-cebef5103f12", "metadata": {}, "source": [ "##### As a quick test -- to do it manually, now invoke the LLM endpoint and feed the docs along with the query\n", "\n", "The results still will not come close to the answer we are expecting" ] }, { "cell_type": "code", "execution_count": 42, "id": "f35a4bd1-463e-4c17-b79d-9214edca61ee", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Question being asked is -- > What is SageMaker Spot Instances:\n" ] }, { "data": { "text/plain": [ "'{\"generated_texts\": [\"Amazon SageMaker FAQs\"]}'" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt = f\"Summarize this {relevant_docs} \"\n", "print(f\"Question being asked is -- > {query}:\")\n", "payload = {'text_inputs': prompt, \n", " 'max_length': MAX_LENGTH, \n", " 'num_return_sequences': NUM_RETURN_SEQUENCES,\n", " 'top_k': TOP_K,\n", " 'top_p': TOP_P,\n", " 'do_sample': DO_SAMPLE}\n", "payload = json.dumps(payload).encode('utf-8')\n", "boto3_sm_client.invoke_endpoint(\n", " EndpointName=os.environ[\"FLAN_XL_ENDPOINT\"],\n", " Body=payload,\n", " ContentType=\"application/json\",\n", ")[\"Body\"].read().decode(\"utf8\")" ] }, { "cell_type": "markdown", "id": "19949609-e6f5-4537-9d1f-9f9bdea42b44", "metadata": {}, "source": [ "## Exploring Chains and Prompt templates\n", "In this section we will look at the various flavors of chains and prompt templates\n" ] }, { "cell_type": "markdown", "id": "31906bd6-a1cf-4699-8d8c-7ef9220fcd7e", "metadata": {}, "source": [ "#### Define a Chain\n", "\n", "[Chains](https://python.langchain.com/en/harrison-docs-refactor-3-24/modules/chains.html) are the key to having a conversation in a chatbot manner. Here we will test **MANUALLY** injecting the documents retrieved by doing a similarity search. The final result matches our previous results in any case\n", "\n", "**Simplest QA Chain with NO CONTEXT being passed.**" ] }, { "cell_type": "markdown", "id": "06015dc7-26ee-4354-8e02-3ada40baa267", "metadata": {}, "source": [ "#### PromptTemplate \n", "\n", "This can be enhanced by using a prompt template. More details at [PROMPT Template](https://python.langchain.com/en/harrison-docs-refactor-3-24/modules/prompts/prompt_templates.html) \n", "\n", "We will start with a simple Chain and build up from there.\n", "\n" ] }, { "cell_type": "code", "execution_count": 43, "id": "bc49f66c-fb02-48fe-8b5c-5b0f79e7aa9f", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "PromptTemplate(input_variables=['context', 'question'], output_parser=None, partial_variables={}, template='\\n The following is a friendly conversation between a human and an AI. \\n The AI is talkative and provides lots of specific details from its context.\\n If the AI does not know the answer to a question, it truthfully says it \\n does not know.\\n {context}\\n Instruction: Based on the above documents, provide a detailed answer for, {question} Answer \"don\\'t know\" if not present in the document. Solution:\\n ', template_format='f-string', validate_template=True)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# - assume a chat bot asks a question\n", "from langchain.prompts import PromptTemplate\n", "prompt_template = \"\"\"\n", " The following is a friendly conversation between a human and an AI. \n", " The AI is talkative and provides lots of specific details from its context.\n", " If the AI does not know the answer to a question, it truthfully says it \n", " does not know.\n", " {context}\n", " Instruction: Based on the above documents, provide a detailed answer for, {question} Answer \"don't know\" if not present in the document. 
Solution:\n", " \"\"\"\n", "PROMPT_T = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n", "PROMPT_T" ] }, { "cell_type": "code", "execution_count": 44, "id": "06e236c9-e5dd-4085-8ae9-9567c639c715", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "1\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=458:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': [\"don't know\"]}:\n", "CPU times: user 115 ms, sys: 11.9 ms, total: 127 ms\n", "Wall time: 434 ms\n" ] }, { "data": { "text/plain": [ "{'output_text': \"don't know\"}" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "## -- Load and run the Chain based on the prompt\n", "query=\"What is Amazon Managed SageMaker Spot Instances?\"\n", "\n", "# - increasing the search to 8 relevant documents works for the GPT-J embeddings model\n", "relevant_docs = vectorstore_faiss_aws.as_retriever(search_type='similarity', search_kwargs={\"k\": 3}).get_relevant_documents(query) \n", "print(len(relevant_docs))\n", "chain = load_qa_chain(llm=sm_llm, prompt=PROMPT_T)\n", "result = chain({\"input_documents\": relevant_docs, \"question\": query}, return_only_outputs=True)\n", "result\n" ] }, { "cell_type": "markdown", "id": "4a5f1443-028a-48d4-807c-f0de7b677b4d", "metadata": {}, "source": [ "##### LLM Chain is another flavour of a simple chain. In reality you will be using a combination of a few different chains, as we will see in the chatbot section" ] }, { "cell_type": "code", "execution_count": 45, "id": "35e76361-8c1f-46f2-b238-08e5a08b213c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ContentHandlerSMLMI::LangChain:::LEN:input_str=563:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': [\"don't know\"]}:\n", "What is Amazon SageMaker Managed Spot Instances?\n", "CPU times: user 1.76 ms, sys: 3.86 ms, total: 5.62 ms\n", "Wall time: 374 ms\n" ] }, { "data": { "text/plain": [ "\"don't know\"" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "from langchain.chains import LLMChain\n", "\n", "query=\"What is Amazon SageMaker Managed Spot Instances?\"\n", "chain_t = LLMChain(llm=sm_llm, prompt=PROMPT_T)\n", "## -- Invoke the Chain ( call LLM ) to generate the Response\n", "result = chain_t({\"context\": relevant_docs, \"question\": query}, return_only_outputs=True)\n", "print(query)\n", "result['text']" ] }, { "cell_type": "markdown", "id": "88ad6b08-2035-4828-a332-3784bb7d8075", "metadata": {}, "source": [ "#### With LangChain we do not need to manage this explicitly; the starting point is a RetrievalQA chain \n", "The RetrievalQA chain uses load_qa_chain under the hood: we retrieve the most relevant chunks of text and feed those into the language model. Below shows how it works. In most situations we will be using complex chains via the Chain module to get the results for the user's query. 
We use the RetrievalQA and pass in the Vector Store to get the same results\n", "\n", "However the results do not yet match our expectations" ] }, { "cell_type": "code", "execution_count": 46, "id": "d8b01db4-c3cb-4a53-93ee-24b581f8ec0b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=264:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': ['a cloud computing service']}:\n" ] }, { "data": { "text/plain": [ "'a cloud computing service'" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qa = RetrievalQA.from_chain_type(\n", " llm=sm_llm, \n", " chain_type=\"stuff\", \n", " retriever=vectorstore_faiss_aws.as_retriever(search_type='similarity', search_kwargs={\"k\": 3})\n", " # - k of 8 brings 32k chars which is more than what our LLM can handle\n", ")\n", "\n", "#query=\"Simplified method for business use of home deduction\"\n", "query=\"What is Amazon SageMaker Managed Spot Instances?\"\n", "result = qa.run(query)\n", "result" ] }, { "cell_type": "markdown", "id": "791f09b8-b380-4988-bf2c-529db7f76538", "metadata": {}, "source": [ "#### Retrieval QA Chain\n", "\n", "You will see better results with `VectorRun` using the QA chain " ] }, { "cell_type": "code", "execution_count": 47, "id": "abdf8c6b-c008-4b3c-8a89-754e91ff8001", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=264:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': ['a cloud computing service']}:\n" ] }, { "data": { "text/plain": [ "'a cloud computing service'" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qa_prompt = RetrievalQA.from_chain_type(\n", " llm=sm_llm, \n", " chain_type=\"stuff\", \n", " retriever=vectorstore_faiss_aws.as_retriever(search_type='similarity', search_kwargs={\"k\": 3})\n", ")\n", "#query = \"Which instances can I use with Managed Spot Training in SageMaker?\"\n", "result = qa_prompt.run(query)\n", "result" ] }, { "cell_type": "markdown", "id": "60e5cdc0", "metadata": {}, "source": [ "## Chatbot application\n", "\n", "#### For the chatbot we need `context management, history, vector stores, and many other things`. We will start with a ConversationalRetrievalChain\n", "\n", "This uses conversation memory and a RetrievalQA chain, which allows passing in chat history that can be used for follow-up questions. Source: https://python.langchain.com/en/latest/modules/chains/index_examples/chat_vector_db.html\n", "\n", "Set verbose to True to see all that is going on behind the scenes\n", "\n", "**We use a custom prompt template to fine-tune the output responses**" ] }, { "cell_type": "code", "execution_count": 51, "id": "42367663", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting chat bot\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "['Enter your query, q to quit'] what is Amazon SageMaker? 
\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=242:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': ['a machine learning platform']}:\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "['Enter your query, q to quit', 'Question:what is Amazon SageMaker? \\nAI:Answer:a machine learning platform'] Can I use it for training models? \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ContentHandlerSMLMI::LangChain:::LEN:input_str=487:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': ['What is Amazon SageMaker?']}:\n", "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=241:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': ['a machine learning platform']}:\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "['Enter your query, q to quit', 'Question:what is Amazon SageMaker? \\nAI:Answer:a machine learning platform', 'Question:Can I use it for training models? \\nAI:Answer:a machine learning platform'] quit\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Breaking\n", "Thank you , that was a nice chat !!\n" ] } ], "source": [ "from langchain import LLMChain\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain.vectorstores import Chroma\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.chains import ConversationalRetrievalChain\n", "from langchain.chains import LLMChain\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT\n", "\n", "\n", "def create_prompt_template():\n", " _template = \"\"\"\n", " Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. Use the following pieces of context to answer the question at the end. 
If you don't know the answer, just say that you do not know, do not try to make up an answer.\n", " Chat History:\n", " {chat_history}\n", " Follow Up Input: {question}\n", " Standalone question:\n", " \"\"\"\n", " CONVO_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n", " return CONVO_QUESTION_PROMPT\n", "memory_chain = ConversationBufferMemory(memory_key=\"chat_history\", input_key=\"question\", return_messages=True)\n", "chat_history=[]\n", "qa = ConversationalRetrievalChain.from_llm(\n", " llm=sm_llm, \n", " #retriever=vectorstore_faiss_aws.as_retriever(), \n", " retriever=vectorstore_faiss_aws.as_retriever(search_type='similarity', search_kwargs={\"k\": 3}),\n", " memory=memory_chain,\n", " #verbose=True,\n", " condense_question_prompt=create_prompt_template(), #CONDENSE_QUESTION_PROMPT, # use the condense prompt template\n", " #chain_type='map_reduce',\n", " max_tokens_limit=100\n", " #combine_docs_chain_kwargs=key_chain_args,\n", "\n", ")\n", "print(\"Starting chat bot\")\n", "input_str = ['Enter your query, q to quit']\n", "while True:\n", " query = input(str(input_str))\n", " if 'q' == query or 'quit' == query or 'Q' == query:\n", " print(\"Breaking\")\n", " break\n", " else:\n", " result = qa.run({'question':query, 'chat_history':chat_history} )\n", " input_str.append(f\"Question:{query}\\nAI:Answer:{result}\")\n", "\n", "print(\"Thank you , that was a nice chat !!\")" ] }, { "cell_type": "markdown", "id": "2ffdaaf4-17e6-45c8-acca-7c9905b8c556", "metadata": {}, "source": [ "#### Refine as Chain type with no similarity searches" ] }, { "cell_type": "code", "execution_count": 50, "id": "257cf2f5-345b-4473-af92-ea29de832e01", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting Refine chat bot\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "['Enter your query, q to quit'] What is Amazon SageMaker? \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CustomHFEmbeddings::QUERY::shape:returned -- > (768,):\n", "ContentHandlerSMLMI::LangChain:::LEN:input_str=203:: will truncate if > 10000::\n", "ContentHandlerSMLMI::LangChain::output={'generated_texts': [' ']}:\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "['Enter your query, q to quit', 'Question:What is Amazon SageMaker? \\nAI:Answer: '] quit\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Breaking\n", "Thank you , that was a nice chat !!\n" ] } ], "source": [ "from langchain import LLMChain\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain.vectorstores import Chroma\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.chains import ConversationalRetrievalChain\n", "from langchain.chains import LLMChain\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT\n", "\n", "\n", "def create_prompt_template():\n", " _template = \"\"\"\n", " Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. Use the following pieces of context to answer the question at the end. 
If you don't know the answer, just say that you do not know, do not try to make up an answer.\n", " Chat History:\n", " {chat_history}\n", " Follow Up Input: {question}\n", " Standalone question:\n", " \"\"\"\n", " CONVO_QUESTION_PROMPT = PromptTemplate.from_template(_template)\n", " return CONVO_QUESTION_PROMPT\n", "memory_chain = ConversationBufferMemory(memory_key=\"chat_history\", input_key=\"question\", return_messages=True)\n", "chat_history=[]\n", "qa = ConversationalRetrievalChain.from_llm(\n", " llm=sm_llm, \n", " retriever=vectorstore_faiss_aws.as_retriever(), \n", " #retriever=vectorstore_faiss_aws.as_retriever(search_type='similarity', search_kwargs={\"k\": 2}),\n", " memory=memory_chain,\n", " #verbose=True,\n", " condense_question_prompt=create_prompt_template(), #CONDENSE_QUESTION_PROMPT, create_prompt_template(), # use the condense prompt template\n", " chain_type='refine', #'map_rerank', #'refine', # s(['stuff', 'map_reduce', 'refine', 'map_rerank'])\n", " max_tokens_limit=100,\n", " get_chat_history=lambda h : h,\n", ") \n", "print(\"Starting Refine chat bot\")\n", "input_str = ['Enter your query, q to quit']\n", "while True:\n", " query = input(str(input_str))\n", " if 'q' == query or 'quit' == query or 'Q' == query:\n", " print(\"Breaking\")\n", " break\n", " else:\n", " result = qa.run({'question':query, 'chat_history':chat_history} )\n", " input_str.append(f\"Question:{query}\\nAI:Answer:{result}\")\n", "\n", "print(\"Thank you , that was a nice chat !!\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d1403d1b-dddb-4439-ad0a-ec096c442f45", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "availableInstances": [ { "_defaultOrder": 0, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.t3.medium", "vcpuNum": 2 }, { "_defaultOrder": 1, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.t3.large", "vcpuNum": 2 }, { "_defaultOrder": 2, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.t3.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 3, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.t3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 4, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5.large", "vcpuNum": 2 }, { "_defaultOrder": 5, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 6, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 7, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 8, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 9, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 10, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, 
"hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 11, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 12, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5d.large", "vcpuNum": 2 }, { "_defaultOrder": 13, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5d.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 14, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5d.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 15, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5d.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 16, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5d.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 17, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5d.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 18, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5d.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 19, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 20, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": true, "memoryGiB": 0, "name": "ml.geospatial.interactive", "supportedImageNames": [ "sagemaker-geospatial-v1-0" ], "vcpuNum": 0 }, { "_defaultOrder": 21, "_isFastLaunch": true, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.c5.large", "vcpuNum": 2 }, { "_defaultOrder": 22, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.c5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 23, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.c5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 24, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.c5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 25, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 72, "name": "ml.c5.9xlarge", "vcpuNum": 36 }, { "_defaultOrder": 26, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 96, "name": "ml.c5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 27, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 144, "name": "ml.c5.18xlarge", "vcpuNum": 72 }, { "_defaultOrder": 28, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.c5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 29, "_isFastLaunch": true, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g4dn.xlarge", 
"vcpuNum": 4 }, { "_defaultOrder": 30, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g4dn.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 31, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g4dn.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 32, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g4dn.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 33, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g4dn.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 34, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g4dn.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 35, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 61, "name": "ml.p3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 36, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 244, "name": "ml.p3.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 37, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 488, "name": "ml.p3.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 38, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.p3dn.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 39, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.r5.large", "vcpuNum": 2 }, { "_defaultOrder": 40, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.r5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 41, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.r5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 42, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.r5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 43, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.r5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 44, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.r5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 45, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 512, "name": "ml.r5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 46, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.r5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 47, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 48, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 49, "_isFastLaunch": false, "category": 
"Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 50, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 51, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 52, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 53, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.g5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 54, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.g5.48xlarge", "vcpuNum": 192 }, { "_defaultOrder": 55, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 1152, "name": "ml.p4d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 56, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 1152, "name": "ml.p4de.24xlarge", "vcpuNum": 96 } ], "instance_type": "ml.g4dn.xlarge", "kernelspec": { "display_name": "Python 3 (Data Science 2.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }