{ "cells": [ { "cell_type": "markdown", "id": "6df488b2-c156-431d-80d6-95409cec7b1c", "metadata": {}, "source": [ "## Layout-Aware Entity Detection with Amazon Textract and Amazon SageMaker\n", "\n", "# End-to-End Workshop\n", "\n", "> *This notebook works well with the `Data Science 3.0 (Python 3)` kernel on SageMaker Studio*\n", "\n", "This alternative notebook accompanies a **guided workshop** on the Amazon Textract Transformer Pipeline solution. The steps have been somewhat streamlined, and the inline commentary reduced, compared to the main numbered notebook series. If you're trying out the solution on your own, you may prefer to start with [Notebook 1: Data Preparation](1.%20Data%20Preparation.ipynb) instead." ] }, { "cell_type": "markdown", "id": "7eaf75a5-f133-4d85-b780-c42a256ef620", "metadata": {}, "source": [ "---\n", "## Environment setup \n", "\n", "### SageMaker notebook permissions\n", "\n", "▶️ In the [AWS IAM Console](https://console.aws.amazon.com/iamv2/home#/roles), check that you've attached the deployed OCR pipeline stack's **data science policy** to your SageMaker Execution Role, before continuing. You can find your deployed OCRPipeline stack in the [AWS CloudFormation Console](https://console.aws.amazon.com/cloudformation/home), and the Data Science Policy name is one of the Stack outputs." ] }, { "cell_type": "markdown", "id": "1e799284-0075-4343-8e99-94a9cdc357a4", "metadata": {}, "source": [ "### Notebook libraries and configurations\n", "\n", "This notebook will require some additional libraries that aren't available by default in the SageMaker Studio Data Science kernel. 
Run the cell below to install the extra dependencies:" ] }, { "cell_type": "code", "execution_count": null, "id": "e9cdc9fb-377f-44dc-bb8c-51401c4a07d0", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "# Install Python libraries:\n", "!pip install amazon-textract-response-parser \\\n", " sagemaker-studio-image-build \\\n", " \"sagemaker>=2.87,<3\"\n", "\n", "# Install NodeJS:\n", "NODE_VER = \"v16.18.0\"\n", "NODE_DISTRO = \"linux-x64\"\n", "!mkdir -p /usr/local/lib/nodejs\n", "!wget -c https://nodejs.org/dist/{NODE_VER}/node-{NODE_VER}-{NODE_DISTRO}.tar.xz -O - | tar -xJ -C /usr/local/lib/nodejs\n", "NODE_BIN_DIR = f\"/usr/local/lib/nodejs/node-{NODE_VER}-{NODE_DISTRO}/bin\"\n", "ONPATH_BIN_DIR = \"/usr/local/bin\"\n", "!ln -fs {NODE_BIN_DIR}/node {ONPATH_BIN_DIR}/node && \\\n", " ln -fs {NODE_BIN_DIR}/npm {ONPATH_BIN_DIR}/npm && \\\n", " ln -fs {NODE_BIN_DIR}/npx {ONPATH_BIN_DIR}/npx && \\\n", " echo \"NodeJS {NODE_VER} installed!\"" ] }, { "cell_type": "markdown", "id": "2f2e0741-21ec-438e-b782-e1728b301106", "metadata": {}, "source": [ "With the extra libraries installed, you're ready to load them into the kernel and initialise clients for the various AWS services we'll be calling from the notebook:" ] }, { "cell_type": "code", "execution_count": null, "id": "d433651b-ff28-4540-8c8d-ba355000f120", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Python Built-Ins:\n", "from datetime import datetime\n", "import json\n", "from logging import getLogger\n", "import os\n", "import random\n", "import re\n", "import shutil\n", "import time\n", "from zipfile import ZipFile\n", "\n", "# External Dependencies:\n", "import boto3 # AWS SDK for Python\n", "from IPython import display # To display rich content in notebook\n", "import pandas as pd # For tabular data analysis\n", "import sagemaker # High-level SDK for SageMaker\n", "from tqdm.notebook import tqdm # Progress bars\n", "\n", "# Local 
Dependencies:\n", "import util\n", "\n", "# AWS service clients:\n", "s3 = boto3.resource(\"s3\")\n", "smclient = boto3.client(\"sagemaker\")\n", "ssm = boto3.client(\"ssm\")\n", "\n", "logger = getLogger()" ] }, { "cell_type": "markdown", "id": "772da7c5-e69d-4e2c-bf81-97bd47e6dc71", "metadata": {}, "source": [ "This notebook will work with data sandboxes in Amazon S3, and connect to a deployed document processing pipeline solution. Below, we configure S3 data folders and read deployed pipeline parameter configuration from [AWS Systems Manager Parameter Store (AWS SSM)](https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html):" ] }, { "cell_type": "code", "execution_count": null, "id": "9d2a9655-b6a7-4a75-ab2f-99f2833ce4a9", "metadata": {}, "outputs": [], "source": [ "# S3 data locations:\n", "bucket_name = sagemaker.Session().default_bucket()\n", "bucket_prefix = \"textract-transformers-wshp/\"\n", "raw_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/raw\"\n", "imgs_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/imgs-clean\"\n", "textract_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/textracted\"\n", "thumbs_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/thumbnails\"\n", "annotations_base_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/annotations\"\n", "print(f\"Working in bucket s3://{bucket_name}/{bucket_prefix}\\n\")\n", "\n", "try:\n", " config = util.project.init(\"ocr-transformers-demo\")\n", " print(config)\n", "except Exception as e:\n", " try:\n", " print(f\"Your SageMaker execution role is: {sagemaker.get_execution_role()}\")\n", " except Exception:\n", " print(\"Couldn't look up your SageMaker execution role\")\n", " raise e" ] }, { "cell_type": "markdown", "id": "71a9cb45-40ea-4f65-89d2-9cb52fc0a763", "metadata": {}, "source": [ "### SageMaker Ground Truth work team\n", "\n", "For this demo, you'll also need to manually set up a private work \"team\" in SageMaker Ground Truth, and enrol yourself 
to be able to use the data annotation UI.\n", "\n", "▶️ **Open** the [Amazon SageMaker Ground Truth console, *Labeling Workforces* page](https://console.aws.amazon.com/sagemaker/groundtruth?#/labeling-workforces)\n", "\n", "> ⚠️ **Check** that SM Ground Truth opens in the same **AWS Region** where this notebook and your CloudFormation stack are deployed: You may find it defaults to `N. Virginia`. Use the drop-down in the top right of the screen to switch regions.\n", "\n", "▶️ **Select** the *Private* tab and click **Create private team**\n", "\n", "- Choose an appropriate **name** for your team, e.g. `just-me`\n", "- (If you get the option) select **Invite new workers via email** and enter your email address (you'll need access to this address to log in and annotate the data)\n", "- Leave the other (Cognito, SNS, etc.) parameters as default.\n", "\n", "▶️ **If you didn't get the option** to add workers during team creation (typically because your account is already set up for SageMaker Ground Truth), then after the team is created you can:\n", "\n", "- Click **Invite new workers** to add your email address to the workforce, and then\n", "- Click on your **team name** to open the team details, then navigate to the *Workers* tab to add yourself to the team\n", "\n", "▶️ **Copy** the *name* of your workteam and paste it into the cell below to store it:" ] }, { "cell_type": "code", "execution_count": null, "id": "e56f8fd8-f604-46b0-ac1e-8a8bb2970bc0", "metadata": {}, "outputs": [], "source": [ "workteam_name = \"just-me\" # TODO: Update this to match yours, if different\n", "\n", "workteam_arn = util.smgt.workteam_arn_from_name(workteam_name)" ] }, { "cell_type": "markdown", "id": "3633c52e-566b-4472-94fe-5fd9aef295ac", "metadata": {}, "source": [ "Finally:\n", "\n", "▶️ **Check your email** for an invitation and log in to the labelling portal. 
You'll be asked to configure a password on first login.\n", "\n", "\n", "Your completed setup should look something like this in the AWS Console:\n", "\n", "![](img/smgt-private-workforce.png \"Screenshot of SageMaker Ground Truth private workforces configuration\")" ] }, { "cell_type": "markdown", "id": "4d01f0c2-fde1-48fc-907c-439d8d288c0d", "metadata": {}, "source": [ "---\n", "\n", "## Fetch the raw document corpus\n", "\n", "In this example, we'll explore entity detection on specimen **credit card agreements** published by the United States' [Consumer Financial Protection Bureau](https://www.consumerfinance.gov/credit-cards/agreements/). This dataset includes providers across the US, and is interesting for our purposes because the documents are:\n", "\n", "- **Diverse** in formatting, as various providers present the required information in different ways\n", "- **Representative of commercial** documents - rather than, for example, academic papers which might have quite different tone and structure\n", "- **Complex** in structure, with common data points in theory (e.g. interest rates, fees, etc.) - but a lot of nuances and differences between documents in practice.\n", "\n", "The sample dataset (approx. 900MB uncompressed) is published as an archive file (approx. 750MB), which we'll need to extract to get the raw PDFs. Since it's a manageable size, we can perform the extraction here in SageMaker Studio, which also gives us local copies of the raw files to inspect."
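The download-and-extract cells that follow also tidy up the archive: they drop the `__MACOSX` metadata folder and flatten a single wrapper folder. That cleanup can be sketched as one standalone function. This is a minimal sketch rather than the notebook's exact code. `extract_and_denest` is a hypothetical name, and the sketch assumes an archive counts as "nested" when exactly one folder remains at the top level after `__MACOSX` is removed. The notebook instead looks for a folder whose name starts with `Credit_Card_Agreements`.

```python
import os
import shutil
import tempfile
from zipfile import ZipFile


def extract_and_denest(zip_path: str, out_dir: str) -> list:
    """Extract zip_path into out_dir, drop macOS metadata and flatten one wrapper folder.

    Hypothetical helper: a simplified version of the notebook's extraction cell.
    """
    # Start from a clean output folder (the notebook uses shutil.rmtree for this too):
    shutil.rmtree(out_dir, ignore_errors=True)
    with ZipFile(zip_path, "r") as fzip:
        fzip.extractall(out_dir)

    # Remove the resource-fork folder that macOS archivers often add:
    shutil.rmtree(os.path.join(out_dir, "__MACOSX"), ignore_errors=True)

    # If everything sits under one top-level folder, move its contents up a level:
    root_items = os.listdir(out_dir)
    if len(root_items) == 1 and os.path.isdir(os.path.join(out_dir, root_items[0])):
        wrapper = os.path.join(out_dir, root_items[0])
        for sub in os.listdir(wrapper):
            shutil.move(os.path.join(wrapper, sub), os.path.join(out_dir, sub))
        os.rmdir(wrapper)
    return sorted(os.listdir(out_dir))


# Usage example with a tiny synthetic archive (not the real CFPB data):
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, "sample.zip")
    with ZipFile(demo_zip, "w") as fzip:
        fzip.writestr("Credit_Card_Agreements_Sample/BankA/agreement.pdf", b"%PDF-1.4")
        fzip.writestr("Credit_Card_Agreements_Sample/BankB/terms.pdf", b"%PDF-1.4")
        fzip.writestr("__MACOSX/._agreement.pdf", b"")
    top_level = extract_and_denest(demo_zip, os.path.join(tmp, "raw"))
print(top_level)  # -> ['BankA', 'BankB']
```

Checking for a single remaining folder avoids hard-coding the archive's folder name. The trade-off is that it would also flatten a custom dataset that legitimately contains only one provider folder.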
] }, { "cell_type": "code", "execution_count": null, "id": "23bad343-ea5b-402b-b94a-ea72e811d16b", "metadata": {}, "outputs": [], "source": [ "%%time\n", "os.makedirs(\"data/raw\", exist_ok=True)\n", "\n", "# Fetch the example data:\n", "!wget -O data/CC_Agreements.zip https://files.consumerfinance.gov/a/assets/Credit_Card_Agreements_2020_Q4.zip" ] }, { "cell_type": "code", "execution_count": null, "id": "965cf6bb-cef5-4620-bf62-ae361754db86", "metadata": {}, "outputs": [], "source": [ "%%time\n", "# Extract the file:\n", "print(\"Extracting...\")\n", "shutil.rmtree(\"data/raw\")\n", "with ZipFile(\"data/CC_Agreements.zip\", \"r\") as fzip:\n", " fzip.extractall(\"data/raw\")\n", "\n", "# Clean up unneeded files and remap if the folder became nested:\n", "# (This is written specific to our sample data zip, but is unlikely to break most custom data)\n", "original_root_items = os.listdir(\"data/raw\")\n", "if \"__MACOSX\" in original_root_items:\n", " shutil.rmtree(\"data/raw/__MACOSX\")\n", "if len(original_root_items) < 4:\n", " try:\n", " folder = next(f for f in original_root_items if f.startswith(\"Credit_Card_Agreements\"))\n", " print(f\"De-nesting folder '{folder}'...\")\n", " for sub in os.listdir(f\"data/raw/{folder}\"):\n", " shutil.move(f\"data/raw/{folder}/{sub}\", f\"data/raw/{sub}\")\n", " time.sleep(0.1) # (Saw a FileNotFound error during renames one time in SMStudio)\n", " os.rmdir(f\"data/raw/{folder}\")\n", " except StopIteration:\n", " pass\n", "\n", "print(\"Done!\")" ] }, { "cell_type": "code", "execution_count": null, "id": "03c12dea-8932-401e-bb5b-db7e30194521", "metadata": {}, "outputs": [], "source": [ "# The s3 sync command can upload folders from SageMaker to S3 (or download, swapping the args).\n", "# For the example data, we extracted locally so will upload:\n", "print(f\"Uploading raw PDFs to {raw_s3uri}...\")\n", "!aws s3 sync --quiet data/raw {raw_s3uri}\n", "print(\"Done\")" ] }, { "cell_type": "markdown", "id": 
"dac05fc9-49ab-4999-b11f-df6f56579e19", "metadata": {}, "source": [ "To build an initial manifest/index of the data, we'd like to filter out any unsupported system files or other non-document content in the folder:" ] }, { "cell_type": "code", "execution_count": null, "id": "6454c9e9-2484-4582-8e37-e8368a401aa1", "metadata": {}, "outputs": [], "source": [ "raw_bucket_name, raw_prefix = util.s3.s3uri_to_bucket_and_key(raw_s3uri)\n", "\n", "valid_file_types = {\"jpeg\", \"jpg\", \"pdf\", \"png\", \"tif\", \"tiff\"}\n", "\n", "n_files = 0\n", "with open(\"data/raw-all.manifest.jsonl\", \"w\") as f:\n", " # sorted() guarantees output order for reproducible sampling later:\n", " for obj in sorted(\n", " s3.Bucket(raw_bucket_name).objects.filter(Prefix=raw_prefix + \"/\"),\n", " key=lambda obj: obj.key,\n", " ):\n", " # Filter out any files you know shouldn't be counted:\n", " file_ext = obj.key.rpartition(\".\")[2].lower()\n", " if \"/.\" in obj.key or file_ext not in valid_file_types:\n", " print(f\"Skipping s3://{obj.bucket_name}/{obj.key}\")\n", " continue\n", "\n", " # Save\n", " item = {\"raw-ref\": f\"s3://{obj.bucket_name}/{obj.key}\"}\n", " f.write(json.dumps(item)+\"\\n\")\n", " n_files += 1\n", "\n", "print(f\"\\nFound {n_files} valid files for OCR\")" ] }, { "cell_type": "markdown", "id": "a8ad41d0-a5b0-4541-ab8d-b6e9b393d60f", "metadata": {}, "source": [ "With the documents downloaded and catalogued, we can explore some examples to get an initial idea of the kind of content in the dataset:" ] }, { "cell_type": "code", "execution_count": null, "id": "dca3a214-4149-4c93-9c01-cc278ce5c53d", "metadata": {}, "outputs": [], "source": [ "# Read from docs manifest:\n", "with open(\"data/raw-all.manifest.jsonl\") as f:\n", " raw_doc_s3uris = [json.loads(l)[\"raw-ref\"] for l in f]\n", "\n", "# Choose a document by index number:\n", "disp_record = raw_doc_s3uris[0]\n", "filepath = disp_record.replace(raw_s3uri+\"/\", \"data/raw/\")\n", "\n", "print(f\"Displaying: 
{filepath}\")\n", "display.IFrame(\n", " filepath,\n", " height=\"600\",\n", " width=\"100%\",\n", ")" ] }, { "cell_type": "markdown", "id": "9835769b-c255-4226-978b-8258891eb1e8", "metadata": {}, "source": [ "---\n", "\n", "## Define the challenge\n", "\n", "So we have our sample documents - what information would we like to extract from them?\n", "\n", "As an example, we'll consider a market data aggregation use case: Collecting information like interest rates, fees, provider and product names, and some other more challenging examples like minimum payment descriptions and locally-applicable terms. The cell below defines the list of entities for the use-case, with some tips on how to annotate them that you'll also be able to see in the data labelling UI later:" ] }, { "cell_type": "code", "execution_count": null, "id": "d3046e3c-2c87-44d5-8d72-9987eb7e32ab", "metadata": {}, "outputs": [], "source": [ "from util.postproc.config import FieldConfiguration\n", "\n", "# For config API details, you can see the docs in the source file or run:\n", "# help(FieldConfiguration)\n", "\n", "fields = [\n", " # (To prevent human error, enter class_id=0 each time and update programmatically below)\n", " FieldConfiguration(0, \"Agreement Effective Date\", optional=True, select=\"first\",\n", " annotation_guidance=(\n", " \"

Avoid labeling extraneous dates which are not necessarily the effective date of \"\n", " \"the document: E.g. copyright dates/years, or other dates mentioned in text.

\"\n", " \"

Do not include unnecessary qualifiers e.g. 'from 2020/01/01'.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - Introductory\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Use this class (instead of the others) for ANY case where the rate is \"\n", " \"offered for a fixed introductory period - regardless of interest rate subtype e.g. \"\n", " \"balance transfers, purchases, etc.

\"\n", " \"

Include the term of the introductory period in cases where it's directly listed \"\n", " \"(e.g. '20.00% for the first 6 months'). Try to minimize/exclude extraneous \"\n", " \"information about the offer (e.g. '20.00% for the first 6 months after account \"\n", " \"opening').

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - Balance Transfers\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Use for interest rates which are specific to balance transfers.

\"\n", " \"

Avoid including extraneous information about the terms of balance transfers, or \"\n", " \"using this class for fixed-term introductory rates.

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - Cash Advances\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Use for interest rates which are specific to cash advances.

\"\n", " \"

Avoid including extraneous information about the terms of cash advances, or using \"\n", " \"this class for fixed-term introductory rates.

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - Purchases\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Use for interest rates which are specific to purchases.

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - Penalty\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Use for penalty interest rates applied under certain conditions.

\"\n", " \"

Exclude information about the conditions under which the penalty \"\n", " \"rate comes into effect: Only include the interest rate which will be applied.

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - General\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Use for interest rates which are general and not specifically tied to a \"\n", " \"particular transaction type e.g. purchases / balance transfers.

\"\n", " \"

Avoid using this class for fixed-term introductory rates.

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"APR - Other\", optional=True, select=\"confidence\",\n", " # TODO: Remove this class\n", " annotation_guidance=(\n", " \"

Use only for interest rates which don't fall into any other category (including \"\n", " \"general or introductory rates). You may not see any examples in the data.

\"\n", " \"

Avoid using this class for fixed-term introductory rates.

\"\n", " \"

'Prime rate + X%' mentions are acceptable and should be labeled.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Fee - Annual\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Include cases where the document explicitly indicates no fee, e.g. 'None'.

\"\n", " \"

Avoid any introductory terms e.g. '$0 for the first 6 months' or extraneous \"\n", " \"words: Label only the standard fee.

\"\n", " \"

Label only the annual amount of the fee, in cases where other breakdowns are \"\n", " \"specified: E.g. '$120', not '$10 per month ($120 per year)'.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Fee - Balance Transfer\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " # TODO: Review\n", " \"

Try to be concise and exclude extra terms where not necessary

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Fee - Late Payment\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Label only the fee, not the circumstances in which it is payable.

\"\n", " \"

Limits e.g. 'Up to $25' are acceptable (don't just label '$25').

\"\n", " \"

Do NOT include non-specific mentions of pass-through costs (e.g. 'legal \"\n", " \"costs', 'reasonable expenses', etc.) incurred in the general collections process.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Fee - Returned Payment\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Label only the fee, not the circumstances in which it is payable.

\"\n", " \"

Limits e.g. 'Up to $25' are acceptable (don't just label '$25').

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Fee - Foreign Transaction\", optional=True, select=\"shortest\",\n", " annotation_guidance=(\n", " \"

Do NOT include explanations of how exchange rates are calculated or \"\n", " \"non-specific indications of margins between rates. DO include specific \"\n", " \"charges/margins with brief clarifying info where listed e.g. '3% of the US \"\n", " \"dollar amount'.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Fee - Other\", ignore=True,\n", " annotation_guidance=(\n", " \"

Common examples include: Minimum interest charge, cash advance fees, and \"\n", " \"overlimit fees.

\"\n", " \"

Do NOT include fixed-term introductory rates for fees (e.g. '$0 during \"\n", " \"the first year. After the first year...') - only the standard fees

\"\n", " \"

DO include qualifying information on the amount and limits of the fee, \"\n", " \"e.g. '$5 or 5% of the amount of each transaction, whichever is the greater'.

\"\n", " \"

Do NOT include general information on the nature of the fee and \"\n", " \"circumstances under which it is applied: E.g. 'Cash advance fee' or 'If the amount \"\n", " \"of interest payable is...'

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Card Name\",\n", " annotation_guidance=(\n", " \"

Label instances of the brand name of specific card(s) offered by the provider \"\n", " \"under the agreement, e.g. 'Rewards Platinum Card'

\"\n", " \"

Include the ' Card' suffix where available, but also annotate instances without \"\n", " \"such as 'Rewards Platinum'

\"\n", " \"

Avoid including the Provider Name (use the separate class for this) e.g. \"\n", " \"'AnyCompany Rewards Card' unless it's been substantially modified/abbreviated for \"\n", " \"the card name (e.g. 'AnyCo Rewards Card') or the company name is different from the \"\n", " \"credit card provider (e.g. AnyBank offering a store credit card for AnyCompany).

\"\n", " \"

Do NOT include fixed-term introductory rates for fees (e.g. '$0 during \"\n", " \"the first year. After the first year...') - only the standard fees

\"\n", " \"

Avoid labeling generic payment provider names e.g. 'VISA card' or \"\n", " \"'Mastercard', except in contexts where the provider clearly uses them as the brand \"\n", " \"name for the offered card (e.g. 'VISA Card' from 'AnyCompany VISA Card').

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Provider Address\", optional=True, select=\"confidence\",\n", " annotation_guidance=(\n", " \"

Include department or 'attn:' lines where present (but not Provider Name where \"\n", " \"used at the start of an address e.g. 'AnyCompany; 100 Main Street...').

\"\n", " \"

Include zip/postcode where present.

\"\n", " \"

Avoid labeling addresses for non-provider entities, such as watchdogs, \"\n", " \"market regulators, or independent agencies.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Provider Name\", select=\"longest\",\n", " annotation_guidance=(\n", " \"

Label the name of the card provider, including abbreviated mentions.

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Min Payment Calculation\", ignore=True,\n", " annotation_guidance=(\n", " \"

Label clauses describing how the minimum payment is calculated.

\"\n", " \"

Exclude lead-in e.g. 'The minimum payment is calculated as...' and label directly \"\n", " \"from e.g. 'the minimum of...'.

\"\n", " \"

Do NOT include clauses from related subjects e.g. how account balance is \"\n", " \"calculated

\"\n", " ),\n", " ),\n", " FieldConfiguration(0, \"Local Terms\", ignore=True,\n", " annotation_guidance=(\n", " \"

Label full terms specific to residents of certain states/countries, or applying \"\n", " \"only in particular jurisdictions.

\"\n", " \"

Include the scope of where the terms apply e.g. 'Residents of GA and \"\n", " \"VA...'

\"\n", " \"

Include locally-applicable interest rates, instead of annotating these \"\n", " \"with the 'APR - ' classes

\"\n", " ),\n", " )\n", "]\n", "for ix, cfg in enumerate(fields):\n", " cfg.class_id = ix\n", "\n", "# Print out to a simple list:\n", "entity_classes = [f.name for f in fields]\n", "print(\"\\n\".join(entity_classes))" ] }, { "cell_type": "markdown", "id": "65f19c38-3c31-4a52-b9a8-497265f5cb30", "metadata": {}, "source": [ "---\n", "## Filter a sample corpus\n", "\n", "For a quick example model, there's no need for us to process or annotate all ~2,500 documents in the original corpus. Here, we'll select a random subset - but ensuring those present in the pre-prepared annotation data are kept:" ] }, { "cell_type": "code", "execution_count": null, "id": "0675a4ec-324a-4030-927d-2e18ec5c16c2", "metadata": {}, "outputs": [], "source": [ "# Crawl source annotated Textract URIs from the job manifests:\n", "annotated_textract_s3uris = util.ocr.list_preannotated_textract_uris(\n", " ann_jobs_folder=\"data/annotations\",\n", " exclude_job_names=[\"LICENSE\"],\n", ")\n", "\n", "# Define how to check for matches:\n", "def textract_uri_matches_doc_uri(tex_uri, doc_uri) -> bool:\n", " \"\"\"Customize this function if needed for your use case's data layout\"\"\"\n", " # With our sample, Textract URIs will look like:\n", " # some/prefix/data/textracted/subfolders/file.pdf/consolidated.json\n", " tex_s3key = tex_uri[len(\"s3://\"):].partition(\"/\")[2]\n", " # With our sample, Raw URIs will look like:\n", " # some/prefix/data/raw/subfolders/file.pdf\n", " doc_s3key = doc_uri[len(\"s3://\"):].partition(\"/\")[2]\n", "\n", " # Given the expectations above:\n", " tex_rel_filepath = tex_s3key.partition(\"data/textracted/\")[2].rpartition(\"/\")[0]\n", " doc_rel_filepath = doc_s3key.partition(\"data/raw/\")[2]\n", " return doc_rel_filepath == tex_rel_filepath\n", "\n", "# Build the list of docs for which some annotations exist (prioritising debug over speed here):\n", "annotated_doc_s3uris = set()\n", "for uri in annotated_textract_s3uris:\n", " matching_doc_s3uris = [\n", " 
doc_s3uri\n", " for doc_s3uri in raw_doc_s3uris\n", " if textract_uri_matches_doc_uri(uri, doc_s3uri)\n", " ]\n", " n_matches = len(matching_doc_s3uris)\n", " if n_matches == 0:\n", " raise ValueError(\n", " \"Couldn't find matching document in dataset for annotated Textract URI: %s\"\n", " % (uri,)\n", " )\n", " if n_matches > 1:\n", " logger.warning(\n", " \"Textract URI matched %s document URIs: Matching criterion may be too loose.\\n%s\\n%s\",\n", " n_matches,\n", " uri,\n", " matching_doc_s3uris,\n", " )\n", " annotated_doc_s3uris.update(matching_doc_s3uris)\n", "\n", "# This sorted list of required document S3 URIs is the main result you need to get to here:\n", "annotated_doc_s3uris = sorted(annotated_doc_s3uris)\n", "print(f\"Found {len(annotated_doc_s3uris)} docs with pre-existing annotations\")\n", "print(\"For example:\")\n", "print(\"\\n\".join(annotated_doc_s3uris[:5] + [\"...\"]))" ] }, { "cell_type": "markdown", "id": "d7f984f6-beab-420e-a9a1-503ae4e0894f", "metadata": {}, "source": [ "Both Amazon Textract and the multi-lingual entity recognition model we'll use later should be capable of processing Spanish, but you may want to exclude the small number of Spanish-language docs in the corpus if you're not able to confidently read and annotate them!" 
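You can try out the filename heuristics from the next cell in isolation before committing to a sample. The sketch below reuses the same regular expressions, inverted into a hypothetical `looks_spanish()` helper, and runs them on a few made-up filenames. Filename rules like these are only an approximation: a Spanish-language document can still slip through if its name doesn't mention the language.

```python
import re

# Same patterns as the notebook's include_filename(), inverted: a match means "likely Spanish".
SPANISH_PATTERNS = [
    r"spanish",
    r"tarjeta",
    r"espa[nñ]ol",
    r"[\[\(]esp?[\]\)]",  # bracketed language tags such as "(es)" or "[esp]"
    r"cr[eé]dito",
]


def looks_spanish(name: str) -> bool:
    """Hypothetical helper: True if any heuristic pattern matches the lower-cased name."""
    name_l = name.lower()
    return any(re.search(pattern, name_l) for pattern in SPANISH_PATTERNS)


# Made-up example filenames, not taken from the real corpus:
examples = [
    "BankA/Platinum Card Agreement.pdf",
    "BankA/Platinum Card Agreement (ES).pdf",
    "BankB/Acuerdo de Tarjeta de Crédito.pdf",
    "BankC/Rewards_Spanish.pdf",
    "BankD/Business Credit Card.pdf",
]
for name in examples:
    print(f"{'skip' if looks_spanish(name) else 'keep'}: {name}")
```

Note that the English word "credit" does not trigger the `cr[eé]dito` pattern, because the pattern needs the trailing "o".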
] }, { "cell_type": "code", "execution_count": null, "id": "65d8d03f-7834-4378-a63c-7e9202db4eb8", "metadata": {}, "outputs": [], "source": [ "N_DOCS_KEPT = 120\n", "SKIP_SPANISH_DOCS = True\n", "\n", "\n", "def include_filename(name: str) -> bool:\n", " \"\"\"Filter out likely Spanish/non-English docs (if SKIP_SPANISH_DOCS enabled)\"\"\"\n", " if not name:\n", " return False\n", " if not SKIP_SPANISH_DOCS:\n", " return True\n", " name_l = name.lower()\n", " if (\n", " \"spanish\" in name_l\n", " or \"tarjeta\" in name_l\n", " or re.search(r\"espa[nñ]ol\", name_l)\n", " or re.search(r\"[\\[\\(]esp?[\\]\\)]\", name_l)\n", " or re.search(r\"cr[eé]dito\", name_l)\n", " ):\n", " return False\n", " return True\n", "\n", "\n", "if N_DOCS_KEPT < len(annotated_doc_s3uris):\n", " raise ValueError(\n", " \"Existing annotations cannot be used for model training unless the target documents are \"\n", " \"Textracted. To proceed with fewer docs than have already been annotated, you'll need to \"\n", " \"`exclude_job_names` per the 'data/annotations' folder (e.g. ['augmentation-1']) AND \"\n", " \"remember to not include them in notebook 2 (model training). Alternatively, increase \"\n", " f\"your N_DOCS_KEPT. 
(Got {N_DOCS_KEPT} vs {len(annotated_doc_s3uris)} prev annotations).\"\n", " )\n", "\n", "with open(\"data/raw-all.manifest.jsonl\") as f:\n", " # First apply filtering rules:\n", " sampled_docs = [\n", " doc for doc in (json.loads(line) for line in f)\n", " if include_filename(doc[\"raw-ref\"])\n", " ]\n", "\n", "# Forcibly including the pre-annotated docs *after* the shuffling ensures that the order of\n", "# sampling new docs is independent of what/how many have been pre-annotated:\n", "required_docs = [d for d in sampled_docs if d[\"raw-ref\"] in annotated_doc_s3uris]\n", "random.Random(1337).shuffle(sampled_docs)\n", "new_docs = [d for d in sampled_docs if d[\"raw-ref\"] not in annotated_doc_s3uris]\n", "sampled_docs = sorted(\n", " required_docs + new_docs[:N_DOCS_KEPT - len(required_docs)],\n", " key=lambda doc: doc[\"raw-ref\"],\n", ")\n", "\n", "# Write the selected set to file:\n", "with open(\"data/raw-sample.manifest.jsonl\", \"w\") as f:\n", " for d in sampled_docs:\n", " f.write(json.dumps(d) + \"\\n\")\n", "\n", "print(f\"Extracted random sample of {len(sampled_docs)} docs\")\n", "sampled_docs[:5] + [\"...\"]" ] }, { "cell_type": "markdown", "id": "4c7285f0-0f3f-4697-ad70-e512a65310aa", "metadata": {}, "source": [ "> ▶️ In [data/raw-sample.manifest.jsonl](data/raw-sample.manifest.jsonl) you should now have an alphabetized list of the `N_DOCS_KEPT` randomly selected documents, which should include any documents referenced in existing annotations under `data/annotations`." 
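The sampling rule above can be sketched as a reusable function: always keep the pre-annotated documents, top up with a seeded random selection, then sort. `sample_with_required` is a hypothetical name. The sketch assumes plain URI strings rather than the manifest dicts the notebook uses, and it sorts its input first so that the shuffle is reproducible regardless of input order.

```python
import random


def sample_with_required(all_uris, required_uris, n_keep, seed=1337):
    """Keep required URIs found in all_uris, top up with a seeded random sample, then sort.

    Hypothetical helper mirroring the notebook's sampling cell, using plain URI strings.
    """
    required = set(required_uris)
    if n_keep < len(required):
        raise ValueError(f"n_keep={n_keep} is smaller than the {len(required)} required docs")
    candidates = sorted(all_uris)  # fixed starting order makes the shuffle reproducible
    kept_required = [u for u in candidates if u in required]
    random.Random(seed).shuffle(candidates)
    # The shuffled order depends only on all_uris and seed, not on which docs are required:
    new = [u for u in candidates if u not in required]
    return sorted(kept_required + new[: n_keep - len(kept_required)])


# Usage example with made-up URIs:
docs = [f"s3://doc-example-bucket/data/raw/doc{i:02d}.pdf" for i in range(10)]
required = [
    "s3://doc-example-bucket/data/raw/doc03.pdf",
    "s3://doc-example-bucket/data/raw/doc07.pdf",
]
sample = sample_with_required(docs, required, n_keep=5)
print(len(sample), all(r in sample for r in required))  # -> 5 True
```

Filtering the required documents out of the shuffled list, rather than shuffling only the unannotated ones, means that adding or removing annotations never changes the relative order in which new documents are drawn.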
] }, { "cell_type": "markdown", "id": "f7a43ae7-0a41-471b-88bb-d6255c4e123f", "metadata": { "tags": [] }, "source": [ "---\n", "## OCR the input documents\n", "\n", "> ⚠️ **Note:** Refer to the [Amazon Textract Pricing Page](https://aws.amazon.com/textract/pricing/) for up-to-date guidance before running large extraction jobs.\n", ">\n", "> At the time of writing, the projected cost (in `us-east-1`, ignoring free tier allowances) of analyzing 100 documents with 10 pages on average was approximately \\\\$67 with `TABLES` and `FORMS` enabled, or \\\\$2 without. Across the full corpus, we measured the average number of pages per document at approximately 6.7.\n", "\n", "With (a subset of) the raw documents selected, the next ingredient is to link them with Amazon Textract-compatible OCR results in a new manifest - with entries something like:\n", "\n", "```json\n", "{\"raw-ref\": \"s3://doc-example-bucket/folder/mydoc.pdf\", \"textract-ref\": \"s3://doc-example-bucket/folder/mydoc-textracted.json\"}\n", "```" ] }, { "cell_type": "markdown", "id": "d53aec4c-882b-4a89-aafe-46fe1165489c", "metadata": {}, "source": [ "We need to be mindful of the service [quotas](https://docs.aws.amazon.com/general/latest/gr/textract.html#limits_textract) when processing large batches of documents with Amazon Textract, to avoid excessive rate limiting and retries. 
Since an OCR pipeline solution stack is already set up for this sample, you can use just the *Amazon Textract portion of the pipeline* to process the documents in bulk.\n", "\n", "> ⏰ This process took about 6 minutes to run against the 120-document sample set in our tests.\n", "\n", "> ⚠️ **If you see errors in the output:**\n", ">\n", "> - Try re-running the cell: rate limiting can sometimes cause intermittent failures, and the function will skip successfully processed files in repeat runs.\n", "> - Persistent errors (on custom datasets) could be due to malformed files (remove them from the manifest) or very large files (see the [/CUSTOMIZATION_GUIDE.md](../CUSTOMIZATION_GUIDE.md) for tips on re-configuring your pipeline to handle very large documents)." ] }, { "cell_type": "code", "execution_count": null, "id": "d8456c95-7e74-477b-b0e3-36271f788a2e", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "textract_results = util.ocr.call_textract(\n", " textract_sfn_arn=config.plain_textract_sfn_arn,\n", " # Can instead use raw-all.manifest.jsonl to process whole dataset (see cost note above):\n", " input_manifest=\"data/raw-sample.manifest.jsonl\",\n", " manifest_raw_field=\"raw-ref\",\n", " manifest_out_field=\"textract-ref\",\n", " # Map subpaths of {input_base} to subpaths of {output_base}:\n", " output_base_s3uri=textract_s3uri,\n", " input_base_s3uri=raw_s3uri,\n", " # Note that turning on additional features can have significant impact on API costs:\n", " features=[\"FORMS\", \"TABLES\"],\n", " skip_existing=True,\n", ")" ] }, { "cell_type": "markdown", "id": "1a184c07-dd58-4912-8c6c-a2d14fcc88a7", "metadata": {}, "source": [ "Once the extraction is done, write only the successful items to a manifest file:" ] }, { "cell_type": "code", "execution_count": null, "id": "38867c33-0e99-4980-b94d-a185598ffd6b", "metadata": {}, "outputs": [], "source": [ "n_success = 0\n", "n_fail = 0\n", "with open(\"data/textracted-all.manifest.jsonl\", \"w\") 
as fout:\n", "    for ix, item in enumerate(textract_results):\n", "        if isinstance(item[\"textract-ref\"], str):\n", "            fout.write(json.dumps(item) + \"\\n\")\n", "            n_success += 1\n", "        else:\n", "            if n_fail == 0:\n", "                logger.error(\"First failure at index %s:\\n%s\", ix, item[\"textract-ref\"])\n", "            n_fail += 1\n", "\n", "print(f\"{n_success} of {n_success + n_fail} docs processed successfully\")\n", "if n_fail > 0:\n", "    raise ValueError(\n", "        \"Are you sure you want to continue? Consider re-trying to process the failed docs\"\n", "    )" ] }, { "cell_type": "markdown", "id": "c686fc8c-353b-487c-b664-1ebf4a87fa6f", "metadata": {}, "source": [ "> ▶️ You should now have a [data/textracted-all.manifest.jsonl](data/textracted-all.manifest.jsonl) JSON-Lines manifest file mapping source documents (`raw-ref`) to Amazon Textract result JSONs (`textract-ref`), both as `s3://...` URIs." ] }, { "cell_type": "markdown", "id": "fdae5616-5de2-4e62-954f-e9aca3f51335", "metadata": { "tags": [] }, "source": [ "---\n", "## Extract clean input images (batch)\n", "\n", "To annotate our documents with SageMaker Ground Truth image task UIs, we need **individual page images**, stripped of EXIF rotation metadata (because, at the time of writing, SMGT ignores this rotation for annotation consistency) and converted to compatible formats (since some formats like TIFF are not supported by most browsers).\n", "\n", "For large corpora, this process of splitting PDFs and rotating and converting images may require significant resources - but is easy to parallelize.\n", "\n", "Therefore, instead of pre-processing the raw documents here in the notebook, we'll use a scalable [SageMaker Processing Job](https://docs.aws.amazon.com/sagemaker/latest/dg/processing-job.html).\n", "\n", "The job uses a **custom container image**, since the PDF reading tools we use aren't installed by default in pre-built SageMaker containers and aren't `pip install`able. 
However, the image has already been built and deployed to [Amazon Elastic Container Registry (ECR)](https://aws.amazon.com/ecr/) by the CDK stack (see `preproc_image` in [/pipeline/\\_\\_init\\_\\_.py](../pipeline/__init__.py)). All we need to do here is look it up from the stack parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "de79e673-4cb9-430c-a5ea-ed964722c477", "metadata": {}, "outputs": [], "source": [ "from sagemaker.processing import FrameworkProcessor, ProcessingInput, ProcessingOutput\n", "\n", "ecr_image_uri = config.preproc_image_uri\n", "print(f\"Using pre-built custom container image:\\n{ecr_image_uri}\")\n", "\n", "# Output S3 locations:\n", "imgs_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/imgs-clean\"\n", "thumbs_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/thumbnails\"" ] }, { "cell_type": "markdown", "id": "ca4cafa4-97a9-4b3d-a137-e9824e5536e5", "metadata": {}, "source": [ "> **Note:** The 'non-augmented' manifest files used below for job data loading are still JSON-based, but use a different format from the JSON-**Lines** manifests we use in most other places in this sample. You can find guidance on this format in the [S3DataSource API doc](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html), and separate information in the [Ground Truth documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-input-data-input-manifest.html) on the \"augmented\" JSON-Lines manifests used elsewhere." 
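, "\n", "\n", "For illustration, the conversion performed in the next cell can be sketched as a standalone function. This is a minimal sketch, not the solution's own utility: `jsonl_to_s3datasource_manifest` is a hypothetical helper, the `doc-example-bucket` URIs are placeholders, and it assumes every URI in the chosen field sits under the given base `s3://` URI. It highlights the key difference between the two formats: the non-augmented manifest is a single JSON array whose first element is a `{\"prefix\": ...}` object, followed by object keys *relative* to that prefix.\n", "\n", "```python\n", "import json\n", "\n", "\n", "def jsonl_to_s3datasource_manifest(jsonl_text, field, base_s3uri):\n", "    \"\"\"Convert JSON-Lines records into a (non-augmented) S3DataSource manifest list\"\"\"\n", "    prefix = base_s3uri.rstrip(\"/\") + \"/\"\n", "    rel_paths = []\n", "    for line in jsonl_text.splitlines():\n", "        if not line.strip():\n", "            continue  # Skip blank lines\n", "        uri = json.loads(line)[field]\n", "        if not uri.startswith(prefix):\n", "            raise ValueError(f\"{uri} is not under {prefix}\")\n", "        rel_paths.append(uri[len(prefix):])  # Key relative to the shared prefix\n", "    # First element declares the shared prefix; the rest are relative keys:\n", "    return [{\"prefix\": prefix}] + rel_paths\n", "\n", "\n", "records = [\n", "    {\"raw-ref\": \"s3://doc-example-bucket/raw/a.pdf\"},\n", "    {\"raw-ref\": \"s3://doc-example-bucket/raw/sub/b.pdf\"},\n", "]\n", "jsonl_text = \"\\n\".join(json.dumps(r) for r in records)\n", "print(json.dumps(jsonl_to_s3datasource_manifest(jsonl_text, \"raw-ref\", \"s3://doc-example-bucket/raw\")))\n", "# [{\"prefix\": \"s3://doc-example-bucket/raw/\"}, \"a.pdf\", \"sub/b.pdf\"]\n", "```"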
] }, { "cell_type": "code", "execution_count": null, "id": "b45cc00a-345d-4893-b20c-85b161899efa", "metadata": {}, "outputs": [], "source": [ "# Process only the sampled subset of raw docs:\n", "\n", "# Load the list of sampled docs from file (skipping any blank lines):\n", "with open(\"data/raw-sample.manifest.jsonl\") as fin:\n", "    doc_relpaths = [\n", "        json.loads(line)[\"raw-ref\"][len(raw_s3uri) + 1:]  # Relative file paths\n", "        for line in fin\n", "        if line.strip()\n", "    ]\n", "\n", "# Prepare a true JSON (*NON-JSONLINES*) manifest file for SageMaker Processing:\n", "preproc_input_manifest_path = \"data/raw-dataclean-input.manifest.json\"\n", "with open(preproc_input_manifest_path, \"w\") as fout:\n", "    fout.write(json.dumps(\n", "        [{\"prefix\": raw_s3uri + \"/\"}]\n", "        + doc_relpaths\n", "    ))\n", "\n", "# Upload the manifest to S3:\n", "preproc_input_manifest_s3uri = f\"s3://{bucket_name}/{bucket_prefix}{preproc_input_manifest_path}\"\n", "!aws s3 cp {preproc_input_manifest_path} {preproc_input_manifest_s3uri}\n", "\n", "# Set the processing job inputs to reference the manifest:\n", "preproc_inputs = [\n", "    ProcessingInput(\n", "        destination=\"/opt/ml/processing/input/raw\",  # Expected input location, per our script\n", "        input_name=\"raw\",\n", "        s3_data_distribution_type=\"ShardedByS3Key\",  # Distribute between instances, if multiple\n", "        s3_data_type=\"ManifestFile\",\n", "        source=preproc_input_manifest_s3uri,  # Manifest of sample raw documents\n", "    ),\n", "]\n", "print(\"Selected sample subset of documents\")" ] }, { "cell_type": "markdown", "id": "7b3d3fae-0852-4f7c-b2d1-08e7b9b2525b", "metadata": {}, "source": [ "The cell below will **run the processing job** and show logs from the job as it progresses. 
You can also check up on the status and history of jobs in the [Processing page of the Amazon SageMaker Console](https://console.aws.amazon.com/sagemaker/home?#/processing-jobs).\n", "\n", "> ⏰ **Note:** In our tests, it took (including job start-up overheads) about 8 minutes to process the 120-document sample with 2x `ml.c5.2xlarge` instances" ] }, { "cell_type": "code", "execution_count": null, "id": "5e5d3223-5fa9-47b4-bbe7-3594bffd108a", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "\n", "processor = FrameworkProcessor(\n", " estimator_cls=util.preproc.DummyFramework,\n", " image_uri=ecr_image_uri,\n", " framework_version=\"\", # Not needed as image URI already provided\n", " base_job_name=\"ocr-img-dataclean\",\n", " role=sagemaker.get_execution_role(),\n", " instance_count=2,\n", " instance_type=\"ml.c5.2xlarge\",\n", " volume_size_in_gb=15,\n", ")\n", "\n", "processor.run(\n", " code=\"preproc.py\", # PDF splitting / image conversion script\n", " source_dir=\"preproc\",\n", " inputs=preproc_inputs[:], # Either whole corpus or sample, as above\n", " outputs=[\n", " ProcessingOutput(\n", " destination=imgs_s3uri,\n", " output_name=\"imgs-clean\",\n", " s3_upload_mode=\"Continuous\",\n", " source=\"/opt/ml/processing/output/imgs-clean\", # Hi-res images for labelling\n", " ),\n", " ProcessingOutput(\n", " destination=thumbs_s3uri,\n", " output_name=\"thumbnails\",\n", " s3_upload_mode=\"Continuous\",\n", " source=\"/opt/ml/processing/output/thumbnails\", # Low-res images for model inputs\n", " ),\n", " ],\n", ")" ] }, { "cell_type": "markdown", "id": "18deaa89-6cb6-45c1-b721-22754cc5af12", "metadata": {}, "source": [ "Once the images have been extracted, we'll also **optionally** download them locally to the notebook for use in visualizations later:" ] }, { "cell_type": "code", "execution_count": null, "id": "1f0dfede-76cf-46f0-8541-b51d5292f8ec", "metadata": {}, "outputs": [], "source": [ "print(f\"Downloading cleaned images from 
{imgs_s3uri}...\")\n", "!aws s3 sync --quiet {imgs_s3uri} data/imgs-clean\n", "print(f\"Downloading thumbnail images from {thumbs_s3uri}...\")\n", "!aws s3 sync --quiet {thumbs_s3uri} data/imgs-thumb\n", "print(\"Done\")" ] }, { "cell_type": "markdown", "id": "53cfa331-7e41-484b-8021-2fbe23cf06d3", "metadata": {}, "source": [ "You'll see that this job also generates uniformly resized \"thumbnail\" images for each page when the second (optional) `thumbnails` output is specified. These aren't needed for the human annotation process, but will be used later for model training." ] }, { "cell_type": "markdown", "id": "1f08ef2c-08ee-4ed7-8abe-3586d6eb7100", "metadata": {}, "source": [ "### Collate OCR and image data for annotation\n", "\n", "We now have a filtered corpus of documents with Amazon Textract results, plus cleaned and standardized images for each page, all available on Amazon S3.\n", "\n", "To prepare for data annotation and later model training, we'll need to collate these together in a **page-level manifest** in JSON-Lines format, with records like:\n", "\n", "```json\n", "{\"source-ref\": \"s3://doc-example-bucket/img-prefix/folder/filename-0001-01.png\", \"textract-ref\": \"s3://doc-example-bucket/tex-prefix/folder/filename.pdf/consolidated.json\", \"page-num\": 1}\n", "```\n", "\n", "Key features of the format are:\n", "- The `source-ref` is the path to a full-resolution cleaned page image (**not** a thumbnail), **but** model training later in this notebook will assume the equivalent thumbnail path is identical, except for a different s3://... 
bucket & prefix.\n", "- The `page-num` is one-based (always >= 1), and for model training must match the image to the appropriate page number **in the linked Textract JSON file**.\n", "    - For example, if you have thumbnail `filename-0001-15.png` for page 15 of some long document, but for some reason your `textract-ref` JSON file contains *only* detections from page 15 of the document, you would set `\"page-num\": 1`.\n", "- Including the `raw-ref` mapping here is nice to have, but optional, as the model training won't refer to the original document.\n", "\n", "The key goal is to create a page-level catalogue that we're confident is correct, and for that reason the example function below will actually **validate that the artifacts are present on S3** in the expected locations.\n", "\n", "> ⏰ Because of these validation checks, the cell below may take a minute or two to run against our 120-document sample set." ] }, { "cell_type": "code", "execution_count": null, "id": "ea8ec1ae-a892-4f55-8ac2-2f493a5a58b2", "metadata": { "scrolled": true }, "outputs": [], "source": [ "warnings = util.preproc.collate_data_manifest(\n", "    # Output file:\n", "    \"data/pages-all-sample.manifest.jsonl\",\n", "    # Input manifest:\n", "    input_manifest=\"data/textracted-all.manifest.jsonl\",\n", "    # s3://... base URI used to try and map 'textract-ref's to cleaned images:\n", "    textract_s3_prefix=textract_s3uri,\n", "    # The s3://... base URI under which page images are stored:\n", "    imgs_s3_prefix=imgs_s3uri,\n", "    # Optional s3://... base URI also used to try and map 'raw-ref's to images if present:\n", "    raw_s3_prefix=raw_s3uri,\n", "    # Other output manifest settings:\n", "    by=\"page\",\n", "    no_content=\"omit\",\n", ")\n", "\n", "if len(warnings):\n", "    raise ValueError(\n", "        \"Manifest usable but incomplete - %s docs failed. 
Please see `warnings` for details\"\n", "        % len(warnings)\n", "    )" ] }, { "cell_type": "markdown", "id": "05bb7800-7bad-40bf-bcce-7e72296ad270", "metadata": {}, "source": [ "> ▶️ You should now have a page-level catalogue linking `source-ref`, `textract-ref` and `page-num` in [data/pages-all-sample.manifest.jsonl](data/pages-all-sample.manifest.jsonl).\n", "\n", "Let's briefly explore the catalogue we've created. Each line of the file is a JSON record identifying a particular page:" ] }, { "cell_type": "code", "execution_count": null, "id": "38ff4bfe-31d7-4f13-8eb9-42b270679238", "metadata": {}, "outputs": [], "source": [ "with open(\"data/pages-all-sample.manifest.jsonl\", \"r\") as f:\n", "    for ix, line in enumerate(f):\n", "        print(line, end=\"\")\n", "        if ix >= 2:\n", "            print(\"...\")\n", "            break" ] }, { "cell_type": "markdown", "id": "52809830-8c28-474c-ba0a-fd879e3021b0", "metadata": {}, "source": [ "The credit cards corpus has a very skewed distribution of the number of pages per document, with a few outliers dragging up the average significantly. 
In our tests, the corpus-wide statistics were:\n", "\n", "- The overall average was **~6.7 pages per document**\n", "- The 25th percentile was 3 pages; the 50th percentile was 6 pages; and the 75th percentile was 11 pages\n", "- The longest document was 402 pages\n", "\n", "Your results for sub-sampled sets will likely vary, but can be analyzed as below:" ] }, { "cell_type": "code", "execution_count": null, "id": "ced078ea-1268-418c-96df-7ba6017a1d42", "metadata": {}, "outputs": [], "source": [ "with open(\"data/pages-all-sample.manifest.jsonl\", \"r\") as f:\n", "    manifest_df = pd.DataFrame([json.loads(line) for line in f])\n", "page_counts_by_doc = manifest_df.groupby(\"textract-ref\")[\"textract-ref\"].count()\n", "\n", "print(\"Document page count statistics\")\n", "page_counts_by_doc.describe()" ] }, { "cell_type": "markdown", "id": "61da00fe-f0ed-4e7c-b1c8-d0ed13e5bfbf", "metadata": {}, "source": [ "---\n", "## Start the data labelling job\n", "\n", "Now that we have a correlated set of cleaned page images and OCR results for each page, we're ready to start annotating entities to collect model training data. Typically this is an iterative process with multiple rounds of labelling to balance experimentation speed with model accuracy. Here, though, we'll set up a single small labelling job and combine its results with pre-existing annotations." 
] }, { "cell_type": "markdown", "id": "322cad31-3f79-4e25-8d12-c77bf13cd319", "metadata": {}, "source": [ "### Sample a dataset to label\n", "\n", "Below, we:\n", "\n", "- **Shuffle** our data (in a *reproducible*/deterministic way), to ensure we annotate documents/pages from a range of providers - not just concentrating on the first provider/doc(s)\n", "- **Exclude** any examples for which the page image has **already been labeled** in the `data/annotations` output folder\n", "- **Stratify** the sample, to obtain a specific (boosted) proportion of first-page samples, since we observed the first pages of documents to often be most useful for the fields of interest in the sample credit cards use case. (Many documents use the first page for a fact-sheet/summary, followed by subsequent pages of dense legal terms).\n", "\n", "Run the cells below to select a small subset of previously-unlabelled pages and build a manifest file listing them:" ] }, { "cell_type": "code", "execution_count": null, "id": "85916f48-eaa0-4a1f-afa2-df80eb7d1641", "metadata": {}, "outputs": [], "source": [ "annotation_job_name = \"cfpb-workshop-1\" # What will this job be called?\n", "N_JOB_EXAMPLES = 15 # Select 15 new pages to annotate\n", "PCT_FIRST_PAGE = .4 # 40% of samples should be page-num 1\n", "\n", "preannotated_img_uris = [\n", " f\"{imgs_s3uri}/{path}\"\n", " for path in util.preproc.list_preannotated_img_paths(\n", " annotations_folder=\"data/annotations\",\n", " exclude_job_names=[],\n", " key_prefix=\"data/imgs-clean/\",\n", " )\n", "]\n", "\n", "job_input_manifest_file = f\"data/manifests/{annotation_job_name}.jsonl\"\n", "os.makedirs(\"data/manifests\", exist_ok=True)\n", "print(f\"'{annotation_job_name}' saving to: {job_input_manifest_file}\")\n", "\n", "with open(job_input_manifest_file, \"w\") as f:\n", " for ix, example in enumerate(\n", " util.preproc.stratified_sample_first_page_examples(\n", " input_manifest_path=\"data/pages-all-sample.manifest.jsonl\",\n", " 
n_examples=N_JOB_EXAMPLES, \n", " pct_first_page=PCT_FIRST_PAGE,\n", " exclude_source_ref_uris=preannotated_img_uris,\n", " )\n", " ):\n", " if ix < 3:\n", " print(example)\n", " elif ix == 3:\n", " print(\"...\")\n", " f.write(json.dumps(example) + \"\\n\")" ] }, { "cell_type": "markdown", "id": "9a0955cd-cf0c-4f9f-8e3f-ce04f494527a", "metadata": {}, "source": [ "To create the labelling job in SageMaker, this manifest file will also need to be uploaded to Amazon S3:" ] }, { "cell_type": "code", "execution_count": null, "id": "650135a2-65d7-4bf3-85f7-bee1d301793c", "metadata": {}, "outputs": [], "source": [ "input_manifest_s3uri = f\"s3://{bucket_name}/{bucket_prefix}{job_input_manifest_file}\"\n", "!aws s3 cp $job_input_manifest_file $input_manifest_s3uri" ] }, { "cell_type": "markdown", "id": "785e618b-4dc8-4aca-bad5-424289813e8b", "metadata": {}, "source": [ "### Create the labelling job\n", "\n", "With a manifest file defining which pages should be included, and your \"work team\" already set up from earlier, you're ready to create your SageMaker Ground Truth labelling job.\n", "\n", "You could also explore creating this via the AWS Console for SageMaker, but the code below will set up the job with the correct settings for you automatically:" ] }, { "cell_type": "code", "execution_count": null, "id": "809aa7b8-3f18-4ef8-803f-cf6da1de9682", "metadata": {}, "outputs": [], "source": [ "util.smgt.ensure_bucket_cors(bucket_name)\n", "\n", "print(f\"Starting labeling job {annotation_job_name}\\non data {input_manifest_s3uri}\\n\")\n", "create_labeling_job_resp = util.smgt.create_bbox_labeling_job(\n", " annotation_job_name,\n", " bucket_name=bucket_name,\n", " execution_role_arn=sagemaker.get_execution_role(),\n", " fields=fields,\n", " input_manifest_s3uri=input_manifest_s3uri,\n", " output_s3uri=annotations_base_s3uri,\n", " workteam_arn=workteam_arn,\n", " # To create a review/adjustment job from a manifest with existing labels in:\n", " # 
reviewing_attribute_name=\"label\",\n", " s3_inputs_prefix=f\"{bucket_prefix}data/manifests\",\n", ")\n", "print(f\"\\nLABELLING JOB STARTED:\\n{create_labeling_job_resp['LabelingJobArn']}\")\n", "print()\n", "print(input_manifest_s3uri)\n", "print(annotations_base_s3uri)\n", "print(sagemaker.get_execution_role())\n", "print(\"\\n\".join([\"\\nLabels:\", \"-------\"] + entity_classes))" ] }, { "cell_type": "markdown", "id": "e294758e-1fd5-4558-b0a5-cf37798ba0ac", "metadata": {}, "source": [ "---\n", "## Before you label - build custom containers\n", "\n", "The entity recognition model we'll train later uses **customized containers**, which install extra libraries over the standard [SageMaker Hugging Face framework containers](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/index.html).\n", "\n", "> ⏰ Building these can take several minutes - so before you start labelling your documents in the SageMaker Ground Truth portal, **start the below cells running** to save some time.\n", ">\n", "> You don't need to wait for them to finish - just move on to the next \"Label the data\" section." 
] }, { "cell_type": "code", "execution_count": null, "id": "af1d114b-4f62-4866-8440-b7de14fbc228", "metadata": {}, "outputs": [], "source": [ "# Configurations:\n", "hf_version = \"4.17\"\n", "py_version = \"py38\"\n", "pt_version = \"1.10\"\n", "train_repo_name = \"sm-ocr-training\"\n", "train_repo_tag = f\"hf-{hf_version}-pt-gpu\"\n", "inf_repo_name = \"sm-ocr-inference\"\n", "inf_repo_tag = train_repo_tag\n", "\n", "account_id = sagemaker.Session().account_id()\n", "region = os.environ[\"AWS_REGION\"]\n", "\n", "base_image_params = {\n", " \"framework\": \"huggingface\",\n", " \"region\": region,\n", " \"instance_type\": \"ml.p3.2xlarge\", # (Just used to check whether GPUs/accelerators are used)\n", " \"py_version\": py_version,\n", " \"version\": hf_version,\n", " \"base_framework_version\": f\"pytorch{pt_version}\",\n", "}\n", "\n", "train_base_uri = sagemaker.image_uris.retrieve(**base_image_params, image_scope=\"training\")\n", "inf_base_uri = sagemaker.image_uris.retrieve(**base_image_params, image_scope=\"inference\")\n", "\n", "# Combine together into the final URIs:\n", "train_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{train_repo_name}:{train_repo_tag}\"\n", "print(f\"Target training image: {train_image_uri}\")\n", "inf_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{inf_repo_name}:{inf_repo_tag}\"\n", "print(f\"Target inference image: {inf_image_uri}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "a2310bdf-17a5-49ba-a125-7aef8134fce4", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "# (No need to re-run this cell if your train image is already in ECR)\n", "\n", "# Build and push the training image:\n", "!cd custom-containers/train-inf && sm-docker build . 
\\\n", " --compute-type BUILD_GENERAL1_LARGE \\\n", " --repository {train_repo_name}:{train_repo_tag} \\\n", " --role {config.sm_image_build_role} \\\n", " --build-arg BASE_IMAGE={train_base_uri}" ] }, { "cell_type": "markdown", "id": "0a4bb843-284e-4b9e-b322-59c2ec98afff", "metadata": {}, "source": [ "Note that although our training and inference containers use the [same Dockerfile](custom-containers/train-inf/Dockerfile), they're built from different parent images so both are needed in ECR:" ] }, { "cell_type": "code", "execution_count": null, "id": "4f6fd214-1739-4dc0-a14f-6702f5a9ecb2", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%time\n", "# (No need to re-run this cell if your inference image is already in ECR)\n", "\n", "# Build and push the inference image:\n", "!cd custom-containers/train-inf && sm-docker build . \\\n", " --compute-type BUILD_GENERAL1_LARGE \\\n", " --repository {inf_repo_name}:{inf_repo_tag} \\\n", " --role {config.sm_image_build_role} \\\n", " --build-arg BASE_IMAGE={inf_base_uri}" ] }, { "cell_type": "code", "execution_count": null, "id": "186cc52b-59ea-4d09-8c2e-893ed668eb00", "metadata": {}, "outputs": [], "source": [ "# Check from notebook whether the images were successfully created:\n", "ecr = boto3.client(\"ecr\")\n", "for repo, tag, uri in (\n", " (train_repo_name, train_repo_tag, train_image_uri),\n", " (inf_repo_name, inf_repo_tag, inf_image_uri)\n", "):\n", " imgs_desc = ecr.describe_images(\n", " registryId=account_id,\n", " repositoryName=repo,\n", " imageIds=[{\"imageTag\": tag}],\n", " )\n", " assert len(imgs_desc[\"imageDetails\"]) > 0, f\"Couldn't find ECR image {uri} after build\"\n", " print(f\"Found {uri}\")" ] }, { "cell_type": "markdown", "id": "3576ffea-d0ae-4f45-924a-c1457a5a5af4", "metadata": {}, "source": [ "---\n", "\n", "## Label the data!\n", "\n", "Shortly after the labeling job has been created, you'll see a new task for your user in the SageMaker Ground Truth **labeling portal**. 
If you lost the portal link from your email, you can access it from the *Private* tab of the [SageMaker Ground Truth Workforces console](https://console.aws.amazon.com/sagemaker/groundtruth?#/labeling-workforces).\n", "\n", "▶️ Click **Start working** and annotate the examples until all of them are finished and you're returned to the portal homepage.\n", "\n", "▶️ **Try to be as consistent as possible** in how you annotate the classes, because inconsistent annotations can significantly degrade final model accuracy. Refer to the guidance (in this notebook and the 'Full Instructions') that we applied when annotating the example set.\n", "\n", "![](img/smgt-task-pending.png \"Screenshot of SMGT labeling portal with pending task\")" ] }, { "cell_type": "markdown", "id": "e3b2a5c2-f399-4f15-93ee-61715a141c2e", "metadata": {}, "source": [ "### Sync the results locally (and iterate?)\n", "\n", "Once you've finished annotating and the job shows as \"Complete\" in the [SMGT Console](https://console.aws.amazon.com/sagemaker/groundtruth?#/labeling-jobs) (which **might take an extra minute or two**, while your annotations are consolidated), you can download the results here to the notebook via the cell below:" ] }, { "cell_type": "code", "execution_count": null, "id": "0ceb096a-3ecf-4112-a2bd-3899d59d7f15", "metadata": {}, "outputs": [], "source": [ "!aws s3 sync --quiet $annotations_base_s3uri ./data/annotations" ] }, { "cell_type": "markdown", "id": "84178ce6-4794-49cf-8223-a00b7d4e06f1", "metadata": {}, "source": [ "You should see a subfolder created with the name of your annotation job, under which the **`manifests/output/output.manifest`** file contains the consolidated results of your labelling - again in the open JSON-Lines format.\n", "\n", "▶️ **Check** that your results appear as expected, and explore the file format.\n", "\n", "> Because label outputs are in JSON-Lines, it's easy to consolidate, transform, and manipulate these results as required using open source tools!" 
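, "\n", "\n", "As one example, the sketch below tallies the annotated boxes per class in an output manifest, which can serve as a quick sense-check before training. It assumes each record follows Ground Truth's standard bounding-box output format: a label attribute (named after the labelling job, by default) holding an `annotations` list of `class_id`/`left`/`top`/`width`/`height` entries, plus a `<label-attribute>-metadata` object whose `class-map` maps IDs to class names. `count_boxes_by_class` is a hypothetical helper rather than part of this sample's `util` package, so check it against your actual output file before relying on it.\n", "\n", "```python\n", "import json\n", "from collections import Counter\n", "\n", "\n", "def count_boxes_by_class(manifest_lines, label_attr):\n", "    \"\"\"Tally annotated bounding boxes per class name in SMGT output manifest lines\"\"\"\n", "    counts = Counter()\n", "    for line in manifest_lines:\n", "        if not line.strip():\n", "            continue  # Skip blank lines\n", "        record = json.loads(line)\n", "        label = record.get(label_attr)\n", "        if label is None:\n", "            continue  # No label stored under this attribute name\n", "        # Map numeric class IDs to names, where the metadata provides a class-map:\n", "        class_map = record.get(f\"{label_attr}-metadata\", {}).get(\"class-map\", {})\n", "        for box in label.get(\"annotations\", []):\n", "            class_id = str(box[\"class_id\"])\n", "            counts[class_map.get(class_id, class_id)] += 1\n", "    return counts\n", "\n", "\n", "# Usage, assuming your job was named \"cfpb-workshop-1\" as above:\n", "# with open(\"data/annotations/cfpb-workshop-1/manifests/output/output.manifest\") as f:\n", "#     print(count_boxes_by_class(f, \"cfpb-workshop-1\").most_common())\n", "```"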
] }, { "cell_type": "markdown", "id": "57ee9463-88eb-4eec-98eb-0d0b6eb10a64", "metadata": {}, "source": [ "---\n", "## Consolidate annotated data\n", "\n", "To construct a model training set, we'll typically need to consolidate the results of multiple SageMaker Ground Truth labelling jobs: Perhaps because the work was split up into more manageable chunks - or maybe because additional review/adjustment jobs were run to improve label quality.\n", "\n", "Inside your `data/annotations` folder, you'll find some **pre-annotated augmentation data** provided for you already (in the `augmentation-` subfolders). These datasets are not especially large or externally useful, but will help you train an example model without too much (or even any!) manual annotation effort.\n", "\n", "▶️ **Edit** the `include_jobs` line below to control which datasets (pre-provided and your own) will be included:" ] }, { "cell_type": "code", "execution_count": null, "id": "5524dde5-560f-4893-8b1e-1b3848624d06", "metadata": {}, "outputs": [], "source": [ "include_jobs = [\n", " \"augmentation-1\",\n", " \"augmentation-2\",\n", " # TODO: Can edit the below to include your custom data, if you were able to label it:\n", " # \"cfpb-workshop-1\",\n", "]\n", "\n", "\n", "source_manifests = []\n", "for job_name in sorted(filter(\n", " lambda n: os.path.isdir(f\"data/annotations/{n}\"),\n", " os.listdir(\"data/annotations\")\n", ")):\n", " if job_name not in include_jobs:\n", " logger.warning(f\"Skipping {job_name} (not in include_jobs list)\")\n", " continue\n", " job_manifest_path = f\"data/annotations/{job_name}/manifests/output/output.manifest\"\n", " if not os.path.isfile(job_manifest_path):\n", " raise RuntimeError(f\"Could not find job output manifest {job_manifest_path}\")\n", " source_manifests.append({\"job_name\": job_name, \"manifest_path\": job_manifest_path})\n", "\n", "print(f\"Got {len(source_manifests)} annotated manifests:\")\n", "print(\"\\n\".join(map(lambda o: o[\"manifest_path\"], 
source_manifests)))" ] }, { "cell_type": "markdown", "id": "1d09939f-b6c3-41c3-b1ae-95bce8535e39", "metadata": {}, "source": [ "Note that to **combine multiple output manifests into a single dataset**:\n", "\n", "- The labels must be stored in the same attribute on every record (by default, each job stores its labels in an attribute named after the labeling job, so the attribute name will differ between jobs).\n", "- If importing data collected from some other account (like the `augmentation-` sets), we'll need to **map the S3 URIs** to equivalent locations in your own bucket." ] }, { "cell_type": "code", "execution_count": null, "id": "dec049d0-7fc2-4f28-b2c9-0ea74aa5e122", "metadata": {}, "outputs": [], "source": [ "standard_label_field = \"label\"\n", "\n", "print(\"Writing data/annotations/annotations-all.manifest.jsonl\")\n", "with open(\"data/annotations/annotations-all.manifest.jsonl\", \"w\") as fout:\n", "    util.preproc.consolidate_data_manifests(\n", "        source_manifests,\n", "        fout,\n", "        standard_label_field=standard_label_field,\n", "        bucket_mappings={\"DOC-EXAMPLE-BUCKET\": bucket_name},\n", "        prefix_mappings={\"EXAMPLE-PREFIX/\": bucket_prefix},\n", "    )" ] }, { "cell_type": "markdown", "id": "6127578f-1f89-4aa0-8cca-3f988026482f", "metadata": {}, "source": [ "### Split training and test sets\n", "\n", "To get some insight into how well our model is generalizing to real-world data, we'll need to reserve some annotated data as a testing/validation set.\n", "\n", "Below, we randomly partition the data into training and test sets and then upload the two manifests to S3:" ] }, { "cell_type": "code", "execution_count": null, "id": "1db93c29-1f24-40dc-a792-31b983c9c9cb", "metadata": {}, "outputs": [], "source": [ "def split_manifest(f_in, f_train, f_test, train_pct=0.9, random_seed=1337):\n", "    \"\"\"Split `f_in` manifest file into `f_train`, `f_test`\"\"\"\n", "    logger.info(f\"Reading {f_in}\")\n", "    with open(f_in, \"r\") as fin:\n", "        # Skip blank lines, and newline-terminate every record so shuffling can't merge two:\n", "        lines = [line if line.endswith(\"\\n\") else line + \"\\n\" for line in fin if line.strip()]\n", "    logger.info(\"Shuffling 
records\")\n", " random.Random(random_seed).shuffle(lines)\n", " n_train = round(len(lines) * train_pct)\n", "\n", " with open(f_train, \"w\") as ftrain:\n", " logger.info(f\"Writing {n_train} records to {f_train}\")\n", " for line in lines[:n_train]:\n", " ftrain.write(line)\n", " with open(f_test, \"w\") as ftest:\n", " logger.info(f\"Writing {len(lines) - n_train} records to {f_test}\")\n", " for line in lines[n_train:]:\n", " ftest.write(line)\n", "\n", "\n", "split_manifest(\n", " \"data/annotations/annotations-all.manifest.jsonl\",\n", " \"data/annotations/annotations-train.manifest.jsonl\",\n", " \"data/annotations/annotations-test.manifest.jsonl\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "294ea8f9-214c-4a3e-ada9-92b295f7c9e6", "metadata": {}, "outputs": [], "source": [ "train_manifest_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/annotations/annotations-train.manifest.jsonl\"\n", "!aws s3 cp data/annotations/annotations-train.manifest.jsonl $train_manifest_s3uri\n", "\n", "test_manifest_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/annotations/annotations-test.manifest.jsonl\"\n", "!aws s3 cp data/annotations/annotations-test.manifest.jsonl $test_manifest_s3uri" ] }, { "cell_type": "markdown", "id": "b9ff1563-6a77-468e-bea5-94efb4ee62f0", "metadata": {}, "source": [ "### Visualize the data\n", "\n", "Before training the model, we'll sense-check the data by plotting a few examples.\n", "\n", "The utility function below will overlay the page image with the annotated bounding boxes, the locations of `WORD` blocks detected from the Amazon Textract results, and the resulting classification of individual Textract `WORD`s. 
To render these results, the Amazon Textract OCR results need to be downloaded locally to the notebook:" ] }, { "cell_type": "code", "execution_count": null, "id": "6d2fbd5d-e416-46c5-9652-066fe5041f73", "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "!aws s3 sync --quiet $textract_s3uri ./data/textracted" ] }, { "cell_type": "code", "execution_count": null, "id": "14457108-c638-4c13-b137-0d25bb4cbc5c", "metadata": {}, "outputs": [], "source": [ "with open(\"data/annotations/annotations-test.manifest.jsonl\", \"r\") as fman:\n", "    test_examples = [json.loads(line) for line in fman if line.strip()]  # Skip blank lines\n", "\n", "util.viz.draw_from_manifest_items(\n", "    test_examples,\n", "    standard_label_field,\n", "    entity_classes,\n", "    imgs_s3uri[len(\"s3://\"):].partition(\"/\")[2],\n", "    textract_s3key_prefix=textract_s3uri[len(\"s3://\"):].partition(\"/\")[2],\n", "    imgs_local_prefix=\"data/imgs-clean\",\n", "    textract_local_prefix=\"data/textracted\",\n", ")" ] }, { "cell_type": "markdown", "id": "9bf34f40-949a-470e-9e76-ad29fb15a20e", "metadata": { "tags": [] }, "source": [ "---\n", "## Train the entity recognition model\n", "\n", "We now have all the data needed to train and validate a layout- and page-image-aware entity recognition model in a [SageMaker Training Job](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html).\n", "\n", "In this process:\n", "\n", "- SageMaker will run the job on a dedicated, managed instance of a type we choose (we'll use `ml.p*` or `ml.g*` GPU-accelerated types), allowing us to keep this notebook's resources modest and only pay for the seconds of GPU time the training job needs.\n", "- The data specified in the manifest files will be downloaded from Amazon S3.\n", "- The bundle of scripts we provide (in `src/`) will be transparently uploaded to S3 and then run inside the specified SageMaker-provided [framework container](https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers-prebuilt.html). 
Our customized images (built earlier) only add a few extra libraries on top of these framework containers, so there's no need for us to implement our own serving stack for inference (although fully-custom containers are [also supported](https://docs.aws.amazon.com/sagemaker/latest/dg/docker-containers.html)).\n", "- Job hyperparameters will be passed through to our `src/` scripts as CLI arguments.\n", "- SageMaker will analyze the logs from the job (i.e. the output of `print()` or `logger` calls in our script) with the regular expressions specified in `metric_definitions`, to scrape structured time-series metrics like loss and accuracy.\n", "- When the job finishes, the contents of the `/opt/ml/model` folder in the container will be automatically tarballed and uploaded to a `model.tar.gz` in Amazon S3.\n", "\n", "You can also refer to [Hugging Face's own docs for training on SageMaker](https://huggingface.co/transformers/sagemaker.html) for more information and examples." ] }, { "cell_type": "code", "execution_count": null, "id": "de122f49-cf2b-4331-86a4-08ba0c6bfed7", "metadata": {}, "outputs": [], "source": [ "from sagemaker.huggingface import HuggingFace as HuggingFaceEstimator\n", "\n", "hyperparameters = {\n", "    \"model_name_or_path\": \"microsoft/layoutxlm-base\",\n", "\n", "    # (See src/code/config.py for more info on script parameters)\n", "    \"annotation_attr\": standard_label_field,\n", "    \"images_prefix\": imgs_s3uri[len(\"s3://\"):].partition(\"/\")[2],\n", "    \"textract_prefix\": textract_s3uri[len(\"s3://\"):].partition(\"/\")[2],\n", "    \"num_labels\": len(fields) + 1,  # +1 for \"other\"\n", "\n", "    \"per_device_train_batch_size\": 2,\n", "    \"per_device_eval_batch_size\": 4,\n", "\n", "    \"num_train_epochs\": 20,\n", "    \"early_stopping_patience\": 15,\n", "    \"metric_for_best_model\": \"eval_focus_else_acc_minus_one\",\n", "    \"greater_is_better\": \"true\",\n", "\n", "    # Early stopping implies checkpointing every evaluation (epoch), so limit the total checkpoints\n", "    # kept to avoid filling up disk:\n", "    \"save_total_limit\": 10,\n", "}\n", "\n", 
"metric_definitions = [\n", " {\"Name\": \"epoch\", \"Regex\": util.training.get_hf_metric_regex(\"epoch\")},\n", " {\"Name\": \"learning_rate\", \"Regex\": util.training.get_hf_metric_regex(\"learning_rate\")},\n", " {\"Name\": \"train:loss\", \"Regex\": util.training.get_hf_metric_regex(\"loss\")},\n", " {\n", " \"Name\": \"validation:n_examples\",\n", " \"Regex\": util.training.get_hf_metric_regex(\"eval_n_examples\"),\n", " },\n", " {\"Name\": \"validation:loss_avg\", \"Regex\": util.training.get_hf_metric_regex(\"eval_loss\")},\n", " {\"Name\": \"validation:acc\", \"Regex\": util.training.get_hf_metric_regex(\"eval_acc\")},\n", " {\n", " \"Name\": \"validation:n_focus_examples\",\n", " \"Regex\": util.training.get_hf_metric_regex(\"eval_n_focus_examples\"),\n", " },\n", " {\n", " \"Name\": \"validation:focus_acc\",\n", " \"Regex\": util.training.get_hf_metric_regex(\"eval_focus_acc\"),\n", " },\n", " {\n", " \"Name\": \"validation:target\",\n", " \"Regex\": util.training.get_hf_metric_regex(\"eval_focus_else_acc_minus_one\"),\n", " },\n", "]\n", "\n", "estimator = HuggingFaceEstimator(\n", " role=sagemaker.get_execution_role(),\n", " entry_point=\"train.py\",\n", " source_dir=\"src\",\n", " py_version=py_version,\n", " pytorch_version=pt_version,\n", " transformers_version=hf_version,\n", " image_uri=train_image_uri, # Use customized training container image\n", "\n", " base_job_name=\"ws-xlm-cfpb-hf\",\n", " output_path=f\"s3://{bucket_name}/{bucket_prefix}trainjobs\",\n", "\n", " instance_type=\"ml.p3.2xlarge\", # Could also consider ml.g4dn.xlarge\n", " instance_count=1,\n", " volume_size=80,\n", "\n", " debugger_hook_config=False,\n", "\n", " hyperparameters=hyperparameters,\n", " metric_definitions=metric_definitions,\n", " environment={\n", " # Required for our custom dataset loading code (which depends on tokenizer):\n", " \"TOKENIZERS_PARALLELISM\": \"false\",\n", " },\n", ")" ] }, { "cell_type": "markdown", "id": 
"5f2a23a1-2307-4e13-88c3-02059d64dbd2", "metadata": {}, "source": [ "Finally, the below cell will actually kick off the training job and stream logs from the running container.\n", "\n", "> ℹ️ You'll also be able to check the status of the job in the [Training jobs page of the SageMaker Console](https://console.aws.amazon.com/sagemaker/home?#/jobs)." ] }, { "cell_type": "code", "execution_count": null, "id": "e327065b-bb2d-47c4-8920-343cdf8dd181", "metadata": { "scrolled": true }, "outputs": [], "source": [ "inputs = {\n", " \"images\": thumbs_s3uri,\n", " \"train\": train_manifest_s3uri,\n", " \"textract\": textract_s3uri + \"/\",\n", " \"validation\": test_manifest_s3uri,\n", "}\n", "estimator.fit(inputs)" ] }, { "cell_type": "markdown", "id": "33f211cd-423d-4267-8432-9f044fcc6082", "metadata": { "tags": [] }, "source": [ "### One-click model deployment\n", "\n", "Once the training job is complete, the model can be deployed to an endpoint via `estimator.deploy()` - specifying any extra parameters needed such as environment variables and, in this case, configurations for [Asynchronous Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html). Async inference endpoints in SageMaker can accept larger payloads and auto-scale down to 0 instances when not in use (if configured) - making them a useful option for many document processing use cases." 
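, "\n", "\n", "For reference, the SageMaker Python SDK's async predictor wraps SageMaker's low-level async invocation API. The following is a minimal sketch of submitting a request directly with boto3. It assumes the request payload has already been uploaded to S3, and the function name `invoke_async` and its arguments are hypothetical placeholders rather than part of this solution:\n", "\n", "```python\n", "import boto3\n", "\n", "\n", "def invoke_async(endpoint_name: str, input_s3uri: str, content_type: str = \"application/json\") -> str:\n", "    \"\"\"Queue a request on an async endpoint and return the S3 URI its result will be written to.\"\"\"\n", "    smruntime = boto3.client(\"sagemaker-runtime\")\n", "    resp = smruntime.invoke_endpoint_async(\n", "        EndpointName=endpoint_name,\n", "        InputLocation=input_s3uri,  # Async requests are read from S3, not sent inline\n", "        ContentType=content_type,\n", "    )\n", "    return resp[\"OutputLocation\"]\n", "```\n", "\n", "The call returns without waiting for inference. The result is written to the returned `OutputLocation` once processing completes, and SNS topics like those in the `notification_config` below can signal success or failure."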
] }, { "cell_type": "code", "execution_count": null, "id": "6d5a0173-a3c1-486d-b7bc-c610eeb11020", "metadata": {}, "outputs": [], "source": [ "training_job_name = estimator.latest_training_job.describe()[\"TrainingJobName\"]\n", "# Or:\n", "# training_job_name = tuner.best_training_job()\n", "\n", "predictor = estimator.deploy(\n", " # Avoid us accidentally deploying the same model twice by setting name per training job:\n", " endpoint_name=training_job_name,\n", " initial_instance_count=1,\n", " instance_type=\"ml.g4dn.xlarge\", # Or try ml.m5.2xlarge\n", " image_uri=inf_image_uri,\n", "\n", " serializer=sagemaker.serializers.JSONSerializer(),\n", " deserializer=sagemaker.deserializers.JSONDeserializer(),\n", "\n", " env={\n", " \"PYTHONUNBUFFERED\": \"1\", # TODO: Disable once debugging is done\n", " \"MMS_MAX_REQUEST_SIZE\": str(100*1024*1024), # Accept large payloads (docs)\n", " \"MMS_MAX_RESPONSE_SIZE\": str(100*1024*1024), # Allow large responses\n", " },\n", "\n", " # Deploy in Asynchronous mode, to support large req/res payloads:\n", " async_inference_config=sagemaker.async_inference.AsyncInferenceConfig(\n", " output_path=f\"s3://{config.model_results_bucket}\",\n", " max_concurrent_invocations_per_instance=2,\n", " notification_config={\n", " \"SuccessTopic\": config.model_callback_topic_arn,\n", " \"ErrorTopic\": config.model_callback_topic_arn,\n", " },\n", " ),\n", ")" ] }, { "cell_type": "markdown", "id": "d1d4661a-fe4a-4e86-b97f-2b47b25a1cc2", "metadata": {}, "source": [ "If needed (for example, if your kernel crashes or restarts), you can also attach to previously deployed endpoints. 
Just look up the endpoint name from the SageMaker Console:" ] }, { "cell_type": "code", "execution_count": null, "id": "bb81b22b-ff28-45f2-b391-b6f875045a20", "metadata": {}, "outputs": [], "source": [ "# endpoint_name=\"xlm-cfpb-hf-2022-05-23-14-10-19-602\"\n", "# predictor = sagemaker.predictor_async.AsyncPredictor(\n", "# sagemaker.Predictor(\n", "# endpoint_name,\n", "# serializer=sagemaker.serializers.JSONSerializer(),\n", "# deserializer=sagemaker.deserializers.JSONDeserializer(),\n", "# ),\n", "# name=endpoint_name,\n", "# )" ] }, { "cell_type": "markdown", "id": "71c90243-db89-40d9-88ec-30e337b03571", "metadata": { "tags": [] }, "source": [ "---\n", "## Extract clean input images on-demand\n", "\n", "Just as we generated page thumbnail images to originally train our model, online inference should be able to generate these input features on-demand. In this example, the same code we previously used in a batch processing job has already been automatically deployed to a SageMaker inference endpoint for you. We can look up the endpoint name from the deployed stack parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "a51024fa-68c4-4b9a-8ed9-952d4212ff05", "metadata": {}, "outputs": [], "source": [ "preproc_endpoint_name = ssm.get_parameter(\n", " Name=config.thumbnail_endpoint_name_param,\n", ")[\"Parameter\"][\"Value\"]\n", "print(f\"Pre-created thumbnailer endpoint name:\\n {preproc_endpoint_name}\")" ] }, { "cell_type": "markdown", "id": "f788bbf2-8535-4f4e-a27b-fe380e70cd2a", "metadata": {}, "source": [ "The online thumbnail-generation endpoint accepts raw input documents (i.e. PDFs, images), and returns compressed arrays of page image data. 
From the name of the endpoint, you can configure I/O formats and connect from the notebook as shown below:" ] }, { "cell_type": "code", "execution_count": null, "id": "646e61ea-ff3a-44ae-9035-64e37e048b9d", "metadata": {}, "outputs": [], "source": [ "try:\n", " desc = smclient.describe_endpoint(EndpointName=preproc_endpoint_name)\n", "except smclient.exceptions.ClientError as e:\n", " if e.response.get(\"Error\", {}).get(\"Message\", \"\").startswith(\"Could not find\"):\n", " desc = None # Endpoint does not exist\n", " else:\n", " raise e # Some other unknown issue\n", "\n", "if desc is None:\n", " raise ValueError(\n", " \"The configured thumbnailing endpoint does not exist in SageMaker. See the 'Optional \"\n", " \"Extras.ipynb' notebook for instructions to manually deploy the thumbnailer before \"\n", " \"continuing. Missing endpoint: %s\" % preproc_endpoint_name\n", " )\n", "\n", "preproc_predictor = sagemaker.predictor_async.AsyncPredictor(\n", " sagemaker.Predictor(\n", " preproc_endpoint_name,\n", " serializer=util.deployment.FileSerializer.from_filename(\"any.pdf\"),\n", " deserializer=util.deployment.CompressedNumpyDeserializer(),\n", " ),\n", " name=preproc_endpoint_name,\n", ")" ] }, { "cell_type": "markdown", "id": "6b199be3-eaf8-48b3-9607-3431650aec99", "metadata": {}, "source": [ "So how would it look to test the endpoint from Python? 
Let's see an example:" ] }, { "cell_type": "code", "execution_count": null, "id": "284883db-5305-427e-b2ba-a2f463bf8e18", "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "# Choose an input (document or image):\n", "input_file = \"data/raw/121 Financial Credit Union/Visa Credit Card Agreement.pdf\"\n", "#input_file = \"data/imgs-clean/121 Financial Credit Union/Visa Credit Card Agreement-0001-1.png\"\n", "\n", "# Ensure de/serializers are correctly set up (since they depend on the input file type):\n", "preproc_predictor.serializer = util.deployment.FileSerializer.from_filename(input_file)\n", "preproc_predictor.deserializer = util.deployment.CompressedNumpyDeserializer()\n", "# Duplication because of https://github.com/aws/sagemaker-python-sdk/issues/3100\n", "preproc_predictor.predictor.serializer = preproc_predictor.serializer\n", "preproc_predictor.predictor.deserializer = preproc_predictor.deserializer\n", "\n", "# Run prediction:\n", "print(\"Calling endpoint...\")\n", "resp = preproc_predictor.predict(input_file)\n", "print(f\"Got response of type {type(resp)}\")\n", "\n", "# Render result:\n", "util.viz.draw_thumbnails_response(resp)" ] }, { "cell_type": "markdown", "id": "d4d1c312-a55d-4c64-877c-3ebf938a778f", "metadata": {}, "source": [ "---\n", "## Using the entity recognition model\n", "\n", "Now that the model is deployed and a page thumbnail generator is available, we're ready to test out inference on some documents!" ] }, { "cell_type": "markdown", "id": "c845df49-d798-4a92-b6f3-ba3f5f834d95", "metadata": {}, "source": [ "### Making requests and rendering results\n", "\n", "At a high level, the layout+language model accepts Textract-like JSON (e.g. 
as returned by [AnalyzeDocument](https://docs.aws.amazon.com/textract/latest/dg/API_AnalyzeDocument.html#API_AnalyzeDocument_ResponseSyntax) or [DetectDocumentText](https://docs.aws.amazon.com/textract/latest/dg/API_DetectDocumentText.html#API_DetectDocumentText_ResponseSyntax) APIs) and classifies each `WORD` [block](https://docs.aws.amazon.com/textract/latest/dg/API_Block.html) according to the entity classes we defined earlier, returning the same JSON with additional fields added to indicate the predictions.\n", "\n", "In addition (per the logic in [src/code/inference.py](src/code/inference.py)):\n", "\n", "- To incorporate image features (for models that support them), requests can also include an `S3Thumbnails: { Bucket, Key }` object pointing to a thumbnailer endpoint response on S3.\n", "- Rather than passing the (typically large and already-S3-resident) Amazon Textract JSON inline, you can pass an `S3Input: { Bucket, Key }` reference (this is how the standard pipeline integration works).\n", "- Output could also be redirected by passing an `S3Output: { Bucket, Key }` field in the request, but this field is ignored (and not needed) on async endpoint deployments.\n", "- `TargetPageNum` and `TargetPageOnly` fields can be specified to limit processing to a single page of the input document.\n", "\n", "We can use utility functions to render these predictions as we did the manual annotations previously:\n", "\n", "> ⏰ **Inference may take time in some cases:**\n", ">\n", "> - Although enabling thumbnails can add several seconds to the demo inference below, the end-to-end pipeline generates these images in parallel with running Amazon Textract - so there's usually no significant impact in practice.\n", "> - If you enabled **auto-scale-to-zero** on your thumbnailer and/or model endpoint, you may see a cold-start of several minutes.\n", "\n", "> ⚠️ **Check:** Because of the way the SageMaker Python SDK's 
[AsyncPredictor](https://sagemaker.readthedocs.io/en/stable/api/inference/predictor_async.html) emulates a synchronous `predict()` interface for async endpoints, you may find the notebook waits indefinitely instead of raising an error when something goes wrong. If an inference takes more than ~30s to complete, check the endpoint logs from your [SageMaker Console Endpoints page](https://console.aws.amazon.com/sagemaker/home?#/endpoints) to see if your request resulted in an error." ] }, { "cell_type": "code", "execution_count": null, "id": "c16db0e3-2c66-4032-8784-ccf7fdb5e255", "metadata": {}, "outputs": [], "source": [ "import ipywidgets as widgets\n", "import trp\n", "\n", "# Enabling thumbnails can significantly increase inference time here, but can improve results for\n", "# models that consume image features (like LayoutLMv2, XLM):\n", "include_thumbnails = False\n", "\n", "def predict_from_manifest_item(\n", " item,\n", " predictor,\n", " imgs_s3key_prefix=imgs_s3uri[len(\"s3://\"):].partition(\"/\")[2],\n", " raw_s3uri_prefix=raw_s3uri,\n", " textract_s3key_prefix=textract_s3uri[len(\"s3://\"):].partition(\"/\")[2],\n", " imgs_local_prefix=\"data/imgs-clean\",\n", " textract_local_prefix=\"data/textracted\",\n", " draw=True,\n", "):\n", " paths = util.viz.local_paths_from_manifest_item(\n", " item,\n", " imgs_s3key_prefix,\n", " textract_s3key_prefix=textract_s3key_prefix,\n", " imgs_local_prefix=imgs_local_prefix,\n", " textract_local_prefix=textract_local_prefix,\n", " )\n", "\n", " if include_thumbnails:\n", " doc_textract_s3key = item[\"textract-ref\"][len(\"s3://\"):].partition(\"/\")[2]\n", " doc_raw_s3uri = raw_s3uri_prefix + doc_textract_s3key[len(textract_s3key_prefix):].rpartition(\"/\")[0]\n", " print(f\"Fetching thumbnails for {doc_raw_s3uri}\")\n", " thumbs_async = preproc_predictor.predict_async(input_path=doc_raw_s3uri)\n", " thumbs_bucket, _, thumbs_key = thumbs_async.output_path[len(\"s3://\"):].partition(\"/\")\n", " # Wait for the request 
to complete:\n", " thumbs_async.get_result(sagemaker.async_inference.WaiterConfig())\n", " req_extras = {\"S3Thumbnails\": {\"Bucket\": thumbs_bucket, \"Key\": thumbs_key}}\n", " print(\"Got thumbnails result\")\n", " else:\n", " req_extras = {}\n", "\n", " result_json = predictor.predict({\n", " \"S3Input\": {\"S3Uri\": item[\"textract-ref\"]},\n", " \"TargetPageNum\": item[\"page-num\"],\n", " \"TargetPageOnly\": True,\n", " **req_extras,\n", " })\n", "\n", " if \"Warnings\" in result_json:\n", " for warning in result_json[\"Warnings\"]:\n", " logger.warning(warning)\n", " result_trp = trp.Document(result_json)\n", "\n", " if draw:\n", " util.viz.draw_smgt_annotated_page(\n", " paths[\"image\"],\n", " entity_classes,\n", " annotations=[],\n", " textract_result=result_trp,\n", " # Note that page_num should be item[\"page-num\"] if we requested the full set of pages\n", " # from the model above:\n", " page_num=1,\n", " )\n", " return result_trp\n", "\n", "\n", "widgets.interact(\n", " lambda ix: predict_from_manifest_item(test_examples[ix], predictor),\n", " ix=widgets.IntSlider(\n", " min=0,\n", " max=len(test_examples) - 1,\n", " step=1,\n", " value=0,\n", " description=\"Example:\",\n", " )\n", ")" ] }, { "cell_type": "markdown", "id": "ae6ddf5a-b833-4750-b953-acb8a0ecb741", "metadata": {}, "source": [ "### From token classification to entity detection\n", "\n", "You may have noticed a slight mismatch: We're talking about extracting 'fields' or 'entities' from the document, but our model just classifies individual words. 
Going from words to entities assumes we're able to understand which words go \"together\" and what order they should be read in.\n", "\n", "Fortunately, Amazon Textract helps here too, because the word blocks are already grouped into `LINE`s.\n", "\n", "For many straightforward applications, we can simply loop through the lines on a page and define an \"entity detection\" as a contiguous group of the same class - as below:" ] }, { "cell_type": "code", "execution_count": null, "id": "5cea40eb-e46e-49d5-902a-f5b217f2ea31", "metadata": {}, "outputs": [], "source": [ "res = predict_from_manifest_item(\n", "    test_examples[6],\n", "    predictor,\n", "    draw=False,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "49081201-35e1-4ac9-9fd6-04ab23f69966", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "other_cls = len(entity_classes)\n", "prev_cls = other_cls\n", "current_entity = \"\"\n", "\n", "for page in res.pages:\n", "    for line in page.lines:\n", "        for word in line.words:\n", "            pred_cls = word._block[\"PredictedClass\"]\n", "            if pred_cls != prev_cls:\n", "                if prev_cls != other_cls:\n", "                    print(f\"----------\\n{entity_classes[prev_cls]}:\\n{current_entity}\")\n", "                prev_cls = pred_cls\n", "                if pred_cls != other_cls:\n", "                    current_entity = word.text\n", "                else:\n", "                    current_entity = \"\"\n", "                continue\n", "            current_entity = \" \".join((current_entity, word.text))\n", "\n", "# Print the final entity too, in case the page ends part-way through one:\n", "if prev_cls != other_cls:\n", "    print(f\"----------\\n{entity_classes[prev_cls]}:\\n{current_entity}\")" ] }, { "cell_type": "markdown", "id": "1d4bd908-4a5c-437c-90e4-895c572745f0", "metadata": {}, "source": [ "Of course there may be some instances where this heuristic breaks down, but we still have access to all the position (and text) information from each `LINE` and `WORD` to write additional rules for reading order and separation if needed." 
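, "\n", "\n", "The same grouping heuristic can also be sketched as a reusable function. This minimal version (the name `group_entities` is hypothetical, not part of the solution's `util` library) works on plain `(text, class_id)` pairs so it can be tested without calling the endpoint, and also emits the final entity if the input ends part-way through one:\n", "\n", "```python\n", "from typing import Iterable, List, Tuple\n", "\n", "\n", "def group_entities(words: Iterable[Tuple[str, int]], other_cls: int) -> List[Tuple[int, str]]:\n", "    \"\"\"Merge runs of consecutive same-class words into (class_id, text) entities.\"\"\"\n", "    entities: List[Tuple[int, str]] = []\n", "    prev_cls, current = other_cls, []\n", "    for text, cls in words:\n", "        if cls != prev_cls:\n", "            # Class changed: close off the previous entity (if it was one)\n", "            if prev_cls != other_cls and current:\n", "                entities.append((prev_cls, \" \".join(current)))\n", "            prev_cls, current = cls, []\n", "        if cls != other_cls:\n", "            current.append(text)\n", "    # Flush the last entity, in case the input ended mid-entity\n", "    if prev_cls != other_cls and current:\n", "        entities.append((prev_cls, \" \".join(current)))\n", "    return entities\n", "\n", "\n", "# Example: class 2 plays the role of \"other\" here\n", "print(group_entities([(\"APR\", 0), (\"19.99%\", 0), (\"the\", 2), (\"Acme\", 1)], other_cls=2))\n", "# -> [(0, 'APR 19.99%'), (1, 'Acme')]\n", "```\n", "\n", "To apply it to a model response like `res` above, you could first flatten the words in line order, for example `[(w.text, w._block[\"PredictedClass\"]) for page in res.pages for line in page.lines for w in line.words]`, and pass `len(entity_classes)` as `other_cls`."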
] }, { "cell_type": "markdown", "id": "93670571-10cf-489b-8b97-cb5a20d8401f", "metadata": {}, "source": [ "---\n", "## Setting up the end-to-end pipeline\n", "\n", "### Integrating the entity detection model\n", "\n", "So far we've demonstrated running entity detection requests from here in the notebook, but how can this model be integrated into the end-to-end document processing pipeline stack?\n", "\n", "First, identify the **endpoint name** of your deployed model and the **AWS Systems Manager (SSM) Parameter** that tells the pipeline stack which SageMaker endpoint to call:" ] }, { "cell_type": "code", "execution_count": null, "id": "b5a21365-586a-4b83-8a7f-9a0b4cee0f89", "metadata": {}, "outputs": [], "source": [ "print(f\"Endpoint name:\\n {predictor.endpoint_name}\")\n", "print(f\"\\nEndpoint SSM param:\\n {config.sagemaker_endpoint_name_param}\")" ] }, { "cell_type": "markdown", "id": "46912d03-8748-485b-9d50-841f154a670e", "metadata": {}, "source": [ "Next, we'll update this SSM parameter to point to the deployed SageMaker endpoint.\n", "\n", "The code below should do this for you automatically:\n", "\n", "> ⚠️ **Note:** The [Lambda function](../pipeline/enrichment/fn-call-sagemaker/main.py) that calls your model from the OCR pipeline caches the endpoint name for a few minutes (`CACHE_TTL_SECONDS`) to reduce unnecessary ssm:GetParameter calls - so if you processed a document recently, it may take a little time for an update here to take effect." 
] }, { "cell_type": "code", "execution_count": null, "id": "a0a6d62f-47bd-42ea-a74f-fed32fcd8f85", "metadata": {}, "outputs": [], "source": [ "pipeline_endpoint_name = predictor.endpoint_name\n", "\n", "print(f\"Configuring pipeline with model: {pipeline_endpoint_name}\")\n", "\n", "ssm.put_parameter(\n", "    Name=config.sagemaker_endpoint_name_param,\n", "    Overwrite=True,\n", "    Value=pipeline_endpoint_name,\n", ")" ] }, { "cell_type": "markdown", "id": "e32572c6-4fdc-4ade-8801-bfe8f76934c2", "metadata": {}, "source": [ "Alternatively, you could open the [AWS Systems Manager Parameter Store console](https://console.aws.amazon.com/systems-manager/parameters/?tab=Table) and click on the *name* of the parameter to open its detail page, then the **Edit** button in the top right corner as shown below:\n", "\n", "![](img/ssm-param-detail-screenshot.png \"Screenshot of SSM parameter detail page showing Edit button\")\n", "\n", "From this screen you can manually set the **Value** of the parameter and save the changes.\n", "\n", "Whether you updated the SSM parameters via code or the console, the pre-processing and enrichment stages of your stack should now be configured to use your endpoints!" ] }, { "cell_type": "markdown", "id": "81811e20-7b0d-474d-9602-d260db900a86", "metadata": {}, "source": [ "### Updating the pipeline entity definitions\n", "\n", "As well as configuring the *enrichment* stage of the pipeline to reference the deployed version of the model, we need to configure the *post-processing* stage to match the model's **definition of entity/field types**.\n", "\n", "The entity configuration is the same one used earlier for annotation, but the `annotation_guidance` attributes are not needed:\n", "\n", "> ℹ️ **Note:** As well as the mapping from ID numbers (returned by the model) to human-readable class names, this configuration controls how the pipeline consolidates entity matches into \"fields\" of the document: E.g. 
choosing the \"most likely\" or \"first\" value between multiple detections, or setting up a multi-value field." ] }, { "cell_type": "code", "execution_count": null, "id": "4688cdf8-e448-45b5-ad5d-c0dbe5a46b04", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "pipeline_entity_config = json.dumps([f.to_dict(omit=[\"annotation_guidance\"]) for f in fields], indent=2)\n", "print(pipeline_entity_config)" ] }, { "cell_type": "markdown", "id": "cc1b8289-1bab-414f-aadf-c50fb3d36464", "metadata": {}, "source": [ "As above, you *could* set this value manually in the SSM console for the parameter named as `EntityConfig`.\n", "\n", "...But we can make the same update via code through the APIs:" ] }, { "cell_type": "code", "execution_count": null, "id": "82bca521-6b34-47c9-9062-8d681d9ae817", "metadata": {}, "outputs": [], "source": [ "print(f\"Setting pipeline entity configuration\")\n", "ssm.put_parameter(\n", " Name=config.entity_config_param,\n", " Overwrite=True,\n", " Value=pipeline_entity_config,\n", ")" ] }, { "cell_type": "markdown", "id": "11386c0a-ad81-4a20-93a8-3f89cdf8388c", "metadata": {}, "source": [ "### Set up online review with Amazon Augmented AI (A2I)\n", "\n", "Whereas our original batch annotation used the [built-in](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-task-types.html) image bounding box / object detection task UI, a custom task template is provided for online review.\n", "\n", "Since the template is built using a web framework (VueJS), we'll need to install some extra dependencies to enable building it:" ] }, { "cell_type": "code", "execution_count": null, "id": "dc459c88-e10b-46d5-82c3-24d74cc3dd49", "metadata": {}, "outputs": [], "source": [ "!cd review && npm install" ] }, { "cell_type": "markdown", "id": "a638c64f-79a9-413f-8887-220b50da9526", "metadata": {}, "source": [ "Then, build the UI HTML template from source:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"7a5b333b-a887-4f05-a87c-c605ca4fd8aa", "metadata": {}, "outputs": [], "source": [ "!cd review && npm run build\n", "ui_template_file = \"review/dist/index.html\"" ] }, { "cell_type": "markdown", "id": "9ece1831-28dd-43f1-8077-91b064b6e531", "metadata": {}, "source": [ "Next, upload the built file as an A2I human review task UI:" ] }, { "cell_type": "code", "execution_count": null, "id": "09d403d4-08cb-4f4f-aefc-47fbfba80412", "metadata": {}, "outputs": [], "source": [ "with open(ui_template_file, \"r\") as f:\n", "    create_template_resp = smclient.create_human_task_ui(\n", "        HumanTaskUiName=\"fields-validation-1\",  # (Can change this name as you like)\n", "        UiTemplate={\"Content\": f.read()},\n", "    )\n", "\n", "task_template_arn = create_template_resp[\"HumanTaskUiArn\"]\n", "print(f\"Created A2I task template:\\n{task_template_arn}\")" ] }, { "cell_type": "markdown", "id": "933e1637-e04a-4078-9079-beb56f8e5f17", "metadata": {}, "source": [ "We already defined a \"team\" above for routing SageMaker Ground Truth labelling tasks, and can re-use that team for the online review flow.\n", "\n", "To finish setting up the workflow itself, we need 2 more pieces of information:\n", "\n", "- The **location in S3** where review outputs should be stored\n", "- An appropriate **execution role** which will give the A2I workflow permission to read input documents and write review results.\n", "\n", "These are determined by the **OCR pipeline solution stack**, because the reviews bucket is created by the pipeline with event triggers to resume the next stage when reviews are uploaded.\n", "\n", "The code below should be able to look up these parameters for you automatically:" ] }, { "cell_type": "code", "execution_count": null, "id": "ad6f7c1f-d6d5-4d84-ba68-8bf9b86b812f", "metadata": {}, "outputs": [], "source": [ "reviews_bucket_name = config.pipeline_reviews_bucket_name\n", "print(reviews_bucket_name)\n", "reviews_role_arn = config.a2i_execution_role_arn\n", "print(reviews_role_arn)" ] 
}, { "cell_type": "markdown", "id": "cb724957-a783-48fb-85d3-39b16937e877", "metadata": {}, "source": [ "Alternatively, you may **find** your pipeline solution stack from the [AWS CloudFormation Console](https://console.aws.amazon.com/cloudformation/home?#/stacks) and click through to the stack detail page. From the **Outputs** tab, you should see the `A2IHumanReviewBucketName` and `A2IHumanReviewExecutionRoleArn` values as shown below.\n", "\n", "(You may also note the `A2IHumanReviewFlowParamName`, which we'll use in the next section)\n", "\n", "![](img/cfn-stack-outputs-a2i.png \"CloudFormation stack outputs for OCR pipeline\")" ] }, { "cell_type": "markdown", "id": "5e05cd6d-314b-435b-9994-8ec9bef79d4a", "metadata": {}, "source": [ "Once these values are populated, you're ready to create your review workflow by running the code below.\n", "\n", "Note that you can also manage flows via the [A2I Human Review Workflows Console](https://console.aws.amazon.com/a2i/home?#/human-review-workflows/)." 
] }, { "cell_type": "code", "execution_count": null, "id": "cd1f6192-99cb-4c4e-95da-656c64c3bebe", "metadata": {}, "outputs": [], "source": [ "create_flow_resp = smclient.create_flow_definition(\n", " FlowDefinitionName=\"ocr-fields-validation-1\", # (Can change this name as you like)\n", " HumanLoopConfig={\n", " \"WorkteamArn\": workteam_arn,\n", " \"HumanTaskUiArn\": task_template_arn,\n", " \"TaskTitle\": \"Review OCR Field Extractions\",\n", " \"TaskDescription\": \"Review and correct credit card agreement field extractions\",\n", " \"TaskCount\": 1, # One reviewer per item\n", " \"TaskAvailabilityLifetimeInSeconds\": 60 * 60, # Availability timeout\n", " \"TaskTimeLimitInSeconds\": 60 * 60, # Working timeout\n", " },\n", " OutputConfig={\n", " \"S3OutputPath\": f\"s3://{reviews_bucket_name}/reviews\",\n", " },\n", " RoleArn=reviews_role_arn,\n", ")\n", "\n", "print(f\"Created review workflow:\\n{create_flow_resp['FlowDefinitionArn']}\")" ] }, { "cell_type": "markdown", "id": "3ed48057-1311-4609-839a-7cb5abf0b294", "metadata": {}, "source": [ "Finally, when the human review flow is created and registered, we can configure the document pipeline to use it - similarly to our SageMaker endpoint and entity configuration:" ] }, { "cell_type": "code", "execution_count": null, "id": "5fb62b35-7fe9-4b9a-b6fe-8dd1e2bf6c33", "metadata": {}, "outputs": [], "source": [ "print(f\"Configuring pipeline with review workflow: {create_flow_resp['FlowDefinitionArn']}\")\n", "\n", "ssm = boto3.client(\"ssm\")\n", "ssm.put_parameter(\n", " Name=config.a2i_review_flow_arn_param,\n", " Overwrite=True,\n", " Value=create_flow_resp[\"FlowDefinitionArn\"],\n", ")" ] }, { "cell_type": "markdown", "id": "087ba0fc-1470-4d51-ba47-ac3ced075e91", "metadata": {}, "source": [ "Alternatively through the console, you would follow these steps:\n", "\n", "▶️ **Check** the `A2IHumanReviewFlowParamName` output of your OCR pipeline stack in 
[CloudFormation](https://console.aws.amazon.com/cloudformation/home?#/stacks) (as we did above)\n", "\n", "▶️ **Open** the [AWS Systems Manager Parameter Store console](https://console.aws.amazon.com/systems-manager/parameters/?tab=Table) and **find the review flow parameter in the list**.\n", "\n", "▶️ **Click** on the name of the parameter to open its detail page, and then on the **Edit** button in the top right corner. Set the value to the **workflow ARN** (see previous code cell in this notebook) and save the changes.\n", "\n", "![](img/ssm-a2i-param-detail.png \"Screenshot of SSM parameter detail page for human workflow\")" ] }, { "cell_type": "markdown", "id": "0768a157-9115-4261-aeae-6827bf7189f2", "metadata": {}, "source": [ "---\n", "## Final testing\n", "\n", "Your OCR pipeline should now be fully functional! Let's try it out:\n", "\n", "▶️ **Log in** to the labelling portal (URL available from the [SageMaker Ground Truth Workforces Console](https://console.aws.amazon.com/sagemaker/groundtruth?#/labeling-workforces) for your correct AWS Region)\n", "\n", "![](img/smgt-find-workforce-url.png \"Screenshot of SMGT console with workforce login URL\")\n", "\n", "▶️ **Upload** one of the sample documents to your pipeline's input bucket in Amazon S3, either using the code snippets below or drag and drop in the [Amazon S3 Console](https://console.aws.amazon.com/s3/)" ] }, { "cell_type": "code", "execution_count": null, "id": "a9f202de-d041-4e94-af90-d80df1fe84f8", "metadata": {}, "outputs": [], "source": [ "pdfpaths = []\n", "for currpath, dirs, files in os.walk(\"data/raw\"):\n", " if \"/.\" in currpath or \"__\" in currpath:\n", " continue\n", " pdfpaths += [\n", " os.path.join(currpath, f) for f in files\n", " if f.lower().endswith(\".pdf\")\n", " ]\n", "pdfpaths = sorted(pdfpaths)" ] }, { "cell_type": "code", "execution_count": null, "id": "23811e21-bd19-45b1-ae92-a385f1981bb9", "metadata": {}, "outputs": [], "source": [ "test_filepath = pdfpaths[14]\n", 
"test_s3uri = f\"s3://{config.pipeline_input_bucket_name}/{test_filepath}\"\n", "\n", "!aws s3 cp '{test_filepath}' '{test_s3uri}'" ] }, { "cell_type": "markdown", "id": "68794a23-e2ec-49df-8d8b-7b5a8c375c2a", "metadata": {}, "source": [ "▶️ **Open up** your \"Processing Pipeline\" state machine in the [AWS Step Functions Console](https://console.aws.amazon.com/states/home?#/statemachines)\n", "\n", "After a few seconds you should find that a Step Function execution is automatically triggered and (since we enabled so many fields that at least one is always missing) the example is eventually forwarded for human review in A2I.\n", "\n", "As you'll see from the `ModelResult` field in your final *Step Output*, this pipeline produces a rich but usefully-structured output - with good opportunities for onward integration into further Step Functions steps or external systems. You can find more information and sample solutions for integrating AWS Step Functions in the [Step Functions Developer Guide](https://docs.aws.amazon.com/step-functions/latest/dg/welcome.html).\n", "\n", "![](img/sfn-statemachine-success.png \"Screenshot of successful Step Function execution with output JSON\")" ] }, { "cell_type": "markdown", "id": "bede76bf-3660-4418-82e3-ec303cb2a5a5", "metadata": { "tags": [] }, "source": [ "## Conclusion\n", "\n", "In this worked example we showed how advanced, open-source language processing models specifically tailored for document understanding can be integrated with [Amazon Textract](https://aws.amazon.com/textract/): providing a trainable, ML-driven framework for tackling more niche or complex requirements where Textract's [built-in structure extraction tools](https://aws.amazon.com/textract/features/) may not fully solve the challenges out-of-the-box.\n", "\n", "The underlying principle of the model - augmenting multi-task neural text processing architectures with positional data - is highly extensible, with potential to tackle a wide range of use cases 
where joint understanding of the content and presentation of text can deliver better results than considering text alone.\n", "\n", "We demonstrated how an end-to-end process automation pipeline applying this technology might look: Developing and deploying the model with [Amazon SageMaker](https://aws.amazon.com/sagemaker/), building a serverless workflow with [AWS Step Functions](https://aws.amazon.com/step-functions/) and [AWS Lambda](https://aws.amazon.com/lambda/), and driving quality with human review of low-confidence documents through [Amazon Augmented AI](https://aws.amazon.com/augmented-ai/).\n", "\n", "Thanks for following along, and for more information, don't forget to check out:\n", "\n", "- The other published [Amazon Textract Examples](https://docs.aws.amazon.com/textract/latest/dg/other-examples.html) listed in the [Textract Developer Guide](https://docs.aws.amazon.com/textract/latest/dg/what-is.html)\n", "- The extensive repository of [Amazon SageMaker Examples](https://github.com/aws/amazon-sagemaker-examples) and usage documentation in the [SageMaker Python SDK User Guide](https://sagemaker.readthedocs.io/en/stable/) - as well as the [SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/index.html)\n", "- The wide range of other open algorithms and models published by [HuggingFace Transformers](https://huggingface.co/transformers/), and their specific documentation on [using the library with SageMaker](https://huggingface.co/transformers/sagemaker.html)\n", "- The conversational AI and NLP area (and others) of Amazon's own [Amazon.Science](https://www.amazon.science/conversational-ai-natural-language-processing) blog\n", "\n", "Happy building!" 
] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 3.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }