{ "cells": [ { "cell_type": "markdown", "id": "0855b561", "metadata": {}, "source": [ "# Integrate Modern Data Architectures with Generative AI and interact using prompts for querying SQL databases & APIs" ] }, { "cell_type": "markdown", "id": "26105729-b3e3-42d0-a583-8446fff89277", "metadata": {}, "source": [ "This notebook demonstrates how **large language models, such as Amazon Titan and Claude Anthropic, accessible via [Amazon BedRock](https://aws.amazon.com/bedrock/)** interact with AWS databases, data stores, and third-party data warehousing solutions like Snowflake. We showcase this interaction 1) by generating and running SQL queries, and 2) making requests to API endpoints. We achieve all of this by using the LangChain framework, which allows the language model to interact with its environment and connect with other sources of data. The LangChain framework operates based on the following principles: calling out to a language model, being data-aware, and being agentic. " ] }, { "cell_type": "markdown", "id": "d02c8cc5-5104-44aa-bbce-ad3ca7562a29", "metadata": { "tags": [] }, "source": [ "This notebook focuses on establishing connection to one data source, consolidating metadata, and returning fact-based data points in response to user queries using LLMs and LangChain. The solution can be enhanced to add multiple data sources." ] }, { "cell_type": "markdown", "id": "a310d6ea-2ee1-4979-bb5e-b65cb892c0cd", "metadata": { "tags": [] }, "source": [ "\n", "\n" ] }, { "cell_type": "markdown", "id": "e0986ea2-f794-431f-a341-b94f0118cb7d", "metadata": { "tags": [] }, "source": [ "### Pre-requisites:\n", "1. Use kernel Base Python 3.0.\n", "2. Install the required packages.\n", "3. Run the One time Setup by entering the user input parameters, copying the dataset, setup IAM role and finally run the crawler.\n", "3. Access to the LLM API. In this notebook, Anthropic Model is used. Refer [here](https://console.anthropic.com/docs/access) for detais on how to get access to Anthropic API key.\n", "\n", "**Note - This notebook was tested on kernel - conda_python3 in Region us-east-1**" ] }, { "cell_type": "markdown", "id": "9597c6f9", "metadata": {}, "source": [ "1. Attach AmazonAthenaFullAccess, AWSGlueServiceRole in IAM.\n", "2. Add the following custom policy in IAM to grant creating policy (double click cell to get json format)." ] }, { "cell_type": "markdown", "id": "4f9731fb", "metadata": {}, "source": [ "{\n", " \"Version\": \"2012-10-17\",\n", " \"Statement\": [\n", " {\n", " \"Action\": [\n", " \"iam:AttachRolePolicy\",\n", " \"iam:CreateRole\",\n", " \"iam:CreatePolicy\",\n", " \"iam:GetRole\",\n", " \"iam:PassRole\"\n", " ],\n", " \"Effect\": \"Allow\",\n", " \"Resource\": \"*\"\n", " }\n", " ]\n", "}" ] }, { "cell_type": "markdown", "id": "5d0297e0-f2dd-464b-9254-6693c45ebafc", "metadata": { "tags": [] }, "source": [ "### Solution Walkthrough:\n", "\n", "Step 1. Connection to S3 through which LLMs can talk to your data. These channels include:\n", " - S3/Athena - to connect to the SageMaker's offline feature store on claims information. \n", " \n", "Step 2. Usage of Dynamic generation of prompt templates by populating metadata of the tables using Glue Data Catalog(GDC) as context. GDC was populated by running a crawler on the databases. Refer to the information here to create and run a glue crawler. In case of api, a line item was created in GDC data extract.\n", "\n", "Step 3. 
Step 3. Define functions to 1/ determine the best data channel to answer the user query, and 2/ generate the response to the user query.\n", "\n", "Step 4. Pass the user query to the LLM via LangChain to determine the data channel. After the channel is determined, run the LangChain SQL Database chain to convert 'text to SQL' and run the generated query against the source data channel. \n", "\n", "Finally, display the results.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "eac07c9c-6b99-4776-b6cd-7659d11c49df", "metadata": { "tags": [] }, "outputs": [], "source": [ "!python3 -m pip install boto3-1.26.142-py3-none-any.whl --quiet" ] }, { "cell_type": "code", "execution_count": null, "id": "f75ea08f-fe57-4774-90aa-04f1a7f57bc3", "metadata": { "tags": [] }, "outputs": [], "source": [ "!python3 -m pip install botocore-1.29.142-py3-none-any.whl --quiet" ] }, { "cell_type": "code", "execution_count": null, "id": "9556eddc-8e45-4e42-9157-213316ec468a", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%writefile requirements.txt\n", "sqlalchemy==1.4.47\n", "snowflake-sqlalchemy\n", "langchain==0.0.190\n", "sqlalchemy-aurora-data-api\n", "PyAthena[SQLAlchemy]==2.25.2\n", "anthropic\n", "redshift-connector==2.0.910\n", "sqlalchemy-redshift==0.8.14\n", "snowflake\n", "streamlit\n", "streamlit-chat " ] }, { "cell_type": "code", "execution_count": null, "id": "b55d516c", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "!pip install -r requirements.txt --quiet" ] }, { "cell_type": "code", "execution_count": null, "id": "91c153cd", "metadata": { "tags": [] }, "outputs": [], "source": [ "import json\n", "import boto3\n", "\n", "import sqlalchemy\n", "from sqlalchemy import create_engine\n", "# from snowflake.sqlalchemy import URL\n", "\n", "from langchain.docstore.document import Document\n", "from langchain import PromptTemplate,SagemakerEndpoint,SQLDatabase, SQLDatabaseChain, LLMChain\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.prompts.prompt import PromptTemplate\n", "from langchain.chains import SQLDatabaseSequentialChain\n", "\n", "from langchain.chains.api.prompt import API_RESPONSE_PROMPT\n", "from langchain.chains import APIChain\n", "from langchain.chat_models import ChatAnthropic\n", "from langchain.chains.api import open_meteo_docs\n", "\n", "from typing import Dict\n", "import time" ] }, { "cell_type": "markdown", "id": "87f05601-529a-42d4-8cab-ba0b9445e695", "metadata": { "tags": [] }, "source": [ "### One Time Setup\n", "Some of the resources needed for this notebook, such as the IAM policy, AWS Glue database, and Glue crawler, are created by the code below (in other variants of this solution they may be provisioned through a CloudFormation template). 
The next block of code performs this setup based on your user inputs.\n", "\n", "**NOTE - The next two blocks of code need to be run only the first time.**" ] }, { "cell_type": "markdown", "id": "fb560be7-cdea-4ff8-9230-0679252ecf5d", "metadata": { "tags": [] }, "source": [ "### User Input\n" ] }, { "cell_type": "code", "execution_count": null, "id": "23ab506d-e40d-48ab-af71-82ebb78445c0", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Provide user input\n", "glue_databucket_name = 'sagemaker-studio-741094476554-9zkt2s8krvb' # Create this bucket in S3\n", "glue_db_name = 'ihmnick-bankadditional'\n", "glue_role = 'ihmnick-AWSGlueServiceRole-glueworkshop120'\n", "glue_crawler_name = glue_db_name + '-crawler120'" ] }, { "cell_type": "markdown", "id": "a6344f3f-40da-4ffb-8b4f-b27c52e1eb02", "metadata": { "tags": [] }, "source": [ "### Create the IAM role that runs the crawler" ] }, { "cell_type": "code", "execution_count": null, "id": "2d905b91", "metadata": { "tags": [] }, "outputs": [], "source": [ "import boto3\n", "import os\n", "# Retrieve the AWS account number\n", "sts_client = boto3.client('sts')\n", "account_number = sts_client.get_caller_identity().get('Account')\n", "# Retrieve the AWS region\n", "#region = os.environ['AWS_REGION']\n", "region = boto3.session.Session().region_name\n", "print(\"AWS Account Number:\", account_number)\n", "print(\"AWS Region:\", region)\n", "# Trust policy that allows the Glue service to assume the crawler role\n", "trust_policy=\"\"\"{\n", "  \"Version\": \"2012-10-17\",\n", "  \"Statement\": [\n", "    {\n", "      \"Sid\": \"\",\n", "      \"Effect\": \"Allow\",\n", "      \"Principal\": {\n", "        \"Service\": \"glue.amazonaws.com\"\n", "      },\n", "      \"Action\": \"sts:AssumeRole\"\n", "    }\n", "  ]\n", "}\"\"\"\n", "managed_policy=\"\"\"{\n", "    \"Version\": \"2012-10-17\",\n", "    \"Statement\": [\n", "        {\n", "            \"Action\": [\n", "                \"glue:*\"\n", "            ],\n", "            \"Resource\": [\n", "                \"arn:aws:glue:\"\"\"+region+\"\"\":\"\"\"+account_number+\"\"\":catalog\",\n", "                \"arn:aws:glue:\"\"\"+region+\"\"\":\"\"\"+account_number+\"\"\":database/*\",\n", "                \"arn:aws:glue:\"\"\"+region+\"\"\":\"\"\"+account_number+\"\"\":table/*\"\n", "            ],\n", "            \"Effect\": \"Allow\",\n", "            \"Sid\": \"Readcrawlerresources\"\n", "        },\n", "        {\n", "            \"Action\": [\n", "                \"logs:CreateLogGroup\",\n", "                \"logs:CreateLogStream\",\n", "                \"logs:PutLogEvents\"\n", "            ],\n", "            \"Resource\": [\n", "                \"arn:aws:logs:\"\"\"+region+\"\"\":\"\"\"+account_number+\"\"\":log-group:/aws-glue/crawlers*\",\n", "                \"arn:aws:logs:*:*:/aws-glue/*\",\n", "                \"arn:aws:logs:*:*:/customlogs/*\"\n", "            ],\n", "            \"Effect\": \"Allow\",\n", "            \"Sid\": \"ReadlogResources\"\n", "        },\n", "        {\n", "            \"Action\": [\n", "                \"s3:PutObject\",\n", "                \"s3:GetObject\",\n", "                \"s3:PutBucketLogging\",\n", "                \"s3:ListBucket\",\n", "                \"s3:PutBucketVersioning\"\n", "            ],\n", "            \"Resource\": [\n", "                \"arn:aws:s3:::\"\"\"+glue_databucket_name+\"\"\"\",\n", "                \"arn:aws:s3:::\"\"\"+glue_databucket_name+\"\"\"/*\"\n", "            ],\n", "            \"Effect\": \"Allow\",\n", "            \"Sid\": \"ReadS3Resources\"\n", "        }\n", "    ]\n", "}\"\"\"\n", "# Write the policy documents to disk for the AWS CLI calls in the next cell\n", "with open('managed-policy.json', 'w') as f:\n", "    f.write(managed_policy)\n", "with open('trust-policy.json', 'w') as f:\n", "    f.write(trust_policy)" ] }, { "cell_type": "code", "execution_count": null, "id": "c4c27763-e1e4-4b41-bfa4-7ff8f6500c31", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%sh -s \"$glue_role\" \n", "echo $1 \n", "glue_role=\"$1\"\n", "managed_policy_name=\"managed-policy-$glue_role\"\n", "echo $managed_policy_name\n", "# Create the crawler role with the Glue service trust policy\n", "aws iam create-role --role-name $glue_role --assume-role-policy-document file://trust-policy.json\n", 
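"# Create the managed policy, then extract its ARN from the CLI's JSON output and attach it to the role\n", 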
"output=$(aws iam create-policy --policy-document file://managed-policy.json --policy-name $managed_policy_name)\n", "arn=$(echo \"$output\" | grep -oP '\"Arn\": \"\\K[^\"]+')\n", "echo \"$arn\"\n", "aws iam attach-role-policy --policy-arn $arn --role-name $glue_role" ] }, { "cell_type": "code", "execution_count": null, "id": "6fe28c2a-60c4-400a-9693-395e304c5164", "metadata": { "tags": [] }, "outputs": [], "source": [ "import boto3\n", "\n", "client = boto3.client('glue')\n", "\n", "# Create database \n", "try:\n", " response = client.create_database(\n", " DatabaseInput={\n", " 'Name': glue_db_name,\n", " 'Description': 'This database is created using Python boto3',\n", " }\n", " )\n", " print(\"Successfully created database\")\n", "except:\n", " print(\"error in creating database. Check if the database already exists\")\n", "\n", "#introducing some lag for the iam role to create\n", "time.sleep(20) \n", "\n", "# Create Glue Crawler \n", "try:\n", "\n", " response = client.create_crawler(\n", " Name=glue_crawler_name,\n", " Role=glue_role,\n", " DatabaseName=glue_db_name,\n", " Targets={\n", " 'S3Targets': [\n", " {\n", " 'Path': 's3://{BUCKET_NAME}/bank-additional/'.format(BUCKET_NAME =glue_databucket_name)\n", " }\n", " ]\n", " },\n", " TablePrefix=''\n", " )\n", " \n", " print(\"Successfully created crawler\")\n", "except:\n", " print(\"error in creating crawler. However, if the crawler already exists, the crawler will run.\")\n", "\n", "# Run the Crawler\n", "try:\n", " response = client.start_crawler(Name=glue_crawler_name )\n", " print(\"Successfully started crawler. The crawler may take 2-5 mins to detect the schema.\")\n", " while True:\n", " # Get the crawler status\n", " response = client.get_crawler(Name=glue_crawler_name)\n", " # Extract the crawler state\n", " status = response['Crawler']['State']\n", " # Print the crawler status\n", " print(f\"Crawler '{glue_crawler_name}' status: {status}\")\n", " if status == 'READY': # Replace 'READY' with the desired completed state\n", " break # Exit the loop if the desired state is reached\n", "\n", " time.sleep(10) # Sleep for 10 seconds before checking the status again\n", " \n", "except:\n", " print(\"error in starting crawler. Check the logs for the error details.\")" ] }, { "cell_type": "markdown", "id": "4132ffc3-6947-49b6-b627-fae3df870b88", "metadata": { "tags": [] }, "source": [ "Before proceeding to the next step, check the status of the crawler. It should change from RUNNING to READY. " ] }, { "cell_type": "markdown", "id": "b51d1d0e-33fb-46ca-b82f-6294ea867cae", "metadata": { "tags": [] }, "source": [ "### Step 1 - Connect to databases using SQL Alchemy. \n", "\n", "Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. The SQLDatabaseChain can therefore be used with any SQL dialect supported by SQLAlchemy, \n", "such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, and SQLite. Please refer to the SQLAlchemy documentation for more information about requirements for connecting to your database. \n" ] }, { "cell_type": "markdown", "id": "e5f5ce28-9b33-4061-8655-2b297d5c24a2", "metadata": { "tags": [] }, "source": [ "**Important**: The code below establishes a database connection for data sources and Large Language Models. Please note that the solution will only work if the database connection for your sources is defined in the cell below. Please refer to the Pre-requisites section. If your use case requires data from Aurora MySQL alone, then please comment out other data sources. 
Furthermore, please update the cluster details and variables for Aurora MySQL accordingly." ] }, { "cell_type": "code", "execution_count": null, "id": "1583cade", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "# Define connections\n", "\n", "# Collect credentials from Secrets Manager\n", "# Refer here on how to use AWS Secrets Manager - https://docs.aws.amazon.com/secretsmanager/latest/userguide/intro.html\n", "client = boto3.client('secretsmanager')\n", "region = client.meta.region_name\n", "\n", "\n", "# LLM\n", "# Get the LLM API key from Secrets Manager\n", "# Refer here for access to Anthropic API keys - https://console.anthropic.com/docs/access\n", "anthropic_secret_id = \"anthropic\"\n", "response = client.get_secret_value(SecretId=anthropic_secret_id)\n", "secrets_credentials = json.loads(response['SecretString'])\n", "ANTHROPIC_API_KEY = secrets_credentials['ANTHROPIC_API_KEY']\n", "# Define the large language model here. Make sure the ANTHROPIC_API_KEY variable is set\n", "llm = ChatAnthropic(temperature=0, anthropic_api_key=ANTHROPIC_API_KEY, max_tokens_to_sample=512)\n", "\n", "\n", "\n", "# Connect to S3 using Athena\n", "\n", "# Athena variables\n", "connathena = f\"athena.{region}.amazonaws.com\" \n", "portathena = '443' # Update, if port is different\n", "schemaathena = glue_db_name # from user-defined params\n", "s3stagingathena = f's3://{glue_databucket_name}/athenaresults/' # Athena query results location\n", "wkgrpathena = 'primary' # Update, if workgroup is different\n", "# tablesathena=['dataset']#[]\n", "\n", "# Create the Athena connection string\n", "connection_string = f\"awsathena+rest://@{connathena}:{portathena}/{schemaathena}?s3_staging_dir={s3stagingathena}/&work_group={wkgrpathena}\"\n", "\n", "# Create the Athena SQLAlchemy engine\n", "engine_athena = create_engine(connection_string, echo=False)\n", "dbathena = SQLDatabase(engine_athena)\n", "\n", "gdc = [schemaathena] " ] }, { "cell_type": "code", "execution_count": null, "id": "c5a4f470-cbc5-44c0-a348-c49792ab2229", "metadata": { "tags": [] }, "outputs": [], "source": [ "gdc" ] }, { "cell_type": "markdown", "id": "1ea21757-b08a-438b-a5a7-79d85a9a9085", "metadata": {}, "source": [ "### Step 2 - Generate Dynamic Prompt Templates\n", "Build a consolidated view of the Glue Data Catalog by combining the metadata stored for all the databases in pipe-delimited format." ] }, { "cell_type": "code", "execution_count": null, "id": "08a3373d-9285-4fab-81b5-51e5364590b5", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "# Generate dynamic prompts by harvesting the AWS Glue crawler metadata\n", "\n", "def parse_catalog():\n", "    # Connect to the Glue catalog and get the metadata of the registered tables\n", "    columns_str=''\n", "    \n", "    # Define Glue client\n", "    glue_client = boto3.client('glue')\n", "    \n", "    for db in gdc:\n", "        response = glue_client.get_tables(DatabaseName=db)\n", "        for tables in response['TableList']:\n", "            # The classification in the response differs for S3 and other sources. 
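S3-backed tables expose an s3:// Location, while other sources carry a 'classification' table parameter. 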
Set the classification from the storage location accordingly.\n", "            if tables['StorageDescriptor']['Location'].startswith('s3'):  classification='s3' \n", "            else:  classification = tables['Parameters']['classification']\n", "            for columns in tables['StorageDescriptor']['Columns']:\n", "                dbname,tblname,colname=tables['DatabaseName'],tables['Name'],columns['Name']\n", "                columns_str=columns_str+f'\\n{classification}|{dbname}|{tblname}|{colname}' \n", "    # API\n", "    ## Append the metadata of the API to the unified Glue data catalog\n", "    columns_str=columns_str+'\\n'+('api|meteo|weather|weather')\n", "    return columns_str\n", "\n", "glue_catalog = parse_catalog()\n", "\n", "# Display a few lines from the catalog\n", "print('\\n'.join(glue_catalog.splitlines()[-10:]) )\n" ] }, { "cell_type": "markdown", "id": "a94e6770-42c3-402b-a60e-9c21fb99d5f6", "metadata": { "tags": [] }, "source": [ "### Step 3 - Define functions to 1/ determine the best data channel to answer the user query, and 2/ generate the response to the user query" ] }, { "cell_type": "markdown", "id": "adda3714-3f32-4480-9526-91cca37489d1", "metadata": {}, "source": [ "In this code sample, we use an Anthropic model to generate inferences. You can use SageMaker JumpStart models to achieve the same. \n", "Guidance on how to use the JumpStart models is available in the notebook - mda_with_llm_langchain_smjumpstart_flant5xl" ] }, { "cell_type": "code", "execution_count": null, "id": "4efcc59b", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Function 1 'Infer Channel'\n", "# Define a function that infers the channel/database/table and sets the database for querying\n", "def identify_channel(query):\n", "    db = {}\n", "    # Prompt 1 'Infer Channel'\n", "    ## Set the prompt template. It instructs the LLM on how to evaluate and respond to the user query. It is referred to as dynamic because the Glue data catalog is first generated and then appended to the prompt.\n", "    prompt_template = \"\"\"\n", "    From the table below, find the database (in column database) which will contain the data (in corresponding column_names) to answer the question \n", "    {query} \\n\n", "    \"\"\"+glue_catalog +\"\"\" \n", "    Give your answer as database == \n", "    Also, give your answer as database.table == \n", "    \"\"\"\n", "    ## Define prompt 1\n", "    PROMPT_channel = PromptTemplate( template=prompt_template, input_variables=[\"query\"] )\n", "\n", "    # Define the LLM chain\n", "    llm_chain = LLMChain(prompt=PROMPT_channel, llm=llm)\n", "    # Run the query and save the generated text\n", "    generated_texts = llm_chain.run(query)\n", "    print('identified channel:', generated_texts)\n", "\n", "    # Set the channel from where the query can be answered.\n", "    # Check the API channel first: the word 'database' can appear in the answer for either channel because the prompt asks for it.\n", "    if 'api' in generated_texts: \n", "        channel='api'\n", "        print(\"SET database to weather api\")\n", "    elif 'database' in generated_texts: \n", "        channel='db'\n", "        db=dbathena\n", "        print(\"SET database to athena\")\n", "    else: raise Exception(\"User question cannot be answered by any of the channels mentioned in the catalog\")\n", "    print(\"Step complete. Channel is: \", channel)\n", "    \n", "    return channel, db\n", "\n", "# Function 2 'Run Query'\n", "# Define a function that runs the user query against the channel identified above\n", "def run_query(query):\n", "\n", "    channel, db = identify_channel(query) # Call the identify channel function first\n", "    \n", "    ## Prompt 2 'Run Query'\n", "    # After determining the data channel, run the LangChain SQL Database chain to convert 'text to SQL' and run the query against the source data channel. 
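The chain fills in {dialect} and {table_info} from the SQLDatabase object at run time.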
\n", "    # Provide rules for running the SQL queries in the default template --> table info\n", "\n", "    _DEFAULT_TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n", "\n", "    Do not append 'Query:' to SQLQuery.\n", "    \n", "    Display SQLResult after the query is run in plain English that users can understand. \n", "\n", "    Provide the answer in a simple English statement.\n", "    \n", "    Only use the following tables:\n", "\n", "    {table_info}\n", "\n", "    Question: {input}\"\"\"\n", "\n", "    PROMPT_sql = PromptTemplate(\n", "        input_variables=[\"input\", \"table_info\", \"dialect\"], template=_DEFAULT_TEMPLATE\n", "    )\n", "\n", "    \n", "    if channel=='db':\n", "        db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT_sql, verbose=True, return_intermediate_steps=False)\n", "        response=db_chain.run(query)\n", "    elif channel=='api':\n", "        chain_api = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS, verbose=True)\n", "        response=chain_api.run(query)\n", "    else: raise Exception(\"Unlisted channel. Check your unified catalog\")\n", "    return response\n", "\n" ] }, { "cell_type": "markdown", "id": "390a92cd-e1b4-4feb-ab7a-f97030ba7f84", "metadata": {}, "source": [ "### Step 4 - Run the run_query function, which in turn calls the LangChain SQL Database chain to convert 'text to SQL' and runs the query against the source data channel\n", "\n", "Some sample queries are provided below for test runs. Uncomment a query to run it." ] }, { "cell_type": "code", "execution_count": null, "id": "f82599a2", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Enter the query\n", "## A few queries to try out - \n", "# query = \"\"\"How many people are married?\"\"\" \n", "# query = \"\"\"What is the maximum age?\"\"\" \n", "# query = \"\"\"What percentage of customers in each age group enroll for CDs?\"\"\"\n", "query = \"\"\"What is the enrollment rate for each marital status?\"\"\"\n", "# query = \"\"\"What percentage of customers in each education level enroll for CDs?\"\"\"\n", "# query = \"\"\"What percentage of customers with existing loans (housing loan, personal loan) enroll for CDs?\"\"\"\n", "# query = \"\"\"What is the average enrollment rate over time for each month/quarter?\"\"\"\n", "# query = \"\"\"What is the average enrollment rate at different time periods after the last marketing contact (e.g. 0-7 days, 8-14 days, 15-30 days, 30+ days)?\"\"\"\n", "# query = \"\"\"What is the enrollment percentage for different outcomes of the previous marketing campaign (e.g. 
success, failure, no campaign)?\"\"\"\n", "# query = \"\"\"What percentage of customers contacted X times enroll for CDs, versus customers contacted Y times?\"\"\"\n", "\n", "# API - product - weather\n", "# query = \"\"\"What is the weather like right now in New York City in degrees Fahrenheit?\"\"\"\n", "\n", "# Response from LangChain\n", "response = run_query(query)\n", "print(\"----------------------------------------------------------------------\")\n", "print(f'Q: {query} \\nA: {response}')" ] }, { "cell_type": "markdown", "id": "c6fa2500-5226-4b07-a357-19239389ada7", "metadata": {}, "source": [ "# Streamlit UI" ] }, { "cell_type": "markdown", "id": "6dca2441-6c19-4764-9ca9-84f458b409fb", "metadata": {}, "source": [ "### Write the Streamlit app" ] }, { "cell_type": "code", "execution_count": null, "id": "0c499bd9-b5bc-4d52-bd37-4ededd9d9dfd", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%writefile app_search.py\n", "import json\n", "import boto3\n", "\n", "import sqlalchemy\n", "from sqlalchemy import create_engine\n", "# from snowflake.sqlalchemy import URL\n", "\n", "from langchain.docstore.document import Document\n", "from langchain import PromptTemplate,SagemakerEndpoint,SQLDatabase, SQLDatabaseChain, LLMChain\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains.question_answering import load_qa_chain\n", "from langchain.prompts.prompt import PromptTemplate\n", "from langchain.chains import SQLDatabaseSequentialChain\n", "\n", "from langchain.chains.api.prompt import API_RESPONSE_PROMPT\n", "from langchain.chains import APIChain\n", "from langchain.chat_models import ChatAnthropic\n", "from langchain.chains.api import open_meteo_docs\n", "from langchain.memory import ConversationBufferMemory\n", "\n", "from typing import Dict\n", "import time\n", "\n", "import streamlit as st\n", "from streamlit_chat import message\n", "\n", "glue_databucket_name = 'sagemaker-studio-741094476554-9zkt2s8krvb' # Create this bucket in S3\n", "glue_db_name = 'ihmnick-bankadditional'\n", "glue_role = 'ihmnick-AWSGlueServiceRole-glueworkshop120'\n", "glue_crawler_name = glue_db_name + '-crawler120'\n", "\n", "\n", "# Collect the Anthropic API key from Secrets Manager and define the LLM (same setup as in the notebook cells above)\n", "secrets_client = boto3.client('secretsmanager')\n", "region = secrets_client.meta.region_name\n", "response = secrets_client.get_secret_value(SecretId='anthropic')\n", "secrets_credentials = json.loads(response['SecretString'])\n", "ANTHROPIC_API_KEY = secrets_credentials['ANTHROPIC_API_KEY']\n", "llm = ChatAnthropic(temperature=0, anthropic_api_key=ANTHROPIC_API_KEY, max_tokens_to_sample=512)\n", "\n", "\n", "# Connect to S3 using Athena\n", "connathena = f\"athena.{region}.amazonaws.com\" \n", "portathena = '443' # Update, if port is different\n", "schemaathena = glue_db_name # from user-defined params\n", "s3stagingathena = f's3://{glue_databucket_name}/athenaresults/' # Athena query results location\n", "wkgrpathena = 'primary' # Update, if workgroup is different\n", "# tablesathena=['dataset']#[]\n", "\n", "# Create the Athena connection string\n", "connection_string = f\"awsathena+rest://@{connathena}:{portathena}/{schemaathena}?s3_staging_dir={s3stagingathena}/&work_group={wkgrpathena}\"\n", "\n", "# Create the Athena SQLAlchemy engine\n", "engine_athena = create_engine(connection_string, echo=False)\n", "dbathena = SQLDatabase(engine_athena)\n", "gdc = [schemaathena] \n", "\n", "# Setup memory\n", "memory = ConversationBufferMemory(memory_key=\"chat_history\", return_messages=True)\n", "\n", "# Generate dynamic prompts by harvesting the AWS Glue crawler metadata\n", "def parse_catalog():\n", "    # Connect to the Glue catalog and get the metadata of the registered tables\n", "    columns_str=''\n", "    \n", "    # Define Glue client\n", "    glue_client = boto3.client('glue')\n", "    \n", "    for db in gdc:\n", "        response = 
glue_client.get_tables(DatabaseName=db)\n", "        for tables in response['TableList']:\n", "            # The classification in the response differs for S3 and other sources. Set the classification from the storage location accordingly.\n", "            if tables['StorageDescriptor']['Location'].startswith('s3'):  classification='s3' \n", "            else:  classification = tables['Parameters']['classification']\n", "            for columns in tables['StorageDescriptor']['Columns']:\n", "                dbname,tblname,colname=tables['DatabaseName'],tables['Name'],columns['Name']\n", "                columns_str=columns_str+f'\\n{classification}|{dbname}|{tblname}|{colname}' \n", "    # API\n", "    # Append the metadata of the API to the unified Glue data catalog\n", "    columns_str=columns_str+'\\n'+('api|meteo|weather|weather')\n", "    print('columns_str', columns_str)\n", "    return columns_str\n", "    \n", "glue_catalog = parse_catalog()\n", "\n", "# Function 1 'Infer Channel'\n", "# Define a function that infers the channel/database/table and sets the database for querying\n", "def identify_channel(query):\n", "    db = {}\n", "    \n", "    # Prompt 1 'Infer Channel'\n", "    # Set the prompt template. It instructs the LLM on how to evaluate and respond to the user query. It is referred to as dynamic because the Glue data catalog is first generated and then appended to the prompt.\n", "    prompt_template = \"\"\"\n", "    From the table below, find the database (in column database) which will contain the data (in corresponding column_names) to answer the question \n", "    {query} \\n\n", "    \"\"\"+glue_catalog +\"\"\" \n", "    Give your answer as database == \n", "    Also, give your answer as database.table == \n", "    \"\"\"\n", "    \n", "    # Define prompt 1\n", "    PROMPT_channel = PromptTemplate( template=prompt_template, input_variables=[\"query\"] )\n", "\n", "    # Define the LLM chain\n", "    llm_chain = LLMChain(prompt=PROMPT_channel, llm=llm)\n", "    \n", "    # Run the query and save the generated text\n", "    generated_texts = llm_chain.run(query)\n", "    print('identified channel:', generated_texts)\n", "\n", "    # Set the channel from where the query can be answered.\n", "    # Check the API channel first: the word 'database' can appear in the answer for either channel because the prompt asks for it.\n", "    if 'api' in generated_texts: \n", "        channel='api'\n", "        print(\"SET database to weather api\")\n", "    elif 'database' in generated_texts: \n", "        channel='db'\n", "        db=dbathena\n", "        print(\"SET database to athena\")\n", "    else: \n", "        raise Exception(\"User question cannot be answered by any of the channels mentioned in the catalog\")\n", "    \n", "    print(\"Step complete. Channel is: \", channel)\n", "    \n", "    return channel, db\n", "\n", "# Define a function that runs the user query against the channel identified above\n", "def run_query(query):\n", "\n", "    channel, db = identify_channel(query) # Call the identify channel function first\n", "\n", "    # Prompt 2 'Run Query'\n", "    # After determining the data channel, run the LangChain SQL Database chain to convert 'text to SQL' and run the query against the source data channel. \n", "    # Provide rules for running the SQL queries in the default template --> table info\n", "\n", "    _DEFAULT_TEMPLATE = \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\n", "\n", "    Do not append 'Query:' to SQLQuery.\n", "    \n", "    Display SQLResult after the query is run in plain English that users can understand. 
\n", "\n", "    Provide the answer in a simple English statement.\n", "    \n", "    Only use the following tables:\n", "\n", "    {table_info}\n", "\n", "    Question: {input}\"\"\"\n", "\n", "    PROMPT_sql = PromptTemplate(\n", "        input_variables=[\"input\", \"table_info\", \"dialect\"], template=_DEFAULT_TEMPLATE\n", "    )\n", "\n", "    if channel=='db':\n", "        db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT_sql, verbose=True, return_intermediate_steps=False, memory=memory, use_query_checker=True)\n", "        response = db_chain.run(query)\n", "    elif channel == 'api':\n", "        chain_api = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS, verbose=True, memory=memory)\n", "        response = chain_api.run(query)\n", "    else: \n", "        raise Exception(\"Unlisted channel. Check your unified catalog\")\n", "    \n", "    return response\n", "\n", "# Storing the chat\n", "if 'generated' not in st.session_state:\n", "    st.session_state['generated'] = []\n", "\n", "if 'past' not in st.session_state:\n", "    st.session_state['past'] = []\n", "    \n", "def action_search():\n", "    st.title('Market Research Assistant')\n", "    \n", "    col1, col2 = st.columns(2)\n", "    with col1:\n", "        query = st.text_input('**Ask a question:**', '')\n", "        button_search = st.button('Ask')\n", "    \n", "    if query or button_search:\n", "        reply = run_query(query)\n", "        # Store the output \n", "        st.session_state.past.append(query)\n", "        st.session_state.generated.append(reply)\n", "\n", "    if st.session_state['generated']:\n", "        for i in range(len(st.session_state['generated'])-1, -1, -1):\n", "            message(st.session_state[\"generated\"][i], key=str(i))\n", "            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')\n", "    \n", "def app_sidebar():\n", "    with st.sidebar:\n", "        st.write('## How to use:')\n", "        description = \"\"\"Assume the role of a marketing analyst working at a bank. There is an ask to perform analysis on whether a customer will enroll for a certificate of deposit (CD). To support the analysis, the marketing dataset contains information on customer demographics, responses to marketing events, and external factors.\n", "\n", "- What percentage of customers in each age group enroll for CDs?\n", "- What is the enrollment rate for each marital status?\n", "- What percentage of customers in each education level enroll for CDs?\n", "- What percentage of customers with existing loans (housing loan, personal loan) enroll for CDs?\n", "- What is the average enrollment rate over time for each month/quarter?\n", "- What is the average enrollment rate at different time periods after the last marketing contact (e.g. 0-7 days, 8-14 days, 15-30 days, 30+ days)?\n", "- What is the enrollment percentage for different outcomes of the previous marketing campaign (e.g. 
success, failure, no campaign)?\n", "- What percentage of customers contacted X times enroll for CDs, versus customers contacted Y times?\"\"\"\n", "        st.write(description)\n", "        st.write('---')\n", "\n", "\n", "def main():\n", "    st.set_page_config(layout=\"wide\")\n", "    app_sidebar()\n", "    action_search()\n", "\n", "\n", "if __name__ == '__main__':\n", "    main()\n", "    \n" ] }, { "cell_type": "markdown", "id": "c9b940dc-7fc3-474d-88d7-d7f34590f411", "metadata": {}, "source": [ "## Start App" ] }, { "cell_type": "markdown", "id": "445df238-cc24-458f-aab5-8b61ab1c071a", "metadata": { "tags": [] }, "source": [ "### Run Streamlit\n", "To run the application:\n", "1. Select File > New > Terminal.\n", "2. In the terminal, run the command: `streamlit run app_search.py --server.runOnSave true`.\n", "    1. Note: ensure you have installed all the required packages first.\n", "3. If this is successful, you will be able to interact with the app using the web address printed by the next cell.\n", "4. When you run the command above, you should see output similar to the block below.\n", "5. The port that's displayed is the same port that MUST be used after the `proxy` segment of that web address.\n", "\n", "```\n", "You can now view your Streamlit app in your browser.\n", "\n", "  Network URL: http://###.###.###.###:8501\n", "  External URL: http://###.###.###.###:8501\n", "```\n" ] }, { "cell_type": "markdown", "id": "24b1097f-8c0a-47d9-8c5d-1f6cc9525be9", "metadata": {}, "source": [ "#### Display Link to Application" ] }, { "cell_type": "code", "execution_count": null, "id": "e62f3a73-7f2b-4942-b771-a1f39075b627", "metadata": { "tags": [] }, "outputs": [], "source": [ "import json\n", "import os\n", "\n", "import boto3\n", "import sagemaker as sm\n", "\n", "def get_sagemaker_session(local_download_dir) -> sm.Session:\n", "    \"\"\"\n", "    Create a SageMaker Session.\n", "    The SageMaker Session object is used to create a SageMaker Endpoint,\n", "    SageMaker Model, and SageMaker Endpoint Config.\n", "    \"\"\"\n", "    sagemaker_client = boto3.client(service_name=\"sagemaker\", region_name=boto3.Session().region_name)\n", "    session_settings = sm.session_settings.SessionSettings(local_download_dir=local_download_dir)\n", "    session = sm.session.Session(sagemaker_client=sagemaker_client, settings=session_settings)\n", "    return session\n", "\n", "model_path = './download_dir'\n", "if not os.path.exists(model_path):\n", "    os.mkdir(model_path)\n", "    \n", "role = sm.get_execution_role()\n", "sagemaker_session = get_sagemaker_session(model_path)  # sm.session.Session()\n", "region = sagemaker_session._region_name\n", "\n", "# These are needed to show where the Streamlit app is hosted\n", "sagemaker_metadata = json.load(open('/opt/ml/metadata/resource-metadata.json', 'r'))\n", "domain_id = sagemaker_metadata['DomainId']\n", "resource_name = sagemaker_metadata['ResourceName']\n", "\n", "print(f'http://{domain_id}.studio.{region}.sagemaker.aws/jupyter/default/proxy/8501/')" ] }, { "cell_type": "markdown", "id": "69371bdc-537f-4e5e-a004-99d852097862", "metadata": {}, "source": [ "### Clean-up\n", "After you run the modern data architecture with Generative AI, make sure to clean up any resources that won’t be utilized. 
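For example, the Glue crawler and database created during the one-time setup can be removed with boto3 - a sketch, assuming the glue_crawler_name and glue_db_name variables defined earlier:\n", "\n", "```python\n", "import boto3\n", "\n", "glue = boto3.client('glue')\n", "glue.delete_crawler(Name=glue_crawler_name)\n", "glue.delete_database(Name=glue_db_name)\n", "```\n", "\n", "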
Delete the data in Amazon S3 and make sure to stop any SageMaker Studio notebook instances to not incur any further charges.\n" ] } ], "metadata": { "availableInstances": [ { "_defaultOrder": 0, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.t3.medium", "vcpuNum": 2 }, { "_defaultOrder": 1, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.t3.large", "vcpuNum": 2 }, { "_defaultOrder": 2, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.t3.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 3, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.t3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 4, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5.large", "vcpuNum": 2 }, { "_defaultOrder": 5, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 6, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 7, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 8, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 9, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 10, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 11, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 12, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5d.large", "vcpuNum": 2 }, { "_defaultOrder": 13, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5d.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 14, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5d.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 15, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5d.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 16, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5d.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 17, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5d.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 18, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5d.16xlarge", "vcpuNum": 64 }, { 
"_defaultOrder": 19, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 20, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": true, "memoryGiB": 0, "name": "ml.geospatial.interactive", "supportedImageNames": [ "sagemaker-geospatial-v1-0" ], "vcpuNum": 0 }, { "_defaultOrder": 21, "_isFastLaunch": true, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.c5.large", "vcpuNum": 2 }, { "_defaultOrder": 22, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.c5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 23, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.c5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 24, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.c5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 25, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 72, "name": "ml.c5.9xlarge", "vcpuNum": 36 }, { "_defaultOrder": 26, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 96, "name": "ml.c5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 27, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 144, "name": "ml.c5.18xlarge", "vcpuNum": 72 }, { "_defaultOrder": 28, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.c5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 29, "_isFastLaunch": true, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g4dn.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 30, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g4dn.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 31, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g4dn.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 32, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g4dn.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 33, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g4dn.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 34, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g4dn.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 35, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 61, "name": "ml.p3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 36, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 244, "name": "ml.p3.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 37, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 488, "name": "ml.p3.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 38, 
"_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.p3dn.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 39, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.r5.large", "vcpuNum": 2 }, { "_defaultOrder": 40, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.r5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 41, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.r5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 42, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.r5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 43, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.r5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 44, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.r5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 45, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 512, "name": "ml.r5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 46, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.r5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 47, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 48, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 49, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 50, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 51, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 52, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 53, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.g5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 54, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.g5.48xlarge", "vcpuNum": 192 } ], "instance_type": "ml.m5.large", "kernelspec": { "display_name": "Python 3 (PyTorch 1.13 Python 3.9 CPU Optimized)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/pytorch-1.13-cpu-py39" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 
"3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }