# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: MIT-0 """Utilities for working with SageMaker Ground Truth """ # Python Built-Ins: import json from logging import getLogger import os from string import Template from textwrap import dedent from typing import Iterable, Optional # External Dependencies: import boto3 # General-purpose AWS SDK for Python # Local Dependencies: from .postproc.config import FieldConfiguration botosess = boto3.Session() s3 = botosess.resource("s3") smclient = botosess.client("sagemaker") logger = getLogger("smgt") # Lambda function ARN components for pre- and post-processing with built-in task types, as per: # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_HumanTaskConfig.html # https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AnnotationConsolidationConfig.html SMGT_LAMBDA_CONFIG = { "bounding-box": { "fn-name-pre": "PRE-BoundingBox", "fn-name-post": "ACS-BoundingBox", "regions": { "us-east-1": {"account-id": "432418664414"}, "us-east-2": {"account-id": "266458841044"}, "us-west-2": {"account-id": "081040173940"}, "ca-central-1": {"account-id": "918755190332"}, "eu-west-1": {"account-id": "568282634449"}, "eu-west-2": {"account-id": "487402164563"}, "eu-central-1": {"account-id": "203001061592"}, "ap-northeast-1": {"account-id": "477331159723"}, "ap-northeast-2": {"account-id": "845288260483"}, "ap-south-1": {"account-id": "565803892007"}, "ap-southeast-1": {"account-id": "377565633583"}, "ap-southeast-2": {"account-id": "454466003867"}, }, }, "bounding-box-adjustment": { "fn-name-pre": "PRE-AdjustmentBoundingBox", "fn-name-post": "ACS-AdjustmentBoundingBox", "regions": { "us-east-1": {"account-id": "432418664414"}, "us-east-2": {"account-id": "266458841044"}, "us-west-2": {"account-id": "081040173940"}, "ca-central-1": {"account-id": "918755190332"}, "eu-west-1": {"account-id": "568282634449"}, "eu-west-2": {"account-id": "487402164563"}, "eu-central-1": {"account-id": "203001061592"}, "ap-northeast-1": {"account-id": "477331159723"}, "ap-northeast-2": {"account-id": "845288260483"}, "ap-south-1": {"account-id": "565803892007"}, "ap-southeast-1": {"account-id": "377565633583"}, "ap-southeast-2": {"account-id": "454466003867"}, }, }, } # Template for built-in bounding box task type, as per: # https://docs.aws.amazon.com/sagemaker/latest/dg/sms-bounding-box.html # ...with placeholders for initial value and instructions added. BBOX_TEMPLATE = """ ${instructions_full} ${instructions_short} """ # Liquid template to incorporate initial bbox values from an existing manifest field, as per: # https://docs.aws.amazon.com/sagemaker/latest/dg/sms-ui-template-crowd-bounding-box.html BBOX_INITIAL_VALUE_TEMPLATE = """initial-value="[ {% for box in task.input.manifestLine.${label_attribute_name}.annotations %} {% capture class_id %}{{ box.class_id }}{% endcapture %} {% assign label = task.input.manifestLine.${label_attribute_name}-metadata.class-map[class_id] %} { label: {{label | to_json}}, left: {{box.left}}, top: {{box.top}}, width: {{box.width}}, height: {{box.height}}, }, {% endfor %} ]" """ def ensure_bucket_cors( bucket_name: str, aws_account_id: str = os.environ.get("AWS_ACCOUNT_ID"), ) -> Optional[dict]: """Ensure S3 bucket has a GET * CORS rule as required by SageMaker Ground Truth Parameters ---------- bucket_name : Name of the S3 bucket to configure. You must have s3:GetBucketCors and s3:PutBucketCors permissions on this bucket. aws_account_id : AWS Account ID. If not provided, will attempt to determine from AWS_ACCOUNT_ID environment variable. Returns ------- resp : An S3 PutBucketCors response if a rule was added, else None if a rule was already present. """ bucket_cors = s3.BucketCors(bucket_name) try: existing_rules = bucket_cors.cors_rules except s3.meta.client.exceptions.ClientError as err: if err.response.get("Error", {}).get("Code") == "NoSuchCORSConfiguration": existing_rules = [] else: raise err if any( r for r in existing_rules if "*" in r["AllowedOrigins"] and "GET" in r["AllowedMethods"] ): logger.info(f"Bucket already set up with CORS permissions: %s", bucket_name) return None else: new_rules = existing_rules + [ { "ID": "SageMakerGroundTruth", "AllowedHeaders": [], "AllowedMethods": ["GET"], "AllowedOrigins": ["*"], "ExposeHeaders": [], "MaxAgeSeconds": 60, }, ] cors_resp = bucket_cors.put( CORSConfiguration={"CORSRules": new_rules}, ExpectedBucketOwner=aws_account_id, ) logger.info("Added CORS permissions to bucket: %s", bucket_name) return cors_resp def get_smgt_lambda_arn( pre: bool, task: str = "bounding-box", region: Optional[str] = None, ) -> str: """Get the pre- or post-processing Lambda ARN for a SM Ground Truth built-in task type Based on documentation from: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_HumanTaskConfig.html https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AnnotationConsolidationConfig.html Parameters ---------- pre : bool True to return the pre-processing ARN, False to return the post-processing ARN task : str Which built-in task you're looking for region : Optional[str] The AWS Region you're working in (if not set, will auto-discover from boto3). """ if task not in SMGT_LAMBDA_CONFIG: raise NotImplementedError( f"Task type '{task}' is not in known set {[k for k in SMGT_LAMBDA_CONFIG]}" ) task_cfg = SMGT_LAMBDA_CONFIG[task] if not region: region = botosess.region_name if region not in task_cfg["regions"]: raise NotImplementedError(f"Lambda details not known for task {task} in region {region}") return "arn:aws:lambda:{}:{}:function:{}".format( region, task_cfg["regions"][region]["account-id"], task_cfg["fn-name-pre" if pre else "fn-name-post"], ) def generate_label_category_config( field_configs: Iterable[FieldConfiguration], reviewing_attribute_name: Optional[str] = None, ) -> dict: """Generate content for 'label configuration file' for an SMGT annotation job on pre-built UI The label configuration file includes parameters like the list of class labels and the worker instructions, which need to be populated when using the pre-built task UIs. For more info on file structure, see: https://docs.aws.amazon.com/sagemaker/latest/dg/sms-label-cat-config-attributes.html#sms-label-cat-config-attributes-schema Parameters ---------- field_configs : Iterable[FieldConfiguration] List of field configuration objects describing the label classes (see postproc utils) reviewing_attribute_name : Optional[str] If setting up a label review/validation job, provide the name of the attribute in the manifest where previous annotations are stored. """ result = { "document-version": "2018-11-28", "labels": [{"label": field.name} for field in field_configs], "instructions": { "shortInstruction": ( "Draw bounding boxes to highlight all instances of the listed entity types. Use " "overlapping boxes of the same type to highlight contiguous, non-square regions. " "See the full instructions for more details." ), "fullInstruction": dedent( """

Use the bounding box tool to highlight all instances of the listed entity types on the page.

Overlapping boxes of the same type are consolidated to a single object by the model: So you can use this pattern to highlight non-rectangular regions (such as specific multi-line sentences in a paragraph); and should avoid overlapping same-class boxes where the two regions are semantically separate.

""" ), }, } if any(field.annotation_guidance for field in field_configs): guidance = ["

Per-Field Guidance

"] for field in filter(lambda f: f.annotation_guidance, field_configs): guidance.append(f"

{field.name}

") guidance.append(field.annotation_guidance) result["instructions"]["fullInstruction"] += "\n".join(guidance) if reviewing_attribute_name is not None: result["auditLabelAttributeName"] = reviewing_attribute_name return result def get_bbox_template( header: str, instructions_short: str = "", instructions_full: str = "", reviewing_attribute_name: Optional[str] = None, ) -> str: initial_value_attr = ( Template(BBOX_INITIAL_VALUE_TEMPLATE).substitute( {"label_attribute_name": reviewing_attribute_name} ) if reviewing_attribute_name else "" ) return Template(BBOX_TEMPLATE).substitute( { "header": header, "initial_value_attr": initial_value_attr, "instructions_short": instructions_short, "instructions_full": instructions_full, } ) def workteam_arn_from_name(name: str) -> str: """Validate that a SageMaker Ground Truth Workteam exists with members, and return its ARN""" desc = smclient.describe_workteam(WorkteamName=name) if not len(desc["Workteam"]["MemberDefinitions"]): raise ValueError(f"Workteam '{name}' has no members! Add members to use for annotation") return desc["Workteam"]["WorkteamArn"] def create_bbox_labeling_job( job_name: str, bucket_name: str, execution_role_arn: str, fields: Iterable[FieldConfiguration], input_manifest_s3uri: str, output_s3uri: str, workteam_arn: str, local_inputs_folder: str = os.path.join("data", "manifests"), reviewing_attribute_name: Optional[str] = None, s3_inputs_prefix: str = "data/manifests", task_template: Optional[str] = None, pre_lambda_arn: Optional[str] = None, post_lambda_arn: Optional[str] = None, ) -> dict: """Create a SageMaker Ground Truth labelling job with the built-in Bounding Box task UI Parameters ---------- job_name : Name of the job to create (must be unique in your AWS Account+Region) bucket_name : Name of the S3 bucket where input/output manifests and job metadata will be stored execution_role_arn : ARN of the SageMaker Execution Role (in AWS IAM) that the labelling job will run as. The role must have permission to access your selected `bucket_name`. fields : Iterable[FieldConfiguration] Field/entity types list input_manifest_s3uri : 's3://...' URI where the input JSON-Lines manifest file is (already) stored output_s3uri : 's3://...' URI where the job output should be stored (SMGT will add a job subfolder) workteam_arn : ARN of the SageMaker Ground Truth workteam who will be performing the task local_inputs_folder : Local folder where configuration files for SMGT will be stored before uploading to S3. (Default 'data/manifests') reviewing_attribute_name : Optional[str] Set the name of the manifest attribute where existing labels are stored, to trigger an adjustment job on pre-existing labels. (Default None) s3_inputs_prefix : Key prefix (with or without trailing slash) under which configuration files for SMGT will be uploaded to the S3 bucket_name. (Default 'data/manifests') task_template : Optional custom task template file (local path). If not provided, the standard SMGT Bounding Box task UI will be used. pre_lambda_arn : Override AWS Lambda ARN for Ground Truth task pre-processing. When unset, the default pre-processing Lambda for SMGT Bounding Box (adjustment) task UI will be used. Set this parameter to use your own function instead. post_lambda_arn : Override AWS Lambda ARN for Ground Truth task post-processing. When unset, the default post-processing Lambda for SMGT Bounding Box (adjustment) task UI will be used. Set this parameter to use your own function instead. Returns ------- response : dict As per boto3 sagemaker client.create_labeling_job() """ # Validate/normalize inputs: if local_inputs_folder.endswith(os.path.sep): local_inputs_folder = local_inputs_folder[:-1] if s3_inputs_prefix.startswith("/"): s3_inputs_prefix = s3_inputs_prefix[1:] if s3_inputs_prefix.endswith("/"): s3_inputs_prefix = s3_inputs_prefix[:-1] bucket = s3.Bucket(bucket_name) # Generate and upload a job metadata file (including things like the list of class names, and # any instructions to include for workers): input_category_file = os.path.join(local_inputs_folder, f"{job_name}.meta.json") input_category_s3key = "/".join((s3_inputs_prefix, f"{job_name}.meta.json")) input_category_s3uri = f"s3://{bucket_name}/{input_category_s3key}" with open(input_category_file, "w") as f: label_category_config = generate_label_category_config( fields, reviewing_attribute_name=reviewing_attribute_name, ) f.write(json.dumps(label_category_config)) bucket.upload_file(input_category_file, input_category_s3key) print(f"Uploaded Labeling Category Config {input_category_file} to:\n{input_category_s3uri}") # Generate and upload the task template: task_template_s3key = "/".join((s3_inputs_prefix, f"{job_name}.liquid.html")) task_template_s3uri = f"s3://{bucket_name}/{task_template_s3key}" if task_template is None: task_template_file = os.path.join(local_inputs_folder, f"{job_name}.liquid.html") with open(task_template_file, "w") as f: f.write( get_bbox_template( header="Highlight the entities with bounding boxes", instructions_short=( label_category_config.get("instructions", {}).get("shortInstruction", "") ), instructions_full=( label_category_config.get("instructions", {}).get("fullInstruction", "") ), reviewing_attribute_name=reviewing_attribute_name, ) ) else: task_template_file = task_template bucket.upload_file(task_template_file, task_template_s3key) print(f"Uploaded resolved task UI template {task_template_file} to:\n{task_template_s3uri}") # Create the actual labeling job: task = "bounding-box-adjustment" if reviewing_attribute_name else "bounding-box" return smclient.create_labeling_job( LabelingJobName=job_name, LabelAttributeName=job_name, InputConfig={ "DataSource": { "S3DataSource": {"ManifestS3Uri": input_manifest_s3uri}, }, # If adapting this code for use with A2I public workforce, you may need to add # additional content classifiers as described here: # https://docs.aws.amazon.com/sagemaker/latest/dg/sms-workforce-management-public.html # https://docs.aws.amazon.com/augmented-ai/2019-11-07/APIReference/API_HumanLoopDataAttributes.html # "DataAttributes": { # "ContentClassifiers": [ # "FreeOfPersonallyIdentifiableInformation", "FreeOfAdultContent" # ], # }, }, OutputConfig={ "S3OutputPath": output_s3uri, }, RoleArn=execution_role_arn, LabelCategoryConfigS3Uri=input_category_s3uri, # Required for built-in tasks only HumanTaskConfig={ "WorkteamArn": workteam_arn, "UiConfig": { "UiTemplateS3Uri": task_template_s3uri, }, "PreHumanTaskLambdaArn": ( get_smgt_lambda_arn(pre=True, task=task) if pre_lambda_arn is None else pre_lambda_arn ), "TaskTitle": "Credit Card Agreement Entities", "TaskDescription": "Highlight the entities with bounding boxes", "NumberOfHumanWorkersPerDataObject": 1, "TaskTimeLimitInSeconds": 60 * 60, "TaskAvailabilityLifetimeInSeconds": 10 * 24 * 60 * 60, "MaxConcurrentTaskCount": 250, "AnnotationConsolidationConfig": { "AnnotationConsolidationLambdaArn": ( get_smgt_lambda_arn(pre=False, task=task) if post_lambda_arn is None else post_lambda_arn ), }, }, )