{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Amazon Comprehend Custom Entity Recognizer\n",
"\n",
"This notebook will serve as a template for the overall process of taking a text dataset and integrating it into [Amazon Comprehend Custom Entity Recognizer](https://docs.aws.amazon.com/comprehend/latest/dg/custom-entity-recognition.html) and perform natural language processing (NLP) to detect custom entities in your text.\n",
"\n",
"## Overview\n",
"\n",
"1. [Introduction to Amazon Comprehend Custom NER](#Introduction)\n",
"1. [Obtaining Your Data](#data)\n",
"1. [Pre-processing data](#preprocess)\n",
"1. [Training a custom recognizer](#train)\n",
"1. [Real time inference](#inference)\n",
"1. [Cleanup](#cleanup)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction to Amazon Comprehend Custom Entity Recognition \n",
"\n",
"Amazon Comprehend recognizes and detects nine entity types out of the box from your data, such as person, date, place etc. Custom entity recognition extends the capability of Amazon Comprehend by helping you identify your specific new entity types that are not of from the preset generic entity types. In this case, this notebook trains Amazon Comprehend to detect three additional entity types - Robot Ethics, Positronic Brain and Kinematics.\n",
"\n",
"Building a custom entity recognizer helps to identify key words and phrases that are relevant to your business needs, and Amazon Comprehend helps you in reducing the complexity by providing automatic annotation and model training to create a custom entity model. For more information, see [Comprehend Custom Entity Recognition](https://docs.aws.amazon.com/comprehend/latest/dg/custom-entity-recognition.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Obtaining Your Data \n",
"\n",
"To train a custom entity recognizer, Amazon Comprehend needs training data in one of two formats -\n",
"1. **Entity Lists (plain text only)**\n",
"You specify a list of documents that contain your entities, and in addition, specify a list of specific entities to search for in the documents. This is preferred when you have a finite list of entities to work with (for example, the EasyTron model names).\n",
"2. **Annotations**\n",
"This is more comprehensive, and provides the location of your entities in a large number of documents using the entity locations (offsets). Through this, Comprehnd can train on both the entity and its context. \n",
"\n",
"For our use case, to generate custom annotations, we make use of [Amazon SageMaker Ground Truth](https://aws.amazon.com/sagemaker/groundtruth/). We use Ground Truth with a private workforce to annotate the entities in hundreds of documents, and generate annotation files using the results. To learn more about how to use Ground Truth to annotate data, see [Named Entity Recognition](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-named-entity-recg.html).\n",
"\n",
"For the lab, we have already labeled the data and the annotation files are provided. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-processing data "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Firstly, we import necessary libraries and initialize clients\n",
"import re\n",
"import time\n",
"import json\n",
"import uuid\n",
"import boto3\n",
"import random\n",
"import secrets\n",
"import datetime\n",
"import sagemaker\n",
"import pandas as pd\n",
"from sagemaker import get_execution_role\n",
"\n",
"\n",
"s3 = boto3.client('s3')\n",
"comprehend = boto3.client('comprehend')\n",
"\n",
"# provide the name of your S3 bucket here. This was already created in your account for this workshop\n",
"bucket = '' \n",
"\n",
"region = boto3.session.Session().region_name\n",
"\n",
"# Amazon S3 (S3) client\n",
"s3 = boto3.client('s3', region)\n",
"s3_resource = boto3.resource('s3')\n",
"try:\n",
" s3.head_bucket(Bucket=bucket)\n",
"except:\n",
" print(\"The S3 bucket name {} you entered seems to be incorrect, please try again\".format(bucket))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is the execution role that will be used to call Amazon Transcribe and Amazon Translate\n",
"role = get_execution_role()\n",
"display(role)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### We already provided you a training dataset and an annotations file in the repository, let's have a look at them now"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pd.read_csv('train.csv',header=None).head(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pd.read_csv('annotations.csv').head(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# let's upload our train and annotation files to S3\n",
"s3.upload_file('train.csv', bucket, 'comprehend/train/train.csv')\n",
"s3.upload_file('annotations.csv', bucket, 'comprehend/train/annotations.csv')\n",
"s3_train_channel = \"s3://\" + bucket + \"/comprehend/train/train.csv\"\n",
"s3_annot_channel = \"s3://\" + bucket + \"/comprehend/train/annotations.csv\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create Comprehend Custom Entity Recognizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"custom_entity_request = {\n",
" \"DataFormat\": \"COMPREHEND_CSV\",\n",
" \"Documents\": { \n",
" \"S3Uri\": s3_train_channel,\n",
" \"InputFormat\": \"ONE_DOC_PER_LINE\"\n",
" },\n",
" \"Annotations\": { \n",
" \"S3Uri\": s3_annot_channel\n",
" },\n",
" \"EntityTypes\": [\n",
" {\n",
" \"Type\": \"MOVEMENT\"\n",
" },\n",
" {\n",
" \"Type\": \"BRAIN\"\n",
" },\n",
" {\n",
" \"Type\": \"ETHICS\"\n",
" }\n",
" ]\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create unique ID for recognizer\n",
"uid = str(uuid.uuid4())\n",
"\n",
"response = comprehend.create_entity_recognizer(\n",
" RecognizerName=f\"aim317-ner-{uid}\", \n",
" DataAccessRoleArn=role,\n",
" InputDataConfig=custom_entity_request,\n",
" LanguageCode=\"en\",\n",
" VersionName= 'v001'\n",
")\n",
"\n",
"print(response['EntityRecognizerArn'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check training status in Amazon Comprehend console\n",
"\n",
"[Go to Amazon Comprehend Console](https://console.aws.amazon.com/comprehend/v2/home?region=us-east-1#entity-recognition)\n",
"\n",
"This will take approximately 20 minutes. **Execute the Entity Recongizer Metrics step below only after** the entity recognizer model has been created and is ready for use. Otherwise you will get an error message. If this is the case no worries, just try it again after the entity recognizer has finished training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"describe_response = comprehend.describe_entity_recognizer(\n",
" EntityRecognizerArn=response['EntityRecognizerArn']\n",
")\n",
"\n",
"print(describe_response['EntityRecognizerProperties']['Status'])"
]
},
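{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of re-running the status cell by hand, you can poll until training finishes. The cell below is a minimal sketch and not part of the original workshop code: `wait_for_status` and its parameters are hypothetical names, and the terminal states in the commented usage are assumed from the recognizer statuses listed in the Amazon Comprehend API reference. Once it returns, re-run the status cell above so that `describe_response` holds the final description."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Hypothetical helper (illustration only): call describe_fn repeatedly until\n",
"# get_status(response) returns one of the states in done_states.\n",
"def wait_for_status(describe_fn, get_status, done_states, poll_seconds=60, max_polls=60):\n",
"    for _ in range(max_polls):\n",
"        status = get_status(describe_fn())\n",
"        print(status)\n",
"        if status in done_states:\n",
"            return status\n",
"        time.sleep(poll_seconds)\n",
"    raise TimeoutError('Resource did not reach a terminal state in time')\n",
"\n",
"# Usage (uncomment to run; the terminal states listed here are an assumption):\n",
"# wait_for_status(\n",
"#     lambda: comprehend.describe_entity_recognizer(EntityRecognizerArn=response['EntityRecognizerArn']),\n",
"#     lambda r: r['EntityRecognizerProperties']['Status'],\n",
"#     done_states={'TRAINED', 'TRAINED_WITH_WARNING', 'IN_ERROR', 'STOPPED'},\n",
"# )"
]
},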
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Entity Recognizer Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print recognizer metrics\n",
"print(\"Entity recognizer metrics:\")\n",
"for ent in describe_response[\"EntityRecognizerProperties\"][\"RecognizerMetadata\"][\"EntityTypes\"]:\n",
" print(ent['Type'])\n",
" metrics = ent['EvaluationMetrics']\n",
" for k, v in metrics.items():\n",
" metrics[k] = round(v, 2)\n",
" print(metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"describe_response['EntityRecognizerProperties']['EntityRecognizerArn']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create endpoint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that the model is trained, we'll deploy the model to an Amazon Comprehend endpoint for synchronous, real-time inference. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NOTE - We are using real-time endpoints and chunked text for demo purposes in this workshop. For your actual use case\n",
" # if you don't need real-time insights from Comprehend, we suggest using Comprehend start_entities_detection_job or batch_detect_entities to send the full corpus for entity detection\n",
" # If your need is real-time inference, please use the Comprehend real-time endpoint as we show in this notebook.\n",
" # We have used 4 Inference Units (IU) in this workshop, each IU has a throughput of 100 characters per second.\n",
"endpoint_response = comprehend.create_endpoint(\n",
" EndpointName=f\"aim317-ner-endpoint\",\n",
" ModelArn=describe_response['EntityRecognizerProperties']['EntityRecognizerArn'],\n",
" DesiredInferenceUnits=4, # you are charged based on Inference Units, for this workshop lets create 4 IUs\n",
" DataAccessRoleArn=role\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(endpoint_response['EndpointArn'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Check endpoint status in Amazon Comprehend console\n",
"\n",
"[Go to Amazon Comprehend Console](https://console.aws.amazon.com/comprehend/v2/home?region=us-east-1#endpoints)\n",
"\n",
"This will take approximately 10 minutes. Go to the **Run Inference** step below after the endpoint has been created and is ready for use. Running the cells prior to the endpoint being ready will result in error. You can re-execute the cell after the endpoint becomes available."
]
},
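{
"cell_type": "markdown",
"metadata": {},
"source": [
"The endpoint status can also be checked from the notebook rather than the console. The cell below is a minimal sketch, not part of the original workshop code: `endpoint_status` is a hypothetical helper that wraps the boto3 `describe_endpoint` call, and the expectation that a ready endpoint reports `IN_SERVICE` is taken from the Amazon Comprehend API reference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (illustration only): return the current status of a Comprehend endpoint\n",
"def endpoint_status(client, endpoint_arn):\n",
"    description = client.describe_endpoint(EndpointArn=endpoint_arn)\n",
"    return description['EndpointProperties']['Status']\n",
"\n",
"# Usage (uncomment to run once the endpoint has been requested):\n",
"# print(endpoint_status(comprehend, endpoint_response['EndpointArn']))"
]
},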
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Input files ready for entity recognition\n",
"!aws s3 ls s3://{bucket}/comprehend/input/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Prepare to page through our transcripts in S3\n",
"\n",
"# Define the S3 handles\n",
"s3 = boto3.client('s3')\n",
"s3_resource = boto3.resource('s3')\n",
"\n",
"\n",
"# Specify an S3 output prefix\n",
"t_prefix = 'quicksight/data/entity'\n",
"\n",
"\n",
"# Lets define the bucket name that contains the transcripts first\n",
"# So far we used a session bucket we created for training and testing the classifier\n",
"paginator = s3.get_paginator('list_objects_v2')\n",
"pages = paginator.paginate(Bucket=bucket, Prefix='comprehend/input')\n",
"job_name_list = []\n",
"\n",
"# We will use a temp DataFrame to extract the entity type that is most prominent in the transcript\n",
"tempcols = ['Type', 'Score']\n",
"df_temp = pd.DataFrame(columns=tempcols)\n",
"\n",
"\n",
"# We will define a DataFrame to store the results of the classifier\n",
"cols = ['transcript_name', 'entity_type']\n",
"df_ent = pd.DataFrame(columns=cols)\n",
"\n",
"# Now lets page through the transcripts\n",
"for page in pages:\n",
" for obj in page['Contents']:\n",
" entity = ''\n",
" # get the transcript file name\n",
" transcript_file_name = obj['Key'].split('/')[2]\n",
" # now lets get the transcript file contents\n",
" temp = s3_resource.Object(bucket, obj['Key'])\n",
" transcript_content = temp.get()['Body'].read().decode('utf-8')\n",
" # Send a chunk of the transcript for entity recognition\n",
" # NOTE - We are using real-time endpoints and chunked text for demo purposes in this workshop. For your actual use case\n",
" # if you don't need real-time insights from Comprehend, we suggest using Comprehend start_entities_detection_job or batch_detect_entities to send the full corpus for entity detection\n",
" # If your need is real-time inference, please use the Comprehend real-time endpoint as we show in this notebook.\n",
" # We have used 4 Inference Units (IU) in this workshop, each IU has a throughput of 100 characters per second.\n",
" transcript_truncated = transcript_content[400:1800]\n",
" # Call Comprehend to get the entity types the transcript belongs to\n",
" response = comprehend.detect_entities(Text=transcript_truncated, LanguageCode='en', EndpointArn=endpoint_response['EndpointArn'])\n",
" # Extract prominent entity\n",
" df_temp = pd.DataFrame(columns=tempcols)\n",
" for ent in response['Entities']:\n",
" df_temp.loc[len(df_temp.index)] = [ent['Type'],ent['Score']]\n",
" if len(df_temp) > 0:\n",
" entity = df_temp.iloc[df_temp.Score.argmax(), 0:2]['Type']\n",
" else:\n",
" entity = 'No entities'\n",
" \n",
" # Update the results DataFrame with the detected entities\n",
" df_ent.loc[len(df_ent.index)] = [transcript_file_name.strip('en-').strip('.txt'),entity] \n",
"\n",
" # Create a CSV file with cta label from this DataFrame\n",
"df_ent.to_csv('s3://' + bucket + '/' + t_prefix + '/' + 'entities.csv', index=False)\n",
"df_ent"
]
},
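{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above keeps, for each transcript, the entity type with the highest confidence score. The cell below is a minimal sketch of the same selection without the temporary DataFrame. `most_prominent_entity` is a hypothetical helper name (not part of the workshop code), it assumes each item carries the `Type` and `Score` keys that `detect_entities` returns under `Entities`, and the sample scores are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (illustration only): pick the entity type with the highest score\n",
"def most_prominent_entity(entities, default='No entities'):\n",
"    # entities: list of dicts shaped like response['Entities'] from detect_entities\n",
"    if not entities:\n",
"        return default\n",
"    return max(entities, key=lambda ent: ent['Score'])['Type']\n",
"\n",
"# Example with made-up scores\n",
"sample = [{'Type': 'BRAIN', 'Score': 0.71}, {'Type': 'ETHICS', 'Score': 0.93}]\n",
"print(most_prominent_entity(sample))"
]
},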
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### We are done here. You can return to the workshop instructions for next steps"
]
}
],
"metadata": {
"instance_type": "ml.t3.medium",
"interpreter": {
"hash": "9ddb102edfbd95000dbbd260d8bbcf82701cc06b4dcf114fa04ba84aab75adcb"
},
"kernelspec": {
"display_name": "conda_python3",
"language": "python",
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}