# Intelligent document processing with Gen AI, Amazon Textract, and FLAN-T5 on SageMaker JumpStart
____


<div class="alert alert-block alert-info"> 
    <b>NOTE:</b> This notebook requires a Jupyter kernel with Python 3.9 or above, for example the `PyTorch 1.13 Python 3.9` image. Selecting an image may also require a different instance type; we recommend the `GPU Based image with ml.g4dn.xlarge instance type` configuration.
</div>


In this notebook we first walk through Amazon Textract's document extraction capabilities. We then perform Q&A over a document: we extract its text with Amazon Textract, split the text into chunks, store them in a vector DB, and query a FLAN-T5 model deployed to a SageMaker endpoint via SageMaker JumpStart to get precise answers. Finally, we use the same endpoint to perform text summarization.

# Set up the notebook <a id="step1"></a>

In this step, we install the libraries used throughout this notebook.

In [None]:
!pip install -U langchain 
!pip install pdfplumber
!pip install unstructured
!pip install chromadb
!pip install -U sentence-transformers
!pip install pydantic==1.10.11 #use 1.10.11 version due to stability
#textractor libraries
!python -m pip install -q amazon-textract-caller --upgrade
!python -m pip install -q amazon-textract-prettyprinter --upgrade
!python -m pip install -q amazon-textract-response-parser --upgrade

# Module 1 - Document Extraction 

In [None]:
import boto3
import botocore
import sagemaker
from sagemaker.session import Session
from IPython.display import Image, display, JSON
from textractcaller.t_call import call_textract, Textract_Features, call_textract_expense
from textractprettyprinter.t_pretty_print import convert_table_to_list
from trp import Document
import os
import pandas as pd

# variables
sagemaker_session = Session()
data_bucket = sagemaker.Session().default_bucket()
region = boto3.session.Session().region_name
aws_role = sagemaker_session.get_caller_identity_arn()

# boto3 clients
s3=boto3.client('s3')
textract = boto3.client('textract', region_name=region)

print(f"Region is {region}, IAM Role: {aws_role}, S3 Bucket: {data_bucket}")

## Upload sample data to S3 bucket


The sample documents are in the local `samples` directory. The cell below copies them to the `idp/genai` prefix of the default SageMaker S3 bucket.

In [None]:
# Upload the sample documents to the S3 bucket:

!aws s3 cp samples s3://{data_bucket}/idp/genai --recursive --only-show-errors

---
# Extract structured data such as tables and key-value pairs using Amazon Textract

In this step we will take a brief look at how to extract table and key-value pair information from our sample healthcare policy document.  

### Extracting Tables


In [None]:
prefix = "idp/genai"
file_key = "health_plan.pdf"
resp = call_textract(input_document=f's3://{data_bucket}/{prefix}/{file_key}', features=[Textract_Features.TABLES])
tdoc = Document(resp)
dfs = list()

In [None]:
for page in tdoc.pages:
    for table in page.tables:
        tab_list = convert_table_to_list(trp_table=table)
        print(tab_list)
        dfs.append(pd.DataFrame(tab_list))
df1 = dfs[0]
df2 = dfs[1]

In [None]:
df1

In [None]:
df2
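
`convert_table_to_list` returns each table as a list of rows (a list of lists of cell strings), so `df1` and `df2` above use integer column labels and keep the header row as ordinary data. The following is a minimal sketch, assuming the first extracted row of a table is its header (often true, but check your own document); it promotes that row to column names using only the standard library. The helper `rows_to_records` and the sample rows are hypothetical, made up for illustration.

```python
from typing import Dict, List


def rows_to_records(table: List[List[str]]) -> List[Dict[str, str]]:
    """Turn a list-of-rows table into dicts keyed by the (assumed) header row."""
    if not table:
        return []
    # Assumption: the first extracted row holds the column headers
    header = [cell.strip() for cell in table[0]]
    records = []
    for row in table[1:]:
        # Pad short rows so every header gets a value
        padded = list(row) + [""] * (len(header) - len(row))
        records.append({name: cell.strip() for name, cell in zip(header, padded)})
    return records


# Hypothetical table, shaped like the output of convert_table_to_list
sample_table = [
    ["Service ", "In-network ", "Out-of-network "],
    ["Primary care visit ", "$20 copay ", "40% coinsurance "],
    ["Specialist visit ", "$40 copay "],
]

for record in rows_to_records(sample_table):
    print(record)
```

With pandas, the equivalent step is `pd.DataFrame(tab_list[1:], columns=tab_list[0])`.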

### Extracting Forms (key-value pairs) data


In [None]:
from textractcaller.t_call import call_textract, Textract_Features
from textractprettyprinter.t_pretty_print import Pretty_Print_Table_Format, Textract_Pretty_Print, get_string


# Call Amazon Textract
response = call_textract(input_document=f's3://{data_bucket}/{prefix}/{file_key}', features=[Textract_Features.FORMS])


print(get_string(textract_json=response,
               table_format=Pretty_Print_Table_Format.csv,
               output_type=[Textract_Pretty_Print.FORMS]))
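
Under the hood, the FORMS feature returns `KEY_VALUE_SET` blocks: a block whose `EntityTypes` contains `KEY` points to its value block through a `VALUE` relationship, and both point to their `WORD` blocks through `CHILD` relationships. The sketch below walks those links on a tiny, hand-built response (block IDs and text are made up) to show roughly what the pretty printer resolves for us. It is a simplification, not the library's implementation: it ignores selection elements, confidence scores, and geometry.

```python
from typing import Dict


def _child_text(block: Dict, blocks_by_id: Dict[str, Dict]) -> str:
    """Join the text of the WORD blocks a block points to via CHILD relationships."""
    words = []
    for rel in block.get("Relationships", []):
        if rel["Type"] == "CHILD":
            for child_id in rel["Ids"]:
                child = blocks_by_id[child_id]
                if child["BlockType"] == "WORD":
                    words.append(child["Text"])
    return " ".join(words)


def extract_key_values(response: Dict) -> Dict[str, str]:
    """Map each KEY's text to its VALUE's text in a Textract FORMS response."""
    blocks_by_id = {b["Id"]: b for b in response["Blocks"]}
    pairs = {}
    for block in response["Blocks"]:
        is_key = block["BlockType"] == "KEY_VALUE_SET" and "KEY" in block.get("EntityTypes", [])
        if not is_key:
            continue
        value_text = ""
        for rel in block.get("Relationships", []):
            if rel["Type"] == "VALUE":
                # Follow the KEY -> VALUE link, then read the VALUE block's words
                value_text = " ".join(
                    _child_text(blocks_by_id[v_id], blocks_by_id) for v_id in rel["Ids"]
                )
        pairs[_child_text(block, blocks_by_id)] = value_text
    return pairs


# Hand-built, heavily trimmed response (real responses also carry Geometry, Confidence, etc.)
sample_response = {
    "Blocks": [
        {"Id": "k1", "BlockType": "KEY_VALUE_SET", "EntityTypes": ["KEY"],
         "Relationships": [{"Type": "VALUE", "Ids": ["v1"]},
                           {"Type": "CHILD", "Ids": ["w1", "w2"]}]},
        {"Id": "v1", "BlockType": "KEY_VALUE_SET", "EntityTypes": ["VALUE"],
         "Relationships": [{"Type": "CHILD", "Ids": ["w3"]}]},
        {"Id": "w1", "BlockType": "WORD", "Text": "Plan"},
        {"Id": "w2", "BlockType": "WORD", "Text": "Type:"},
        {"Id": "w3", "BlockType": "WORD", "Text": "PPO"},
    ]
}

print(extract_key_values(sample_response))
```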

# Module 2 - Enhancing IDP with Foundation Models

## Select a pre-trained model
---
You can continue with the default model, or choose a different one by changing `model_id` in the next cell (the commented-out lines list alternatives). A complete list of SageMaker pre-trained models is available at [SageMaker pre-trained Models](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html#).


In [None]:
# "huggingface-text2text-flan-t5-xl",
# "huggingface-text2text-flan-t5-large",

model_id, model_version = (
    "huggingface-text2text-flan-t5-xl",
    "*",  # "*" selects the latest available version of the model
)

## Retrieve Artifacts & Deploy a HuggingFace FLAN-T5 Endpoint

---

Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `deploy_image_uri`, `deploy_source_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.


In [None]:
def get_sagemaker_session(local_download_dir) -> sagemaker.Session:
    """Return the SageMaker session."""

    sagemaker_client = boto3.client(
        service_name="sagemaker", region_name=boto3.Session().region_name
    )

    session_settings = sagemaker.session_settings.SessionSettings(
        local_download_dir=local_download_dir
    )

    # Session that stores downloaded artifacts in local_download_dir
    session = sagemaker.session.Session(
        sagemaker_client=sagemaker_client, settings=session_settings
    )

    return session

We need to create a directory to host the downloaded model.

In [None]:
!mkdir -p download_dir

The code block below retrieves the model artifacts and deploys the model to a SageMaker inference endpoint; the endpoint name is built from `SOLUTION_PREFIX` in a local `config` module. We use the `ml.g5.2xlarge` inference instance type, and deployment may take about 10 minutes to complete.

In [None]:
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base
import config


endpoint_name = name_from_base(f"{config.SOLUTION_PREFIX}-{model_id}")

inference_instance_type = "ml.g5.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="inference"
)


# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

#Create model
model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
    # volume_size=30,
)

---
# Perform common sense reasoning and Q&A on a document

In this section, we perform common sense reasoning and Q&A on a document. This section does the following:

- Extracts text from the document with Amazon Textract and stores it in S3 in plain-text format
- Splits the text into chunks
- Generates embeddings from the text
- Stores the embeddings in an in-memory vector database
- Performs a similarity search on the in-memory vector DB to find the pieces of text most relevant to the user's question
- Builds the context for the LLM from the search results
- Gives the model the context and the original question
- Gets the answer back from the LLM
- Profit

> _"Wait but that's a lot of steps just for getting an answer back? Why?"_

We would love to dive deeper into why, but here's a paper that does a better job of explaining the why and the how: https://arxiv.org/pdf/2005.11401.pdf. In short, LLMs know too much; _sometimes so much that they get confused and wander into the proverbial forest of their own world knowledge and start gathering firewood, when they were actually asked to pick some fruit_. To solve this problem, and to get accurate answers (or, when the answer isn't in the document, no answer at all), we use Retrieval-Augmented Generation (aka RAG) to give the LLM a bit more _stuff_ to work with so that it produces the desired output (like a fruit basket in our example, so that it knows it's only supposed to pick fruit).

As a first step, we read a file (document) using Amazon Textract and write the plaintext into S3.

In [None]:
from textractcaller.t_call import call_textract
from trp.trp2 import TDocumentSchema
from trp.t_pipeline import order_blocks_by_geo
import boto3
import sagemaker
import trp

data_bucket = sagemaker.Session().default_bucket()
# Reuse the sample document uploaded in Module 1 (prefix "idp/genai", file "health_plan.pdf")
doc_path = f's3://{data_bucket}/{prefix}/{file_key}'
s3 = boto3.client('s3')

print(f"Bucket is {data_bucket}")

# Plain text detection (no FORMS/TABLES features). Note that Textract's synchronous
# API only accepts single-page documents; multi-page PDFs need the asynchronous API.
j = call_textract(input_document=doc_path)

# Sort the blocks into reading order, then wrap the result in a trp.Document
t_doc = TDocumentSchema().load(j)
ordered_doc = order_blocks_by_geo(t_doc)
trp_doc = trp.Document(TDocumentSchema().dump(ordered_doc))

# Write one plain-text file per page to s3://<bucket>/llm/sample/page-<n>.txt
for page_num, page in enumerate(trp_doc.pages, start=1):
    # Build each page's text from scratch so a file never repeats earlier pages
    doc_content = "\n".join(line.text for line in page.lines)
    s3.put_object(Bucket=data_bucket,
                  Key=f"llm/sample/page-{page_num}.txt",
                  Body=doc_content.encode('utf-8'))
    print(f"Page text written into llm/sample/page-{page_num}.txt")

The code above calls Amazon Textract on the sample PDF and stores the document's text in S3 as plain text, one file per page. Keep in mind that Textract's real-time (synchronous) `DetectDocumentText` API only works on single-page documents; multi-page documents are processed with the asynchronous `StartDocumentTextDetection` API.
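
For intuition about what ends up in each `page-<n>.txt` file, here is a simplified stand-in that skips `trp`: it takes `LINE` blocks from a raw Textract response, groups them by their `Page` number, and sorts each page top-to-bottom, then left-to-right, using the `BoundingBox` coordinates. This is a naive approximation, not the algorithm behind `order_blocks_by_geo`, and it will interleave columns on multi-column layouts. The `_line` helper and the sample blocks are hypothetical.

```python
from collections import defaultdict
from typing import Dict


def page_texts(response: Dict) -> Dict[int, str]:
    """Return {page_number: text} built from LINE blocks in a naive reading order."""
    lines_by_page = defaultdict(list)
    for block in response["Blocks"]:
        if block["BlockType"] == "LINE":
            # Synchronous responses may omit "Page"; treat those lines as page 1
            lines_by_page[block.get("Page", 1)].append(block)

    texts = {}
    for page, lines in sorted(lines_by_page.items()):
        # Sort by vertical position (rounded to tolerate slight skew), then horizontal
        lines.sort(key=lambda b: (round(b["Geometry"]["BoundingBox"]["Top"], 2),
                                  b["Geometry"]["BoundingBox"]["Left"]))
        # Each page is built independently, so page N never repeats page N-1's text
        texts[page] = "\n".join(b["Text"] for b in lines)
    return texts


def _line(text, page, top, left):
    """Hypothetical helper that builds a trimmed LINE block for the demo."""
    return {"BlockType": "LINE", "Text": text, "Page": page,
            "Geometry": {"BoundingBox": {"Top": top, "Left": left}}}


sample_response = {"Blocks": [
    _line("Deductible: $500", 1, 0.30, 0.10),
    _line("Summary of Benefits", 1, 0.05, 0.10),
    _line("Copays", 2, 0.10, 0.10),
]}

for page, text in page_texts(sample_response).items():
    print(f"page-{page}.txt:\n{text}\n")
```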

Next, we load the plain-text files we wrote to S3 into LangChain's `Document` interface, which integrates easily with the vector DBs LangChain supports (in this case ChromaDB, used here as an in-memory vector DB). We then split the documents into chunks; this is required because we may have a large multi-page document and our LLMs will have token limits. These chunks are then loaded into the vector DB so that we can perform similarity search in the subsequent steps.
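
To make the chunking step concrete, here is a minimal sketch of sentence-based chunking: split the text on sentence boundaries, then greedily pack whole sentences into chunks of at most `chunk_size` characters. The regex splitter is a crude stand-in for NLTK's sentence tokenizer, and the sketch is not `NLTKTextSplitter`'s implementation; for instance, it omits the overlap between consecutive chunks that LangChain's splitters support. The function names are hypothetical.

```python
import re
from typing import List


def split_sentences(text: str) -> List[str]:
    """Crude sentence splitter: break after ., ! or ? followed by whitespace."""
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]


def chunk_text(text: str, chunk_size: int = 550) -> List[str]:
    """Greedily pack whole sentences into chunks of at most chunk_size characters.

    A single sentence longer than chunk_size becomes its own (oversized) chunk.
    """
    chunks: List[str] = []
    current = ""
    for sentence in split_sentences(text):
        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) <= chunk_size:
            current = candidate
        else:
            # Current chunk is full: emit it and start a new one with this sentence
            if current:
                chunks.append(current)
            current = sentence
    if current:
        chunks.append(current)
    return chunks


sample = ("The deductible is $500 per person. Preventive care is covered at no cost. "
          "Specialist visits require a $40 copay. Emergency room visits cost $250.")
for chunk in chunk_text(sample, chunk_size=80):
    print(repr(chunk))
```

Packing whole sentences keeps each chunk readable on its own, which matters because each chunk is later passed to the model as context.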

However, before we store the document in the vector DB, we have to generate embeddings for the text. For that purpose we use `HuggingFaceEmbeddings`, which is built into LangChain. For other models, you may choose embedding models as suggested by the model provider.
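
`HuggingFaceEmbeddings` maps each chunk to a dense vector, and the vector DB later compares vectors to find related text. As a dependency-free stand-in (not a real embedding model: it captures shared words, not meaning), the sketch below builds a sparse bag-of-words vector per text and L2-normalizes it, so cosine similarity reduces to a dot product over shared words. `toy_embed` and `cosine` are hypothetical names.

```python
import math
import re
from collections import Counter
from typing import Dict


def toy_embed(text: str) -> Dict[str, float]:
    """Sparse bag-of-words vector, L2-normalized (a stand-in for a real embedding model)."""
    counts = Counter(re.findall(r"[a-z0-9$]+", text.lower()))
    norm = math.sqrt(sum(c * c for c in counts.values()))
    return {tok: c / norm for tok, c in counts.items()} if norm else {}


def cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Cosine similarity of two unit-length sparse vectors: a dot product over shared keys."""
    return sum(weight * b.get(tok, 0.0) for tok, weight in a.items())


q = toy_embed("What is the deductible?")
d1 = toy_embed("The deductible is $500 per person.")
d2 = toy_embed("Emergency room visits cost $250.")
print(cosine(q, d1), cosine(q, d2))
```

A real embedding model would also score paraphrases (for example "annual out-of-pocket threshold") as similar, which word overlap cannot.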

In [None]:
from langchain.document_loaders import S3DirectoryLoader
from langchain.vectorstores import Chroma
from langchain.text_splitter import NLTKTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
import sagemaker

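data_bucket = sagemaker.Session().default_bucket()
prefix = 'llm/sample'  # where the per-page text files were written above

embeddings = HuggingFaceEmbeddings()  # sentence-transformers embedding model, run locally
loader = S3DirectoryLoader(data_bucket, prefix=prefix)  # load every text file under the prefix
docs = loader.load()
text_splitter = NLTKTextSplitter(chunk_size=550)  # sentence-aware chunks of up to ~550 characters
texts = text_splitter.split_documents(docs)
vectordb = Chroma.from_documents(texts, embeddings)  # embed the chunks and index them in memory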

In [None]:
docs

## Using HuggingFace FLAN-T5 XL SageMaker endpoint

Our vector DB is now loaded with chunks of the document. All that's left is to take a question from the user, perform a similarity search on the vector DB, give the model the context and the prompt, and let it answer the question. Before that, let's define a custom QA chain on the same SageMaker endpoint with a prompt template that asks the model to answer the question from the supplied text. We use a simple prompt built with LangChain's `PromptTemplate`; more careful prompt engineering can produce a more robust QA prompt.

Let's first set the text-generation parameters. When invoking the endpoint, our JSON payload can include any inference parameters that help control the length, sampling strategy, and output token sequence restrictions.

You may refer to this [documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) by HuggingFace for detailed explanation on generation parameters. 
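
To see how `temperature`, `top_k`, and `top_p` interact when `do_sample` is `True`, here is a minimal sketch over a made-up next-token distribution: logits are divided by the temperature and softmaxed, cut to the `top_k` most likely tokens, cut again to the smallest set whose cumulative probability reaches `top_p`, renormalized, and sampled. It follows the general idea of these parameters but is not HuggingFace's implementation (ordering details and tie handling may differ). The function names and token scores are hypothetical.

```python
import math
import random
from typing import Dict


def filter_distribution(logits: Dict[str, float], temperature: float = 1.0,
                        top_k: int = 50, top_p: float = 1.0) -> Dict[str, float]:
    """Apply temperature, top-k and top-p (nucleus) filtering; return renormalized probs."""
    # Temperature: values < 1 sharpen the distribution, values > 1 flatten it
    scaled = {tok: value / temperature for tok, value in logits.items()}
    max_logit = max(scaled.values())
    exps = {tok: math.exp(v - max_logit) for tok, v in scaled.items()}  # stable softmax
    total = sum(exps.values())
    ranked = sorted(((e / total, tok) for tok, e in exps.items()), reverse=True)

    # Top-k: keep only the k most likely tokens
    ranked = ranked[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for prob, tok in ranked:
        kept.append((prob, tok))
        cumulative += prob
        if cumulative >= top_p:
            break

    norm = sum(prob for prob, _ in kept)
    return {tok: prob / norm for prob, tok in kept}


def sample_token(probs: Dict[str, float], rng: random.Random) -> str:
    """Draw one token from an already-filtered distribution."""
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


toy_logits = {"$500": 3.0, "$1,000": 2.0, "none": 1.0, "banana": -2.0}
probs = filter_distribution(toy_logits, temperature=0.97, top_k=50, top_p=0.95)
print(probs)
print(sample_token(probs, random.Random(0)))
```

With these settings the unlikely `banana` token falls outside the 0.95 nucleus and can never be sampled, while the remaining three keep their relative order.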

In [None]:
FLAN_T5_PARAMETERS = {
    "temperature": 0.97,           # the value used to modulate the next token probabilities.
    "max_length": 100,             # restrict the length of the generated text.
    "num_return_sequences": 3,     # number of output sequences returned (the content handler below keeps only the first).
    "top_k": 50,                   # in each step of text generation, sample from only the top_k most likely words.
    "top_p": 0.95,                 # in each step of text generation, sample from the smallest possible set of words with cumulative probability top_p.
    "do_sample": True              # whether or not to use sampling; use greedy decoding otherwise.
}

<div class="alert alert-block alert-info"> 
    <b>NOTE:</b> You will need to insert an endpoint name below if you are using your own endpoint. At this point, you should already have a deployed FLAN-T5 model in your account.
</div>

In [None]:
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import json
from typing import Dict

class QAContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["generated_texts"][0]

qa_content_handler = QAContentHandler()
prompt_template = """Given the following text from a document, answer the question to the best of your abilities. Answer only from the provided document. If you do not know the answer,
just say you don't know. DO NOT make up an answer.

Document: {document}
Question: {question}
Answer:
"""

prompt = PromptTemplate(input_variables=["document", "question"],
                        template=prompt_template)

qa_chain = LLMChain(
    llm=SagemakerEndpoint(
        endpoint_name=endpoint_name, # replace with your endpoint name if needed
        region_name=region,
        model_kwargs=FLAN_T5_PARAMETERS,
        content_handler=qa_content_handler
    ),
    prompt=prompt
)

In [None]:
question="What is the deductible?"

## Common sense reasoning / natural language inference

Perform a similarity search on the vector DB with `k=3`, which returns the top three chunks of text most relevant to the question asked.
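
Conceptually, `similarity_search(question, k=3)` embeds the question, scores every stored chunk against it, and returns the best three; the next cell then joins their text with blank lines. Below is a minimal in-memory sketch of that loop, using brute-force cosine similarity and made-up 3-dimensional vectors in place of real embeddings. It is not Chroma's implementation: a real vector DB typically uses an approximate nearest-neighbor index, and its distance metric is configurable. `normalize` and `top_k` are hypothetical names.

```python
import heapq
import math
from typing import List, Tuple


def normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit length so the dot product equals cosine similarity."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def top_k(query: List[float], store: List[Tuple[str, List[float]]], k: int = 3) -> List[str]:
    """Brute-force cosine search: return the texts of the k most similar chunks, best first."""
    q = normalize(query)
    scored = ((sum(a * b for a, b in zip(q, normalize(vec))), text) for text, vec in store)
    return [text for _, text in heapq.nlargest(k, scored)]


# Made-up 3-d "embeddings" standing in for real ones
store = [
    ("The deductible is $500 per person.", [0.9, 0.1, 0.0]),
    ("Out-of-pocket maximum is $3,000.", [0.7, 0.3, 0.1]),
    ("Emergency room visits cost $250.", [0.1, 0.9, 0.2]),
    ("Preventive care is covered at no cost.", [0.0, 0.2, 0.9]),
]
question_vec = [1.0, 0.2, 0.0]

# Same joining convention as the notebook: chunks separated by a blank line
context = "\n\n".join(top_k(question_vec, store, k=3))
print(context)
```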

In [None]:
similar_docs = vectordb.similarity_search(question, k=3)  # see also: vectordb.max_marginal_relevance_search(question, k=3)
context_list = [a.page_content for a in similar_docs]
metadata_list = [a.metadata.get('source') for a in similar_docs]
context = "\n\n".join(context_list)
context

## Question and answering

We can now use the custom QA chain with the SageMaker endpoint to provide an answer to our question, based on the content of the documents as shown below.

In [None]:
qa_chain.run({
    'document': context,
    'question': question
    })

# Text summarization

Text summarization involves condensing a given text or document into a shorter version while retaining its key information. This technique supports efficient information retrieval, enabling users to quickly grasp the key points of a document without reading the entire content.

While Amazon Textract doesn't directly perform text summarization, it provides foundational capabilities that we can leverage here. Amazon Textract can accurately extract text from various types of documents, as seen in the earlier modules. This extracted text serves as the input to our LLM for text summarization tasks.



## Use LangChain to create LLM classes for text extraction and SageMaker endpoint calls

---
Now that our endpoint is deployed, it is ready to perform summarization on our document. We use LangChain to perform inference, so we first need to create two LLM classes based on LangChain's base LLM class. Read more about the LangChain LLM class [here](https://python.langchain.com/en/latest/modules/models/llms.html). Specifically, we create two custom LLM classes:

1. An LLM class to extract text from our document using Amazon Textract
2. An LLM class to make calls to the SageMaker endpoint where our FLAN-T5 model is deployed

The purpose of building these custom LLM classes is to use these constructs easily with LangChain's pre-built or custom chains. Read more about LangChain chains [here](https://python.langchain.com/en/latest/modules/chains.html).

### Custom OCR LLM with Amazon Textract

The first step in the chain is to read the document using Amazon Textract: given a document path, we call Amazon Textract's `detect_document_text()` API, and its output is passed to the next LLM together with the summarization prompt. For this purpose we subclass LangChain's `LLM` class and create a custom LLM class that calls Amazon Textract's real-time (synchronous) `detect_document_text()` API through `call_textract` from the `amazon-textract-caller` package, and then formats the output with the `amazon-textract-response-parser` library (`trp`). The input to this LLM class is the path to the document, and the output is serialized text.

In [None]:
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from typing import Optional, List
from textractcaller.t_call import call_textract, Textract_Features
from trp.trp2 import TDocumentSchema
from trp.t_pipeline import order_blocks_by_geo_x_y
import trp
import json

class OcrLLM(LLM):
    """Custom LangChain LLM: the 'prompt' is a document path, the output is its OCR text."""

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # prompt is the document path (a local file or an s3:// URI)
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Detect text with Amazon Textract, then sort the blocks into reading order
        j = call_textract(input_document=prompt)
        t_doc = TDocumentSchema().load(j)
        ordered_doc = order_blocks_by_geo_x_y(t_doc)
        trp_doc = trp.Document(TDocumentSchema().dump(ordered_doc))
        # Serialize every line of every page, one line of text per row
        document = "\n".join(line.text for page in trp_doc.pages for line in page.lines)
        return document

ocrllm = OcrLLM()
ocr_prompt = PromptTemplate(
    input_variables=["doc_path"],
    template="{doc_path}",
)
ocr_chain = LLMChain(llm=ocrllm, prompt=ocr_prompt)

## Custom SageMaker endpoint LLM class
Next, we create a custom LangChain LLM class using LangChain's built-in support for SageMaker endpoints, which makes calls to SageMaker-hosted inference endpoints. Earlier, we used the FLAN-T5 model for common sense reasoning and QA tasks. The following class takes the endpoint and runs inference on the provided text for text summarization, and it is reusable in any LangChain chain.
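
The content handler is the model-specific glue: `transform_input` serializes the prompt and the generation parameters into the JSON body sent to the endpoint, and `transform_output` reads the response stream and keeps the first of the `generated_texts`. The stand-alone sketch below mimics that contract without LangChain or AWS, using `io.BytesIO` as a stand-in for the response body and a fabricated response. The `text_inputs`/`generated_texts` field names are the ones the handlers in this notebook use; `to_payload` and `from_response` are hypothetical names.

```python
import io
import json
from typing import Any, BinaryIO, Dict


def to_payload(prompt: str, model_kwargs: Dict[str, Any]) -> bytes:
    """Mirror transform_input: merge the prompt and generation parameters into JSON bytes."""
    return json.dumps({"text_inputs": prompt, **model_kwargs}).encode("utf-8")


def from_response(body: BinaryIO) -> str:
    """Mirror transform_output: read the stream and return the first generated text."""
    response_json = json.loads(body.read().decode("utf-8"))
    return response_json["generated_texts"][0]


payload = to_payload("Summarize this text: The deductible is $500.",
                     {"max_length": 150, "do_sample": True})
print(payload)

# Fabricated response body, shaped like what the handlers in this notebook expect
fake_body = io.BytesIO(json.dumps({"generated_texts": ["first", "second", "third"]}).encode("utf-8"))
print(from_response(fake_body))  # only the first of the returned sequences is kept
```

Because only the first sequence is kept, setting `num_return_sequences` above 1 mostly adds generation work without changing what the chain returns.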

In [None]:
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import LLMChain
from langchain.prompts import load_prompt, PromptTemplate
import json

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        input_str = json.dumps({"text_inputs": prompt,  **model_kwargs})
        return input_str.encode('utf-8')
    
    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json['generated_texts'][0]

content_handler = ContentHandler()
prompt_template = """Write a short summary for this text using your own words without quoting text directly from the provided text. Make sure to only include full and complete sentences: 
{document}"""
prompt = PromptTemplate.from_template(prompt_template)

llm_chain = LLMChain(
    llm=SagemakerEndpoint(
        endpoint_name=endpoint_name, # replace with your endpoint name if needed
        region_name=region,
        model_kwargs={"temperature":0.97,
                      "max_length": 150,
                      "num_return_sequences": 3,
                      "top_k": 50,
                      "top_p": 0.95,
                      "do_sample": True},
        content_handler=content_handler
    ),
    prompt=prompt
)

## Putting things together

We now have two LangChain LLM classes ready. The first runs Amazon Textract OCR on the document and outputs plain text. The second calls the SageMaker endpoint hosting the FLAN-T5 model to generate the summary. Note that the first LLM, the Amazon Textract LLM class, only needs the path of the document as its prompt. The second LLM class is given the summarization prompt, into which the output of the first LLM (the plain text from Textract) is injected.
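
`SimpleSequentialChain` feeds each chain's single output string in as the next chain's single input. A LangChain-free sketch of that control flow makes the data flow explicit; `run_sequential`, `fake_ocr`, and `fake_summarize` are hypothetical stand-ins for the chain, the Textract OCR chain, and the FLAN-T5 chain respectively.

```python
from typing import Callable, List

Step = Callable[[str], str]


def run_sequential(steps: List[Step], first_input: str) -> str:
    """Feed each step's output string into the next step, returning the final output."""
    value = first_input
    for step in steps:
        value = step(value)
    return value


def fake_ocr(doc_path: str) -> str:
    """Hypothetical stand-in for the Textract OCR chain: returns canned page text."""
    return f"[text of {doc_path}]\nDeductible: $500\nPrimary care copay: $20"


def fake_summarize(document: str) -> str:
    """Hypothetical stand-in for the FLAN-T5 chain: fills the prompt, 'summarizes' naively."""
    prompt = f"Write a short summary for this text:\n{document}"
    # A real model would generate new text; here we just return the prompt's last line
    return prompt.splitlines()[-1]


print(run_sequential([fake_ocr, fake_summarize], "./samples/health_plan_pg1.png"))
```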

In [None]:
doc_path="./samples/health_plan_pg1.png"

In [None]:
from langchain.chains import SimpleSequentialChain

overall_chain = SimpleSequentialChain(chains=[ocr_chain, llm_chain], verbose=False)
summary = overall_chain.run(doc_path)
print(summary) 

---

### Cleanup

Don't forget to free up memory by deleting the in-memory vector DB collection, so that your SageMaker Studio domain doesn't run out of memory:
`vectordb.delete_collection()`

### Delete the endpoint

Now that you have successfully performed a real-time inference, you do not need the endpoint any more. You can terminate the endpoint to avoid being charged.

In [None]:
model.sagemaker_session.delete_endpoint(endpoint_name)
model.sagemaker_session.delete_endpoint_config(endpoint_name)

### Delete the model

In [None]:
model.delete_model()