{ "cells": [ { "cell_type": "markdown", "id": "5aa67c75-08d0-413b-93c8-fc6f8dd566a2", "metadata": {}, "source": [ "# Intelligent document processing with Gen AI, Amazon Textract and FlanT5 on SageMaker Jumpstart\n", "____\n", "\n", "\n", "
\n", " NOTE: This notebook requires a Jupyter kernel with Python 3.9 or above; for example, you can use the `PyTorch 1.13 Python 3.9` image. Selecting an image may also require a different instance type. We recommend a GPU-based image with the `ml.g4dn.xlarge` instance type.\n", "
\n", "\n", "\n", "In this notebook, we first walk through Amazon Textract's document extraction capabilities. We then perform Q&A on a document: we extract its text with Amazon Textract, split the text into chunks, store the chunks in a vector DB, and query a FLAN-T5 model deployed on a SageMaker endpoint via SageMaker JumpStart to get precise answers. Later on, we also use the endpoint to perform text summarization. " ] }, { "cell_type": "markdown", "id": "70342dd7-132e-4f71-b2a7-6311e7dab54c", "metadata": { "tags": [] }, "source": [ "# Setup notebook \n", "\n", "In this step, we install the libraries that are used throughout this notebook. " ] }, { "cell_type": "code", "execution_count": null, "id": "b250172e-3298-4157-bf3e-0349d1e60536", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "!pip install -U langchain \n", "!pip install pdfplumber\n", "!pip install unstructured\n", "!pip install chromadb\n", "!pip install -U sentence-transformers\n", "!pip install pydantic==1.10.11 # pin version 1.10.11 for stability\n", "# Amazon Textract helper libraries\n", "!python -m pip install -q amazon-textract-caller --upgrade\n", "!python -m pip install -q amazon-textract-prettyprinter --upgrade\n", "!python -m pip install -q amazon-textract-response-parser --upgrade" ] }, { "cell_type": "markdown", "id": "ec3f2a96-7efb-4eab-8e52-93583fae0afb", "metadata": {}, "source": [ "# Module 1 - Document Extraction " ] }, { "cell_type": "code", "execution_count": null, "id": "4586ad84-849d-4078-b993-2d063fb9bcce", "metadata": { "tags": [] }, "outputs": [], "source": [ "import boto3\n", "import botocore\n", "import sagemaker\n", "from sagemaker.session import Session\n", "from IPython.display import Image, display, JSON\n", "from textractcaller.t_call import call_textract, Textract_Features, call_textract_expense\n", "from 
textractprettyprinter.t_pretty_print import convert_table_to_list\n", "from trp import Document\n", "import os\n", "import pandas as pd\n", "\n", "# variables\n", "sagemaker_session = Session()\n", "data_bucket = sagemaker_session.default_bucket()  # reuse the session created above\n", "region = boto3.session.Session().region_name\n", "aws_role = sagemaker_session.get_caller_identity_arn()\n", "\n", "# boto3 clients\n", "s3 = boto3.client('s3')\n", "textract = boto3.client('textract', region_name=region)\n", "\n", "print(f\"Region is {region}, IAM Role: {aws_role}, S3 Bucket: {data_bucket}\")" ] }, { "cell_type": "markdown", "id": "1116bdef-fdf8-490b-ae57-a5dd8ded896f", "metadata": {}, "source": [ "## Upload sample data to S3 bucket\n", "\n", "\n", "The sample documents for this workshop are in the `samples` directory. The next cell uploads them to the default S3 bucket under the `idp/genai` prefix." ] }, { "cell_type": "code", "execution_count": null, "id": "99c5ba3c-a1be-423c-8866-5f0a01386125", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Upload the sample documents to the S3 bucket:\n", "\n", "!aws s3 cp samples s3://{data_bucket}/idp/genai --recursive --only-show-errors" ] }, { "cell_type": "markdown", "id": "2d6ddc56-cc4c-4b19-b770-451bacf4a6ac", "metadata": { "tags": [] }, "source": [ "---\n", "# Extract structured data such as tables and key-value pairs using Amazon Textract\n", "\n", "In this step, we take a brief look at how to extract table and key-value pair information from our sample healthcare policy document. 
\n", "\n", "### Extracting Tables\n" ] }, { "cell_type": "code", "execution_count": null, "id": "56706bf7-ae61-4c50-bd1d-41a4f552330b", "metadata": { "tags": [] }, "outputs": [], "source": [ "prefix = \"idp/genai\"\n", "file_key = \"health_plan.pdf\"\n", "resp = call_textract(input_document=f's3://{data_bucket}/{prefix}/{file_key}', features=[Textract_Features.TABLES])\n", "tdoc = Document(resp)\n", "dfs = list()" ] }, { "cell_type": "code", "execution_count": null, "id": "97a9239b-8000-4ad2-ab93-af955c1c7354", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Convert every table on every page into a pandas DataFrame\n", "for page in tdoc.pages:\n", "    for table in page.tables:\n", "        tab_list = convert_table_to_list(trp_table=table)\n", "        print(tab_list)\n", "        dfs.append(pd.DataFrame(tab_list))\n", "df1 = dfs[0]\n", "df2 = dfs[1]" ] }, { "cell_type": "code", "execution_count": null, "id": "1bfff420-d4ee-482a-91ad-fd1646e0c14c", "metadata": { "tags": [] }, "outputs": [], "source": [ "df1" ] }, { "cell_type": "code", "execution_count": null, "id": "8c253f5b-dd6b-4314-bc94-db266c675e6f", "metadata": { "tags": [] }, "outputs": [], "source": [ "df2" ] }, { "cell_type": "markdown", "id": "f2b8cebe-50a8-4574-90cf-8f578a862eab", "metadata": { "tags": [] }, "source": [ "### Extracting Forms (key-value pairs) data\n" ] }, { "cell_type": "code", "execution_count": null, "id": "61e0ac8a-bff1-4845-8bdb-cc6d0330c2f8", "metadata": { "tags": [] }, "outputs": [], "source": [ "from textractcaller.t_call import call_textract, Textract_Features\n", "from textractprettyprinter.t_pretty_print import Pretty_Print_Table_Format, Textract_Pretty_Print, get_string\n", "\n", "\n", "# Call Amazon Textract\n", "response = call_textract(input_document=f's3://{data_bucket}/{prefix}/{file_key}', features=[Textract_Features.FORMS])\n", "\n", "\n", "print(get_string(textract_json=response,\n", "                 table_format=Pretty_Print_Table_Format.csv,\n", "                 output_type=[Textract_Pretty_Print.FORMS]))" ] }, { "cell_type": "markdown", "id": 
"f1efd28f-8fc3-4613-a6b8-25d41c24ba7b", "metadata": {}, "source": [ "# Module 2 - Enhancing IDP with Foundation Models" ] }, { "cell_type": "markdown", "id": "a435b72c-988c-4e0c-9a12-9996889e01b4", "metadata": {}, "source": [ "## Select a pre-trained model\n", "---\n", "You can continue with the default model or choose a different one by editing `model_id` in the next cell; the commented lines at the top of that cell show two FLAN-T5 model IDs. A complete list of SageMaker pre-trained models can also be accessed at [SageMaker pre-trained Models](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html#).\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fbf51554-c002-4c6e-a0b3-a1490dd46ff9", "metadata": { "tags": [] }, "outputs": [], "source": [ "# \"huggingface-text2text-flan-t5-xl\",\n", "# \"huggingface-text2text-flan-t5-large\",\n", "\n", "model_id, model_version = (\n", "    \"huggingface-text2text-flan-t5-xl\",\n", "    \"*\",\n", ")" ] }, { "cell_type": "markdown", "id": "3916715b-cef5-4ad9-bc2b-3b897b54889d", "metadata": {}, "source": [ "## Retrieve Artifacts & Deploy a HuggingFace FLAN-T5 Endpoint\n", "\n", "---\n", "\n", "Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `deploy_image_uri`, `deploy_source_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. 
This may take a few minutes.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b67f4200-5d12-4396-93fe-1f31f8eefdf3", "metadata": { "tags": [] }, "outputs": [], "source": [ "def get_sagemaker_session(local_download_dir) -> sagemaker.Session:\n", "    \"\"\"Return a SageMaker session that downloads artifacts to local_download_dir.\"\"\"\n", "\n", "    sagemaker_client = boto3.client(\n", "        service_name=\"sagemaker\", region_name=boto3.Session().region_name\n", "    )\n", "\n", "    session_settings = sagemaker.session_settings.SessionSettings(\n", "        local_download_dir=local_download_dir\n", "    )\n", "\n", "    # build the session from the client and the download settings above\n", "    session = sagemaker.session.Session(\n", "        sagemaker_client=sagemaker_client, settings=session_settings\n", "    )\n", "\n", "    return session" ] }, { "cell_type": "markdown", "id": "518f52a2-6c18-4f6c-a048-91b9d51cc984", "metadata": {}, "source": [ "We need to create a directory to host the downloaded model." ] }, { "cell_type": "code", "execution_count": null, "id": "eb526248-24c5-494b-af8e-f69758571fb3", "metadata": { "tags": [] }, "outputs": [], "source": [ "!mkdir -p download_dir" ] }, { "cell_type": "markdown", "id": "49ae79c1-b2eb-47a1-ba5e-fa7205941511", "metadata": {}, "source": [ "We will use the code block below to retrieve the model artifacts and then deploy the model onto a SageMaker inference endpoint. Note that we use the `ml.g5.2xlarge` inference instance type to deploy the model, and the deployment may take about 10 minutes to complete."
] }, { "cell_type": "code", "execution_count": null, "id": "25398a0e-194b-4df1-bfac-d347eff859ef", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", "from sagemaker.model import Model\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "import config\n", "\n", "\n", "endpoint_name = name_from_base(f\"{config.SOLUTION_PREFIX}-{model_id}\")\n", "\n", "inference_instance_type = \"ml.g5.2xlarge\"\n", "\n", "# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.\n", "deploy_image_uri = image_uris.retrieve(\n", "    region=None,\n", "    framework=None,  # automatically inferred from model_id\n", "    image_scope=\"inference\",\n", "    model_id=model_id,\n", "    model_version=model_version,\n", "    instance_type=inference_instance_type,\n", ")\n", "\n", "# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.\n", "deploy_source_uri = script_uris.retrieve(\n", "    model_id=model_id, model_version=model_version, script_scope=\"inference\"\n", ")\n", "\n", "\n", "# Retrieve the model uri.\n", "model_uri = model_uris.retrieve(\n", "    model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", ")\n", "\n", "# Create the model\n", "model = Model(\n", "    image_uri=deploy_image_uri,\n", "    model_data=model_uri,\n", "    role=aws_role,\n", "    predictor_cls=Predictor,\n", "    name=endpoint_name,\n", ")\n", "\n", "# deploy the Model. 
Note that we need to pass the Predictor class when we deploy the model through the Model class,\n", "# so that we can run inference through the SageMaker API.\n", "model_predictor = model.deploy(\n", "    initial_instance_count=1,\n", "    instance_type=inference_instance_type,\n", "    predictor_cls=Predictor,\n", "    endpoint_name=endpoint_name,\n", "    # volume_size=30,\n", ")" ] }, { "cell_type": "markdown", "id": "2e20477d-736b-455e-bb17-b50f0fabd587", "metadata": {}, "source": [ "---\n", "# Perform Common sense reasoning and QA on a document\n", "\n", "In this section, we will perform common sense reasoning and Q&A on a document. This section does the following:\n", "\n", "- Extract text from the document and store it in S3 in plain-text format\n", "- Generate embeddings from the text\n", "- Use an in-memory vector database to store the embeddings\n", "- Perform a similarity search on the in-memory vector DB to find the pieces of text most relevant to the user's question\n", "- Generate the context for the LLM using the search results\n", "- Give the model the context and the original question asked\n", "- Get the answer back from the LLM\n", "- Profit\n", "\n", "> _\"Wait but that's a lot of steps just for getting an answer back? Why?\"_\n", "\n", "We would love to explain and dive deeper into why, but here's a paper that does a better job of explaining the why and the how: https://arxiv.org/pdf/2005.11401.pdf . In short, LLMs know a lot, _sometimes so much that a model may get confused, wander into the proverbial forest of its own world knowledge, and start gathering firewood when it was actually asked to go pick some fruit_. 
To solve this problem and get accurate answers (or, better, no answer at all), we use Retrieval-Augmented Generation (RAG) to give the LLM a bit more _stuff_ to work with, so that it gives us the desired output (like a fruit basket in our example, so that it knows it is only supposed to pick fruit).\n", "\n", "As a first step, we read a file (document) using Amazon Textract and write the plain text into S3." ] }, { "cell_type": "code", "execution_count": null, "id": "76812801-acd0-4de5-a5c3-93839b71c466", "metadata": { "tags": [] }, "outputs": [], "source": [ "from textractcaller.t_call import call_textract, Textract_Features\n", "from trp.trp2 import TDocument, TDocumentSchema\n", "from trp.t_pipeline import order_blocks_by_geo\n", "import boto3\n", "import sagemaker\n", "import pdfplumber\n", "import mimetypes\n", "import trp\n", "import json\n", "\n", "data_bucket = sagemaker.Session().default_bucket()\n", "# Use the upload prefix explicitly: `prefix` is reassigned in a later cell, so this keeps the cell safe to re-run\n", "doc_path = f's3://{data_bucket}/idp/genai/{file_key}'\n", "s3 = boto3.client('s3')\n", "doc_text = list()\n", "page_num = 1\n", "\n", "print(f\"Bucket is {data_bucket}\")\n", "\n", "if not doc_text:\n", "    # CAREFUL: this only works with Single pages of scanned PDF documents\n", "    # typically we will have OCR done on the page in advance of the lang chain initiation\n", "    j = call_textract(input_document=doc_path)\n", "\n", "    t_doc = TDocumentSchema().load(j)\n", "    ordered_doc = order_blocks_by_geo(t_doc)  # sort by reading order\n", "    trp_doc = trp.Document(TDocumentSchema().dump(ordered_doc))\n", "\n", "    # Iterate over the pages in the document\n", "    for page in trp_doc.pages:\n", "        # Start each page with an empty buffer so that every file holds one page's text\n", "        doc_content = str()\n", "        for line in page.lines:\n", "            doc_content = doc_content + \"\\n\" + line.text\n", "\n", "        content_res = bytes(doc_content, 'utf-8')\n", "        s3.put_object(Bucket=data_bucket,\n", "                      Key=f\"llm/sample/page-{page_num}.txt\",\n", "                      Body=content_res)\n", 
"        print(f\"Page text written into llm/sample/page-{page_num}.txt\")\n", "        page_num = page_num + 1" ] }, { "cell_type": "markdown", "id": "46ef3ad0-a98c-4a32-85fc-92f05f0d1aac", "metadata": {}, "source": [ "The code above calls Amazon Textract on a document and stores each page's content in S3 as plain text. It is written with a single-page document in mind; similar logic can be implemented for multi-page PDFs using the asynchronous `StartDocumentTextDetection` API. For the sake of brevity, we used Textract's real-time `DetectDocumentText` API, which only works on single-page documents.\n", "\n", "Next, we load the plain-text files that we wrote to S3 into LangChain's `Document` interface, which integrates easily with the vector DBs that LangChain supports (in this case ChromaDB, an in-memory vector DB). We then split the document into chunks; this is required because we may have a large multi-page document and our LLMs will have token limits. These chunks are then loaded into the vector DB for the similarity search in the subsequent steps. \n", "\n", "However, before we store the document in the vector DB, we have to generate embeddings for the text. We use `HuggingFaceEmbeddings`, which is built into LangChain, for that purpose. For other models, you may choose an embedding model as suggested by the model provider."
] }, { "cell_type": "code", "execution_count": null, "id": "09069c6e-aedb-45dc-ac99-2dcb48ea979f", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "from langchain.document_loaders import S3DirectoryLoader\n", "from langchain.vectorstores import Chroma\n", "from langchain.text_splitter import NLTKTextSplitter\n", "from langchain.embeddings import HuggingFaceEmbeddings\n", "from langchain.schema import Document\n", "import sagemaker\n", "\n", "data_bucket = sagemaker.Session().default_bucket()\n", "prefix='llm/sample'\n", "\n", "embeddings = HuggingFaceEmbeddings()\n", "loader = S3DirectoryLoader(data_bucket, prefix=prefix)\n", "docs = loader.load()\n", "text_splitter = NLTKTextSplitter(chunk_size=550)\n", "texts = text_splitter.split_documents(docs)\n", "vectordb = Chroma.from_documents(texts, embeddings)" ] }, { "cell_type": "code", "execution_count": null, "id": "639f01ec-07d8-4e8a-8fdc-84d12a7ca1e0", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "docs" ] }, { "cell_type": "markdown", "id": "658eca5f-7e36-4372-813d-6f2c088fc248", "metadata": {}, "source": [ "## Using HuggingFace FLAN-T5 XL SageMaker endpoint\n", "\n", "Our Vector DB is now loaded with the chunks of the document. All that is left is to take a question from the user, perform a similarity search on the Vector DB, give the model the context and the prompt, and wait for it to answer the question. Before that, let's define a custom QA chain for the SageMaker endpoint, with a prompt template that asks the model to answer the question from the provided text. We won't do detailed prompt engineering here and use a simple prompt instead, although more careful prompt design can yield a more robust QA prompt. 
We use LangChain's `PromptTemplate` to craft the prompt." ] }, { "cell_type": "markdown", "id": "8fc2f09f-f6e2-442d-83fc-8b82467d361c", "metadata": {}, "source": [ "Let's first set the payload parameters for text generation. When invoking the endpoint, our JSON payload can include any desired inference parameters that help control the length, sampling strategy, and output token sequence restrictions. \n", "\n", "You may refer to this [documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) by HuggingFace for a detailed explanation of the generation parameters. " ] }, { "cell_type": "code", "execution_count": null, "id": "9c094f65-2b38-4f58-9791-67a4c44204f1", "metadata": { "tags": [] }, "outputs": [], "source": [ "FLAN_T5_PARAMETERS = {\n", "    \"temperature\": 0.97,  # the value used to modulate the next token probabilities.\n", "    \"max_length\": 100,  # restrict the length of the generated text.\n", "    \"num_return_sequences\": 3,  # number of output sequences returned.\n", "    \"top_k\": 50,  # in each step of text generation, sample from only the top_k most likely words.\n", "    \"top_p\": 0.95,  # in each step of text generation, sample from the smallest possible set of words with cumulative probability top_p.\n", "    \"do_sample\": True  # whether or not to use sampling; use greedy decoding otherwise.\n", "}" ] }, { "cell_type": "markdown", "id": "80d4271f-bf40-4038-9d0c-29b04acb034f", "metadata": {}, "source": [ "
\n", " NOTE: You will need to insert an endpoint name below if you are using your own endpoint. At this point, you should already have a deployed FLAN-T5 model in your account.\n", "
" ] }, { "cell_type": "code", "execution_count": null, "id": "31f099d4-2ee6-4199-8339-dfaca37927d4", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain import SagemakerEndpoint\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains import LLMChain\n", "from langchain.prompts import PromptTemplate\n", "import json\n", "from typing import Dict\n", "\n", "class QAContentHandler(LLMContentHandler):\n", "    content_type = \"application/json\"\n", "    accepts = \"application/json\"\n", "\n", "    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", "        # Serialize the prompt and the generation parameters into the JSON request body\n", "        input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", "        return input_str.encode('utf-8')\n", "\n", "    def transform_output(self, output: bytes) -> str:\n", "        # Parse the endpoint response and return the first generated sequence\n", "        response_json = json.loads(output.read().decode(\"utf-8\"))\n", "        return response_json[\"generated_texts\"][0]\n", "\n", "qa_content_handler = QAContentHandler()\n", "prompt_template=\"\"\"Given the following text from a document, answer the question to the best of your abilities. Answer only from the provided document; if you do not know the answer,\n", "just say you don't know. 
DO NOT make up an answer.\n", "\n", "Document: {document}\n", "Question: {question}\n", "Answer:\n", "\"\"\"\n", "\n", "prompt = PromptTemplate(input_variables=[\"document\", \"question\"],\n", "                        template=prompt_template)\n", "\n", "qa_chain = LLMChain(\n", "    llm=SagemakerEndpoint(\n", "        endpoint_name=endpoint_name,  # replace with your endpoint name if needed\n", "        region_name=region,\n", "        model_kwargs=FLAN_T5_PARAMETERS,\n", "        content_handler=qa_content_handler\n", "    ),\n", "    prompt=prompt\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "5018ff18", "metadata": { "tags": [] }, "outputs": [], "source": [ "question=\"What is the deductible?\"" ] }, { "cell_type": "markdown", "id": "acccba52-47db-4e26-95e7-ab7f6f034f9f", "metadata": {}, "source": [ "## Common sense reasoning / natural language inference\n", "\n", "Perform a similarity search on the document with `k=3`, which returns the three chunks of text most relevant to the question asked." ] }, { "cell_type": "code", "execution_count": null, "id": "d12d790d-d490-4a57-8b38-8386421aa193", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "similar_docs = vectordb.similarity_search(question, k=3)  # see also: max_marginal_relevance_search(question, k=3)\n", "context_list = [a.page_content for a in similar_docs]\n", "metadata_list = [a.metadata.get('source') for a in similar_docs]\n", "context = \"\\n\\n\".join(context_list)\n", "context" ] }, { "cell_type": "markdown", "id": "551eb3f3", "metadata": {}, "source": [ "## Question answering\n", "\n", "We can now use the custom QA chain with the SageMaker endpoint to answer our question based on the content of the document, as shown below."
] }, { "cell_type": "code", "execution_count": null, "id": "f6856c8b-7ea2-4dc8-80eb-7d8aff4e8dce", "metadata": { "tags": [] }, "outputs": [], "source": [ "qa_chain.run({\n", "    'document': context,\n", "    'question': question\n", "})" ] }, { "cell_type": "markdown", "id": "314fcfbc", "metadata": {}, "source": [ "# Text summarization\n", "\n", "Text summarization involves condensing a given text or document into a shorter version while retaining its key information. This technique is beneficial for efficient information retrieval, since it lets users quickly grasp the key points of a document without reading the entire content. \n", "\n", "While Amazon Textract doesn't directly perform text summarization, it provides the foundational capabilities that can be leveraged here. Amazon Textract can accurately extract text from various types of documents, as seen in the earlier modules. This extracted text serves as the input to our LLM for performing text summarization tasks.\n", "\n" ] }, { "cell_type": "markdown", "id": "1f14e7ba-3bde-4f8e-9da2-1dea380bf08d", "metadata": {}, "source": [ "## Use LangChain to create LLM class for Text extraction and SageMaker endpoint calls\n", "\n", "---\n", "Now that we have deployed our endpoint, it is ready to use for summarizing our document. We will use LangChain to perform inference, and we first need to create two LLM classes using the base LangChain LLM class. Read more about the LangChain LLM class [here](https://python.langchain.com/en/latest/modules/models/llms.html). Specifically, we will create two custom LLM classes:\n", "\n", "1. An LLM class to extract text from our document using Amazon Textract\n", "2. An LLM class to make calls to the SageMaker endpoint where our FLAN-T5 model is deployed\n", "\n", "The purpose of building these custom LLM classes is to be able to easily use these constructs with LangChain's pre-built or custom chains. 
Read more about LangChain chains [here](https://python.langchain.com/en/latest/modules/chains.html)" ] }, { "cell_type": "markdown", "id": "63349e31-6e7a-4bf7-8594-78c828b7410b", "metadata": { "tags": [] }, "source": [ "### Custom OCR LLM with Amazon Textract\n", "\n", "The first step is to read the document using Amazon Textract. As the first step in the chain, we call Amazon Textract's `detect_document_text()` API for a given document path; its output is then sent to the LLM together with the summarization prompt. For this purpose, we subclass LangChain's LLM class to create a custom LLM class that calls Amazon Textract's real-time (synchronous) `detect_document_text()` API through the `amazon-textract-caller` library and then formats the output using the `amazon-textract-response-parser` library. The input to this LLM class is the path to the document and the output is serialized text." ] }, { "cell_type": "code", "execution_count": null, "id": "aee84ad2-cd3d-4cbc-af7c-78df8bf84268", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "from langchain.llms.base import LLM\n", "from langchain.prompts import PromptTemplate\n", "from langchain.chains import LLMChain\n", "from typing import Optional, List\n", "from textractcaller.t_call import call_textract, Textract_Features\n", "from trp.trp2 import TDocumentSchema\n", "from trp.t_pipeline import order_blocks_by_geo_x_y\n", "import trp\n", "import json\n", "\n", "class OcrLLM(LLM):\n", "    @property\n", "    def _llm_type(self) -> str:\n", "        return \"custom\"\n", "\n", "    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n", "        # prompt is the document path\n", "        if stop is not None:\n", "            raise ValueError(\"stop kwargs are not permitted.\")\n", "        j = call_textract(input_document=prompt)\n", "        t_doc = TDocumentSchema().load(j)\n", "        ordered_doc = order_blocks_by_geo_x_y(t_doc)\n", "        trp_doc = trp.Document(TDocumentSchema().dump(ordered_doc))\n", "        # Concatenate every line of every page into one plain-text string\n", "        document = str()\n", "        for page in trp_doc.pages:\n", "            for line in page.lines:\n", "                document = document + \"\\n\" + line.text\n", "        return document\n", "\n", "ocrllm = OcrLLM()\n", "ocr_prompt = PromptTemplate(\n", "    input_variables=[\"doc_path\"],\n", "    template=\"{doc_path}\",\n", ")\n", "ocr_chain = LLMChain(llm=ocrllm, prompt=ocr_prompt)" ] }, 
{ "cell_type": "markdown", "id": "ce112d8e", "metadata": {}, "source": [ "## Custom SageMaker endpoint LLM class\n", "Next, we create a custom LangChain LLM class using LangChain's built-in support for SageMaker endpoints, which makes calls to the SageMaker-hosted inference endpoint. Earlier, we used the FLAN-T5 model for common sense reasoning and QA tasks. The following class takes the endpoint, runs inference on the provided text for text summarization tasks, and is reusable in any LangChain chain." ] }, { "cell_type": "code", "execution_count": null, "id": "d6cf5dd6", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain import SagemakerEndpoint\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains import LLMChain\n", "from langchain.prompts import load_prompt, PromptTemplate\n", "import json\n", "\n", "class ContentHandler(LLMContentHandler):\n", "    content_type = \"application/json\"\n", "    accepts = \"application/json\"\n", "\n", "    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:\n", "        input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", "        return input_str.encode('utf-8')\n", "\n", "    def transform_output(self, output: bytes) -> str:\n", "        response_json = json.loads(output.read().decode(\"utf-8\"))\n", "        return response_json['generated_texts'][0]\n", "\n", "content_handler = ContentHandler()\n", "prompt_template = \"\"\"Write a short summary for this text using your own words without quoting text directly from the provided text. 
Make sure to only include full and complete sentences: \n", "{document}\"\"\"\n", "prompt = PromptTemplate.from_template(prompt_template)\n", "\n", "llm_chain = LLMChain(\n", "    llm=SagemakerEndpoint(\n", "        endpoint_name=endpoint_name,  # replace with your endpoint name if needed\n", "        region_name=region,\n", "        model_kwargs={\"temperature\": 0.97,\n", "                      \"max_length\": 150,\n", "                      \"num_return_sequences\": 3,\n", "                      \"top_k\": 50,\n", "                      \"top_p\": 0.95,\n", "                      \"do_sample\": True},\n", "        content_handler=content_handler\n", "    ),\n", "    prompt=prompt\n", ")" ] }, { "cell_type": "markdown", "id": "347734d5", "metadata": {}, "source": [ "## Putting things together\n", "\n", "We now have two LangChain LLM classes ready. The first one runs Amazon Textract OCR on the document and produces plain-text output. The second one calls the SageMaker endpoint that hosts the FLAN-T5 model to generate the summary. Note that the first LLM, i.e. the Amazon Textract LLM class, only needs the path of the document as its prompt. The second LLM class is given the summarization prompt, into which the output of the first LLM, i.e. the plain text from Textract, is injected."
] }, { "cell_type": "code", "execution_count": null, "id": "0c8f66b2", "metadata": { "tags": [] }, "outputs": [], "source": [ "doc_path=\"./samples/health_plan_pg1.png\"" ] }, { "cell_type": "code", "execution_count": null, "id": "4eb5912a", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.chains import SimpleSequentialChain\n", "\n", "overall_chain = SimpleSequentialChain(chains=[ocr_chain, llm_chain], verbose=False)\n", "summary = overall_chain.run(doc_path)\n", "print(summary) " ] }, { "cell_type": "markdown", "id": "bf70688e-bf6b-4fda-8b75-24bca933b624", "metadata": {}, "source": [ "---\n", "\n", "### Cleanup\n", "\n", "Don't forget to clean up memory by deleting the in-memory Vector DB collection, so that your SageMaker Studio domain doesn't run out of memory:\n", "\n", "`vectordb.delete_collection()`" ] }, { "cell_type": "markdown", "id": "b78a6387-c8c5-4813-9cf2-e88a436d0d72", "metadata": {}, "source": [ "### Delete the endpoint" ] }, { "cell_type": "markdown", "id": "1ab76136-be35-4004-abfa-16f332cfe8b0", "metadata": {}, "source": [ "Now that you have successfully performed real-time inference, you do not need the endpoint any more. You can delete the endpoint to avoid being charged."
] }, { "cell_type": "code", "execution_count": null, "id": "d2ff1718-b07c-4e4e-ba2b-229063e5ed45", "metadata": { "tags": [] }, "outputs": [], "source": [ "model.sagemaker_session.delete_endpoint(endpoint_name)\n", "model.sagemaker_session.delete_endpoint_config(endpoint_name)" ] }, { "cell_type": "markdown", "id": "8f63c96b-5532-46f0-8bc1-0532468efec9", "metadata": {}, "source": [ "### Delete the model" ] }, { "cell_type": "code", "execution_count": null, "id": "49ffdbdc-873b-4ff3-993f-ea64674a7606", "metadata": { "tags": [] }, "outputs": [], "source": [ "model.delete_model()" ] } ], "metadata": { "availableInstances": [ { "_defaultOrder": 0, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.t3.medium", "vcpuNum": 2 }, { "_defaultOrder": 1, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.t3.large", "vcpuNum": 2 }, { "_defaultOrder": 2, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.t3.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 3, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.t3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 4, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5.large", "vcpuNum": 2 }, { "_defaultOrder": 5, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 6, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 7, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": 
"ml.m5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 8, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 9, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 10, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 11, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 12, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.m5d.large", "vcpuNum": 2 }, { "_defaultOrder": 13, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.m5d.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 14, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.m5d.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 15, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.m5d.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 16, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.m5d.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 17, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.m5d.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 18, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.m5d.16xlarge", "vcpuNum": 64 
}, { "_defaultOrder": 19, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.m5d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 20, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "hideHardwareSpecs": true, "memoryGiB": 0, "name": "ml.geospatial.interactive", "supportedImageNames": [ "sagemaker-geospatial-v1-0" ], "vcpuNum": 0 }, { "_defaultOrder": 21, "_isFastLaunch": true, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 4, "name": "ml.c5.large", "vcpuNum": 2 }, { "_defaultOrder": 22, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 8, "name": "ml.c5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 23, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.c5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 24, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.c5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 25, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 72, "name": "ml.c5.9xlarge", "vcpuNum": 36 }, { "_defaultOrder": 26, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 96, "name": "ml.c5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 27, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 144, "name": "ml.c5.18xlarge", "vcpuNum": 72 }, { "_defaultOrder": 28, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.c5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 29, "_isFastLaunch": true, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 
16, "name": "ml.g4dn.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 30, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g4dn.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 31, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g4dn.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 32, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g4dn.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 33, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g4dn.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 34, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g4dn.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 35, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 61, "name": "ml.p3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 36, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 244, "name": "ml.p3.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 37, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 488, "name": "ml.p3.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 38, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.p3dn.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 39, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.r5.large", "vcpuNum": 2 }, { "_defaultOrder": 40, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, 
"hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.r5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 41, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.r5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 42, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.r5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 43, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.r5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 44, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.r5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 45, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 512, "name": "ml.r5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 46, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.r5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 47, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 16, "name": "ml.g5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 48, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 32, "name": "ml.g5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 49, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 64, "name": "ml.g5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 50, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "hideHardwareSpecs": false, "memoryGiB": 128, "name": "ml.g5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 51, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, 
"hideHardwareSpecs": false, "memoryGiB": 256, "name": "ml.g5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 52, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 192, "name": "ml.g5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 53, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "hideHardwareSpecs": false, "memoryGiB": 384, "name": "ml.g5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 54, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 768, "name": "ml.g5.48xlarge", "vcpuNum": 192 }, { "_defaultOrder": 55, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 1152, "name": "ml.p4d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 56, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "hideHardwareSpecs": false, "memoryGiB": 1152, "name": "ml.p4de.24xlarge", "vcpuNum": 96 } ], "instance_type": "ml.g4dn.xlarge", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 5 }