---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

---


# Deploy OpenChatKit Model with high performance on SageMaker 

In this notebook, we explore how to host a large language model on SageMaker using the latest container that packages some of the most popular open source libraries for model parallel inference like DeepSpeed and Hugging Face Accelerate. We use DJLServing as the model serving solution in this example. DJLServing is a high-performance universal model serving solution powered by the Deep Java Library (DJL) that is programming language agnostic. To learn more about DJL and DJLServing, you can refer to our recent blog post (https://aws.amazon.com/blogs/machine-learning/deploy-bloom-176b-and-opt-30b-on-amazon-sagemaker-with-large-model-inference-deep-learning-containers-and-deepspeed/).

Language models have recently exploded in both size and popularity. In 2018, BERT-large entered the scene and, with its 340M parameters and novel transformer architecture, set the standard on NLP task accuracy. Within just a few years, state-of-the-art NLP model size has grown by more than 500x with models such as OpenAI’s 175 billion parameter GPT-3 and the similarly sized open source Bloom 176B raising the bar on NLP accuracy. This increase in the number of parameters is driven by the simple and empirically demonstrated positive relationship between model size and accuracy: more is better. With easy access from model zoos such as Hugging Face and improved accuracy in NLP tasks such as classification and text generation, practitioners are increasingly reaching for these large models. However, deploying them can be a challenge because of their size.

Model parallelism can help deploy large models that would normally be too large for a single GPU. With model parallelism, we partition and distribute a model across multiple GPUs. Each GPU holds a different part of the model, resolving the memory capacity issue for the largest deep learning models with billions of parameters. This notebook uses tensor parallelism techniques which allow GPUs to work simultaneously on the same layer of a model and achieve low latency inference relative to a pipeline parallel solution.
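To make the idea of tensor parallelism concrete, here is a toy, pure-Python sketch for a single linear layer: the weight matrix is split column-wise across hypothetical "devices", each device computes its slice of the output, and concatenating the slices reproduces the full result. Real frameworks do this on GPU tensors and use collective communication (for example an all-gather) instead of a Python list concatenation; the helper names here are illustrative only.

```python
from typing import List

Matrix = List[List[float]]


def matvec(x: List[float], w: Matrix) -> List[float]:
    """Compute y = x @ W for a row vector x and a weight matrix W (rows x cols)."""
    cols = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(cols)]


def split_columns(w: Matrix, num_shards: int) -> List[Matrix]:
    """Split W column-wise into num_shards contiguous shards, one per device."""
    cols = len(w[0])
    assert cols % num_shards == 0, "columns must divide evenly across shards"
    size = cols // num_shards
    return [[row[s * size:(s + 1) * size] for row in w] for s in range(num_shards)]


def tensor_parallel_matvec(x: List[float], w: Matrix, num_shards: int) -> List[float]:
    """Each shard computes part of the output; concatenation stands in for an all-gather."""
    partial_outputs = [matvec(x, shard) for shard in split_columns(w, num_shards)]
    return [value for part in partial_outputs for value in part]


x = [1.0, 2.0]
w = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 3.0]]
print(matvec(x, w))                     # [1.0, 2.0, 4.0, 7.0]
print(tensor_parallel_matvec(x, w, 2))  # same values, computed in two shards
```

Because every shard sees the full input and owns a disjoint set of output columns, the shards can run at the same time on the same layer, which is the property the paragraph above relies on for low latency.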

SageMaker provides a large model inference (LMI) container with DeepSpeed and Accelerate, which lets users take advantage of managed serving capabilities and offloads the undifferentiated heavy lifting of hosting large models.

In this notebook, we deploy the open source [GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B?text=As+part+of+OpenChatKit+%28codebase+available+here%29%2C+GPT-NeoXT-Chat-Base-20B+is+a+20B+parameter+language+model%2C+fine-tuned+from+EleutherAI%E2%80%99s+GPT-NeoX+with+over+40+million+instructions+on+100%25+carbon+negative+compute) (OpenChatKit) model across the GPUs of an ml.g5.12xlarge instance. The open source [GPT-JT-Moderation-6B](https://huggingface.co/togethercomputer/GPT-JT-Moderation-6B) model is deployed across the GPUs of the same instance.

OpenChatKit provides a powerful, open source base to create both specialized and general purpose chatbots for various applications. The kit includes an instruction-tuned 20 billion parameter language model, a 6 billion parameter moderation model, and an extensible retrieval system for including up-to-date responses from custom repositories. The chat model was trained on the OIG-43M training dataset, which was a collaboration between Together, LAION, and Ontocord.ai. Together describes the release as more than a model: it is the beginning of an open source project, with a set of tools and processes for ongoing improvement through community contributions. You can read more about OpenChatKit [here](https://github.com/togethercomputer/OpenChatKit).

In this example, we demonstrate how to use the SageMaker large model inference container to host OpenChatKit. We use HuggingFace Accelerate's model parallel techniques with multiple GPUs on a single SageMaker machine learning instance. OpenChatKit also includes an extensible retrieval system, with which the chatbot can incorporate regularly updated or custom content, such as knowledge from Wikipedia, news feeds, or sports scores, in its responses. The additional component of OpenChatKit is a 6 billion parameter moderation model fine-tuned from GPT-JT. In chat applications, the moderation model runs in tandem with the main chat model, checking the user utterance for any inappropriate content. Based on the moderation model’s assessment, the chatbot can limit the input to moderated subjects. For narrower tasks, the moderation model can be used to detect out-of-domain questions and override the response when a question is off topic. Please refer to [this](https://www.together.xyz/blog/openchatkit) blog post to extend this model with the retrieval system.

Invocations to SageMaker endpoints are stateless, so a model cannot automatically refer to past messages in computing new outputs. As a result, a DynamoDB table is created to store conversations based on a unique identifier generated by the endpoint. When this identifier is passed in with the invocation request, the model concatenates the new prompt with the previous conversation before performing inference.
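Based on the handler in `model.py` later in this notebook, a request carries the user text in `inputs` and generation settings in `parameters`, with an optional `session_id` inside `parameters` to continue an existing conversation. A minimal sketch of building such request bodies, assuming that shape; `build_payload` is a hypothetical helper, the session ID value is made up, and `max_new_tokens`/`do_sample` are ordinary Hugging Face `generate` arguments.

```python
import json
from typing import Optional


def build_payload(prompt: str, session_id: Optional[str] = None, **generate_kwargs) -> str:
    """Serialize a request body of the form {"inputs": ..., "parameters": {...}}.

    Passing the session_id returned by a previous call asks the endpoint to
    prepend the stored conversation to the new prompt.
    """
    parameters = dict(generate_kwargs)
    if session_id is not None:
        parameters["session_id"] = session_id
    return json.dumps({"inputs": prompt, "parameters": parameters})


# First turn: no session_id, so the endpoint generates one and returns it
first = build_payload("What is Amazon SageMaker?", max_new_tokens=128, do_sample=True)
print(first)

# Follow-up turn: reuse the session_id from the first response (value is hypothetical)
follow_up = build_payload(
    "Which instance types does it support?", session_id="1234-abcd", max_new_tokens=128
)
print(follow_up)

# With a deployed endpoint, the body would be sent like this (endpoint_name is hypothetical):
# response = smr_client.invoke_endpoint(
#     EndpointName=endpoint_name, Body=follow_up, ContentType="application/json"
# )
# result = json.loads(response["Body"].read())  # {"outputs": ..., "session_id": ...}
```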

As a result, the IAM role used for the endpoint needs permissions for the following actions:
- `dynamodb:CreateTable`
- `dynamodb:DescribeTable`
- `dynamodb:PutItem`
- `dynamodb:GetItem`
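A minimal sketch of an inline IAM policy document granting exactly these actions, assuming the default table name `openchatkit_chat_logs` used by `conversation.py` later in this notebook; the region and account ID in the example are placeholders. Scoping `Resource` to the single table ARN keeps the permission narrow.

```python
import json


def chat_log_policy(region: str, account_id: str, table_name: str = "openchatkit_chat_logs") -> dict:
    """Build an IAM policy document letting the endpoint manage its conversation table."""
    table_arn = f"arn:aws:dynamodb:{region}:{account_id}:table/{table_name}"
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "dynamodb:CreateTable",
                    "dynamodb:DescribeTable",
                    "dynamodb:PutItem",
                    "dynamodb:GetItem",
                ],
                "Resource": table_arn,
            }
        ],
    }


# Placeholder region and account ID
print(json.dumps(chat_log_policy("us-west-2", "123456789012"), indent=2))
```

The resulting JSON could then be attached to the execution role, for example with `boto3.client("iam").put_role_policy(RoleName=..., PolicyName=..., PolicyDocument=json.dumps(policy))`.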


HuggingFace Accelerate is used for tensor parallel inference, while DJLServing handles inference requests and the distributed workers. For further reading on HuggingFace, you can refer to https://huggingface.co/docs.

## Licence agreement
 - View license information https://github.com/togethercomputer/OpenChatKit/blob/main/LICENSE before using the model.
 - This notebook is a sample notebook and not intended for production use. Please refer to the licence at https://github.com/aws/mit-0. 
 - Faiss is available from https://github.com/facebookresearch/faiss. View license information at https://github.com/facebookresearch/faiss/blob/main/LICENSE.
 
 


In [None]:
!pip install boto3 huggingface_hub sagemaker-studio-image-build --upgrade --quiet

In [None]:
import sagemaker
import jinja2
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path

In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

model_bucket = sess.default_bucket()  # bucket to house the model artifacts
s3_code_prefix = "hf-large-model-djl-/code_gpt_neoxt-chatbase"  # folder within bucket where the code artifact will go
s3_model_prefix = "hf-large-model-djl-/model_gpt_neoxt-chatbase"  # folder within bucket where the model artifacts will go
region = sess.boto_region_name
account_id = sess.account_id()

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

jinja_env = jinja2.Environment()

# define a variable to contain the s3url of the location that has the model
pretrained_model_location = f"s3://{model_bucket}/{s3_model_prefix}/"
print(f"Pretrained model will be uploaded to ---- > {pretrained_model_location}")

### Download the model from Hugging Face and upload the model artifacts to Amazon S3

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path
import os

# Download the model into a folder under the notebook's current working directory
local_model_path = Path("./openchatkit")
local_model_path.mkdir(exist_ok=True)
model_name = "togethercomputer/GPT-NeoXT-Chat-Base-20B"
# Only download the PyTorch checkpoint, config, and tokenizer files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model"]

# Use snapshot_download because the model files are stored in the repository with Git LFS
chat_model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_model_path,
    allow_patterns=allow_patterns,
)

In [None]:
model_artifact = sess.upload_data(path=chat_model_download_path, key_prefix=s3_model_prefix)
print(f"Model uploaded to --- > {model_artifact}")
print(f"We will set option.s3url={model_artifact}")

In [None]:
!rm -rf openchatkit/

## Create SageMaker compatible Model artifact, upload Model to S3 and bring your own inference script.

SageMaker Large Model Inference containers can be used to host models without providing your own inference code. This is extremely useful when there is no custom pre-processing of the input data or post-processing of the model's predictions.

However, in this notebook, we demonstrate how to deploy a model with custom inference code.

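SageMaker needs the model artifacts to be in a tarball format. In this example, we provide the configuration file `serving.properties`, the inference script `model.py`, and the helper modules `conversation.py`, `wikipedia_prepare.py`, and `wikipedia.py`.

The tarball has the following layout:

```
openchatkit/
├── serving.properties
├── model.py
├── conversation.py
├── wikipedia_prepare.py
└── wikipedia.py
```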

- `serving.properties` is the configuration file that can be used to configure the model server.
- `model.py` is the script that handles any requests for serving.
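Further down, the notebook creates this tarball with `tar czvf`. As an alternative, the same layout can be produced from Python with the standard-library `tarfile` module. This sketch assumes the files have already been written to a local `openchatkit/` directory and archives that directory under its own name. `make_model_tarball` is a hypothetical helper.

```python
import tarfile
from pathlib import Path
from typing import List


def make_model_tarball(source_dir: str, output_path: str = "model.tar.gz") -> List[str]:
    """Create a gzip-compressed tarball with source_dir as its top-level folder.

    Returns the sorted member names so the layout can be checked.
    """
    source = Path(source_dir)
    with tarfile.open(output_path, "w:gz") as tar:
        # arcname keeps member paths relative, e.g. "openchatkit/serving.properties"
        tar.add(str(source), arcname=source.name)
    with tarfile.open(output_path, "r:gz") as tar:
        return sorted(tar.getnames())


# Example (assumes ./openchatkit exists):
# print(make_model_tarball("openchatkit"))
```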


#### Create serving.properties 

This is a configuration file to indicate to DJL Serving which model parallelization and inference optimization libraries you would like to use. Depending on your need, you can set the appropriate configuration.

Here is a list of settings that we use in this configuration file -
- `engine`: The engine for DJL to use. In this case, it is **Python**.
- `option.entryPoint`: The entry point python file or module. This should align with the engine that is being used. 
- `option.s3url`: Set this to the URI of the Amazon S3 bucket that contains the model. 

- `option.model_id`: If you want to download the model from huggingface.co instead, set this to the model ID of a pretrained model hosted inside a model repository on huggingface.co (https://huggingface.co/models). The container uses this model ID to download the corresponding model repository.
- `option.tensor_parallel_degree`: Set to the number of GPU devices over which HuggingFace Accelerate needs to partition the model. This parameter also controls the number of workers per model that are started when DJL Serving runs. For example, if we have an 8 GPU machine and we are creating 8 partitions, then we will have 1 worker per model to serve the requests. This notebook sets it to 4, the number of GPUs on an ml.g5.12xlarge.
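The handler in `model.py` reads these settings from a `properties` dictionary without the `option.` prefix, for example `properties.get("tensor_parallel_degree")`. As a rough illustration of that mapping, here is a hypothetical parser. It is not DJLServing's actual implementation, and the bucket name in the example is a placeholder.

```python
def parse_serving_properties(text: str) -> dict:
    """Parse `key = value` lines, dropping the `option.` prefix from option keys."""
    properties = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        key, _, value = line.partition("=")
        key = key.strip()
        if key.startswith("option."):
            key = key[len("option."):]
        properties[key] = value.strip()
    return properties


example = """
engine = Python
option.tensor_parallel_degree = 4
option.s3url = s3://my-bucket/my-prefix/
"""
props = parse_serving_properties(example)
print(props)
# Values are strings here, consistent with model.py casting with int(...)
print(int(props.get("tensor_parallel_degree", -1)))
```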

For more details on the configuration options and an exhaustive list, you can refer to the documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html.

HuggingFace Accelerate can compute the device map automatically when the `device_map` option is set to a supported value, or you can provide an explicit device map. With the `auto` device map, Accelerate splits the model evenly across all available GPUs to make the most of the available GPU memory.
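A back-of-the-envelope check of why the weights must be split across GPUs. In float16, each parameter takes 2 bytes, so the weights of the 20B chat model alone need about 40 GB. That is more than a single 24 GB A10G GPU on an ml.g5.12xlarge provides. This sketch uses nominal parameter counts and ignores activations, the KV cache, and framework overhead, so real usage is higher.

```python
GB = 1e9  # decimal gigabytes


def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the weights alone (float16/bfloat16 use 2 bytes per parameter)."""
    return num_params * bytes_per_param / GB


chat_model = weight_memory_gb(20e9)       # GPT-NeoXT-Chat-Base-20B (nominal size)
moderation_model = weight_memory_gb(6e9)  # GPT-JT-Moderation-6B, loaded twice in model.py
total = chat_model + 2 * moderation_model

num_gpus, gpu_memory_gb = 4, 24           # ml.g5.12xlarge: 4 x NVIDIA A10G, 24 GB each
print(f"chat model: {chat_model:.0f} GB, total weights: {total:.0f} GB")
print(f"per GPU if split evenly: {total / num_gpus:.0f} GB of {gpu_memory_gb} GB")
```

The chat model does not fit on one GPU, but the three models together (about 64 GB of weights) fit comfortably in the 96 GB spread across four GPUs.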

In [None]:
!mkdir openchatkit

In [None]:
%%writefile openchatkit/serving.properties
engine = Python
option.tensor_parallel_degree = 4
option.s3url = {{s3url}}

In [None]:
# we plug in the appropriate model location into our `serving.properties` file based on the region in which this notebook is running
template = jinja_env.from_string(Path("openchatkit/serving.properties").open().read())
Path("openchatkit/serving.properties").open("w").write(
 template.render(s3url=pretrained_model_location)
)
!pygmentize openchatkit/serving.properties | cat -n

The code below implements the handling logic for the main OpenChatKit GPT-NeoX model. The overall solution is implemented across 4 files, which handle:
1. Receiving inference request and handling it (`model.py`)
2. Downloading and preparing the Wikipedia index (`wikipedia_prepare.py`)
3. Searching the Wikipedia Index for relevant documents (`wikipedia.py`)
4. Storing and retrieving the conversation thread in DynamoDB for passing to the model and user (`conversation.py`)


`model.py` implements a class `OpenChatKitService` that passes data between the GPT-JT moderation model, the GPT-NeoX chat model, the Faiss search, and the conversation object. It is called for every inference request. It also generates a unique session ID for each invocation if one is not supplied, so that the prompts can be stored in DynamoDB.

The `ChatModel` class loads the model, partitions it across multiple GPUs, and generates the response. A stopping criterion is configured so that generation only produces the bot's response and stops before continuing into the next human turn.
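`StopWordsCriteria` in `model.py` decodes the newest token on each generation step, appends it to a running string, and halts once any stop word appears. For the chat model, the stop word is the human speaker tag that would begin the user's next turn. A pure-Python sketch of the same idea, using a list of already-decoded token strings in place of a tokenizer. `StopWordMatcher` and `generate_until_stop` are hypothetical names.

```python
from typing import Iterable, List


class StopWordMatcher:
    """Accumulate decoded tokens and report when any stop word has appeared."""

    def __init__(self, stop_words: List[str]):
        self._stop_words = stop_words
        self._partial_result = ""

    def __call__(self, new_token_text: str) -> bool:
        self._partial_result += new_token_text
        return any(stop in self._partial_result for stop in self._stop_words)


def generate_until_stop(tokens: Iterable[str], stop_words: List[str]) -> str:
    """Emit tokens until the matcher fires, mimicking generate() with a stopping criterion."""
    matcher = StopWordMatcher(stop_words)
    produced = ""
    for token in tokens:
        produced += token
        if matcher(token):
            break
    return produced


# A stop word split across several tokens is still caught
stream = [" Paris", " is", " the", " capital", ".\n", "<", "human", ">:", " thanks"]
print(repr(generate_until_stop(stream, ["<human>"])))  # ' Paris is the capital.\n<human>:'
```

As with `generate`, the stop word itself ends up in the output, which is why the conversation code trims everything from the human tag onward. Matching on accumulated text rather than single tokens matters because a tag like `<human>` is usually split across several tokens. Note that an empty stop word would match on the very first token.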

The `ModerationModel` class loads the moderation model and generates the moderation classification. If the classification is `"needs intervention"`, the method returns `True`, and the service replaces the response to the user with a refusal message.
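The moderation decision comes down to string handling on the generated text. The code keeps whatever follows the last `Output:` marker of the few-shot prompt and checks whether it contains the `needs intervention` label. A self-contained sketch of that post-processing, with a hypothetical helper name:

```python
def is_blocked(generated_text: str, blocking_label: str = "needs intervention") -> bool:
    """Return True if the label after the final 'Output:' marker is the blocking label."""
    label = generated_text.split("Output:")[-1].strip()
    return blocking_label in label


few_shot_tail = "Input: I want to call in sick\nOutput: casual\n\nInput: {user text}\nOutput:"
print(is_blocked(few_shot_tail + " needs intervention"))      # True
print(is_blocked(few_shot_tail + " possibly needs caution"))  # False
```

Only `needs intervention` blocks a message. The milder `needs caution` labels pass through unchanged.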

In [None]:
%%writefile openchatkit/model.py
import torch
import logging
import uuid
import wikipedia as wp
import conversation as convo

from djl_python import Input, Output
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from accelerate import infer_auto_device_map, init_empty_weights
from typing import Optional


class StopWordsCriteria(StoppingCriteria):
    """Stop generation once any stop word appears in the generated text."""

    def __init__(self, tokenizer, stop_words):
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        self._partial_result = ""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the newest token and append it to the running text
        text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += text
        for stop_word in self._stop_words:
            if stop_word in self._partial_result:
                return True
        return False


class ModerationModel:
    def __init__(self, properties):
        tensor_parallel = int(properties.get("tensor_parallel_degree", -1))
        # The moderation model is pulled from the Hugging Face Hub at startup
        model_location = "togethercomputer/GPT-JT-Moderation-6B"

        kwargs = {}

        config = AutoConfig.from_pretrained(model_location)

        # Instantiate the architecture on the meta device so a device map can be inferred cheaply
        with init_empty_weights():
            model_from_conf = AutoModelForCausalLM.from_config(config)

        model_from_conf.tie_weights()

        # Default to float16 unless bfloat16 is requested explicitly
        dtype_str, dtype = "float16", torch.float16
        if properties.get("dtype") == "bfloat16":
            dtype_str, dtype = "bfloat16", torch.bfloat16

        if "device_map" in properties:
            kwargs["device_map"] = properties["device_map"]
        elif tensor_parallel > 0:
            kwargs["device_map"] = "auto"
            world_size = torch.cuda.device_count()
            assert (
                world_size == tensor_parallel
            ), f"TP degree ({tensor_parallel}) doesn't match available GPUs ({world_size})"
            logging.info(f"Using {world_size} gpus")
        else:
            kwargs["device_map"] = infer_auto_device_map(
                model_from_conf,
                # GPT-JT uses the GPT-J architecture; keep each GPTJBlock on a single device
                no_split_module_classes=["GPTJBlock"],
                dtype=dtype_str,
            )

        if "low_cpu_mem_usage" in properties:
            kwargs["low_cpu_mem_usage"] = properties.get("low_cpu_mem_usage")

        logging.info(f"Using device map: {kwargs['device_map']}")
        logging.info(f"Loading model in {model_location} with kwargs {kwargs}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_location)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_location,
            torch_dtype=dtype,
            offload_folder="/tmp/offload",
            offload_state_dict=True,
            **kwargs,
        )

        logging.info("Moderation model loaded and initialized")

    def clean_output(self, output):
        # Keep only the label generated after the final "Output:" marker
        return output.split("Output:")[-1]

    def inference(self, prompt):
        # Few-shot classification prompt; the model completes the final "Output:" line
        prompt = f"""Possible labels:
        1. casual
        2. needs caution
        3. needs intervention
        4. possibly needs caution
        5. probably needs caution


        Input: I want to call in sick
        Output: casual

        Input: I hate my job and I don't want to go there.
        Output: possibly needs caution

        Input: It's so hard to deal with this. I hate this job and I want to quit.
        Output: possibly needs caution

        Input: Something tells me that you're sitting in the wrong section of the Stormfront bus.
        Output: needs caution

        Input: {prompt}
        Output:"""

        # Stop as soon as the model starts inventing another "Input:" example
        stop_criteria = StopWordsCriteria(self.tokenizer, ["Input:"])

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=512,
            stopping_criteria=StoppingCriteriaList([stop_criteria]),
        )

        output = self.tokenizer.batch_decode(outputs)[0]
        output = self.clean_output(output)
        logging.info("Inference moderation complete")

        # True means the text must be blocked
        return "needs intervention" in output


class ChatModel:
    # Speaker tags used by GPT-NeoXT-Chat-Base-20B's dialogue format
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, properties):
        tensor_parallel = int(properties.get("tensor_parallel_degree", -1))
        # Prefer an explicit model ID, otherwise fall back to the local model directory
        model_location = properties.get("model_id") or properties.get("model_dir")

        kwargs = {}

        config = AutoConfig.from_pretrained(model_location)

        # Instantiate the architecture on the meta device so a device map can be inferred cheaply
        with init_empty_weights():
            model_from_conf = AutoModelForCausalLM.from_config(config)

        model_from_conf.tie_weights()

        # Default to float16 unless bfloat16 is requested explicitly
        dtype_str, dtype = "float16", torch.float16
        if properties.get("dtype") == "bfloat16":
            dtype_str, dtype = "bfloat16", torch.bfloat16

        if "device_map" in properties:
            kwargs["device_map"] = properties["device_map"]
        elif tensor_parallel > 0:
            kwargs["device_map"] = "auto"
            world_size = torch.cuda.device_count()
            assert (
                world_size == tensor_parallel
            ), f"TP degree ({tensor_parallel}) doesn't match available GPUs ({world_size})"
            logging.info(f"Using {world_size} gpus")
        else:
            kwargs["device_map"] = infer_auto_device_map(
                model_from_conf,
                # Keep each transformer layer on a single device
                no_split_module_classes=["GPTNeoXLayer"],
                dtype=dtype_str,
            )

        if "low_cpu_mem_usage" in properties:
            kwargs["low_cpu_mem_usage"] = properties.get("low_cpu_mem_usage")

        logging.info(f"Using device map: {kwargs['device_map']}")
        logging.info(f"Loading model in {model_location} with kwargs {kwargs}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_location)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_location,
            torch_dtype=dtype,
            offload_folder="/tmp/offload",
            offload_state_dict=True,
            **kwargs,
        )

        logging.info("ChatModel loaded and initialized")

    def do_inference(self, prompt, **generate_kwargs):
        # Stop generating as soon as the model starts a new human turn
        stop_criteria = StopWordsCriteria(self.tokenizer, [self.human_id])
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            pad_token_id=self.tokenizer.eos_token_id,
            stopping_criteria=StoppingCriteriaList([stop_criteria]),
            **generate_kwargs,
        )

        output = self.tokenizer.batch_decode(outputs)[0]
        # Keep only the text after the last bot tag, i.e. the newest reply
        output = output.split(self.bot_id)[-1].strip()

        return output


class OpenChatKitService:
    def __init__(self):
        self.input_model = None
        self.model = None
        self.output_model = None
        self.initialized = False
        self.index = None
        self.conversation = None

    def initialize(self, properties):
        logging.info("Loading models...")
        self.input_model = ModerationModel(properties)
        self.model = ChatModel(properties)
        self.output_model = ModerationModel(properties)

        logging.info("Loading Wikipedia Retrieval")
        # Importing this module clones the Wikipedia index (or waits for another worker to do so)
        import wikipedia_prepare  # noqa: F401

        self.index = wp.WikipediaIndex()
        self.conversation = convo.Conversation(self.model.human_id, self.model.bot_id)
        self.initialized = True

    def inference(self, inputs: Input):
        data = inputs.get_as_json()

        input_sentences = data["inputs"]
        params = data.get("parameters", {})

        # Block flagged user input before it reaches the chat model
        if self.input_model.inference(input_sentences):
            return Output().add_as_json(
                {"outputs": "Unfortunately I am unable to provide any information about this topic"}
            )

        # Continue the caller's session, or start a new one
        if "session_id" in params:
            session_id = params.pop("session_id")
        else:
            session_id = str(uuid.uuid4())

        # Add the best-matching Wikipedia snippet as context unless retrieval is disabled
        if "no_retrieval" not in params:
            results = self.index.search(input_sentences)
            if len(results) > 0:
                self.conversation.push_context_turn(results[0], session_id)
        else:
            params.pop("no_retrieval")

        self.conversation.push_human_turn(input_sentences, session_id)

        # Remaining parameters are passed straight to generate()
        output = self.model.do_inference(self.conversation.get_raw_prompt(session_id), **params)

        self.conversation.push_model_response(output, session_id)

        response = self.conversation.get_last_turn(session_id).strip()

        # Replace a flagged model response with the same refusal message
        if self.output_model.inference(response):
            return Output().add_as_json(
                {"outputs": "Unfortunately I am unable to provide any information about this topic"}
            )

        return Output().add_as_json({"outputs": response, "session_id": session_id})


_service = OpenChatKitService()


def handle(inputs: Input) -> Optional[Output]:
    # Load the models lazily on the first request
    if not _service.initialized:
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
        # The model server sends an empty request to warm up the model on startup
        return None

    return _service.inference(inputs)

`conversation.py` is adapted from the open source OpenChatKit repository. This file is responsible for defining the object that stores the conversation turns between the human and the model. With this, the model is able to retain a session for the conversation allowing a user to refer to previous messages. 

As SageMaker endpoint invocations are stateless, this conversation needs to be stored in a location external to the endpoint instances. On startup, the instance will create a DynamoDB table if it does not exist. All updates to the conversation are then stored in DynamoDB based on the `session_id` key which is generated by the endpoint. Any invocation with a session ID will retrieve the associated conversation string and update it as required.
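To see how the stored prompt grows across turns without touching DynamoDB, here is a sketch that replaces the table with an in-memory dictionary keyed by `session_id`. It mirrors the `push_human_turn`, `push_model_response`, and `get_last_turn` methods of `conversation.py`, assuming the `<human>`/`<bot>` speaker tags of GPT-NeoXT-Chat-Base-20B. For brevity, it omits the date/time preamble, the special-token cleanup, and the `...` truncation marker. `InMemoryConversation` is hypothetical.

```python
import re
from typing import Dict


class InMemoryConversation:
    """Dictionary-backed stand-in for the DynamoDB conversation store."""

    def __init__(self, human_id: str = "<human>", bot_id: str = "<bot>"):
        self._human_id = human_id
        self._bot_id = bot_id
        self._store: Dict[str, str] = {}

    def get_raw_prompt(self, session_id: str) -> str:
        # Unknown sessions start with an empty history
        return self._store.setdefault(session_id, "")

    def push_human_turn(self, query: str, session_id: str) -> None:
        # Append the user's message and open a bot turn for the model to complete
        prompt = self.get_raw_prompt(session_id)
        self._store[session_id] = prompt + f"{self._human_id}: {query}\n{self._bot_id}:"

    def push_model_response(self, response: str, session_id: str) -> None:
        # Drop anything the model generated after starting a new human turn
        bot_turn = response.split(f"{self._human_id}:")[0].strip("\n")
        self._store[session_id] = self.get_raw_prompt(session_id) + f"{bot_turn}\n"

    def get_last_turn(self, session_id: str) -> str:
        tags = f"({re.escape(self._human_id)}:|{re.escape(self._bot_id)}:)"
        return re.split(rf"{tags}\W?", self.get_raw_prompt(session_id))[-1]


convo = InMemoryConversation()
convo.push_human_turn("Hi there", "s1")
convo.push_model_response(" Hello! How can I help?\n<human>: extra", "s1")
print(convo.get_raw_prompt("s1"))
print(repr(convo.get_last_turn("s1")))  # 'Hello! How can I help?\n'
```

Each request reads the whole history for its session, appends to it, and writes it back. Keeping the store outside the process is what lets any endpoint instance serve the next turn.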

In [None]:
%%writefile openchatkit/conversation.py
# This file was adapted from togethercomputer/openchatkit:
# https://github.com/togethercomputer/OpenChatKit/blob/main/inference/conversation.py
#
# The original file was licensed under the Apache 2.0 License

import re
import time
import boto3
import logging

MEANINGLESS_WORDS = ["<pad>", "</s>", "<|endoftext|>"]
PRE_PROMPT = """\
Current Date: {}
Current Time: {}

"""


def clean_response(response):
    # Strip special tokens and surrounding newlines from a model response
    for word in MEANINGLESS_WORDS:
        response = response.replace(word, "")
    response = response.strip("\n")
    return response


class Conversation:
    DEFAULT_KEY_NAME = "session_id"

    def __init__(self, human_id, bot_id, db_name="openchatkit_chat_logs"):
        self._human_id = human_id
        self._bot_id = bot_id
        self.db_name = db_name
        self.ddb_client = boto3.client("dynamodb")

        # Create the conversation table on first use and wait until it exists
        try:
            self.ddb_client.describe_table(TableName=db_name)
        except self.ddb_client.exceptions.ResourceNotFoundException:
            logging.info(f"Table {db_name} not found. Creating...")
            self.ddb_client.create_table(
                TableName=db_name,
                AttributeDefinitions=[
                    {"AttributeName": self.DEFAULT_KEY_NAME, "AttributeType": "S"},
                ],
                KeySchema=[{"AttributeName": self.DEFAULT_KEY_NAME, "KeyType": "HASH"}],
                BillingMode="PAY_PER_REQUEST",
            )
            waiter = self.ddb_client.get_waiter("table_exists")
            waiter.wait(TableName=db_name, WaiterConfig={"Delay": 1})

    def push_context_turn(self, context, session_id):
        # For now, retrieved context is represented as a human turn
        prompt = self.get_raw_prompt(session_id)
        prompt += f"{self._human_id}: {context}\n"
        self.set_prompt(session_id, prompt)

    def push_human_turn(self, query, session_id):
        # Append the user's message and open a bot turn for the model to complete
        prompt = self.get_raw_prompt(session_id)
        prompt += f"{self._human_id}: {query}\n"
        prompt += f"{self._bot_id}:"
        self.set_prompt(session_id, prompt)

    def push_model_response(self, response, session_id):
        has_finished = self._human_id in response
        # Keep only the bot turn, dropping anything after the model starts a new human turn
        bot_turn = response.split(f"{self._human_id}:")[0]
        bot_turn = clean_response(bot_turn)
        # If the response was truncated, append "..." to the end
        if not has_finished:
            bot_turn += "..."

        prompt = self.get_raw_prompt(session_id)
        prompt += f"{bot_turn}\n"
        self.set_prompt(session_id, prompt)

    def get_last_turn(self, session_id):
        human_tag = f"{self._human_id}:"
        bot_tag = f"{self._bot_id}:"
        prompt = self.get_raw_prompt(session_id)
        # Split on speaker tags (plus one following separator) and return the final segment
        turns = re.split(rf"({human_tag}|{bot_tag})\W?", prompt)
        return turns[-1]

    def set_prompt(self, session_id, prompt):
        self.ddb_client.put_item(
            TableName=self.db_name,
            Item={self.DEFAULT_KEY_NAME: {"S": session_id}, "content": {"S": prompt}},
        )

    def get_raw_prompt(self, session_id):
        data = self.ddb_client.get_item(
            TableName=self.db_name, Key={self.DEFAULT_KEY_NAME: {"S": session_id}}
        )

        # No item means the session does not exist yet: start it with the date/time preamble
        if "Item" not in data:
            cur_date = time.strftime("%Y-%m-%d")
            cur_time = time.strftime("%H:%M:%S %p %Z")
            prompt = PRE_PROMPT.format(cur_date, cur_time)
            self.set_prompt(session_id, prompt)
            return prompt

        return data["Item"]["content"]["S"]

    def from_raw_prompt(self, value, session_id):
        # Overwrite the stored conversation for this session with a raw prompt string
        self.set_prompt(session_id, value)

In order to search the Wikipedia documents for relevant text, the index needs to be downloaded from HuggingFace as it is not packaged elsewhere.

This file handles the download when it is imported. Although several inference processes may import it, only one of them can clone the repository; the others wait until the files are present on the local filesystem.
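The module relies on a directory-existence check and on swallowing `git clone` errors so that only the first worker clones. One alternative, which the notebook does not use, is to let exactly one process win an atomic lock-file creation (`os.open` with `O_CREAT | O_EXCL`) and have the others poll for a completion marker. The `run_once` helper, the marker file names, and the `prepare` callback are hypothetical.

```python
import os
import time
from typing import Callable


def run_once(work_dir: str, prepare: Callable[[str], None], poll_seconds: float = 5.0) -> bool:
    """Run prepare(work_dir) in exactly one process; other processes wait for it to finish.

    Returns True if this process did the work, False if it only waited.
    """
    os.makedirs(work_dir, exist_ok=True)
    lock_path = os.path.join(work_dir, ".lock")
    done_path = os.path.join(work_dir, ".done")
    try:
        # O_EXCL makes creation fail if the file already exists, so only one caller wins
        fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        while not os.path.exists(done_path):
            time.sleep(poll_seconds)
        return False
    os.close(fd)
    prepare(work_dir)
    # Write the marker only after the work has completed
    open(done_path, "w").close()
    return True
```

Because the marker is written only after `prepare` returns, waiters never see a half-finished download. The trade-off is that if `prepare` raises, the waiters poll forever, so production code would add a timeout.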

In [None]:
%%writefile openchatkit/wikipedia_prepare.py
# This file was adapted from togethercomputer/openchatkit:
# https://github.com/togethercomputer/OpenChatKit/blob/main/data/wikipedia-3sentence-level-retrieval-index/prepare.py
#
# The original file was licensed under the Apache 2.0 license.

import os
import subprocess
import time

INDEX_DIR = "/tmp/files/index"
INDEX_REPO = "https://huggingface.co/datasets/ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index"

# Only clone if no other worker has already started cloning into the target directory
if not os.path.isdir(INDEX_DIR):
    print("Cloning Wikipedia index to local disk")
    try:
        subprocess.run(f"git clone {INDEX_REPO} {INDEX_DIR}", shell=True, check=True)
    except subprocess.CalledProcessError:
        # The clone failed (for example, another worker is already cloning); fall through and wait
        pass

# Block until the sentence file is present in the index directory
while not os.path.isfile(os.path.join(INDEX_DIR, "wikipedia-en-sentences.parquet")):
    print("Waiting for clone to finish...")
    time.sleep(5)
print("Wikipedia index is ready")

This code is responsible for loading and searching the Wikipedia document index. This helps to provide additional context to the chatbot which can improve performance.
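The key part of `WikipediaIndex.search` is the context expansion. After Faiss returns the best-matching snippet, up to `w` neighbouring snippets on each side are embedded. The result then grows outward only while each neighbour's cosine similarity to the hit exceeds `w_th`. Here is a dependency-free sketch of that expansion on toy 2-D embeddings. The real embeddings come from the Contriever model, and `expand_hit` is a hypothetical name.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def expand_hit(snippets: List[str], embeddings: List[Sequence[float]], hit: int,
               w: int = 5, w_th: float = 0.5) -> str:
    """Grow the snippet at `hit` left and right while neighbours stay similar to it."""
    text = snippets[hit]
    # Walk left, stopping at the first dissimilar neighbour or the start of the list
    for j in range(1, w + 1):
        if hit - j < 0 or cosine(embeddings[hit - j], embeddings[hit]) <= w_th:
            break
        text = snippets[hit - j] + text
    # Walk right in the same way
    for j in range(1, w + 1):
        if hit + j >= len(snippets) or cosine(embeddings[hit + j], embeddings[hit]) <= w_th:
            break
        text += snippets[hit + j]
    return text


snippets = ["Cats purr. ", "Paris is in France. ", "It is the capital. ",
            "It hosts the Louvre. ", "Bananas are yellow. "]
embeddings = [(0.0, 1.0), (1.0, 0.1), (1.0, 0.0), (0.9, 0.2), (-1.0, 0.3)]
print(expand_hit(snippets, embeddings, hit=2, w=2))
# Paris is in France. It is the capital. It hosts the Louvre.
```

Unlike the notebook code, which relies on a broad `try`/`except` around out-of-range neighbours, this sketch checks list bounds explicitly.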

In [None]:
%%writefile openchatkit/wikipedia.py
# This file was adapted from ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index:
# https://huggingface.co/datasets/ChristophSchuhmann/wikipedia-3sentence-level-retrieval-index/blob/main/wikiindexquery.py
#
# The original file was licensed under the Apache 2.0 license.

import os

from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
import pandas as pd

DIR = os.path.dirname(os.path.abspath(__file__))


def mean_pooling(token_embeddings, mask):
    # Average the token embeddings, ignoring padding positions
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings


def cos_sim_2d(x, y):
    norm_x = x / np.linalg.norm(x, axis=1, keepdims=True)
    norm_y = y / np.linalg.norm(y, axis=1, keepdims=True)
    return np.matmul(norm_x, norm_y.T)


class WikipediaIndex:
    def __init__(self):
        # Files cloned by wikipedia_prepare.py
        path = os.path.join("/tmp/files", "index")
        indexpath = os.path.join(path, "knn.index")
        wiki_sentence_path = os.path.join(path, "wikipedia-en-sentences.parquet")

        self._device = "cuda"
        self._tokenizer = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
        self._contriever = AutoModel.from_pretrained("facebook/contriever-msmarco").to(self._device)

        self._df_sentences = pd.read_parquet(wiki_sentence_path, engine="fastparquet")

        # Memory-map the Faiss index read-only instead of loading it fully into RAM
        self._wiki_index = faiss.read_index(indexpath, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

    def search(self, query, k=1, w=5, w_th=0.5):
        # Embed the query with Contriever
        inputs = self._tokenizer(query, padding=True, truncation=True, return_tensors="pt").to(
            self._device
        )
        outputs = self._contriever(**inputs)
        embeddings = mean_pooling(outputs[0], inputs["attention_mask"])

        query_vector = embeddings.cpu().detach().numpy().reshape(1, -1)

        distances, indices = self._wiki_index.search(query_vector, k)

        texts = []
        for i, (dist, indice) in enumerate(zip(distances[0], indices[0])):
            text = self._df_sentences.iloc[indice]["text_snippet"]

            try:
                # Collect up to w neighbouring snippets on each side of the hit
                input_texts = [self._df_sentences.iloc[indice]["text_snippet"]]
                for j in range(1, w + 1):
                    input_texts = [
                        self._df_sentences.iloc[indice - j]["text_snippet"]
                    ] + input_texts
                for j in range(1, w + 1):
                    input_texts = input_texts + [
                        self._df_sentences.iloc[indice + j]["text_snippet"]
                    ]

                inputs = self._tokenizer(
                    input_texts, padding=True, truncation=True, return_tensors="pt"
                ).to(self._device)

                outputs = self._contriever(**inputs)
                embeddings = (
                    mean_pooling(outputs[0], inputs["attention_mask"]).detach().cpu().numpy()
                )

                # Extend to the left while neighbours stay similar to the hit
                for j in range(1, w + 1):
                    if (
                        cos_sim_2d(
                            embeddings[w - j].reshape(1, -1),
                            embeddings[w].reshape(1, -1),
                        )
                        > w_th
                    ):
                        text = self._df_sentences.iloc[indice - j]["text_snippet"] + text
                    else:
                        break

                # Extend to the right in the same way
                for j in range(1, w + 1):
                    if (
                        cos_sim_2d(
                            embeddings[w + j].reshape(1, -1),
                            embeddings[w].reshape(1, -1),
                        )
                        > w_th
                    ):
                        text += self._df_sentences.iloc[indice + j]["text_snippet"]
                    else:
                        break

            except Exception as e:
                print(e)

            texts.append(text)

        return texts

Another feature of OpenChatKit is its moderation capability. Although the chat model itself has some moderation built in, Together trained a [GPT-JT-Moderation-6B](https://huggingface.co/togethercomputer/GPT-JT-Moderation-6B) model with Ontocord.ai's [OIG-moderation dataset](https://huggingface.co/datasets/ontocord/OIG-moderation). This model runs alongside the main chatbot to check that neither the user input nor the bot's answer contains inappropriate content.

In this implementation, both checks use the same `ModerationModel` class. If the user input is classified as `needs intervention`, the endpoint returns a fixed refusal message without calling the chat model. If the bot's answer is classified that way, the same message replaces it.

**Retrieve the image URI for the DJL container**

In [None]:
inference_image_uri = image_uris.retrieve(
 framework="djl-deepspeed", region=sess.boto_session.region_name, version="0.21.0"
)

print(f"Image going to be used is ---- > {inference_image_uri}")

The index search uses Facebook's [Faiss](https://github.com/facebookresearch/faiss) library to perform the similarity search. As Faiss is not included in the base LMI image, the container needs to be extended to install it. The cell below defines a Dockerfile that builds Faiss from source alongside other libraries needed by the bot endpoint.

In [None]:
%%writefile Dockerfile.template
FROM {{imagebase}}

ARG FAISS_URL=https://github.com/facebookresearch/faiss.git
RUN apt-get update && apt-get install -y git-lfs wget cmake pkg-config build-essential apt-utils
RUN apt search openblas && apt-get install -y libopenblas-dev swig

RUN git clone $FAISS_URL && \
 cd faiss && \
 cmake -B build . -DFAISS_OPT_LEVEL=avx2 -DCMAKE_CUDA_ARCHITECTURES="86" && \
 make -C build -j faiss && \
 make -C build -j swigfaiss && \
 make -C build -j swigfaiss_avx2 && \
 (cd build/faiss/python && python -m pip install .)

RUN pip install pandas fastparquet boto3 && \
 git lfs install --skip-repo && \
 apt-get clean all

In [None]:
# we plug the base DJL image URI for this region into the Dockerfile template
template = jinja_env.from_string(Path("Dockerfile.template").open().read())
Path("Dockerfile").open("w").write(template.render(imagebase=inference_image_uri))
!pygmentize Dockerfile | cat -n

The next cell uses the [SageMaker Studio Image Build CLI](https://github.com/aws-samples/sagemaker-studio-image-build-cli) to build the Docker image defined above, because SageMaker Studio does not allow Docker to be installed for building images. The CLI uses AWS CodeBuild to build the image remotely and push it to a private Amazon ECR repository.

The same Dockerfile can be built in any environment that can run Docker commands and push to the relevant ECR repository.

In [None]:
!sm-docker build . --repository openchatkit:djl --compute-type BUILD_GENERAL1_2XLARGE

In [None]:
# URI of the custom image that the build step above pushed to ECR
chat_inference_image_uri = (
    f"{sess.account_id()}.dkr.ecr.{sess.boto_session.region_name}.amazonaws.com/openchatkit:djl"
)
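The URI above follows the standard private Amazon ECR format, `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>`, where `openchatkit:djl` from the `sm-docker build --repository` flag supplies the repository name and tag. A small helper can make that mapping explicit. `ecr_image_uri` and the account ID below are hypothetical, and regions outside the standard AWS partition (for example China regions, which use `amazonaws.com.cn`) need a different domain.

```python
def ecr_image_uri(account_id: str, region: str, repository_with_tag: str) -> str:
    """Build a private ECR image URI from a "repository:tag" string (tag defaults to "latest")."""
    repository, _, tag = repository_with_tag.partition(":")
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag or 'latest'}"


# Hypothetical account ID, for illustration only
print(ecr_image_uri("123456789012", "us-west-2", "openchatkit:djl"))
# 123456789012.dkr.ecr.us-west-2.amazonaws.com/openchatkit:djl
```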

**Create the tarball and upload it to S3**

In [None]:
%%sh
tar czvf model.tar.gz openchatkit/
rm -rf openchatkit
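If you prefer to stay in Python, the standard-library `tarfile` module can build a gzip-compressed archive with the same layout as the `tar czvf` cell above, with paths rooted at `openchatkit/`. This is a sketch rather than part of the original workflow: `make_model_tarball` is a hypothetical helper, and it does not delete the source directory afterwards.

```python
import tarfile
from pathlib import Path
from typing import List


def make_model_tarball(src_dir: str, out_path: str = "model.tar.gz") -> List[str]:
    """Gzip-compress `src_dir` into `out_path` and return the archived member names."""
    src = Path(src_dir)
    with tarfile.open(out_path, "w:gz") as tar:
        # arcname keeps member paths relative to the folder name, e.g.
        # "openchatkit/serving.properties", rather than absolute paths.
        tar.add(str(src), arcname=src.name)
    with tarfile.open(out_path, "r:gz") as tar:
        return tar.getnames()


# Usage (assuming the openchatkit/ directory still exists):
# print(make_model_tarball("openchatkit"))
```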

In [None]:
s3_code_artifact = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {s3_code_artifact}")

### Steps to create the endpoint

1. Build an image adapted from the DJL container that installs Faiss for information retrieval.
2. Create the model using the image and the model tarball uploaded earlier.
3. Create the endpoint config with the following key parameters:
   - Instance type: `ml.g5.12xlarge`
   - `ContainerStartupHealthCheckTimeoutInSeconds`: 3600, which gives the container up to an hour to download and load the model before it must pass SageMaker's startup health check
4. Create the endpoint using that endpoint config.

#### Create the Model
Use the image URI built from the DJL container and the S3 location of the uploaded tarball. The moderation models also use the DJL container.

The container downloads the model into `/tmp` on the instance, because SageMaker maps `/tmp` to the Amazon Elastic Block Store (Amazon EBS) volume that is mounted when you specify the endpoint creation parameter `VolumeSizeInGB`. It uses [s5cmd](https://github.com/peak/s5cmd), which offers very fast download speeds and is therefore well suited to downloading large models.

For instance types that come with local instance storage, such as p4d, the container can continue to use `/tmp`; that mount is large enough to hold the model.
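As a sketch of how the `VolumeSizeInGB` parameter mentioned above fits into an endpoint config, the hypothetical helper below builds a single `ProductionVariant` entry and attaches an EBS volume only when a size is given. Our understanding is that instance types with local NVMe storage, such as the `ml.g5.12xlarge` used here, do not accept `VolumeSizeInGB`, which would explain why the endpoint config in this notebook omits it; check the SageMaker documentation for your instance type before relying on this.

```python
from typing import Optional


def production_variant(
    model_name: str,
    instance_type: str = "ml.g5.12xlarge",
    startup_timeout_s: int = 3600,
    volume_size_gb: Optional[int] = None,
) -> dict:
    """Build one ProductionVariant entry for sm_client.create_endpoint_config."""
    variant = {
        "VariantName": "variant1",
        "ModelName": model_name,
        "InstanceType": instance_type,
        "InitialInstanceCount": 1,
        "ContainerStartupHealthCheckTimeoutInSeconds": startup_timeout_s,
    }
    # Attach an EBS volume (mapped to /tmp) only when explicitly requested;
    # instance types with local NVMe storage use that storage instead.
    if volume_size_gb is not None:
        variant["VolumeSizeInGB"] = volume_size_gb
    return variant


# Usage: sm_client.create_endpoint_config(
#     EndpointConfigName=chat_endpoint_config_name,
#     ProductionVariants=[production_variant(chat_model_name)],
# )
```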


In [None]:
chat_inference_image_uri

In [None]:
from sagemaker.utils import name_from_base

chat_model_name = name_from_base("gpt-neoxt-chatbase-ds")
print(chat_model_name)

create_model_response = sm_client.create_model(
    ModelName=chat_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": chat_inference_image_uri,
        "ModelDataUrl": s3_code_artifact,
    },
)
chat_model_arn = create_model_response["ModelArn"]

print(f"Created Model: {chat_model_arn}")

In [None]:
chat_endpoint_config_name = f"{chat_model_name}-config"
chat_endpoint_name = f"{chat_model_name}-endpoint"

chat_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=chat_endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": chat_model_name,
            "InstanceType": "ml.g5.12xlarge",
            "InitialInstanceCount": 1,
            "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
        },
    ],
)

print(chat_endpoint_config_response)

In [None]:
chat_create_endpoint_response = sm_client.create_endpoint(
    EndpointName=chat_endpoint_name, EndpointConfigName=chat_endpoint_config_name
)

print(f"Created Endpoint: {chat_create_endpoint_response['EndpointArn']}")

### This step can take about 10 minutes or longer, so please be patient

In [None]:
import time

# Poll once a minute until the endpoint leaves the "Creating" state
chat_resp = sm_client.describe_endpoint(EndpointName=chat_endpoint_name)
chat_status = chat_resp["EndpointStatus"]
print(f"Status: {chat_status}")

while chat_status == "Creating":
    time.sleep(60)
    chat_resp = sm_client.describe_endpoint(EndpointName=chat_endpoint_name)
    chat_status = chat_resp["EndpointStatus"]
    print(f"Status: {chat_status}...")

print(f"Arn: {chat_resp['EndpointArn']}")
print(f"Status: {chat_status}")
if chat_status == "Failed":
    print(f"Failure reason: {chat_resp.get('FailureReason')}")
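The polling loop above waits indefinitely. A reusable variant can add a timeout and raise on failure; the sketch below takes the describe call and the sleep function as parameters so it can be exercised without AWS. `wait_while_creating` is a hypothetical helper. boto3 also ships a built-in waiter for this, `sm_client.get_waiter("endpoint_in_service").wait(EndpointName=chat_endpoint_name)`.

```python
import time
from typing import Callable


def wait_while_creating(
    describe: Callable[[], dict],
    poll_seconds: float = 60,
    timeout_seconds: float = 3600,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll `describe()` until EndpointStatus leaves "Creating"; raise on timeout or failure."""
    deadline = time.monotonic() + timeout_seconds
    resp = describe()
    while resp["EndpointStatus"] == "Creating":
        if time.monotonic() >= deadline:
            raise TimeoutError("endpoint still Creating after the timeout")
        sleep(poll_seconds)
        resp = describe()
    if resp["EndpointStatus"] == "Failed":
        raise RuntimeError(resp.get("FailureReason", "endpoint creation failed"))
    return resp


# Usage with the notebook's client:
# wait_while_creating(lambda: sm_client.describe_endpoint(EndpointName=chat_endpoint_name))
```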

#### While you wait for the endpoint to be created, you can read more about:
- [Deep Learning containers for large model inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-dlc.html)
- [Accelerate](https://huggingface.co/docs/accelerate/index)

#### Use Boto3 to invoke the endpoint

This is a generative model, so we pass in text as a prompt, and the model completes the text and returns the result.

You can pass a batch of prompts to the model by setting `inputs` to a list of prompts; the model then returns a result for each prompt. Text generation is configured through the `parameters` field of the request. Refer to the [Hugging Face `GenerationConfig` documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig) for more details.
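The request body is plain JSON with an `inputs` field and a `parameters` field. The hypothetical helper below serializes either a single prompt or a batch, forwarding any keyword arguments as generation parameters; `temperature`, `top_k`, and `max_new_tokens` are standard `GenerationConfig` fields, but which parameters this endpoint's handler honors is determined by the code packaged in `openchatkit/`.

```python
import json
from typing import List, Union


def build_payload(prompts: Union[str, List[str]], **parameters) -> str:
    """Serialize one prompt or a batch of prompts, plus generation parameters, as JSON."""
    return json.dumps({"inputs": prompts, "parameters": parameters})


body = build_payload(
    ["What is Apache Spark?", "What is Apache Kafka?"],
    temperature=0.6,
    top_k=40,
    max_new_tokens=128,
)
print(body)
```

The resulting string can be passed as `Body` to `smr_client.invoke_endpoint` together with `ContentType="application/json"`.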

The code below invokes the endpoint with a prompt and sets some inference parameters. It also accepts an optional session ID, which lets the endpoint reuse that session's previous inputs and outputs as additional context for the conversation.


In [None]:
import json


def chat(prompt, session_id=None, **kwargs):
    """Send a prompt to the chat endpoint and return the decoded JSON response.

    Passing `session_id` lets the endpoint use that session's earlier turns as context.
    Extra keyword arguments override the default generation parameters.
    """
    parameters = {
        "temperature": 0.6,
        "top_k": 40,
        "max_new_tokens": 512,
    }
    parameters.update(kwargs)  # e.g. chat(prompt, max_new_tokens=256)
    if session_id:
        # Follow-up turns reuse the session and skip retrieval
        parameters["session_id"] = session_id
        parameters["no_retrieval"] = True

    chat_response_model = smr_client.invoke_endpoint(
        EndpointName=chat_endpoint_name,
        Body=json.dumps({"inputs": prompt, "parameters": parameters}),
        ContentType="application/json",
    )

    response = chat_response_model["Body"].read().decode("utf8")
    return json.loads(response)

In [None]:
prompts = "What do data engineers do?"

In [None]:
response = chat(prompts)

response

In [None]:
chat("What frameworks do they work with?", session_id=response["session_id"])
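The two calls above thread `session_id` through the conversation by hand. A minimal wrapper can remember the ID returned by the endpoint and send it with every follow-up question. `ChatSession` is hypothetical, not part of OpenChatKit or the SageMaker SDK; it only assumes, as the cell above does, that each response carries a `session_id` key.

```python
from typing import Callable, Optional


class ChatSession:
    """Hypothetical helper that carries the endpoint's session_id across turns."""

    def __init__(self, send: Callable[..., dict]):
        # `send` has the same shape as chat(prompt, session_id=None)
        self.send = send
        self.session_id: Optional[str] = None

    def ask(self, prompt: str) -> dict:
        response = self.send(prompt, session_id=self.session_id)
        # Keep the session id returned by the endpoint for the next turn
        self.session_id = response.get("session_id", self.session_id)
        return response


# Usage with the notebook's chat function:
# convo = ChatSession(chat)
# convo.ask("What do data engineers do?")
# convo.ask("What frameworks do they work with?")
```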

## Clean Up

In [None]:
# Delete the endpoint
sm_client.delete_endpoint(EndpointName=chat_endpoint_name)

In [None]:
# Delete the endpoint config and the model; run this even if endpoint creation failed
sm_client.delete_endpoint_config(EndpointConfigName=chat_endpoint_config_name)
sm_client.delete_model(ModelName=chat_model_name)

In [None]:
# Delete the DynamoDB table that holds the chat logs
dynamodb_client = boto3.client("dynamodb")
dynamodb_client.delete_table(TableName="openchatkit_chat_logs")
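Because the cleanup cells stop at the first error (for example, if the endpoint was never created), a small wrapper can attempt every step and report which ones failed. `run_cleanup` is a hypothetical helper; it catches any exception for simplicity, whereas boto3 calls typically raise `botocore.exceptions.ClientError`.

```python
from typing import Callable, Dict, List


def run_cleanup(steps: Dict[str, Callable[[], object]]) -> List[str]:
    """Run every cleanup step even if an earlier one fails; return the names of failed steps."""
    failed = []
    for name, step in steps.items():
        try:
            step()
        except Exception as exc:  # boto3 typically raises botocore.exceptions.ClientError
            print(f"{name} failed: {exc}")
            failed.append(name)
    return failed


# Example wiring with the notebook's clients (commented out so the sketch runs standalone):
# run_cleanup({
#     "endpoint": lambda: sm_client.delete_endpoint(EndpointName=chat_endpoint_name),
#     "endpoint config": lambda: sm_client.delete_endpoint_config(EndpointConfigName=chat_endpoint_config_name),
#     "model": lambda: sm_client.delete_model(ModelName=chat_model_name),
#     "chat log table": lambda: dynamodb_client.delete_table(TableName="openchatkit_chat_logs"),
# })
```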

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/inference|generativeai|llm-workshop|lab4-openchatkit|deploy_openchatkit_on_sagemaker.ipynb)
