{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Document Extraction - _continued_\n", "\n", "In this notebook, we will train an Amazon Comprehend custom entity recognizer so that we can detect and extract entities from a sample Hospital Discharge Summary. We will be using the [Amazon Textract Parser Library](https://github.com/aws-samples/amazon-textract-response-parser/tree/master/src-python) to extract the plaintext data from the document and use data science library [Pandas](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) to prepare training data. We will also be needing the [Amazon SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/), and [AWS boto3 python sdk](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) libraries. We will perform two types of entity recognition with Amazon Comprehend.\n", "\n", "- [Default entity recognition](#step1)\n", "- [Custom entity recognition](#step2)\n", "\n", "---\n", "\n", "## Setup Notebook\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "import botocore\n", "import sagemaker\n", "import time\n", "import os\n", "import json\n", "import datetime\n", "import io\n", "import uuid\n", "import pandas as pd\n", "import numpy as np\n", "from pytz import timezone\n", "from PIL import Image, ImageDraw, ImageFont\n", "import multiprocessing as mp\n", "from pathlib import Path\n", "from IPython.display import Image, display, HTML, JSON, IFrame\n", "from textractcaller.t_call import call_textract, Textract_Features\n", "from textractprettyprinter.t_pretty_print import Textract_Pretty_Print, get_string\n", "from trp import Document\n", "\n", "# Document\n", "from IPython.display import Image, display, HTML, JSON\n", "from PIL import Image as PImage, ImageDraw\n", "\n", "\n", "# variables\n", "data_bucket = sagemaker.Session().default_bucket()\n", "region = boto3.session.Session().region_name\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "\n", "os.environ[\"BUCKET\"] = data_bucket\n", "os.environ[\"REGION\"] = region\n", "role = sagemaker.get_execution_role()\n", "\n", "print(f\"SageMaker role is: {role}\\nDefault SageMaker Bucket: s3://{data_bucket}\")\n", "\n", "s3=boto3.client('s3')\n", "textract = boto3.client('textract', region_name=region)\n", "comprehend=boto3.client('comprehend', region_name=region)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "# Default Entity Recognition with Amazon Comprehend \n", "\n", "Amazon Comprehend can detect a pre-defined list of default entities using it's pre-trained model. Check out the [documentation](https://docs.aws.amazon.com/comprehend/latest/dg/how-entities.html) for a full list of default entitied. In this section, we will see how we can use Amazon Comprehend's default entity recognizer to get the default entities present in the document." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# document path for the discharge summary in dataset\n", "\n", "display(Image(filename=\"./dataset/document_samples/discharge-summary.png\", width=900, height=400))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now take a look at some of the features available within Comprehend Medical." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#upload to S3\n", "key='idp/textract/discharge_summary.png'\n", "s3.upload_file(Filename='./dataset/document_samples/discharge-summary.png', \n", " Bucket=data_bucket, \n", " Key=key)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "form_resp = textract.analyze_document(Document={'S3Object':{\"Bucket\": data_bucket, \"Name\": key}\n", " }, FeatureTypes=['TABLES','FORMS'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Print text\n", "print(\"\\nText\\n========\")\n", "text = \"\"\n", "for item in form_resp[\"Blocks\"]:\n", " if item[\"BlockType\"] == \"LINE\":\n", " print ('\\033[94m' + item[\"Text\"] + '\\033[0m')\n", " text = text + \" \" + item[\"Text\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "comprehend_med = boto3.client(service_name='comprehendmedical')\n", "# Detect medical entities\n", "\n", "cm_json_data = comprehend_med.detect_entities_v2(Text=text)\n", "\n", "print(\"\\nMedical Entities\\n========\")\n", "for entity in cm_json_data[\"Entities\"]:\n", " print(\"- {}\".format(entity[\"Text\"]))\n", " print (\" Type: {}\".format(entity[\"Type\"]))\n", " print (\" Category: {}\".format(entity[\"Category\"]))\n", " if(entity[\"Traits\"]):\n", " print(\" Traits:\")\n", " for trait in entity[\"Traits\"]:\n", " print (\" - {}\".format(trait[\"Name\"]))\n", " print(\"\\n\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cm_json_data = comprehend_med.infer_icd10_cm(Text=text)\n", "\n", "print(\"\\nMedical coding\\n========\")\n", "\n", "for entity in cm_json_data[\"Entities\"]:\n", " \n", " for icd in entity[\"ICD10CMConcepts\"]:\n", " description = icd['Description']\n", " code = icd[\"Code\"]\n", " print(f'{description}: {code}')\n", " \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now extract the (UTF-8) string text from the document above and use the Amazon Comprehend [DetectEntities](https://docs.aws.amazon.com/comprehend/latest/dg/API_DetectEntities.html) API to detect the default entities.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "comprehend = boto3.client('comprehend')\n", "\n", "response = comprehend.detect_entities(\n", " Text=text,\n", " LanguageCode='en')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for entity in response['Entities']:\n", " print(f'{entity[\"Type\"]} : {entity[\"Text\"]}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "The output above shows us the default entities that Amazon Comprehend was able to detect in the document's text. However, we are interested in knowing specific entity values such as the patient name (which is denoted currently by default entity PERSON), or the patient's ID (which is denoted currently by default entity OTHER). In order to be able to do that, we will need to train an Amazon Comprehend custom entity recognizer which we will do in the following section" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "# Custom Entity Recognition with Amazon Comprehend \n", "\n", "## Data preparation\n", "\n", "There are 2 different ways we can train an Amazon Comprehend custom entity recognizer. 
\n", "\n", "- [Annotations](https://docs.aws.amazon.com/comprehend/latest/dg/cer-annotation.html)\n", "- [Entity lists](https://docs.aws.amazon.com/comprehend/latest/dg/cer-entity-list.html)\n", "\n", "The annotations method can often lead to more refined results for image files, PDFs, or Word documents because you train a model by submitting more accurate context as annotations along with your documents. However, the annotations method can be time-consuming and work-intensive.\n", "\n", "For simplicity of this hands-on, we use the entity lists method, which you can only use for plain text documents. This method gives us a CSV file that should contain the plain text and its corresponding entity type, as shown in the preceding example. The entities in this file are going to be specific to our use case, here - patient name and patient ID.\n", "\n", "For more details on how to prepare the training data for different use cases using annotations or entity lists methods, refer to [Preparing the training data.](https://docs.aws.amazon.com/comprehend/latest/dg/prep-training-data-cer.html)\n", "\n", "\n", "Now, let's take a look at our sample document." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "display(Image(filename=\"./dataset/document_samples/discharge-summary.png\", width=900, height=400))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We would like to extract 2 entities from this document\n", "\n", "- Patient name (`PATIENT_NAME`)\n", "- Patient ID (`PATIENT_ID`)\n", "\n", "Since we are going to use and Entity List with the above two entities, we need to get the sample document's content in UTF-8 encoded plain text format. This can be done by extracting the text from the document file(s) using Amazon textract." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = call_textract(input_document=f's3://{data_bucket}/idp/textract/discharge_summary.png') \n", "lines = get_string(textract_json=response, output_type=[Textract_Pretty_Print.LINES])\n", "text = lines.replace(\"\\n\", \" \")\n", "text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The custom entity recognizer needs atleast 200 document samples, and 250 entity samples for each entity. For the purposes of this hands-on we have provided the augmented manifest file that provide training data for your custom model. An augmented manifest file is a labeled dataset that is produced by Amazon SageMaker Ground Truth." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "\n", "## Training the custom entity recognizer\n", "\n", "Let's take a look at the entity list csv file." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "entities_df = pd.read_csv('./dataset/entity_list.csv', dtype={'Text': object})\n", "entities = entities_df[\"Type\"].unique().tolist()\n", "print(f'Custom entities : {entities}')\n", "print(f'\\nTotal Custom entities: {entities_df[\"Type\"].nunique()}')\n", "print(\"\\n\\nTotal Sample per entity:\")\n", "entities_df['Type'].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we have two types of entities in the entity list CSV file - `PATIENT_ID` and `PATIENT_NAME`. We also have 300 samples for each entity. With the minimum number of samples per entity satisfied, we can now train the custom entity recognizer model for Amazon Comprehend. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's upload the entity list CSV file and the raw text corpus of the training data to S3." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws s3 cp ./dataset/entity_list.csv s3://{data_bucket}/idp-insurance/comprehend/entity_list.csv\n", "!aws s3 cp ./dataset/entity_training_corpus.txt s3://{data_bucket}/idp-insurance/comprehend/entity_training_corpus.txt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's initialize the variables required to start the entity recognizer training job." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "entities_uri = f's3://{data_bucket}/idp-insurance/comprehend/entity_list.csv'\n", "training_data_uri = f's3://{data_bucket}/idp-insurance/comprehend/entity_training_corpus.txt'\n", "\n", "print(f'Entity List CSV File: {entities_uri}')\n", "print(f'Training Data File: {training_data_uri}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a custom entity recognizer\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "id = str(datetime.datetime.now().strftime(\"%s\"))\n", "\n", "entity_recognizer_name = 'insurance-custom-ner-ds'\n", "entity_recognizer_version = 'v1'\n", "entity_recognizer_arn = ''\n", "create_response = None\n", "EntityTypes = []\n", "for e in entities:\n", "    EntityTypes.append({'Type': e})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", "    create_response = comprehend.create_entity_recognizer(\n", "        InputDataConfig={\n", "            'DataFormat': 'COMPREHEND_CSV',\n", "            'EntityTypes': EntityTypes,\n", "            'Documents': {\n", "                'S3Uri': training_data_uri\n", "            },\n", "            'EntityList': {\n", "                'S3Uri': entities_uri\n", "            }\n", "        },\n", "        DataAccessRoleArn=role,\n", "        RecognizerName=entity_recognizer_name,\n", "        VersionName=entity_recognizer_version,\n", "        LanguageCode='en'\n", "    )\n", "\n", "    entity_recognizer_arn = create_response['EntityRecognizerArn']\n", "\n", "    print(f\"Comprehend Custom entity recognizer created with ARN: {entity_recognizer_arn}\")\n", "except Exception as error:\n", "    print(error)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the status of the Amazon Comprehend custom entity recognizer training job with the code below. Alternatively, the status of the training job can also be viewed in the Amazon Comprehend console.\n", "Note that the training may take ~30 minutes. \n", "\n", "Once the training job is completed, move on to the next step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "# Loop through and wait for the training to complete. This may take ~30 minutes.\n", "from IPython.display import clear_output\n", "import time\n", "from datetime import datetime\n", "\n", "jobArn = create_response['EntityRecognizerArn']\n", "\n", "max_time = time.time() + 3*60*60  # 3 hours\n", "while time.time() < max_time:\n", "    now = datetime.now()\n", "    current_time = now.strftime(\"%H:%M:%S\")\n", "\n", "    describe_custom_recognizer = comprehend.describe_entity_recognizer(\n", "        EntityRecognizerArn=jobArn\n", "    )\n", "    status = describe_custom_recognizer[\"EntityRecognizerProperties\"][\"Status\"]\n", "    clear_output(wait=True)\n", "    print(f\"{current_time} : Custom document entity recognizer: {status}\")\n", "\n", "    if status == \"TRAINED\" or status == \"IN_ERROR\":\n", "        break\n", "    time.sleep(10)" ] },
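{ "cell_type": "markdown", "metadata": {}, "source": [ "Once the recognizer reaches the `TRAINED` status, Amazon Comprehend also reports evaluation metrics (precision, recall, and F1 score) for the model. The cell below is a minimal, optional sketch that reads these metrics from the `DescribeEntityRecognizer` response; it assumes the training job above completed successfully." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: inspect the evaluation metrics computed during training\n", "# (assumes the recognizer above finished in the TRAINED status)\n", "recognizer_props = comprehend.describe_entity_recognizer(\n", "    EntityRecognizerArn=entity_recognizer_arn\n", ")[\"EntityRecognizerProperties\"]\n", "\n", "recognizer_metadata = recognizer_props.get(\"RecognizerMetadata\", {})\n", "print(f'Overall metrics: {recognizer_metadata.get(\"EvaluationMetrics\")}')\n", "for entity_type in recognizer_metadata.get(\"EntityTypes\", []):\n", "    print(f'{entity_type.get(\"Type\")}: {entity_type.get(\"EvaluationMetrics\")}')" ] },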
{ "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## Create an Amazon Comprehend custom NER real-time endpoint\n", "\n", "Once our model has been trained successfully, it can be deployed behind a real-time endpoint. Let's look at how we can deploy the trained custom entity recognizer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create comprehend endpoint\n", "model_arn = entity_recognizer_arn\n", "ep_name = 'insurance-custom-ner-endpoint'\n", "\n", "try:\n", "    endpoint_response = comprehend.create_endpoint(\n", "        EndpointName=ep_name,\n", "        ModelArn=model_arn,\n", "        DesiredInferenceUnits=1,\n", "        DataAccessRoleArn=role\n", "    )\n", "    ER_ENDPOINT_ARN = endpoint_response['EndpointArn']\n", "    print(f'Endpoint created with ARN: {ER_ENDPOINT_ARN}')\n", "    %store ER_ENDPOINT_ARN\n", "except botocore.exceptions.ClientError as error:\n", "    if error.response['Error']['Code'] == 'ResourceInUseException':\n", "        print(f'An endpoint with the name \"{ep_name}\" already exists.')\n", "        ER_ENDPOINT_ARN = f'arn:aws:comprehend:{region}:{account_id}:entity-recognizer-endpoint/{ep_name}'\n", "        print(f'The entity recognizer endpoint ARN is: \"{ER_ENDPOINT_ARN}\"')\n", "        %store ER_ENDPOINT_ARN\n", "    else:\n", "        print(error)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the endpoint creation may take ~20 minutes. The status of the deployment can be checked using the code below. You can also view the status of the endpoint creation from the Amazon Comprehend console." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "# Loop through and wait for the endpoint to become active. Takes up to 20 mins.\n", "from IPython.display import clear_output\n", "import time\n", "from datetime import datetime\n", "\n", "ep_arn = ER_ENDPOINT_ARN\n", "\n", "max_time = time.time() + 3*60*60  # 3 hours\n", "while time.time() < max_time:\n", "    now = datetime.now()\n", "    current_time = now.strftime(\"%H:%M:%S\")\n", "\n", "    describe_endpoint_resp = comprehend.describe_endpoint(\n", "        EndpointArn=ep_arn\n", "    )\n", "    status = describe_endpoint_resp[\"EndpointProperties\"][\"Status\"]\n", "    clear_output(wait=True)\n", "    print(f\"{current_time} : Custom entity recognizer endpoint: {status}\")\n", "\n", "    if status == \"IN_SERVICE\" or status == \"FAILED\":\n", "        break\n", "\n", "    time.sleep(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## Detect custom entities using the Endpoint\n", "\n", "We will now detect our two custom entities `PATIENT_NAME` and `PATIENT_ID` in our sample discharge summary letter. The function `get_entities()` is a wrapper that calls the Comprehend `DetectEntities` API: within it, we call the `comprehend.detect_entities()` method and pass the language code, the text, and the ARN of our custom recognizer endpoint as input parameters." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_entities(text):\n", "    try:\n", "        # detect entities using the custom recognizer endpoint\n", "        entities_custom = comprehend.detect_entities(LanguageCode=\"en\", Text=text, EndpointArn=ER_ENDPOINT_ARN)\n", "        df_custom = pd.DataFrame(entities_custom[\"Entities\"], columns=['Text', 'Type', 'Score'])\n", "        df_custom = df_custom.drop_duplicates(subset=['Text']).reset_index()\n", "        return df_custom\n", "    except Exception as e:\n", "        print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The API response gives us the detected entities, their types, and their corresponding confidence scores." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "response = get_entities(text)\n", "response" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "# Conclusion\n", "\n", "In this notebook, we used Amazon Comprehend's pre-trained (default) entity recognition and then trained an Amazon Comprehend custom entity recognizer for further document extraction. For custom entity recognition, we trained a custom model to detect entities in documents containing dense text. We used the `Entity lists` approach to train the custom NER model and deployed the model behind a real-time endpoint. We then used the endpoint to detect our custom entities `PATIENT_NAME` and `PATIENT_ID` in the text that Amazon Textract extracted from our sample Discharge Summary document." ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-2:429704687514:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }