{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "525c8bba-5198-46b4-aad2-b6c4f51e6371",
   "metadata": {
    "tags": []
   },
   "source": [
    "**Post-Processing Amazon Textract with Location-Aware Transformers**\n",
    "\n",
    "# Optional Extras\n",
    "\n",
    "> *This notebook works well with the `PyTorch 1.10 Python 3.8 CPU Optimized (Python 3)` kernel on SageMaker Studio - **different** from the others in the series*\n",
    "\n",
    "This notebook discusses optional extra/alternative steps separate from the typical pipeline setup flow. You won't typically need to run these steps unless specifically guided, or you're digging deeper into customization.\n",
    "\n",
    "## Common setup\n",
    "\n",
    "First, as usual, we'll set up and import required libraries. You should run these cells regardless of which optional section(s) you're using:\n",
    "\n",
    "The Hugging Face `datasets` and `transformers` installs here are used specifically for dataset preparation in the seq2seq section. If you have problems with these libraries and aren't tackling this section, you may be able to omit them. If you regularly need to install several custom libraries in Studio notebooks, refer to the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-byoi-create.html) and [samples](https://github.com/aws-samples/sagemaker-studio-custom-image-samples) on building **Custom kernel images** for SageMaker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01620969-ea29-46c7-ba84-9006ca9e73d1",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install amazon-textract-response-parser \\\n",
    "    \"datasets>=2.4,<3\" \\\n",
    "    \"ipywidgets>=7,<8\" \\\n",
    "    sagemaker-studio-image-build \\\n",
    "    \"sagemaker>=2.87,<3\" \\\n",
    "    \"transformers>=4.25,<4.26\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53fa81a6-6466-4a78-b87b-4bcec14ca983",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Python Built-Ins:\n",
    "import json\n",
    "from logging import getLogger\n",
    "import os\n",
    "import time\n",
    "\n",
    "# External Dependencies:\n",
    "import boto3  # General-purpose AWS SDK for Python\n",
    "import numpy as np  # Matrix/math utilities\n",
    "import pandas as pd  # Data table / dataframe utilities\n",
    "import sagemaker  # High-level Python SDK for Amazon SageMaker\n",
    "\n",
    "# Local Dependencies:\n",
    "import util\n",
    "\n",
    "logger = getLogger()\n",
    "\n",
    "# Configuration:\n",
    "bucket_name = sagemaker.Session().default_bucket()\n",
    "bucket_prefix = \"textract-transformers/\"\n",
    "print(f\"Working in bucket s3://{bucket_name}/{bucket_prefix}\")\n",
    "config = util.project.init(\"ocr-transformers-demo\")\n",
    "print(config)\n",
    "\n",
    "# AWS service clients:\n",
    "smclient = boto3.client(\"sagemaker\")\n",
    "ssm = boto3.client(\"ssm\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2939bcb1-bdce-48f2-83d7-c0e14cd6adea",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "The sections of this notebook are independent:\n",
    "\n",
    "- **[Manual thumbnail generator setup](#Manual-thumbnail-generator-setup)**: Customise online page thumbnail generation endpoint\n",
    "- **[Optimise costs with endpoint auto-scaling](#Optimise-costs-with-endpoint-auto-scaling)**: Configure your SageMaker endpoint(s) to auto-scale based on incoming request volume\n",
    "- **[Experimenting with alternative OCR engines](#Experimenting-with-alternative-OCR-engines)**: Substitute Amazon Textract with open-source OCR tools, for use with unsupported languages\n",
    "- **[Exploring sequence-to-sequence models](#Exploring-sequence-to-sequence-models)**: Use generative models to automatically re-format detected fields and fix common OCR error patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2de05db-4a6a-44c0-afc6-203ecdbd999d",
   "metadata": {
    "tags": []
   },
   "source": [
    "---\n",
    "\n",
    "## Manual thumbnail generator setup\n",
    "\n",
    "> This section walks through manually building and configuring the endpoint to generate resized page thumbnail images in real time.\n",
    ">\n",
    "> You may find it useful if you want to customise the container image or script used by this process, or if you deployed your pipeline without thumbnailing support but want to experiment with image-based models from notebooks.\n",
    ">\n",
    "> ⚠️ **Note:** Deploying and registering a thumbnailing endpoint from the notebook will still not turn on thumbnail generation in a pipeline deployed without support for it. Instead, refer to your CDK app parameters to ensure the pipeline state machine gets updated to include a thumbnail generation step."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71fa91ae-c79b-4ecc-b470-d7daa79033ce",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Build and register custom container image\n",
    "\n",
    "The tools we use to read PDF files aren't installed by default in the pre-built SageMaker containers and aren't `pip install`able, so the thumbnail generator will need a custom container image. We can derive a custom image from an existing AWS DLC serving container, to minimise boilerplate code because a SageMaker-compatible serving stack will already be included.\n",
    "\n",
    "Because SageMaker Studio kernels are already containerized, you won't be able to run typical `docker build` commands you may be used to: So we'll use the [SageMaker Studio Image Build CLI](https://github.com/aws-samples/sagemaker-studio-image-build-cli) to build the image and store it in your account's [Amazon Elastic Container Registry (ECR)](https://aws.amazon.com/ecr/):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f48a7b-cdf4-4f8a-9544-6942fbe582a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configurations:\n",
    "preproc_ecr_repo_name = \"sm-ocr-preproc\"\n",
    "preproc_ecr_image_tag = \"pytorch-1.10-inf-cpu\"\n",
    "\n",
    "preproc_framework_version = \"1.10\"\n",
    "preproc_py_version = \"py38\"\n",
    "\n",
    "base_image_uri = sagemaker.image_uris.retrieve(\n",
    "    framework=\"pytorch\",\n",
    "    region=os.environ[\"AWS_REGION\"],\n",
    "    instance_type=\"ml.c5.xlarge\",  # (Just used to check whether GPUs/accelerators are used)\n",
    "    py_version=preproc_py_version,\n",
    "    image_scope=\"inference\",  # Inference base because we'll also deploy as an endpoint later\n",
    "    version=preproc_framework_version,\n",
    ")\n",
    "\n",
    "# Combine together into the final URI (not needed for the build, but used later in the notebook):\n",
    "account_id = sagemaker.Session().account_id()\n",
    "region = os.environ[\"AWS_REGION\"]\n",
    "preproc_ecr_image_uri = \"{}.dkr.ecr.{}.amazonaws.com/{}:{}\".format(\n",
    "    account_id, region, preproc_ecr_repo_name, preproc_ecr_image_tag\n",
    ")\n",
    "print(f\"Will build to {preproc_ecr_image_uri}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2d84a6-1b5b-4ebc-b057-9ffec1cb5fd1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# (No need to re-run this cell if your image is already in ECR)\n",
    "\n",
    "# Actually build & push the container image:\n",
    "!cd custom-containers/preproc && sm-docker build . \\\n",
    "    --repository {ecr_repo_name}:{ecr_image_tag} \\\n",
    "    --role {config.sm_image_build_role} \\\n",
    "    --build-arg BASE_IMAGE={base_image_uri}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fc24da6-d3b8-4b13-ab94-fd7085836678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check from notebook whether the image was successfully created:\n",
    "ecr = boto3.client(\"ecr\")\n",
    "imgs_desc = ecr.describe_images(\n",
    "    registryId=account_id,\n",
    "    repositoryName=preproc_ecr_repo_name,\n",
    "    imageIds=[{\"imageTag\": preproc_ecr_image_tag}],\n",
    ")\n",
    "assert len(imgs_desc[\"imageDetails\"]) > 0, \"Couldn't find ECR image {} after build\".format(\n",
    "    preproc_ecr_image_uri\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca987aff-e81d-46e6-a723-4f7e980481cc",
   "metadata": {},
   "source": [
    "### Deploy and test the thumbnailer endpoint\n",
    "\n",
    "Because the custom image is based on the standard SageMaker PyTorch inference container, our [preproc/preproc.py](preproc/preproc.py) script can [work with the existing serving stack](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#id3) by exposing custom `model_fn`, `input_fn`, `predict_fn`, and/or `output_fn` functions.\n",
    "\n",
    "We'll bundle the scripts into a `.tar.gz` file in the format the PyTorch container expects: With inference code in a `code/` subfolder.\n",
    "\n",
    "Normally this process (and the setting of the `SAGEMAKER_PROGRAM` and `SAGEMAKER_SUBMIT_DIRECTORY` environment variables) is handled automatically by the `PyTorchModel` - which allows \"re-packing\" the tarball from a training job to create a new tarball with new `source_dir` and `entry_point` scripts. In this case though, we don't need such a two-step process because there's no training artifact to start from and no actual \"model\" in this tarball - PyTorch or otherwise. Our script just defines code to extract and resize page images, and a dummy `model_fn` so the endpoint won't crash from failing to find a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f7cecb5-9b7d-41d8-ad5b-c54c7ebc2c9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compress the archive locally and list the compressed contents:\n",
    "preproc_model_path = util.deployment.tar_as_inference_code(\"preproc\", \"data/preproc-model.tar.gz\")\n",
    "print(f\"(Re)-created {preproc_model_path}\")\n",
    "!tar -ztvf {preproc_model_path}\n",
    "print()\n",
    "\n",
    "# Upload to S3:\n",
    "preproc_model_key = \"\".join((\n",
    "    bucket_prefix,\n",
    "    \"preproc-model/\",\n",
    "    util.uid.append_timestamp(\"model\"),  # (Maintain history in S3)\n",
    "    \".tar.gz\"\n",
    "))\n",
    "preproc_model_s3uri = f\"s3://{bucket_name}/{preproc_model_key}\"\n",
    "!aws s3 cp {preproc_model_path} {preproc_model_s3uri}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd15056b-2516-4e82-8464-3b9eee33421d",
   "metadata": {},
   "source": [
    "Once a `model.tar.gz` is available on S3, we're ready to create and deploy a SageMaker \"Model\" and Endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "443e9614-a7d0-404f-977b-12c848f7a54b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.pytorch import PyTorchModel\n",
    "\n",
    "if config.thumbnails_callback_topic_arn.startswith(\"arn:\"):\n",
    "    async_notification_config = {\n",
    "        \"SuccessTopic\": config.thumbnails_callback_topic_arn,\n",
    "        \"ErrorTopic\": config.thumbnails_callback_topic_arn,\n",
    "    }\n",
    "else:\n",
    "    logger.warning(\"Pipeline stack deployed without thumbnailing callback topic\")\n",
    "    async_notification_config = {}\n",
    "\n",
    "\n",
    "class PatchedPyTorchModel(PyTorchModel):\n",
    "    \"\"\"Modified PyTorchModel to allow manually setting SM Script Mode environment vars\n",
    "\n",
    "    See: https://github.com/aws/sagemaker-python-sdk/issues/3361\n",
    "    \"\"\"\n",
    "\n",
    "    def prepare_container_def(self, *args, **kwargs):\n",
    "        # Call the parent function:\n",
    "        result = super().prepare_container_def(*args, **kwargs)\n",
    "        # ...But allow our manual env vars configuration to override the internals:\n",
    "        manual_env = dict(self.env)\n",
    "        result[\"Environment\"].update(manual_env)\n",
    "        return result\n",
    "\n",
    "\n",
    "preproc_model = PatchedPyTorchModel(\n",
    "    name=util.uid.append_timestamp(\"ocr-thumbnail\"),\n",
    "    model_data=preproc_model_s3uri,\n",
    "    entry_point=None,  # Set manually via tarball and SAGEMAKER_PROGRAM\n",
    "    framework_version=\"1.10\",\n",
    "    py_version=\"py38\",\n",
    "    image_uri=preproc_ecr_image_uri,\n",
    "    role=sagemaker.get_execution_role(),\n",
    "    env={\n",
    "        \"PYTHONUNBUFFERED\": \"1\",\n",
    "        \"SAGEMAKER_PROGRAM\": \"preproc.py\",\n",
    "        # TorchServe configurations for large payloads & slow inference:\n",
    "        \"TS_MAX_REQUEST_SIZE\": str(100*1024*1024),  # 100MiB instead of default ~6.2MiB\n",
    "        \"TS_MAX_RESPONSE_SIZE\": str(100*1024*1024),  # 100MiB instead of default ~6.2MiB\n",
    "        \"TS_DEFAULT_RESPONSE_TIMEOUT\": str(60*15),  # 15mins instead of the default (60s maybe?)\n",
    "    },\n",
    ")\n",
    "\n",
    "preproc_predictor = preproc_model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=\"ml.m5.xlarge\",\n",
    "    async_inference_config=sagemaker.async_inference.AsyncInferenceConfig(\n",
    "        output_path=f\"s3://{config.model_results_bucket}/preproc\",\n",
    "        max_concurrent_invocations_per_instance=2,\n",
    "        notification_config=async_notification_config,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f573dd1-bae7-430f-b5ca-39d7c372c061",
   "metadata": {},
   "source": [
    "This endpoint accepts images or documents and outputs resized page thumbnail images.\n",
    "\n",
    "For multi-page documents the main output format is `application/x-npz`, which produces a [compressed numpy archive](https://numpy.org/doc/stable/reference/generated/numpy.savez_compressed.html#numpy.savez_compressed) in which `images` is an **array of images** each represented by **PNG bytes**. These formats require customizing the client (predictor) *serializer* and *deserializer* from the default for PyTorch. Since `Predictor` de/serializers set the `Content-Type` and `Accept` headers, we'll also need to re-configure the serializer whenever switching between input document types (for example PDF vs PNG).\n",
    "\n",
    "To support potentially large documents, the preprocessor is deployed to an **asynchronous** endpoint which enables larger request and response payload sizes.\n",
    "\n",
    "So how would it look to test the endpoint from Python? Let's see an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9bfda1-ffce-423c-8bee-d6cf761ece45",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Choose an input (document or image):\n",
    "input_file = \"data/raw/121 Financial Credit Union/Visa Credit Card Agreement.pdf\"\n",
    "#input_file = \"data/imgs-clean/121 Financial Credit Union/Visa Credit Card Agreement-0001-1.png\"\n",
    "\n",
    "# Ensure de/serializers are correctly set up:\n",
    "preproc_predictor.serializer = util.deployment.FileSerializer.from_filename(input_file)\n",
    "preproc_predictor.deserializer = util.deployment.CompressedNumpyDeserializer()\n",
    "# Duplication because of https://github.com/aws/sagemaker-python-sdk/issues/3100\n",
    "preproc_predictor.predictor.serializer = preproc_predictor.serializer\n",
    "preproc_predictor.predictor.deserializer = preproc_predictor.deserializer\n",
    "\n",
    "# Run prediction:\n",
    "print(\"Calling endpoint...\")\n",
    "resp = preproc_predictor.predict(input_file)\n",
    "print(f\"Got response of type {type(resp)}\")\n",
    "\n",
    "# Render result:\n",
    "util.viz.draw_thumbnails_response(resp)"
   ]
  },
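  {
   "cell_type": "markdown",
   "id": "b7f3d2a1-5c6e-4a2b-9f10-3e8d7c4a5b91",
   "metadata": {},
   "source": [
    "If you'd like to work with the raw response rather than the `util.viz` helper, the sketch below shows one way to decode it. It assumes the deserialized `resp` is the dict-like archive described above (with `resp[\"images\"]` holding one PNG-encoded byte string per page) and that Pillow is available in your kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4a9e8f2-1d35-4b7c-8a60-9f2e6d1c3b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch - assumes `resp` from the cell above, with one PNG byte string per page:\n",
    "import io\n",
    "\n",
    "from PIL import Image  # Pillow - usually available in SageMaker Studio kernels\n",
    "\n",
    "page_imgs = [Image.open(io.BytesIO(bytes(png))) for png in resp[\"images\"]]\n",
    "print(f\"Decoded {len(page_imgs)} page thumbnail(s)\")\n",
    "print(f\"First page size (pixels): {page_imgs[0].size}\")"
   ]
  },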
  {
   "cell_type": "markdown",
   "id": "7c2f0719-7314-4097-82f8-98169036e86c",
   "metadata": {},
   "source": [
    "### Connect thumbnailer to the deployed processing pipeline\n",
    "\n",
    "Once your thumbnailer endpoint is deployed and working, you can connect it into your document processing pipeline via SSM parameter configuration - just like the main enrichment model. This will only have an effect if your pipeline was already deployed with thumbnailing enabled, so the cell below will first check whether that seems to be the case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e2ef9eb-e39e-41fe-9309-78edae0375c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "if config.thumbnails_callback_topic_arn == \"undefined\":\n",
    "    raise ValueError(\n",
    "        \"This pipeline CDK stack was deployed with thumbnailing disabled (by setting parameter \"\n",
    "        \"use_thumbnails=False). Either redeploy the CDK stack with updated settings to enable \"\n",
    "        \"thumbnailing, or continue without (and consider deleting the thumbnailing endpoint you \"\n",
    "        \"created, to save unnecessary cost).\"\n",
    "    )\n",
    "\n",
    "thumbnail_endpoint_name = preproc_predictor.endpoint_name\n",
    "print(f\"Configuring pipeline with thumbnailer: {thumbnail_endpoint_name}\")\n",
    "\n",
    "ssm.put_parameter(\n",
    "    Name=config.thumbnail_endpoint_name_param,\n",
    "    Overwrite=True,\n",
    "    Value=thumbnail_endpoint_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7972cc5-8025-46df-9661-143a47d8c8da",
   "metadata": {},
   "source": [
    "### Clean up experimental models\n",
    "\n",
    "Clean up any endpoints you created that are no longer required, to free up resources and avoid unnecessary ongoing costs. The below code demonstrates how to delete an endpoint, and its associated configuration & model records. you may also like to clean up the `preproc-model/` S3 folder to remove any old draft versions.\n",
    "\n",
    "> ⚠️ **Note:** If you delete the active endpoint/model your deployed pipeline is configured to use for thumbnailing, your pipeline will fail to process new documents."
   ]
  },
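  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1e5c7d9-2b48-4f6a-8c31-d0f9e2b7a654",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# A minimal clean-up sketch - assumes `preproc_predictor` from the deployment cell above is\n",
    "# still in scope. Don't run this if your deployed pipeline is actively using the endpoint!\n",
    "preproc_predictor.delete_endpoint(delete_endpoint_config=True)  # Endpoint + endpoint config\n",
    "preproc_predictor.delete_model()  # Associated SageMaker Model record\n",
    "\n",
    "# Optionally, also clear old draft tarballs from the preproc-model/ S3 folder, e.g.:\n",
    "# !aws s3 rm --recursive s3://{bucket_name}/{bucket_prefix}preproc-model/"
   ]
  },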
  {
   "cell_type": "markdown",
   "id": "77093526-6286-401e-97d4-2a408cbbd15b",
   "metadata": {
    "tags": []
   },
   "source": [
    "---\n",
    "\n",
    "*[Back to contents](#Contents)*\n",
    "\n",
    "## Optimise costs with endpoint auto-scaling\n",
    "\n",
    "> This section demonstrates how you can enable and customise auto-scaling on your SageMaker endpoints to optimise resource use and cost.\n",
    ">\n",
    "> **Note:** For endpoints automatically deployed by the pipeline stack (such as the thumbnail generator), there are options available to configure this directly in CDK - which you may prefer.\n",
    "\n",
    "SageMaker Async Inference endpoints support [auto-scaling down to zero instances](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html) when not in use, which can provide significant cost-savings for use cases where document processing is occasional and the pipeline is often idle.\n",
    "\n",
    "⏰ **However:** You should be aware that enabling scale-to-zero can introduce cold-start delays of **several minutes** if requests arrive when all instances backing your endpoint have been shut down."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6365f067-0180-4b1a-b10d-fb1c6c04a482",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Setting up auto-scaling\n",
    "\n",
    "You can configure auto-scaling for your endpoint(s) by first registering them with the [application auto-scaling service](https://docs.aws.amazon.com/autoscaling/application/userguide/what-is-application-auto-scaling.html) and then applying a scaling policy as shown in the following cells.\n",
    "\n",
    "First, configure which SageMaker endpoint you want to auto-scale by name. SageMaker endpoints may be backed by multiple [variants](https://docs.aws.amazon.com/sagemaker/latest/dg/model-ab-testing.html) which can scale independently, but this sample only typically uses the default \"AllTraffic\" variant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9ed66fc-8f53-43b5-9cdb-b5014620c621",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For example, maybe you want to configure whichever enrichment model is currently in pipeline:\n",
    "endpoint_name = ssm.get_parameter(\n",
    "    Name=config.sagemaker_endpoint_name_param,\n",
    ")[\"Parameter\"][\"Value\"]\n",
    "\n",
    "# Default variant name unless you know otherwise:\n",
    "variant_name = \"AllTraffic\"\n",
    "\n",
    "print(f\"Configuring endpoint name:\\n  {endpoint_name}\")\n",
    "print(f\"Configuring variant name:\\n  {variant_name}\")\n",
    "\n",
    "resource_id = f\"endpoint/{endpoint_name}/variant/{variant_name}\"\n",
    "print(f\"\\nAuto-scaling resource ID:\\n  {resource_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fd8a102-b4ba-43d8-9d74-9be06f7b00bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint_name = \"ocr-thumbnail-2022-10-14-03-37-58-529\"\n",
    "variant_name = \"AllTraffic\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d58a5203-3761-4b53-b492-cecd44cff541",
   "metadata": {},
   "source": [
    "From your endpoint and variant name, register a scalable target to configure overall limits:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d0d2a7-4d9b-42f9-9f0c-db45d552a3c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "appscaling = boto3.client(\"application-autoscaling\")\n",
    "\n",
    "# Define and register your endpoint variant\n",
    "appscaling.register_scalable_target(\n",
    "    ServiceNamespace=\"sagemaker\",\n",
    "    ResourceId=resource_id,\n",
    "    ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n",
    "    MinCapacity=0,  # (MinCapacity 0 not supported with real-time endpoints)\n",
    "    MaxCapacity=5,\n",
    ")\n",
    "print(f\"Endpoint registered with auto-scaling service: {endpoint_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dd28f7a-af12-453d-8eec-8787ad47a796",
   "metadata": {},
   "source": [
    "We can also list any scaling policies that may already be active on this resource:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf1ee55-0ee7-4c87-bcb5-5e3526ae5b72",
   "metadata": {},
   "outputs": [],
   "source": [
    "appscaling.describe_scaling_policies(ResourceId=resource_id, ServiceNamespace=\"sagemaker\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dfe4ceb-c4ab-4621-b02a-cb52e4cc869a",
   "metadata": {},
   "source": [
    "As discussed in the [SageMaker Asynchronous Inference Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html), the typical recommended scaling policy for asynchronous endpoints is to track a target on the number of queued requests per active instance - `ApproximateBacklogSizePerInstance`.\n",
    "\n",
    "However, ⚠️ setting this target value `>=1.0` can yield **un-bounded latency** for single requests arriving when the endpoint has scaled off to 0 instances - because scale-out will not be triggered until a big enough queue has formed.\n",
    "\n",
    "You can **combine multiple policies** to set up backlog target tracking but also ensure at least one instance gets started when any requests are in queue, using the alternative `HasBacklogWithoutCapacity` metric:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "052a1c96-f81e-4001-9ff6-81f355c06356",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main backlog-per-instance target tracking policy:\n",
    "scaling_policy_resp = appscaling.put_scaling_policy(\n",
    "    PolicyName=\"BacklogTargetTracking\",\n",
    "    ServiceNamespace=\"sagemaker\",\n",
    "    ResourceId=resource_id,\n",
    "    ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n",
    "    PolicyType=\"TargetTrackingScaling\",\n",
    "    TargetTrackingScalingPolicyConfiguration={\n",
    "        \"TargetValue\": 4.0,\n",
    "        \"CustomizedMetricSpecification\": {\n",
    "            \"MetricName\": \"ApproximateBacklogSizePerInstance\",\n",
    "            \"Namespace\": \"AWS/SageMaker\",\n",
    "            \"Dimensions\": [\n",
    "                {\"Name\": \"EndpointName\", \"Value\": endpoint_name},\n",
    "            ],\n",
    "            \"Statistic\": \"Average\",\n",
    "        },\n",
    "        \"ScaleInCooldown\": 5 * 60,  # (seconds)\n",
    "        \"ScaleOutCooldown\": 4 * 60,  # (seconds)\n",
    "    },\n",
    ")\n",
    "print(f\"Created/updated scaling policy ARN:\\n{scaling_policy_resp['PolicyARN']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06f8d34a-0129-4309-a2a6-1ce581714ea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extra policy to ensure one-off requests get processed promptly:\n",
    "scaling_policy_resp = appscaling.put_scaling_policy(\n",
    "    PolicyName=\"BootstrapSingleRequests\",\n",
    "    ServiceNamespace=\"sagemaker\",\n",
    "    ResourceId=resource_id,\n",
    "    ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n",
    "    PolicyType=\"StepScaling\",\n",
    "    StepScalingPolicyConfiguration={\n",
    "        \"AdjustmentType\": \"ChangeInCapacity\",\n",
    "        \"StepAdjustments\": [{\"MetricIntervalLowerBound\": 1.0, \"ScalingAdjustment\": +1}],\n",
    "        \"Cooldown\": 150,  # (Seconds)\n",
    "        \"MetricAggregationType\": \"Average\",\n",
    "    },\n",
    ")\n",
    "print(f\"Created/updated scaling policy ARN:\\n{scaling_policy_resp['PolicyARN']}\")"
   ]
  },
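  {
   "cell_type": "markdown",
   "id": "e2b6f4a8-9c13-4d57-a2e0-6b8f1c3d5a72",
   "metadata": {},
   "source": [
    "Step scaling policies only take effect when a CloudWatch alarm invokes them, so as a sketch you can attach the `HasBacklogWithoutCapacity` metric to the `BootstrapSingleRequests` policy created above. The alarm name, period and evaluation settings below are illustrative assumptions; the threshold of `0` with a `GreaterThanThreshold` comparison lines up with the policy's `MetricIntervalLowerBound` of `1.0`, so one instance is added whenever the metric reaches `1`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8d2a6c4-3e91-4b05-8d7a-1c5e9f2b6d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: wire the step-scaling policy up to a HasBacklogWithoutCapacity alarm\n",
    "cloudwatch = boto3.client(\"cloudwatch\")\n",
    "\n",
    "bootstrap_alarm_name = f\"{endpoint_name}-HasBacklogWithoutCapacity\"\n",
    "cloudwatch.put_metric_alarm(\n",
    "    AlarmName=bootstrap_alarm_name,\n",
    "    MetricName=\"HasBacklogWithoutCapacity\",\n",
    "    Namespace=\"AWS/SageMaker\",\n",
    "    Dimensions=[{\"Name\": \"EndpointName\", \"Value\": endpoint_name}],\n",
    "    Statistic=\"Average\",\n",
    "    Period=60,  # (Seconds) check every minute\n",
    "    EvaluationPeriods=1,\n",
    "    Threshold=0,\n",
    "    ComparisonOperator=\"GreaterThanThreshold\",\n",
    "    TreatMissingData=\"missing\",\n",
    "    AlarmActions=[scaling_policy_resp[\"PolicyARN\"]],  # The BootstrapSingleRequests policy above\n",
    ")\n",
    "print(f\"Created/updated alarm: {bootstrap_alarm_name}\")"
   ]
  },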
  {
   "cell_type": "markdown",
   "id": "09177701-224c-4e49-bfcd-b14310a95d9d",
   "metadata": {},
   "source": [
    "Your endpoint should now be set up to auto-scale. Refer to the [Endpoints section of the SageMaker Console](https://console.aws.amazon.com/sagemaker/home?#/endpoints) on the detail page for your target endpoint to check.\n",
    "\n",
    "### Disabling auto-scaling\n",
    "\n",
    "If you'd like to de-register an endpoint from auto-scaling, you can delete attached policies and de-register the target as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbd557f2-e70c-410d-b9c6-715744b95dbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "policies = appscaling.describe_scaling_policies(\n",
    "    ResourceId=resource_id,\n",
    "    ServiceNamespace=\"sagemaker\",\n",
    ")[\"ScalingPolicies\"]\n",
    "\n",
    "print(f\"Deleting scaling policies for {resource_id}:\")\n",
    "time.sleep(3)\n",
    "\n",
    "for policy in policies:\n",
    "    appscaling.delete_scaling_policy(\n",
    "        PolicyName=policy[\"PolicyName\"],\n",
    "        ServiceNamespace=policy[\"ServiceNamespace\"],\n",
    "        ResourceId=policy[\"ResourceId\"],\n",
    "        ScalableDimension=policy[\"ScalableDimension\"],\n",
    "    )\n",
    "    print(f\" - {policy['PolicyName']}\")\n",
    "print(\"\\nDone\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84c4922b-c218-4569-94f0-57b10b3c94bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"De-registering from auto-scaling:\\n  {resource_id}\")\n",
    "time.sleep(3)\n",
    "\n",
    "appscaling.deregister_scalable_target(\n",
    "    ServiceNamespace=\"sagemaker\",\n",
    "    ResourceId=resource_id,\n",
    "    ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n",
    ")\n",
    "print(\"Done\")"
   ]
  },
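  {
   "cell_type": "markdown",
   "id": "d5c3a1f7-8b26-4e94-b0c2-7a4f6e1d9c85",
   "metadata": {},
   "source": [
    "If you also created the illustrative `HasBacklogWithoutCapacity` alarm above, you can remove it too (assuming the same alarm name was used):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e7d3c1-6f42-4a58-9d20-3c8a5f7e1b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the illustrative bootstrap alarm, if it was created above:\n",
    "boto3.client(\"cloudwatch\").delete_alarms(\n",
    "    AlarmNames=[f\"{endpoint_name}-HasBacklogWithoutCapacity\"],\n",
    ")\n",
    "print(\"Deleted bootstrap alarm (if it existed)\")"
   ]
  },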
  {
   "cell_type": "markdown",
   "id": "7d3c5912-dfe8-409f-8543-90b53d12c688",
   "metadata": {
    "tags": []
   },
   "source": [
    "---\n",
    "\n",
    "*[Back to contents](#Contents)*\n",
    "\n",
    "## Experimenting with alternative OCR engines\n",
    "\n",
    "> This section demonstrates how to process a batch of documents using alternative, open-source-based OCR engines on Amazon SageMaker - in case you have a use case requiring languages not yet supported by Amazon Textract.\n",
    "\n",
    "As detailed further in the [Customization Guide](../CUSTOMIZATION_GUIDE.md) - You can use alternative, open-source-based OCR engines with this solution if needed, by packaging them to produce Amazon Textract-compatible result formats and integrating them with the pipeline, for which we use Amazon SageMaker Asynchronous Inference for consistency with other steps."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0ad4044-fca1-4579-8b48-ba9d1322f367",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Deploy the alternative engine(s)\n",
    "\n",
    "First, (re)-deploy your solution with the `BUILD_SM_OCRS` variable set, to create container image(s) and SageMaker model(s) for your chosen OCR engine(s).\n",
    "\n",
    "Because resource tags are automatically added to these deployed models, you'll be able to look them up using the same name - by the code below. For example, `ocr_engine_name=tesseract` in the notebook assumes `BUILD_SM_OCRS=...,tesseract,...` at CDK deploy time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1385894-0c2c-4691-9ee9-cf137e81f5a5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "ocr_engine_name = \"tesseract\"\n",
    "ocr_model_desc = util.ocr.describe_sagemaker_ocr_model(ocr_engine_name)\n",
    "\n",
    "print(f\"Found OCR engine {ocr_engine_name}:\\n  {ocr_model_desc['ModelName']}\")\n",
    "\n",
    "ocr_image_uri = ocr_model_desc[\"PrimaryContainer\"][\"Image\"]\n",
    "print(f\"\\nImage: {ocr_image_uri}\")\n",
    "ocr_environment = ocr_model_desc[\"PrimaryContainer\"][\"Environment\"]\n",
    "print(f\"Environment variables:\\n{ocr_environment}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16185606-2a2b-416f-9764-bc022dba5bdb",
   "metadata": {},
   "source": [
    "### Extract documents in batch\n",
    "\n",
    "Just like with batch page image generation in notebook 1, we'll use a SageMaker Processing Job to run the work on a scalable cluster of instances. The input document locations are specified the same way as for page image generation, so the code below takes the whole corpus (S3 prefix) for simplicity.\n",
    "\n",
    "> ⏰ If you'd like to select **just a subset of documents**, you can instead set `ocr_inputs` using the same manifest-based \"OPTION 2\" approach shown to set `preproc_inputs` in the *Extract clean input images* section of [Notebook 1](1.%20Data%20Preparation.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b05109-8b53-4371-a216-4e93e69d9dd6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.processing import FrameworkProcessor, ProcessingInput, ProcessingOutput\n",
    "from util.preproc import DummyFramework\n",
    "\n",
    "# S3 input & output locations:\n",
    "raw_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/raw\"\n",
    "textract_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/textracted\"\n",
    "\n",
    "# (Assuming whole corpus - see NB1 image pre-processing for manifest-based example)\n",
    "ocr_inputs = [\n",
    "    ProcessingInput(\n",
    "        destination=\"/opt/ml/processing/input/raw\",  # Expected input location, per our script\n",
    "        input_name=\"raw\",\n",
    "        s3_data_distribution_type=\"ShardedByS3Key\",  # Distribute between instances, if multiple\n",
    "        source=raw_s3uri,  # S3 prefix for full raw document collection\n",
    "    ),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c27c169-dd52-45c7-a656-2146ecbb12c7",
   "metadata": {},
   "source": [
    "After defining the input and output locations, and with our pre-prepared container image identified, we're ready to run the job.\n",
    "\n",
    "> ⏰ In our tests, the provided Tesseract OCR integration took around 35 minutes on 5x `ml.c5.4xlarge` instances, to process the full ~2,500 document credit cards corpus for English and Thai."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "464154ad-14d5-4dfd-98cf-65070dc596a7",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "processor = FrameworkProcessor(\n",
    "    estimator_cls=DummyFramework,\n",
    "    image_uri=ocr_image_uri,  # As created above\n",
    "    framework_version=None,\n",
    "    base_job_name=\"ocr-custom\",\n",
    "    role=sagemaker.get_execution_role(),\n",
    "    instance_count=5,\n",
    "    instance_type=\"ml.c5.4xlarge\",\n",
    "    volume_size_in_gb=16,\n",
    "    max_runtime_in_seconds=60*60,\n",
    "    env={\n",
    "        \"OMP_THREAD_LIMIT\": \"1\",  # Optimize Tesseract parallelism for batch\n",
    "        \"PYTHONUNBUFFERED\": \"1\",  # For debugging\n",
    "        **ocr_environment,\n",
    "        # Override defaults from the model env vars like this:\n",
    "        \"OCR_DEFAULT_LANGUAGES\": \"eng,tha\",\n",
    "    },\n",
    ")\n",
    "\n",
    "processor.run(\n",
    "    code=\"ocr.py\",  # OCR script\n",
    "    source_dir=\"preproc\",\n",
    "    inputs=ocr_inputs[:],  # Either whole corpus or sample, as above\n",
    "    outputs=[\n",
    "        ProcessingOutput(\n",
    "            destination=textract_s3uri,\n",
    "            output_name=\"ocr\",\n",
    "            s3_upload_mode=\"Continuous\",\n",
    "            source=\"/opt/ml/processing/output/ocr\",  # Output folder, per our script\n",
    "        ),\n",
    "    ],\n",
    "    #logs=False,\n",
    "    #wait=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9381ff91-84c3-4d7d-9c73-775012fcf31d",
   "metadata": {},
   "source": [
    "Once the job is complete, you can crawl the results on Amazon S3 to build up an equivalent manifest ready for the next stage of data preparation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b4e4e8-c3bc-422c-ae14-4d22062c71c7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Given that raw docs live under some S3 prefix:\n",
    "raw_s3uri_prefix = raw_s3uri\n",
    "\n",
    "# ...And Amazon Textract results live under another:\n",
    "textract_s3uri = f\"s3://{bucket_name}/{bucket_prefix}data/textracted\"\n",
    "\n",
    "\n",
    "# ...And you can define a mapping from one to the other:\n",
    "def doc_uri_to_textract_uri(doc_uri: str) -> str:\n",
    "    if not doc_uri.startswith(raw_s3uri_prefix):\n",
    "        raise ValueError(\n",
    "            \"Document S3 URI '%s' did not start with expected prefix: '%s'\"\n",
    "            % (doc_uri, raw_s3uri_prefix)\n",
    "        )\n",
    "    # Replace raw prefix with Textract prefix, and add \"/consolidated.json\" to filename:\n",
    "    return textract_s3uri + doc_uri[len(raw_s3uri_prefix):] + \"/consolidated.json\"\n",
    "\n",
    "# Then build up the combined manifest, checking existence for each result:\n",
    "out_filename = \"data/textracted-all-smocr.manifest.jsonl\"\n",
    "print(f\"Building manifest: {out_filename} ...\")\n",
    "with open(\"data/raw-sample.manifest.jsonl\") as fin:\n",
    "    with open(out_filename, \"w\") as fout:\n",
    "        for doc in (json.loads(line) for line in fin):\n",
    "            textract_uri = doc_uri_to_textract_uri(doc[\"raw-ref\"])\n",
    "            if not util.s3.s3_object_exists(textract_uri):\n",
    "                raise ValueError(\n",
    "                    \"Mapped OCR result URI does not exist in S3.\\nFor: %s\\nGot: %s\"\n",
    "                    % (doc[\"raw-ref\"], textract_uri)\n",
    "                )\n",
    "            doc[\"textract-ref\"] = textract_uri\n",
    "            fout.write(json.dumps(doc) + \"\\n\")\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "048ef094-9f94-466d-bb93-729c4ac82348",
   "metadata": {},
   "source": [
    "### Integrate with the document pipeline\n",
    "\n",
    "The above steps demonstrate how to process documents in batch with alternative, open-source OCR engines, to produce datasets ready for experimenting with multi-lingual model architectures like LayoutXLM. To actually deploy the alternative OCR into your document pipeline, use the `DEPLOY_SM_OCR` and `USE_SM_OCR` variables at CDK deployment. You'll likely want to update `OCR_DEFAULT_LANGUAGES` in [/pipeline/ocr/sagemaker_ocr.py](../pipeline/ocr/sagemaker_ocr.py) to align with your use case's language needs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88994b13-bd18-4793-a4b9-ac3978d2fdf5",
   "metadata": {
    "tags": []
   },
   "source": [
    "---\n",
    "\n",
    "*[Back to contents](#Contents)*\n",
    "\n",
    "## Exploring sequence-to-sequence models\n",
    "\n",
    "> This section demonstrates training a (non-layout-aware) model that edits extracted text fields to normalize data types or fix common OCR error patterns.\n",
    "\n",
    "Since the main flow of this solution focusses on \"extractive\" entity recognition models, you might reasonably wonder whether the same layout-aware ideas could be extended to \"generative\" models capable of actually editing the OCR'd text: For example to reformat fields or fix errors. The answer to this is **\"probably yes, but...\"**:\n",
    "\n",
    "1. Care needs to be taken with large generative models to address bias and privacy concerns: For example will it be possible to extract sensitive or PII data the model was trained on, when it's deployed? Will it be biased to predicting certain patterns that aren't representative of your documents, or are representative on average but leave some user groups with consistently poorer service?\n",
    "2. Published, pre-trained, layout-aware document models have most often provided a decoder-only stack to date: so finding pre-trained initial weights for a generative output module may be challenging. Due to their large size, training these modules from scratch could be resource-intensive.\n",
    "\n",
    "Here we show a more basic approach to start realizing some of the same benefits: Pairing the layout-aware NER model **alongside text-only seq2seq models** to normalize and standardize extracted fields."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cf54db0-354b-4e3c-acba-c81c15f6c1ca",
   "metadata": {},
   "source": [
    "### Collect datasets\n",
    "\n",
    "In this example we'll demonstrate **normalizing dates** to a consistent format. Text-to-text models can tackle this in a flexible, example-driven and statistics-oriented way. Although maximum achievable accuracy might sometimes be higher with rule-based approaches, we'll show how the ML-based approach can yield good performance quickly without needing to build lots of rules and parsing expressions.\n",
    "\n",
    "This task can be tackled via **synthetic dataset generation**: randomly generating dates and input prompts, according to expected statistical distribution of your target data.\n",
    "\n",
    "Run the cell below to generate a training and evaluation dataset. As shown in the preview, the data will include a wide range of source date formats but **also** support multiple different *target* formats - controllable via the first part of the prompt:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67265da0-8ec5-46ce-837a-7883d435a7b6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from src.code.data.seq2seq.date_normalization import generate_seq2seq_date_norm_dataset\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "train_dataset = generate_seq2seq_date_norm_dataset(n=1000, rng=rng)\n",
    "eval_dataset = generate_seq2seq_date_norm_dataset(n=200, rng=rng)\n",
    "\n",
    "train_dataset.save_to_disk(\"data/seq2seq-train\")\n",
    "eval_dataset.save_to_disk(\"data/seq2seq-validation\")\n",
    "\n",
    "print(\"Dataset sample (top 15 records):\")\n",
    "pd.DataFrame(train_dataset[0:15])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38389ddb-80b6-45fe-94d9-2f569becbe9c",
   "metadata": {},
   "source": [
    "As usual with SageMaker, once the datasets are prepared we'll stage them to Amazon S3 ready to use in model training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a21963-a264-49f7-8fae-17f5b5a502c1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_s3uri = f\"s3://{bucket_name}/{bucket_prefix}seq2seq/date-norm/train\"\n",
    "validation_s3uri = f\"s3://{bucket_name}/{bucket_prefix}seq2seq/date-norm/validation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03e5ad33-d8d5-4e85-9f69-f922dcebf0ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!aws s3 sync --delete data/seq2seq-train {train_s3uri}\n",
    "!aws s3 sync --delete data/seq2seq-validation {validation_s3uri}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de96d940-87fb-4b76-b3dc-f4c30fd93473",
   "metadata": {},
   "source": [
    "### 🧪 (Experimental) Training with annotated documents\n",
    "\n",
    "If you annotated your documents using the **custom** SageMaker Ground Truth task UI in Notebook 1 (with OCR transcript reviews), instead of the default (bounding-box-only) UI, you should also be able to directly train the seq2seq model on your manually-annotated data.\n",
    "\n",
    "To do this, set your `train`, `textract` and `validation` channels as shown in Notebook 2 instead of the synthetic/augmented dataset used below. The script will build seq2seq examples from your annotated entity types, raw OCR text, and corrected OCR texts - something like:\n",
    "\n",
    "```json\n",
    "{\n",
    "    \"src_texts\": \"Normalize Card Name: mycool Credit Card.\",\n",
    "    \"tgt_texts\": \"MyCool Credit Card\"\n",
    "}\n",
    "```\n",
    "\n",
    "In the *Integrate with processing pipeline* section below, you'd then configure your normalization prompts to be of the format `Normalize {YourFieldLabel}: ` for each field where you wanted to turn the normalizing model on, instead of the `Convert dates...` prompt we use.\n",
    "\n",
    "You'll probably find it easiest to run through this example with the generated date-normalization dataset first to understand the flow, before trying to use your SMGT annotations instead."
   ]
  },
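  {
   "cell_type": "markdown",
   "id": "c8f1b5d3-2a67-4e09-8b34-5d9c7a2f6e18",
   "metadata": {},
   "source": [
    "As a rough sketch only (the S3 URIs below are hypothetical placeholders - substitute the manifest and Textract result locations you actually prepared in Notebooks 1 and 2), the job inputs for this experimental flow would look something like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6d4f2b8-7c15-4e83-9a20-1f5b3d8c6e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical example channels - use the URIs from your own Notebook 1/2 data preparation:\n",
    "# annotated_inputs = {\n",
    "#     \"train\": f\"s3://{bucket_name}/{bucket_prefix}data/manifests/train.manifest.jsonl\",\n",
    "#     \"validation\": f\"s3://{bucket_name}/{bucket_prefix}data/manifests/validation.manifest.jsonl\",\n",
    "#     \"textract\": f\"s3://{bucket_name}/{bucket_prefix}data/textracted\",\n",
    "# }\n",
    "# estimator.fit(annotated_inputs)  # (The estimator is defined in the 'Train a model' section below)"
   ]
  },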
  {
   "cell_type": "markdown",
   "id": "3f120616-0ae6-49c0-9821-893e5501cf9d",
   "metadata": {},
   "source": [
    "### Look up custom container images\n",
    "\n",
    "The training and inference jobs in this section will use the same customized container images created in the main notebook series for model training and deployment (see [Notebook 2 Model Training](2.%20Model%20Training.ipynb)): so you need to have built those first.\n",
    "\n",
    "The code below will check the container images are already prepared and staged in your account's Amazon Elastic Container Registry (ECR)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb38b61f-4af4-4190-87a7-ce07a8ab2273",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Configurations:\n",
    "train_repo_name = \"sm-ocr-training\"\n",
    "train_repo_tag = \"hf-4.26-pt-gpu\"  # TODO: Check this matches your ECR repo name and tagging\n",
    "inf_repo_name = \"sm-ocr-inference\"\n",
    "inf_repo_tag = train_repo_tag\n",
    "\n",
    "account_id = sagemaker.Session().account_id()\n",
    "region = os.environ[\"AWS_REGION\"]\n",
    "\n",
    "# Combine together into the final URIs:\n",
    "train_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{train_repo_name}:{train_repo_tag}\"\n",
    "print(f\"Target training image: {train_image_uri}\")\n",
    "inf_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/{inf_repo_name}:{inf_repo_tag}\"\n",
    "print(f\"Target inference image: {inf_image_uri}\")\n",
    "\n",
    "# Check from notebook whether the images were successfully created:\n",
    "ecr = boto3.client(\"ecr\")\n",
    "for repo, tag, uri in (\n",
    "    (train_repo_name, train_repo_tag, train_image_uri),\n",
    "    (inf_repo_name, inf_repo_tag, inf_image_uri)\n",
    "):\n",
    "    imgs_desc = ecr.describe_images(\n",
    "        registryId=account_id,\n",
    "        repositoryName=repo,\n",
    "        imageIds=[{\"imageTag\": tag}],\n",
    "    )\n",
    "    assert len(imgs_desc[\"imageDetails\"]) > 0, f\"Couldn't find ECR image {uri} after build\"\n",
    "    print(f\"Found {uri}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d991db9b-0daf-4517-a3da-3e6b24c30a7d",
   "metadata": {},
   "source": [
    "### Train a model\n",
    "\n",
    "With data prepared, model training is very similar to the setup from the main notebooks. Some key differences include:\n",
    "\n",
    "- Setting `task_name: seq2seq` to indicate we're training a sequence-to-sequence model instead of the usual layout-aware `ner`.\n",
    "- Choosing a text-only pre-trained base model compatible with text generation, in this case `t5-base`.\n",
    "- Since the data is synthetic, we can easily generate quite a large dataset in comparison to the amount of training we want to run: So logging, evaluation, and model saving will be controlled in terms of number of training steps rather than number of epochs (passes through the whole dataset)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "264ec056-d96f-440f-8f00-e4df0763980d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.huggingface.estimator import HuggingFace as HuggingFaceEstimator\n",
    "\n",
    "hyperparameters = {\n",
    "    \"model_name_or_path\": \"google/byt5-base\",\n",
    "    \"task_name\": \"seq2seq\",\n",
    "    \"logging_steps\": 100,\n",
    "    \"evaluation_strategy\": \"steps\",\n",
    "    \"eval_steps\": 250,  # (=Twice per epoch, at 1000 data points & batch size 2)\n",
    "    # Only need to set do_eval when validation channel is not provided and want to generate:\n",
    "    \"do_eval\": \"1\",\n",
    "    \"save_strategy\": \"steps\",\n",
    "    \"save_steps\": 250,\n",
    "    \"learning_rate\": 1e-4,\n",
    "    \"per_device_train_batch_size\": 2,\n",
    "    \"per_device_eval_batch_size\": 4,\n",
    "    \"seed\": 1337,\n",
    "\n",
    "    \"num_train_epochs\": 5.01,  # Make sure the epoch==5.0 evaluation gets taken\n",
    "    \"early_stopping_patience\": 4,\n",
    "    \"metric_for_best_model\": \"eval_acc\",\n",
    "    # \"greater_is_better\": \"false\",\n",
    "    # Avoid filling up disk with too many saved model checkpoints:\n",
    "    \"save_total_limit\": 10,\n",
    "}\n",
    "\n",
    "metric_definitions = [\n",
    "    {\"Name\": \"epoch\", \"Regex\": util.training.get_hf_metric_regex(\"epoch\")},\n",
    "    {\"Name\": \"learning_rate\", \"Regex\": util.training.get_hf_metric_regex(\"learning_rate\")},\n",
    "    {\"Name\": \"train:loss\", \"Regex\": util.training.get_hf_metric_regex(\"loss\")},\n",
    "    {\n",
    "        \"Name\": \"validation:n_examples\",\n",
    "        \"Regex\": util.training.get_hf_metric_regex(\"eval_n_examples\"),\n",
    "    },\n",
    "    {\"Name\": \"validation:loss_avg\", \"Regex\": util.training.get_hf_metric_regex(\"eval_loss\")},\n",
    "    {\"Name\": \"validation:acc\", \"Regex\": util.training.get_hf_metric_regex(\"eval_acc\")},\n",
    "]\n",
    "\n",
    "estimator = HuggingFaceEstimator(\n",
    "    role=sagemaker.get_execution_role(),\n",
    "    entry_point=\"train.py\",\n",
    "    source_dir=\"src\",\n",
    "    py_version=None,\n",
    "    pytorch_version=None,\n",
    "    transformers_version=None,\n",
    "    image_uri=train_image_uri,  # Use the customized training container image\n",
    "\n",
    "    base_job_name=\"byt5-datenorm\",\n",
    "    output_path=f\"s3://{bucket_name}/{bucket_prefix}trainjobs\",\n",
    "\n",
    "    instance_type=\"ml.p3.2xlarge\",  # t5-base fits on ml.g4dn.xlarge GPU, but not byt5-base\n",
    "    instance_count=1,\n",
    "    volume_size=80,\n",
    "\n",
    "    debugger_hook_config=False,\n",
    "\n",
    "    hyperparameters=hyperparameters,\n",
    "    metric_definitions=metric_definitions,\n",
    "    environment={\n",
    "        # Required for our custom dataset loading code (which depends on tokenizer):\n",
    "        \"TOKENIZERS_PARALLELISM\": \"false\",\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f19603d2-351b-42ad-8d2e-4e709ca85c46",
   "metadata": {},
   "source": [
    "There is no `textract` input data channel for this job, as both the `training` and `validation` datasets simply provide plain text.\n",
    "\n",
    "Run the cell below to kick off the job and view logs.\n",
    "\n",
    "> ⏰ In our tests, the training took about 30 minutes to complete on an `ml.g4dn.xlarge` instance in default configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38eafacf-0f50-4a5f-9bc6-cdda0e315bb7",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "inputs = {\n",
    "    \"train\": train_s3uri,\n",
    "    \"validation\": validation_s3uri,\n",
    "}\n",
    "\n",
    "estimator.fit(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e12f9b2d-3ad3-4845-a127-b10131a03902",
   "metadata": {},
   "source": [
    "Once the training is complete, you have a model ready to normalize detected dates to specific target formats.\n",
    "\n",
    "As discussed in the main solution notebooks, you can also 'attach' the notebook to previously-completed training jobs as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20e00d35-f89c-48d2-ae2f-4ea253fab41b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#estimator = HuggingFaceEstimator.attach(\"t5-datenorm-2023-01-09-12-19-12-377\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cbf1dda-4099-4dfc-89f7-dcc4314b89eb",
   "metadata": {},
   "source": [
    "### Deploy for inference\n",
    "\n",
    "Model deployment is similar to the entity recognition and other models shown in this solution. Note that for this endpoint we'll set up a [real-time inference endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) (not specifying an `async_inference_config` as with some other examples), and use a separate [inference_seq2seq.py](src/inference_seq2seq.py) entrypoint because the handling logic is quite different from standard `inference.py` models that consume Amazon Textract JSON."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b63c81-a95d-4cf2-a99d-ab215ef6afbd",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sagemaker.huggingface import HuggingFaceModel\n",
    "\n",
    "# Look up the model artifact location from the training job:\n",
    "training_job_desc = estimator.latest_training_job.describe()\n",
    "model_s3uri = training_job_desc[\"ModelArtifacts\"][\"S3ModelArtifacts\"]\n",
    "model_name = training_job_desc[\"TrainingJobName\"]\n",
    "\n",
    "# Make sure we don't accidentally re-use same model:\n",
    "try:\n",
    "    smclient.delete_model(ModelName=model_name)\n",
    "    print(f\"Deleted existing model {model_name}\")\n",
    "except smclient.exceptions.ClientError as e:\n",
    "    if not (\n",
    "        e.response[\"Error\"][\"Code\"] in (404, \"404\")\n",
    "        or e.response[\"Error\"].get(\"Message\", \"\").startswith(\"Could not find model\")\n",
    "    ):\n",
    "        raise e\n",
    "\n",
    "model = HuggingFaceModel(\n",
    "    name=model_name,\n",
    "    model_data=model_s3uri,\n",
    "    role=sagemaker.get_execution_role(),\n",
    "    source_dir=\"src/\",\n",
    "    entry_point=\"inference_seq2seq.py\",\n",
    "    py_version=None,\n",
    "    pytorch_version=None,\n",
    "    transformers_version=None,\n",
    "    image_uri=inf_image_uri,\n",
    "    env={\n",
    "        \"PYTHONUNBUFFERED\": \"1\",  # TODO: Disable once debugging is done\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e06e94-6e4c-4d04-a03e-91c55789e432",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete previous endpoint, if already in use:\n",
    "try:\n",
    "    predictor.delete_endpoint(delete_endpoint_config=True)\n",
    "    print(\"Deleting previous endpoint...\")\n",
    "    time.sleep(8)\n",
    "except (NameError, smclient.exceptions.ResourceNotFound):\n",
    "    pass  # No existing endpoint to delete\n",
    "except smclient.exceptions.ClientError as e:\n",
    "    if \"Could not find\" not in e.response[\"Error\"].get(\"Message\", \"\"):\n",
    "        raise e\n",
    "\n",
    "print(\"Deploying model...\")\n",
    "predictor = model.deploy(\n",
    "    endpoint_name=training_job_desc[\"TrainingJobName\"],\n",
    "    initial_instance_count=1,\n",
    "    instance_type=\"ml.m5.large\",\n",
    "    serializer=sagemaker.serializers.JSONSerializer(),\n",
    "    deserializer=sagemaker.deserializers.JSONDeserializer(),\n",
    ")\n",
    "print(\"\\nDone!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4242c33f-0539-44eb-994a-b55d9451a684",
   "metadata": {},
   "source": [
    "### Validate the endpoint\n",
    "\n",
    "Once the model is deployed, we can run (some or all of) the evaluation dataset through it to validate performance - as shown below.\n",
    "\n",
    "> ⏰ In our tests, it took about a minute to run the full evaluation dataset through the model. For a faster turnaround, you could process just the first N samples of the dataset by instead running e.g. `eval_results = eval_dataset.select(range(N)).map(...`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "961278f4-1d87-42f6-9675-d7c03642ab91",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "eval_dataset = datasets.load_from_disk(\"data/seq2seq-validation\")\n",
    "\n",
    "\n",
    "def predict_batch(batch):\n",
    "    \"\"\"Run a dataset batch through the SageMaker endpoint and check per-example correctness\"\"\"\n",
    "    input_texts = batch[\"src_texts\"]\n",
    "    result = predictor.predict({\"inputs\": input_texts})\n",
    "    result[\"correct\"] = [\n",
    "        gen == batch[\"tgt_texts\"][ix] for ix, gen in enumerate(result[\"generated_text\"])\n",
    "    ]\n",
    "    return {**batch, **result}\n",
    "\n",
    "\n",
    "eval_results = eval_dataset.map(\n",
    "    predict_batch,\n",
    "    desc=\"Running inference\",\n",
    "    batched=True,\n",
    "    batch_size=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c5c1fef-1ba1-49d8-903a-8b601a5f28c3",
   "metadata": {},
   "source": [
    "Below we measure overall \"accuracy\" on this evaluation set and print out some examples, to demonstrate performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b26141de-2304-4ee9-a4a0-e3db89fa41ab",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Calculate overall accuracy:\n",
    "n_correct = sum(eval_results[\"correct\"])\n",
    "n_total = len(eval_results)\n",
    "print(\n",
    "    \"{} of {} samples correct.\\n  Overall accuracy: {:.2%}\".format(\n",
    "        n_correct, n_total, n_correct / n_total\n",
    "    )\n",
    ")\n",
    "\n",
    "# Present some examples from the dataset:\n",
    "pd.DataFrame(eval_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23c9c260-9633-43c5-b9f5-e01ca23431bb",
   "metadata": {},
   "source": [
    "As shown above, this text-to-text model can take in a raw detected date mention (e.g. `Sunday Dec 31st 2000`) with a prompt prefix (e.g. `Convert dates to YYYY-MM-DD: `) and attempt to output the desired normalized format (e.g. `2000-12-31`).\n",
    "\n",
    "Note that the \"overall accuracy\" metric reported above should match with the `eval_acc` metric emitted by the training job, since the same validation dataset is used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "868d1622-e1f2-4c51-8f56-cce0db913676",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Integrate with processing pipeline\n",
    "\n",
    "So how can such a field normalizing model be integrated with the overall document processing pipeline?\n",
    "\n",
    "In fact, the **post-processing Lambda function** invoked after our entity recognition model to extract and consolidate entities, is able to call out to additional \"normalizing\" models where required.\n",
    "\n",
    "These are configured through the same **entity/field type configuration** we originally set up for the pipeline in [Notebook 1 (Data Preparation)](1.%20Data%20Preparation.ipynb).\n",
    "\n",
    "First, we can load up the current pipeline entity configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "147b502b-d807-44ce-8818-0345eb059fe4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(\"Loading current pipeline field configuration...\")\n",
    "# Load JSON text from AWS SSM Parameter Store:\n",
    "# (If this fails, you could also try reading from data/field-config.json)\n",
    "fields_json = ssm.get_parameter(Name=config.entity_config_param)[\"Parameter\"][\"Value\"]\n",
    "# Parse the JSON into Python config classes:\n",
    "fields = [\n",
    "    util.postproc.config.FieldConfiguration.from_dict(cfg)\n",
    "    for cfg in json.loads(fields_json)\n",
    "]\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d53907a-2610-40b2-a28a-c51687228bc6",
   "metadata": {},
   "source": [
    "Next, find any entity type that looks like a date (any with 'date' in the name), and configure the normalizer for those fields:\n",
    "\n",
    "> ⚠️ **Note:** Check the way you prompt your normalization model matches how it was trained, for good results!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "014d5107-40c2-4a5d-afd3-3be4a9e2a6d3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for f in fields:\n",
    "    if \"date\" in f.name.lower():\n",
    "        print(f\"Found date field: {f.name}\")\n",
    "        f.normalizer_endpoint = predictor.endpoint_name\n",
    "        print(f\"  - Setting normalizer_endpoint = '{f.normalizer_endpoint}'\")\n",
    "        f.normalizer_prompt = \"Convert dates to YYYY-MM-DD: \"\n",
    "        print(f\"  - Setting normalizer_prompt = '{f.normalizer_prompt}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbca20a2-0d9e-46e5-bc35-e3934dd28997",
   "metadata": {},
   "source": [
    "When you're happy with the updated field configuration, you can run the below to update the pipeline parameter:\n",
    "\n",
    "You may also like to check these updates in the [AWS Systems Manager Parameter Store console](https://console.aws.amazon.com/systems-manager/parameters/?&tab=Table)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b9b7dff-8070-4d7a-a59a-fcc82807ab66",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(\"Saving new field configuration locally...\")\n",
    "with open(\"data/field-config.json\", \"w\") as f:\n",
    "    f.write(json.dumps(\n",
    "        [cfg.to_dict() for cfg in fields],\n",
    "        indent=2,\n",
    "    ))\n",
    "\n",
    "print(\"Uploading new field configuration to pipeline...\")\n",
    "pipeline_entity_config = json.dumps([f.to_dict(omit=[\"annotation_guidance\"]) for f in fields], indent=2)\n",
    "ssm.put_parameter(\n",
    "    Name=config.entity_config_param,\n",
    "    Overwrite=True,\n",
    "    Value=pipeline_entity_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91826510-9e89-479a-a5f8-054b0c4074d3",
   "metadata": {},
   "source": [
    "After updating your pipeline's field configuration SSM parameter to set `normalizer_endpoint` and `normalizer_prompt` on your target entity types, your pipeline's Post-processing Lambda function should automatically start calling your SageMaker model endpoints to normalize mentions on the relevant fields. For example with the Credit Card Agreements sample data, you should see that `Agreement Effective Date` results start to show in `YYYY-MM-DD` format instead of the document's source format, when reviewing results in Amazon A2I or the Step Functions console.\n",
    "\n",
    "> ⚠️ **Note:** There may be a few minutes' delay before normalization starts to take effect, if your post-processing Lambda is configured to cache the SSM configuration. Check your AWS Lambda logs for error messages, in case normalization model calls are failing.\n",
    "\n",
    "This example of normalizing individual extracted date fields is just one option in a spectrum of ways you could combine generative and extractive models for document understanding. For example, you could:\n",
    "\n",
    "- Train additional normalization types, for example for other data types or to fix common OCR error patterns\n",
    "- Include more context from around the original mention, to help the model perform better (such as interpreting whether a raw date is likely to be DD/MM or MM/DD given other information)\n",
    "- Explore linking generative and layout-aware aspects into one end-to-end trainable model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a9b4159-ad1a-4ecf-886c-7efe183b37c4",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "*[Back to contents](#Contents)*"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (PyTorch 1.10 Python 3.8 CPU Optimized)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/pytorch-1.10-cpu-py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}