{ "cells": [ { "cell_type": "markdown", "id": "db9dc9d8", "metadata": {}, "source": [ "# Building a custom entity recognizer for PDF documents using Amazon Comprehend" ] }, { "attachments": {}, "cell_type": "markdown", "id": "97553a8a", "metadata": {}, "source": [ "In many industries, it is critical to extract custom entities from documents in a timely manner. This can be challenging. Insurance claims, for example, often contain dozens of important attributes (e.g. dates, names, locations, reports) sprinkled across lengthy and dense documents. Manually scanning and extracting such information is, therefore, error-prone and time-consuming. Rule-based software can help, but ultimately is too rigid to adapt to the many varying document types and layouts. \n", "\n", "To help automate and speed up this process, Amazon Comprehend can be used to detect custom entities quickly and accurately by leveraging machine learning. This approach is flexible and accurate, because the system can adapt to new documents by leveraging what it has learned in the past. Until recently, however, this capability could only be applied to plain text documents which meant that positional information was lost when converting the documents from their native format. To address this, it was recently announced that Comprehend can now extract custom entities in native PDF and word format. In this blog post, we will walk through a concrete example from the insurance industry of how you can build a custom recognizer using PDF annotations.\n", "\n", "Specifically, we will:\n", " * Set up IAM permissions to do custom training on Sagemaker \n", " * Explore the format of PDF annotations\n", " * Use the PDF annotations to train a custom model using the Python API\n", " * Obtain evaluation metrics from the trained model\n", " * Perform inference on an unseen document " ] }, { "cell_type": "markdown", "id": "c24fd6cc", "metadata": {}, "source": [ "## Getting started" ] }, { "cell_type": "markdown", "id": "7a9f63bc", "metadata": {}, "source": [ "Before training a model, we will need to give SageMaker permission to pass data from Amazon Simple Storage Service (Amazon S3) to AWS Comprehend to run entity recognition jobs." ] }, { "cell_type": "markdown", "id": "f2be5e44", "metadata": {}, "source": [ "We will:\n", "\n", " A. Create a new IAM role with an attached policy to read from/write to s3 \n", " B. Update the role's trust agreement with Comprehend service\n", " C. Attach a policy to Sagemaker that will pass along the role to Comprehend when asked\n" ] }, { "cell_type": "code", "execution_count": null, "id": "74ba7ebe", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Enter your AWS ACCOUNT ID\n", "AWS_ACCOUNT_ID = 288947426911" ] }, { "cell_type": "markdown", "id": "4cc93f84", "metadata": {}, "source": [ "**A. Create a new IAM role with an attached policy to read from/write to s3**\n", "\n", "1. Create new IAM role, choose EC2 use-case\n", "2. Create a new policy (opens new tab), with following JSON code (execute cell and copy result), and call it \"policy-rk-read-write\"\n", "3. Go back to creating new IAM role (previous tab), refresh policy policy list, and attach \"policy-rk-read-write\"\n", "4. Name role \"comprehends3access\"\n", "5. 
Create a private s3 bucket s3://my-custom-comprehend-output-{AWS_ACCOUNT_ID}, to capture output from your detection job" ] }, { "cell_type": "code", "execution_count": null, "id": "f08002c3", "metadata": { "tags": [] }, "outputs": [], "source": [ "import json\n", "\n", "j = {\n", " \"Statement\": [\n", " {\n", " \"Action\": [\n", " \"s3:ListBucket\"\n", " ],\n", " \"Effect\": \"Allow\",\n", " \"Resource\": [\n", " \"arn:aws:s3:::ee-assets-prod-us-east-1\",\n", " f\"arn:aws:s3:::my-custom-comprehend-output-{AWS_ACCOUNT_ID}/*\"\n", " ]\n", " },\n", " {\n", " \"Action\": [\n", " \"s3:PutObject\",\n", " \"s3:DeleteObject\"\n", " ],\n", " \"Effect\": \"Allow\",\n", " \"Resource\": [\n", " f\"arn:aws:s3:::my-custom-comprehend-output-{AWS_ACCOUNT_ID}/*\"\n", " ]\n", " },\n", " {\n", " \"Action\": [\n", " \"s3:GetObject\"\n", " ],\n", " \"Effect\": \"Allow\",\n", " \"Resource\": [\n", " \"arn:aws:s3:::ee-assets-prod-us-east-1/*\"\n", " ]\n", " }\n", " ],\n", " \"Version\": \"2012-10-17\"\n", "}\n", "\n", "print(json.dumps(j, indent=4, sort_keys=True))\n" ] }, { "cell_type": "markdown", "id": "006729e4", "metadata": {}, "source": [ "**B. Update the role's trust agreement with Comprehend service**\n", "\n", "1. Navigate to \"comprehends3access\" role\n", "2. Edit trust relationship\n", "3. Replace EC2 with Comprehend for the Service field: \"comprehend.amazonaws.com\"" ] }, { "cell_type": "markdown", "id": "0fc23aed", "metadata": {}, "source": [ "**C. Attach a policy to Sagemaker that will pass along the role to Comprehend when asked**\n", "\n", "1. Create new policy, using JSON below (execute and copy)\n", "2. Name the policy \"IAMPassPolicyComprehend\"\n", "3. Navigate to the AmazonSageMaker-ExecutionRole-XXXX role, add permissions, and attach this policy\n", "4. 
Also add \"ComprehendFullAccess\" and \"AmazonTextractFullAccess\" policies to the this role" ] }, { "cell_type": "code", "execution_count": null, "id": "80ab9ae6", "metadata": { "tags": [] }, "outputs": [], "source": [ "j = {\n", " \"Version\": \"2012-10-17\",\n", " \"Statement\": [\n", " {\n", " \"Sid\": \"ComprehendAsyncPass\",\n", " \"Effect\": \"Allow\",\n", " \"Action\": \"iam:PassRole\",\n", " \"Resource\": \"arn:aws:iam::{}:role/comprehends3access\".format(AWS_ACCOUNT_ID)\n", " \n", " }\n", " ]\n", "}\n", "\n", "print(json.dumps(j, indent=4, sort_keys=True))" ] }, { "cell_type": "code", "execution_count": null, "id": "ffa1f8c3", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Imports\n", "%load_ext autoreload\n", "%autoreload 2\n", "from pprint import pprint\n", "import os\n", "import sys\n", "import json\n", "import boto3\n", "import time\n", "import uuid\n", "import pandas as pd\n", "module_path = os.path.join(os.path.abspath(os.path.join('.')), 'helperPackage')\n", "if module_path not in sys.path:\n", " sys.path.append(module_path)\n", "from pdfhelper.PDFHelper import PDFHelper\n", "from IPython.display import IFrame\n", "#!pip install --upgrade pymupdf\n", "\n", "def split_s3_uri(uri):\n", " \"\"\"return (bucket, key) tuple from s3 uri like 's3://bucket/prefix/file.txt' \"\"\"\n", " return uri.replace('s3://','').split('/',1)\n", "\n", "def s3_object_from_uri(uri):\n", " \"\"\"Initialize a boto3 s3 Object instance from a URI\"\"\"\n", " s3 = boto3.resource('s3')\n", " return s3.Object(*split_s3_uri(uri))\n", "\n", "def s3_contents_from_uri(uri, decode=True):\n", " \"\"\"Read contents from S3 object into memory\"\"\"\n", " data = s3_object_from_uri(uri).get()['Body'].read()\n", " return data.decode() if decode else data\n", "\n", "def s3_download_file(s3_uri, localpath):\n", " \"\"\"Download file from s3\"\"\"\n", " s3 = boto3.client('s3')\n", " bucket, key = split_s3_uri(s3_uri)\n", " s3.download_file(bucket, key, localpath)\n", " \n", "def split_s3_uri(s3_uri):\n", " \"\"\"Split url into bucket and key\"\"\"\n", " bucket, key = s3_uri.replace(\"s3://\", \"\").split(\"/\", 1)\n", " return bucket, key\n", "\n", "comprehend = boto3.client('comprehend')" ] }, { "cell_type": "code", "execution_count": null, "id": "e199588d", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Set up paths\n", "ASSETS_S3_PREFIX = 's3://ee-assets-prod-us-east-1/modules/b2d6c897c659445583c2edb826183e8e/v1/'\n", "\n", "# Information about the training data and how the SageMaker Ground Truth job output looks in S3\n", "TRAINING_DOCS_S3_URI_PREFIX = os.path.join(ASSETS_S3_PREFIX, 'documents/')\n", "ANNOTATIONS_S3_URI_PREFIX = os.path.join(ASSETS_S3_PREFIX, 'annotations/annotations/consolidated-annotation/consolidation-response/iteration-1/annotations/')\n", "MANIFEST_S3_URI = os.path.join(ASSETS_S3_PREFIX, 'annotations/manifests/output/output.manifest')\n", "LABEL_ATTRIBUTE_NAME = 'claim-full-job-labeling-job-20211019T163532'\n", "\n", "# local directory containing data reference in this notebook\n", "LOCAL_ARTIFACTS_DIR = 'ComprehendCustom-Artifacts'\n", "# local path to store results in\n", "LOCAL_OUTPUT_DIR = 'tmp/ComprehendCustom'\n", "# set up tmp dir under the working directory\n", "!mkdir -p {LOCAL_OUTPUT_DIR}\n", "\n", "# Download some files for later\n", "# Raw PDFs\n", "!mkdir -p {LOCAL_ARTIFACTS_DIR}/ex_pdfs\n", "fnames = ['INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00000.pdf', 'INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00017.pdf', 'INSR_pm_hipaa_1_pii_00048.pdf']\n", 
"for fname in fnames:\n", " s3_download_file(os.path.join(TRAINING_DOCS_S3_URI_PREFIX, fname), f'{LOCAL_ARTIFACTS_DIR}/ex_pdfs/{fname}')" ] }, { "cell_type": "markdown", "id": "6dace7cd", "metadata": {}, "source": [ "Now that we have set things up, we can train our model to detect the following five entities that we chose because of their relevance to insurance claims: DateOfForm, DateOfLoss, NameOfInsured, LocationOfLoss, and InsuredMailingAddress" ] }, { "cell_type": "markdown", "id": "a725be32", "metadata": {}, "source": [ "## Explore the format of PDF annotations" ] }, { "cell_type": "markdown", "id": "c0858e4a", "metadata": {}, "source": [ "To create annoations for PDF documents, you can use [Amazon SageMaker GroundTruth](https://aws.amazon.com/sagemaker/groundtruth/) - a fully managed data labeling service that makes it easy to build highly accurate training datasets for machine learning.\n", "\n", "For this tutorial, we have already annotated the PDFs, in their native form (i.e. without converting to plain text) using SageMaker GroundTruth. (To set up your own annotation job, refer to the resources in the **Summary/Resources** section of this notebook)\n", "\n", "The Ground Truth job generates three paths we will need for training our Comprehend custom model.\n", "1. Sources: Path to the input PDFs\n", "2. Annotations: Path to the annotation jsons containing the labeled entity information\n", "3. Manifest: Points to the location of the annotations and source PDFs. You will use this manifest file to create an Amazon Comprehend custom entity recognition training job and train your custom model. Manifests are saved in s3://comprehend-semi-structured-documents-us-east-1--/output/your labeling job/manifests/output/" ] }, { "cell_type": "code", "execution_count": null, "id": "499d8a3c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Let's have a look at a sample annotation\n", "\n", "# We will find the line of the manifest corresponding to a particular input document\n", "document_s3_uri = os.path.join(ASSETS_S3_PREFIX, 'documents','INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00000.pdf')\n", "\n", "manifest_data = [json.loads(obj) for obj in s3_contents_from_uri(MANIFEST_S3_URI).splitlines()]\n", "\n", "manifest_line = [r for r in manifest_data if r['source-ref']==document_s3_uri][0]\n", "# manifest_line\n", "\n", "# Let's download the annotation file and look at a sample annotation\n", "\n", "annotations_uri = manifest_line[LABEL_ATTRIBUTE_NAME]['annotation-ref']\n", "\n", "annotations = json.loads(s3_contents_from_uri(annotations_uri))\n", "annotations['Entities'][0]\n", "\n", "## Uncomment the following line to see more of the annotated entities:\n", "# annotations['Entities']" ] }, { "cell_type": "markdown", "id": "755524c0", "metadata": {}, "source": [ "As you can see above, the custom GroundTruth job generates a PDF annotation that captures block-level information about the entity. Such block-level information provides the precise positional coordinates of the entity (with the child blocks representing each word within the entity block). This is distinct from a standard GroundTruth job in which the data in the PDF is flattened to textual format and only offset information - but not precise coordinate information - is captured during annotation. The rich positional information we obtain with this custom annotation paradigm will allow us to train a more accurate model. 
\n", "\n", "The manifest that's generated from this type of job is called an Augmented Manifest, as opposed to a CSV that's used for standard annotations. For more information, see: https://docs.aws.amazon.com/comprehend/latest/dg/cer-annotation.html\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0eda2e3b", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Visualize the annotated pdf inline\n", "\n", "original_file = f'{LOCAL_ARTIFACTS_DIR}/ex_pdfs/INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00000.pdf'\n", "annotated_file = f'{LOCAL_OUTPUT_DIR}/annotated.pdf'\n", "\n", "# using a custom module (PDFHelper) to add annotations the file before displaying\n", "PDFHelper.add_annotations_to_file(annotations, original_file, annotated_file)\n", "IFrame(annotated_file, width=600, height=800)\n", "\n", "# Note: you may need to zoom in to read the label names" ] }, { "cell_type": "code", "execution_count": null, "id": "a58977a6", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Lets look at another annotated sample\n", "\n", "# changed the document s3 uri\n", "document_s3_uri = os.path.join(ASSETS_S3_PREFIX, 'documents','INSR_pm_hipaa_1_pii_00048.pdf')\n", "\n", "# get the annotations data\n", "manifest_data = [json.loads(obj) for obj in s3_contents_from_uri(os.path.join(ASSETS_S3_PREFIX, 'annotations/manifests/output/output.manifest')).splitlines()]\n", "manifest_line = [r for r in manifest_data if r['source-ref']==document_s3_uri][0]\n", "annotations_uri = manifest_line[LABEL_ATTRIBUTE_NAME]['annotation-ref']\n", "annotations = json.loads(s3_contents_from_uri(annotations_uri))\n", "\n", "original_file = f'{LOCAL_ARTIFACTS_DIR}/ex_pdfs/INSR_pm_hipaa_1_pii_00048.pdf'\n", "annotated_file =f'{LOCAL_OUTPUT_DIR}/INSR_pm_hipaa_1_pii_00048_annotated.pdf'\n", "\n", "# using a custom module (PDFHelper) to add annotations the file before displaying\n", "PDFHelper.add_annotations_to_file(annotations, original_file, annotated_file)\n", "IFrame(annotated_file, width=600, height=800)" ] }, { "cell_type": "markdown", "id": "1427e6f2", "metadata": {}, "source": [ "# Use the PDF annotations to train a custom model using the Python API" ] }, { "cell_type": "markdown", "id": "39e70ba3", "metadata": {}, "source": [ "An augmented manifest file must be formatted in JSON Lines format. In JSON Lines format, each line in the file is a complete JSON object followed by a newline separator." ] }, { "cell_type": "code", "execution_count": null, "id": "79169514", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Let's have a look at an entry within this augmented manifest file.\n", "manifest_line" ] }, { "cell_type": "markdown", "id": "d8693078", "metadata": {}, "source": [ "A few things to note:\n", "\n", "* There are 5 labeling types associated with this job: DateOfForm, DateOfLoss, NameOfInsured, LocationOfLoss, and InsuredMailingAddress\n", "* The manifest file makes reference to both the source PDF location and the annotation location\n", "* Metadata about the annotation job (e.g. creation date) is captured.\n", "* Use-textract-only is set to False, meaning the annotation tool will decide whether to use PDFPlumber (for a native PDF) or Amazon Textract (for a scanned PDF). If it were set to true, textract would be used in either case (more costly but potentially more accurate). 
" ] }, { "cell_type": "code", "execution_count": null, "id": "3956bfbf", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Now we can train the recognizer\n", "\n", "comprehend = boto3.client('comprehend')\n", "response = comprehend.create_entity_recognizer(\n", " RecognizerName=\"recognizer-example-{}\".format(str(uuid.uuid4())),\n", " LanguageCode=\"en\",\n", " DataAccessRoleArn=f'arn:aws:iam::{AWS_ACCOUNT_ID}:role/comprehends3access',\n", " InputDataConfig={\n", " \"DataFormat\": \"AUGMENTED_MANIFEST\",\n", " \"EntityTypes\": [\n", " {\n", " \"Type\": \"DateOfForm\"\n", " },\n", " {\n", " \"Type\": \"DateOfLoss\"\n", " },\n", " {\n", " \"Type\": \"NameOfInsured\"\n", " },\n", " {\n", " \"Type\": \"LocationOfLoss\"\n", " },\n", " {\n", " \"Type\": \"InsuredMailingAddress\"\n", " }\n", " ],\n", " \"AugmentedManifests\": [\n", " {\n", " 'S3Uri': MANIFEST_S3_URI,\n", " 'AnnotationDataS3Uri': ANNOTATIONS_S3_URI_PREFIX,\n", " 'SourceDocumentsS3Uri': TRAINING_DOCS_S3_URI_PREFIX,\n", " 'AttributeNames': [LABEL_ATTRIBUTE_NAME],\n", " 'DocumentType': 'SEMI_STRUCTURED_DOCUMENT',\n", " }\n", " ],\n", " }\n", ")\n", "recognizer_arn = response[\"EntityRecognizerArn\"]\n", "recognizer_arn" ] }, { "cell_type": "markdown", "id": "7698f2f4", "metadata": {}, "source": [ "Here, we are creating a recognizer to recognize all five types of entities. Of course, we could have used a subset of these entities if we preferred. You can use up to 25 entities. \n", "\n", "The details of each parameter are given below (source: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend.html#Comprehend.Client.create_entity_recognizer)\n", "\n", "**DataFormat (string) --**\n", "\n", "The format of your training data:\n", "\n", "COMPREHEND_CSV : A CSV file that supplements your training documents. The CSV file contains information about the custom entities that your trained model will detect. The required format of the file depends on whether you are providing annotations or an entity list. If you use this value, you must provide your CSV file by using either the \n", "Annotations or EntityList parameters. You must provide your training documents by using the Documents parameter.\n", "\n", "AUGMENTED_MANIFEST : A labeled dataset that is produced by Amazon SageMaker Ground Truth. This file is in JSON lines format. Each line is a complete JSON object that contains a training document and its labels. Each label annotates a named entity in the training document. If you use this value, you must provide the AugmentedManifests parameter in your request.\n", "If you don't specify a value, Amazon Comprehend uses COMPREHEND_CSV as the default.\n", "\n", "**EntityTypes (list) -- [REQUIRED]**\n", "\n", "The entity types in the labeled training data that Amazon Comprehend uses to train the custom entity recognizer. Any entity types that you don't specify are ignored.\n", "\n", "A maximum of 25 entity types can be used at one time to train an entity recognizer.\n", "\n", "**S3Uri (string) -- [REQUIRED]**\n", "\n", "The Amazon S3 location of the augmented manifest file.\n", "\n", "**AnnotationDataS3Uri (string) --**\n", "\n", "The S3 prefix to the annotation files that are referred in the augmented manifest file.\n", "\n", "**SourceDocumentsS3Uri (string) --**\n", "\n", "The S3 prefix to the source files (PDFs) that are referred to in the augmented manifest file.\n", "\n", "**AttributeNames (list) -- [REQUIRED]**\n", "\n", "The JSON attribute that contains the annotations for your training documents. 
The number of attribute names that you specify depends on whether your augmented manifest file is the output of a single labeling job or a chained labeling job.\n", "\n", "If your file is the output of a single labeling job, specify the LabelAttributeName key that was used when the job was created in Ground Truth.\n", "\n", "If your file is the output of a chained labeling job, specify the LabelAttributeName key for one or more jobs in the chain. Each LabelAttributeName key provides the annotations from an individual job.\n", "\n", "**DocumentType (string) --**\n", "\n", "The type of augmented manifest. PlainTextDocument or SemiStructuredDocument. If you don't specify, the default is PlainTextDocument.\n", "\n", "PLAIN_TEXT_DOCUMENT A document type that represents any unicode text that is encoded in UTF-8.\n", "\n", "SEMI_STRUCTURED_DOCUMENT A document type with positional and structural context, like a PDF. For training with Amazon Comprehend, only PDFs are supported. For inference, Amazon Comprehend support PDFs, DOCX and TXT." ] }, { "cell_type": "code", "execution_count": null, "id": "ca3397ed", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Depending on the size of the training set, training time can vary. \n", "# For this dataset, training will take ~1 hour. \n", "# Let's check on the status of the submitted training job\n", "\n", "# All recognizers\n", "recognizers = comprehend.list_entity_recognizers()\n", "# View the last submitted job\n", "recognizers['EntityRecognizerPropertiesList'][-1]" ] }, { "cell_type": "code", "execution_count": null, "id": "e75142ea", "metadata": { "tags": [] }, "outputs": [], "source": [ "# To monitor the status of the training job, you can use the describe_entity_recognizer API.\n", "# Check status of custom model training periodically until complete.\n", "\n", "recognizer_arn = recognizers['EntityRecognizerPropertiesList'][-1]['EntityRecognizerArn']\n", "\n", "while True:\n", " response = comprehend.describe_entity_recognizer(\n", " EntityRecognizerArn=recognizer_arn\n", " )\n", "\n", " status = response[\"EntityRecognizerProperties\"][\"Status\"]\n", " if \"IN_ERROR\" == status:\n", " print('TRAINING ERROR')\n", " break\n", " if \"TRAINED\" == status:\n", " print('TRAINING COMPLETE')\n", " break\n", " print(status)\n", " time.sleep(60)" ] }, { "cell_type": "markdown", "id": "3157f80f", "metadata": {}, "source": [ "## Obtain evaluation metrics from the trained model" ] }, { "cell_type": "markdown", "id": "eefa40ea", "metadata": {}, "source": [ "Comprehend provides model performance metrics for a trained model, which indicates how well the trained model is expected to make predictions using similar inputs. We can obtain both global precision/recall metrics as well as per-entity metrics. An accurate model will have both high precision and recall. High precision means the model is usually correct when it indicates a particular label, while high recall means that the model found most of the labels. F1 is a composite metric (harmonic mean) of these measures, and is thus high when both components are high. 
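\n", "\n", "As a quick reference, with true positives $TP$, false positives $FP$, and false negatives $FN$ (standard definitions, stated here for convenience): precision $P = \\frac{TP}{TP+FP}$, recall $R = \\frac{TP}{TP+FN}$, and $F1 = \\frac{2PR}{P+R}$.\n", "\n", "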
For a detailed description of the metrics, see: https://docs.aws.amazon.com/comprehend/latest/dg/cer-metrics.html" ] }, { "cell_type": "code", "execution_count": null, "id": "435fcb10", "metadata": { "tags": [] }, "outputs": [], "source": [ "trained_recognizer_arn = recognizers['EntityRecognizerPropertiesList'][-1]['EntityRecognizerArn']\n", "trained_recognizer = comprehend.describe_entity_recognizer(EntityRecognizerArn=trained_recognizer_arn)\n", "\n", "# Global evaluation metrics\n", "trained_recognizer['EntityRecognizerProperties']['RecognizerMetadata']['EvaluationMetrics']" ] }, { "cell_type": "code", "execution_count": null, "id": "6387774b", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Per-entity metrics\n", "entity_metrics = trained_recognizer['EntityRecognizerProperties']['RecognizerMetadata']['EntityTypes']\n", "for entity in entity_metrics:\n", " print(entity['Type'])\n", " print(entity['EvaluationMetrics'])\n", " print()" ] }, { "cell_type": "markdown", "id": "850295c3", "metadata": {}, "source": [ "## Deploy an endpoint using the trained model" ] }, { "cell_type": "code", "execution_count": null, "id": "32d49236", "metadata": { "tags": [] }, "outputs": [], "source": [ "## Let's deploy an endpoint using the boto3 create_endpoint API\n", "\n", "endpoint_name = \"ner-insurance-ep\"\n", "\n", "endpoint_response = comprehend.create_endpoint(\n", " EndpointName=endpoint_name,\n", " ModelArn=recognizer_arn,\n", " DesiredInferenceUnits=1,\n", " Tags=[\n", " {\n", " 'Key': 'name',\n", " 'Value': 'insurance_endpoint'\n", " },\n", " ],\n", " DataAccessRoleArn=f'arn:aws:iam::{AWS_ACCOUNT_ID}:role/comprehends3access'\n", ")\n", "print(endpoint_response)" ] }, { "cell_type": "code", "execution_count": null, "id": "66651492", "metadata": { "tags": [] }, "outputs": [], "source": [ "print(json.dumps(endpoint_response, indent=4, sort_keys=True))\n", "endpoint_response[\"EndpointArn\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "7a2bfe8a", "metadata": { "tags": [] }, "outputs": [], "source": [ "## Check the status of the endpoint deployment; wait until the endpoint is in service before running inference\n", "\n", "EndpointArn = endpoint_response[\"EndpointArn\"]\n", "\n", "ep_response = comprehend.describe_endpoint(\n", " EndpointArn=EndpointArn\n", ")\n", "\n", "ep_status = ep_response[\"EndpointProperties\"][\"Status\"]\n", "\n", "while ep_status != \"IN_SERVICE\":\n", " ep_response = comprehend.describe_endpoint(EndpointArn=EndpointArn)\n", " ep_status = ep_response[\"EndpointProperties\"][\"Status\"]\n", " print(ep_status)\n", " time.sleep(60)\n", "\n", "print(\"Your endpoint is now \" + ep_status)\n" ] }, { "cell_type": "markdown", "id": "c38b2e1e", "metadata": {}, "source": [ "## Perform inference on an unseen document using the real-time endpoint" ] }, { "cell_type": "markdown", "id": "dbd3c36b", "metadata": {}, "source": [ "Let's run real-time inference with our deployed endpoint on a document that was not part of the training procedure. This synchronous API can be used for standard or custom NER. When it is used for custom NER (as it is here), we must pass the ARN of the endpoint backed by our trained model."
] }, { "cell_type": "code", "execution_count": null, "id": "fedcaa01", "metadata": { "tags": [] }, "outputs": [], "source": [ "import base64\n", "import boto3\n", "import json\n", "from botocore.exceptions import ClientError\n", "comprehend = boto3.client('comprehend')\n", "\n", "# Replace this with any document name in the /sample-docs/ directory\n", "document = \"INSR_pm_hipaa_1_pii_00048.pdf\"\n", "\n", "#with open(f\"/home/ec2-user/SageMaker/ComprehendCustom-Artifacts/ex_pdfs/{document}\", mode='rb') as file:\n", "with open(f\"./ComprehendCustom-Artifacts/ex_pdfs/{document}\", mode='rb') as file:\n", " document_bytes = file.read()\n", "\n", "try:\n", " response = comprehend.detect_entities(Bytes = document_bytes, \n", " DocumentReaderConfig={\n", " \"DocumentReadAction\": \"TEXTRACT_DETECT_DOCUMENT_TEXT\",\n", " \"DocumentReadMode\": \"SERVICE_DEFAULT\"\n", " },\n", " EndpointArn=EndpointArn)\n", " print(json.dumps(response, indent=4, sort_keys=True))\n", "\n", "except ClientError as e:\n", " print(e)\n", " print(\"Error\", e.response['Reason'], e.response['Detail']['Reason'])\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "9bd034ec", "metadata": { "tags": [] }, "outputs": [], "source": [ "## lets print the entities detected, Types and score from the response\n", "\n", "extracted_chars = 0\n", "for page_detail in response[\"DocumentMetadata\"][\"ExtractedCharacters\"]:\n", " extracted_chars = extracted_chars + page_detail[\"Count\"]\n", " \n", "print (\"Number of pages in this document : \" + str(page_detail[\"Page\"]) + \" and characters extracted count is : \" + str(extracted_chars))\n", "\n", "\n", "for results in response[\"Entities\"]:\n", " print(\"Entity Types : \" + str(results[\"Type\"]) +\" Entity : \"+str(results[\"Text\"])+ \" Score : \"+ str(round(results[\"Score\"]*100, 2)))\n" ] }, { "cell_type": "markdown", "id": "c0e3f86d", "metadata": {}, "source": [ "## Perform inference on an unseen document using asynchrnous API" ] }, { "cell_type": "markdown", "id": "10d79a10", "metadata": {}, "source": [ "Let's run inference with our trained model on a document that was not part of the training procedure. This asynchrnous API can be used for standard or custom NER. If it is being used for custom NER (as it is here) we must pass the ARN of the trained model." ] }, { "cell_type": "code", "execution_count": null, "id": "dec3a06c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Start entities detection job\n", "\n", "response = comprehend.start_entities_detection_job(\n", " EntityRecognizerArn=trained_recognizer_arn,\n", " JobName=\"Detection-Job-{}\".format(str(uuid.uuid4())),\n", " LanguageCode=\"en\",\n", " DataAccessRoleArn=f'arn:aws:iam::{AWS_ACCOUNT_ID}:role/comprehends3access',\n", " InputDataConfig={\n", " \"InputFormat\": \"ONE_DOC_PER_FILE\",\n", " \"S3Uri\": os.path.join(ASSETS_S3_PREFIX, 'holdout/')\n", " },\n", " OutputDataConfig={\n", " \"S3Uri\": f's3://my-custom-comprehend-output-{AWS_ACCOUNT_ID}/'\n", " }\n", ")\n", "response" ] }, { "cell_type": "markdown", "id": "f4673a64", "metadata": {}, "source": [ "**S3Uri (string) -- [REQUIRED]**\n", " The Amazon S3 URI for the input data. The URI must be in same region as the API endpoint that you are calling. The URI can point to a single input file or it can provide the prefix for a collection of data files.\n", "\n", " For example, if you use the URI S3://bucketName/prefix , if the prefix is a single file, Amazon Comprehend uses that file as input. 
If more than one file begins with the prefix, Amazon Comprehend uses all of them as input.\n", "\n", "**InputFormat (string) --**\n", " Specifies how the text in an input file should be processed:\n", "\n", " ONE_DOC_PER_FILE - Each file is considered a separate document. Use this option when you are processing large documents, such as newspaper articles or scientific papers.\n", " ONE_DOC_PER_LINE - Each line in a file is considered a separate document. Use this option when you are processing many short documents, such as text messages." ] }, { "cell_type": "markdown", "id": "08895725", "metadata": {}, "source": [ "**Waiting for detection job completion**\n", "\n", "This code snippet could be used to periodically check the status of your detection job and wait until it finishes.\n", "\n", "We've already done a detection with a previously trained model so we will take a look at those results instead of using this code to wait for the job to finish, which would take a few minutes.\n", "\n", "```\n", "while True:\n", " job = comprehend.describe_entities_detection_job(\n", " JobId=response['JobId']\n", " )\n", " \n", " status = job[\"EntitiesDetectionJobProperties\"][\"JobStatus\"]\n", " if \"IN_ERROR\" == status:\n", " print('DETECTION ERROR')\n", " break\n", " if \"COMPLETED\" == status:\n", " print('DETECTION COMPLETE')\n", " break\n", " print(status)\n", " time.sleep(60)\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "9a7835e5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Get the output from the detection job\n", "\n", "# download pre-generated inference output for INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00017\n", "!mkdir -p {LOCAL_ARTIFACTS_DIR}/inference_output\n", "INFERENCE_RESULTS_S3_URI = os.path.join(ASSETS_S3_PREFIX, 'detection/output/output.tar.gz')\n", "# If you want to instead use the result from the detection job after it completes, uncomment this line: \n", "# INFERENCE_RESULTS_S3_URI= comprehend.describe_entities_detection_job(JobId=response['JobId'])['EntitiesDetectionJobProperties']['OutputDataConfig']['S3Uri']\n", "\n", "# Pre-computed inference output (will have option to use this, or remake inference output)\n", "s3_download_file(INFERENCE_RESULTS_S3_URI, f'{LOCAL_ARTIFACTS_DIR}/inference_output/output.tar.gz')\n", "\n", "# Detection job output is at {LOCAL_ARTIFACTS_DIR}/inference_output/output.tar.gz\n", "!mkdir -p {LOCAL_OUTPUT_DIR}/inference_output/\n", "!tar -xvzf {LOCAL_ARTIFACTS_DIR}/inference_output/output.tar.gz -C {LOCAL_OUTPUT_DIR}/inference_output/\n", "\n", "INFERENCE_RESULTS_PATH = os.path.join(LOCAL_OUTPUT_DIR, 'inference_output/INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00017_NativePDF.out')" ] }, { "cell_type": "code", "execution_count": null, "id": "c448675a", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Let's look at the inference on an example document\n", "INFERENCE_RESULTS_PATH = os.path.join(LOCAL_OUTPUT_DIR, 'inference_output/INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00017_NativePDF.out')\n", "\n", "with open(INFERENCE_RESULTS_PATH) as f:\n", " detection_output = json.load(f)\n", "\n", "entities_list = detection_output['Entities']\n", "\n", "pd.DataFrame(entities_list)" ] }, { "cell_type": "code", "execution_count": null, "id": "46adf839", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Note that the model prediction output format closely resembles the annotation output format shown above.\n", "entities_list[0]" ] }, { "cell_type": "code", "execution_count": null, "id": 
"8fc6fb0e", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Let's visualize the the labels predicted by our model for this new example pdf\n", "\n", "holdout_pdf = f'{LOCAL_ARTIFACTS_DIR}/ex_pdfs/INSR_ACORD-Property-Loss-Notice-12.05.16_1_pii_00017.pdf'\n", "annotated_file =f'{LOCAL_OUTPUT_DIR}/detection_annotated.pdf'\n", "PDFHelper.add_annotations_to_file(detection_output, holdout_pdf, annotated_file)\n", "IFrame(annotated_file, width=600, height=800)" ] }, { "cell_type": "markdown", "id": "a1a399d0", "metadata": {}, "source": [ "## Summary\n", "\n", "In addition to the standard set of entities recognized by Amazon Comprehend's standard entity detection capabilities, Comprehend enables you to train and use your own custom models for detecting user-defined entities specific to your business use case directly on PDF documents. In this notebook you used a dataset of PDFs annotated with SageMaker Ground Truth to train a entity detection model in Comprehend.\n", "\n", "### Resources\n", "\n", "- Here are additional resources to help you dive deeper:\n", "\n", " - Setting up your own custom annotation job: https://aws.amazon.com/blogs/machine-learning/custom-document-annotation-for-extracting-named-entities-in-documents-using-amazon-comprehend/\n", "\n", " - Training a custom NER model using the Comprehend console: https://aws.amazon.com/blogs/machine-learning/extract-custom-entities-from-documents-in-their-native-format-with-amazon-comprehend/\n", "\n", " - API reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend.\n" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }