# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Lambda to start an A2I human loop to review a non-confident model output

Should be called as an *asynchronous* task from Step Functions (using lambda:invoke.waitForToken)

By passing the Step Functions task token to the A2I task as input, we ensure it gets included in the
output JSON generated by the task and therefore enable our S3-triggered callback function to
retrieve the task token and signal to Step Functions that the review is complete.
"""

# Python Built-Ins:
from datetime import datetime
import json
import logging
import os
import re
import uuid

# External Dependencies:
import boto3

logger = logging.getLogger()
logger.setLevel(logging.INFO)
a2i = boto3.client("sagemaker-a2i-runtime")
ssm = boto3.client("ssm")

# Name of the SSM parameter holding the fallback flow definition ARN, used when the request does
# not carry its own "FlowDefinitionArn":
default_flow_definition_arn_param = os.environ.get("DEFAULT_FLOW_DEFINITION_ARN_PARAM")


class MalformedRequest(ValueError):
    """Returned to SFN when input event structure is invalid"""

    pass


def generate_human_loop_name(s3_object_key: str, max_len: int = 63) -> str:
    """Create a random-but-a-bit-meaningful unique name for human loop job

    Generated names combine timestamp, object filename, and a random element, in the form
    ``{timestamp}-{filename}-{random}``. The filename part is sanitized (lowercased, punctuation
    turned to hyphens, other characters dropped) and is the part that gets clipped to fit.

    Args:
        s3_object_key: S3 key (or full s3:// URI) of the object under review. Only the portion
            after the final "/" is used.
        max_len: Maximum length of the generated name.

    Returns:
        A name of at most `max_len` characters made of lowercase alphanumerics and hyphens.
    """
    filename = s3_object_key.rpartition("/")[2]
    # Turn significant punctuation to hyphens:
    filename_component = re.sub(r"[ _.,!?]", "-", filename)
    # Cut out any remaining disallowed characters:
    filename_component = re.sub(r"[^a-zA-Z0-9\-]", "", filename_component)
    # Condense runs of hyphens (of any length, not just pairs):
    filename_component = re.sub(r"-{2,}", "-", filename_component)
    # A2I human loop names only accept lower-case alphanumerics and hyphens, so upper-case
    # letters from the original filename must be folded rather than passed through:
    filename_component = filename_component.lower().strip("-")

    # Millis is enough, no need for microseconds:
    datetime_component = datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3]
    # Most significant bits section of a GUID:
    random_component = str(uuid.uuid4()).partition("-")[0]

    # Budget for the filename must also reserve the two joining hyphens, otherwise the final
    # truncation eats into the random suffix. Clamp at zero so that a small `max_len` does not
    # turn into a negative slice bound (which would clip from the wrong end):
    available_len = max(0, max_len - len(datetime_component) - len(random_component) - 2)
    clipped_filename_component = filename_component[:available_len].rstrip("-")

    name_parts = [datetime_component, random_component]
    if clipped_filename_component:
        name_parts.insert(1, clipped_filename_component)
    # Final clip is a safety net for very small `max_len`; never leave a trailing hyphen:
    return "-".join(name_parts)[:max_len].rstrip("-")


def _resolve_task_object_uri(task_object) -> str:
    """Normalize the request's TaskObject (URI string or dict) to a single URI string

    Raises:
        MalformedRequest: If `task_object` is neither a string nor a dict with a non-empty
            'S3Uri' key or both 'Bucket' and 'Key' keys.
    """
    if isinstance(task_object, str):
        return task_object
    if isinstance(task_object, dict):
        if task_object.get("S3Uri"):
            return task_object["S3Uri"]
        if "Bucket" in task_object and "Key" in task_object:
            return f"s3://{task_object['Bucket']}/{task_object['Key']}"
    # Anything else (bad dict, list, number, null...) cannot be turned into a loop name or a
    # reviewable object, so reject it up-front instead of failing later with an opaque error:
    raise MalformedRequest(
        "TaskObject must be an s3://... URI string OR an object with 'S3Uri' key or "
        f"both 'Bucket' and 'Key' keys. Got {task_object}"
    )


def _resolve_flow_definition_arn(event) -> str:
    """Pick the flow definition ARN from the request, falling back to the SSM parameter

    Raises:
        MalformedRequest: If the request has no 'FlowDefinitionArn' and either the
            DEFAULT_FLOW_DEFINITION_ARN_PARAM env var is unset or the SSM parameter it names
            holds an empty / placeholder value.
    """
    if "FlowDefinitionArn" in event:
        return event["FlowDefinitionArn"]

    if not default_flow_definition_arn_param:
        raise MalformedRequest(
            "FlowDefinitionArn not specified in request and DEFAULT_FLOW_DEFINITION_ARN_PARAM "
            "env var not set"
        )

    flow_definition_arn = ssm.get_parameter(Name=default_flow_definition_arn_param)["Parameter"][
        "Value"
    ]
    # Placeholder strings are treated the same as an unset parameter:
    if (not flow_definition_arn) or flow_definition_arn.lower() in ("undefined", "null"):
        raise MalformedRequest(
            "Neither request FlowDefinitionArn nor expected SSM parameter are set. Got: "
            f"{default_flow_definition_arn_param} = '{flow_definition_arn}'"
        )
    return flow_definition_arn


def handler(event, context):
    """Lambda entry point: start an A2I human loop for the object described by `event`

    Args:
        event: Request dict. Requires 'TaskToken', 'ModelResult' and 'TaskObject' (an s3:// URI
            string, or a dict with 'S3Uri' or both 'Bucket' and 'Key'). May carry
            'FlowDefinitionArn' to override the SSM-configured default.
        context: Lambda context object (unused).

    Returns:
        The 'HumanLoopArn' from the A2I StartHumanLoop response.

    Raises:
        MalformedRequest: If a required field is missing, TaskObject has an unsupported shape,
            or no flow definition ARN can be determined.
    """
    try:
        task_token = event["TaskToken"]
        model_result = event["ModelResult"]
        task_object = _resolve_task_object_uri(event["TaskObject"])
        task_input = {
            "TaskObject": task_object,
            "TaskToken": task_token,  # Not used within A2I, but for feed-through to callback fn
            "ModelResult": model_result,
        }
        flow_definition_arn = _resolve_flow_definition_arn(event)
    except KeyError as ke:
        # Keep the original KeyError chained so the missing lookup is visible in the traceback:
        raise MalformedRequest(f"Missing field {ke}, please check your input payload") from ke

    logger.info("Starting A2I human loop with input %s", task_input)
    a2i_response = a2i.start_human_loop(
        HumanLoopName=generate_human_loop_name(task_input["TaskObject"]),
        FlowDefinitionArn=flow_definition_arn,
        HumanLoopInput={"InputContent": json.dumps(task_input)},
        # If adapting this code for use with A2I public workforce, you may need to add additional
        # content classifiers as described here:
        # https://docs.aws.amazon.com/sagemaker/latest/dg/sms-workforce-management-public.html
        # https://docs.aws.amazon.com/augmented-ai/2019-11-07/APIReference/API_HumanLoopDataAttributes.html
        # DataAttributes={
        #     "ContentClassifiers": ["FreeOfPersonallyIdentifiableInformation"]
        # }
    )
    logger.info("Human loop started: %s", a2i_response)

    # Doesn't really matter what we return because Step Functions will wait for the callback with
    # the token!
    return a2i_response["HumanLoopArn"]