# Large Language Model Customization using Retrieval-Augmented Generation (RAG) Pattern, Amazon Kendra Enterprise Search Service and Falcon-40B-Instruct Language Model

---
This Amazon SageMaker Studio notebook demonstrates how to use [SageMaker](https://sagemaker.readthedocs.io/en/stable/) and [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) SDKs to generate text using the Retrieval-Augmented Generation (RAG) pattern. The notebook implements semantic search using [Amazon Kendra](https://aws.amazon.com/kendra/) enterprise search service. The language model used for text generation is [Falcon-40B-Instruct](https://huggingface.co/tiiuae/falcon-40b-instruct).

This notebook has the following prerequisites:
- Select an AWS region where [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart) is available. 
- [Set up an Amazon SageMaker Domain](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html).
- [Add an additional permission to the SageMaker Execution Role](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) to call the Amazon Kendra Retrieve API.
- [Available service quota](https://docs.aws.amazon.com/general/latest/gr/sagemaker.html) for "ml.g5.12xlarge for endpoint usage".
- Select the [Amazon SageMaker Kernel](https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-kernels.html), "Python 3 (Data Science 2.0) with Python 3.8" or higher.
- Familiarity with the [Retrieval Augmented Generation (RAG)](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-customize-rag.html) pattern.
- A budget for Amazon SageMaker JumpStart model deployment and usage of other AWS services, which costs less than $10 per hour.
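
The extra permission on the SageMaker Execution Role can be expressed as a small IAM policy. The following is a minimal sketch, assuming the role only needs the `kendra:Retrieve` action on a single index. The region, account ID, and index ID are hypothetical placeholders, and `build_kendra_retrieve_policy` is an illustrative helper, not part of any SDK.

```python
import json

# Hypothetical placeholders: replace with your own region, account ID, and index ID.
REGION = "us-east-1"
ACCOUNT_ID = "111122223333"
INDEX_ID = "xxxxxxx-xxxx-xxxxx-xxxx-xxxxxxxxxxxx"


def build_kendra_retrieve_policy(region, account_id, index_id):
    """Return an IAM policy document that allows kendra:Retrieve on one index."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["kendra:Retrieve"],
                "Resource": f"arn:aws:kendra:{region}:{account_id}:index/{index_id}",
            }
        ],
    }


# Print the policy as JSON so it can be pasted into the IAM console.
policy = build_kendra_retrieve_policy(REGION, ACCOUNT_ID, INDEX_ID)
print(json.dumps(policy, indent=2))
```

The printed JSON can be attached to the execution role as an inline policy. Scoping the resource to one index is a design choice for this sketch; your organization may prefer a broader or narrower policy.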

---

In [None]:
# install Python libraries
!pip install --upgrade pip --quiet
!pip install --upgrade boto3 --quiet
!pip install --upgrade sagemaker --quiet
!pip install ipywidgets --quiet

In [None]:
# import required libraries
import boto3
from sagemaker.jumpstart.model import JumpStartModel

### Step 1: Create Amazon Kendra Index (this step can take between 30 and 60 minutes)

Create an Amazon Kendra index and configure it to index your dataset. Alternatively, you can use the provided AWS CloudFormation [template](https://github.com/aws-samples/amazon-kendra-langchain-extensions/blob/main/kendra_retriever_samples/kendra-docs-index.yaml) to create a new Amazon Kendra index containing AWS online documentation for Amazon Kendra, Amazon Lex, and Amazon SageMaker.
This notebook must have AWS IAM permission to call Amazon Kendra APIs and must run in the same AWS region as the Amazon Kendra index.
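
Because index creation can take a while, it may help to confirm that the index is ready before querying it. The following is a minimal sketch, assuming the notebook role is also allowed to call `kendra:DescribeIndex` (a permission beyond the Retrieve permission listed in the prerequisites). `index_is_active` is a hypothetical helper that takes the Kendra client as an argument.

```python
def index_is_active(kendra_client, index_id):
    """Return True when describe_index reports the index status as ACTIVE."""
    response = kendra_client.describe_index(Id=index_id)
    return response["Status"] == "ACTIVE"


# Example usage (requires AWS credentials and a real index ID):
# import boto3
# print(index_is_active(boto3.client("kendra"), kendra_index_id))
```

Run the commented usage lines after `kendra_index_id` is set in the next cell. If the function returns `False`, the index may still be creating or updating.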

In [None]:
# set your kendra_index_id 
kendra_index_id = "xxxxxxx-xxxx-xxxxx-xxxx-xxxxxxxxxxxx"

### Step 2: Deploy the Falcon-40B-Instruct model using Amazon SageMaker JumpStart

In [None]:
%%time

# Define SageMaker JumpStart Model using model id, instance type, and endpoint timeout
my_model = JumpStartModel(model_id="huggingface-llm-falcon-40b-instruct-bf16",
                          instance_type="ml.g5.12xlarge",
                          env={'ENDPOINT_SERVER_TIMEOUT':'300'})

# Host the model on the instance and deploy an inference endpoint
# Because the model size is >80GB, expect deploy() to take about 15 minutes.
predictor = my_model.deploy()

### Step 3: Define context search function

In [None]:
def retrieve_context(question, top_k):
    content = ""
    documentURI = "" 
    client = boto3.client('kendra')
    response = client.retrieve(IndexId=kendra_index_id, QueryText=question, PageSize=top_k)
    
    for query_result in response["ResultItems"]:
        content = content + query_result["Content"] + "\n\n"
        documentURI = documentURI + query_result["DocumentURI"] + "\n"
 
    return content.strip(), documentURI.strip()

# test retrieve_context() function
input_question = "Is Amazon SageMaker a machine-learning service?"
content, documentURI = retrieve_context(question=input_question, top_k=1) 
print("Question:", input_question)
print("Content:", content)
print("URI:", documentURI)

### Step 4: Define LLM prompt function

In [None]:
def get_prompt(question, context):
    return f"""{context}\n\nUse the above paragraphs to answer the following question: {question}"""
    
def prompt_model(question, context):
    prompt = ""

    if context:
        prompt = get_prompt(question=question, context=context)
    else:
        prompt = question
        
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "num_return_sequences": 1,
            "temperature": 1,
            "num_beams": 1,
            "do_sample": True,
            "top_k": 50,
            "top_p": 0.95,
            "stop": ["<|endoftext|>", "</s>"]
        }
    }
    
    response = predictor.predict(payload)
    
    return response[0]["generated_text"].strip()

# test prompt_model() function
question = "What is Amazon SageMaker?"
context="Amazon SageMaker is a machine-learning service."
prompt = get_prompt(question=question, context=context)
response = prompt_model(question=question, context=context)
print("Question:", question)
print("Context:", context)
print("Prompt:", prompt)
print("Response:", response)

### Step 5: Test LLM prompting (with and without RAG) using an interactive widget

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

def on_send_button_click(button):
    input_question = input_field.value
    documentURI = ""
    content = ""
        
    if rag_check.value:
        content, documentURI = retrieve_context(question=input_question, top_k=3) 
        response = prompt_model(question=input_question, context=content)
    else:
        response = prompt_model(question=input_question, context=None)
    
    with output:
        print("Q:", input_question)
        print("A:", response)
        print("URIs:", documentURI if documentURI else "RAG not active!")
        print("-"*40)

    input_field.value = ""

def on_input_field_submit(text):
    on_send_button_click(None)

# Create the input field and send button
input_field = widgets.Text(placeholder='Type your question here...')
rag_check = widgets.Checkbox(value=True, description='Enable RAG', indent=False)
send_button = widgets.Button(description='Send')
top_box = widgets.HBox([input_field, rag_check])
bottom_box = widgets.HBox([send_button])
v_box = widgets.VBox([top_box, bottom_box])
output = widgets.Output()

# Assign the function to the button click event and the input field submit event
send_button.on_click(on_send_button_click)
input_field.on_submit(on_input_field_submit)

# Display the chat interface
display(output, v_box)

Here are some sample questions to get you started:
- What are the instance types recommended for training in SageMaker?
- Can Amazon Kendra extract the content of images from PowerPoint slides?
- Does Amazon SageMaker support any GPUs made by Microsoft?
- Write a summary about Amazon Kendra Experience Builder.
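
The sample questions can also be run without the widget to compare answers with and without RAG side by side. The following is a minimal sketch. `compare_answers` is a hypothetical helper, and it assumes the retrieval and prompting functions accept the same keyword arguments as `retrieve_context()` and `prompt_model()` defined earlier in this notebook.

```python
def compare_answers(questions, retrieve_fn, prompt_fn, top_k=3):
    """Answer each question twice: once without context and once with retrieved context."""
    results = []
    for q in questions:
        # Fetch supporting passages and their source URIs for this question.
        context, uris = retrieve_fn(question=q, top_k=top_k)
        results.append({
            "question": q,
            "without_rag": prompt_fn(question=q, context=None),
            "with_rag": prompt_fn(question=q, context=context),
            "uris": uris,
        })
    return results


# Example usage with the notebook's functions (requires the endpoint and the index):
# sample = ["Write a summary about Amazon Kendra Experience Builder."]
# for row in compare_answers(sample, retrieve_context, prompt_model):
#     print(row)
```

Because `prompt_model()` samples with `do_sample` set to `True`, the generated answers can differ between runs.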

### SageMaker Clean up

In [None]:
# Delete the SageMaker endpoint
predictor.delete_endpoint()