#
#   Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#   SPDX-License-Identifier: MIT-0
#
from typing import TypedDict, List, Any
import json
import boto3
import os


# The "Sns" portion of a single SNS record, limited to the two string fields
# this module reads: the raw message body and the ARN of the publishing topic.
TextractSnsMessage = TypedDict(
    "TextractSnsMessage",
    {
        "Message": str,
        "TopicArn": str,
    },
)


# One entry of the event's "Records" list; the SNS data sits under the "Sns" key.
TextractSnsEvent = TypedDict(
    "TextractSnsEvent",
    {
        "Sns": TextractSnsMessage,
    },
)


# The event object passed to the handler: a list of SNS records under "Records".
TextractSnsEvents = TypedDict(
    "TextractSnsEvents",
    {
        "Records": List[TextractSnsEvent],
    },
)


def get_task_token(sfn, execution_id: str, sns_topic_arn: str):
    """
    Return the task token for the textract execution that just completed

    The execution ARN is derived from the STATE_MACHINE_ARN environment variable
    and `execution_id`. The execution history is then scanned for a
    TaskScheduled event whose parameters carry a Payload with a matching
    `sns_topic_arn` and a `task_token`.

    The history API is paginated, so every page is read (following
    `nextToken`) until a match is found or the history is exhausted.

    :param sfn: Step Functions client exposing `get_execution_history` and
        `stop_execution` (e.g. `boto3.client("stepfunctions")`)
    :param execution_id: Name of the execution, appended to the execution ARN
    :param sns_topic_arn: ARN of the SNS topic the scheduled task was given
    :return: The task token, or None when no matching event exists. In the
        latter case the execution is stopped before returning.
    :raises KeyError: If STATE_MACHINE_ARN is not set in the environment
    """
    execution_arn = "{}:{}".format(
        os.environ["STATE_MACHINE_ARN"].replace(":stateMachine:", ":execution:"),
        execution_id,
    )

    request = {"executionArn": execution_arn}

    # Find the scheduled event, walking through every page of the history
    while True:
        page = sfn.get_execution_history(**request)

        for event in page["events"]:
            details = event.get("taskScheduledEventDetails", {})
            if event["type"] != "TaskScheduled" or "parameters" not in details:
                continue

            parameters = json.loads(details["parameters"])
            payload = parameters.get("Payload", {})

            # Other tasks may schedule with a non-object Payload; skip those
            if (
                isinstance(payload, dict)
                and payload.get("sns_topic_arn") == sns_topic_arn
                and "task_token" in payload
            ):
                return payload["task_token"]

        next_token = page.get("nextToken")
        if not next_token:
            break
        request["nextToken"] = next_token

    # Stop the execution to avoid indefinitely waiting if there's an unexpected error
    sfn.stop_execution(
        executionArn=execution_arn,
        error="Unable to find task token",
        cause="This indicates a problem with the code in the on_complete handler, or the cdk definition for this state machine",
    )
    return None


def handler(event: TextractSnsEvents, context: Any):
    """
    Handler called when a textract job is completed

    For every SNS record in `event`, the message body is parsed as JSON, the
    task token of the waiting state machine task is looked up via
    `get_task_token`, and the state machine is told whether the textract job
    succeeded or failed.

    :param event: Event with the SNS notifications under "Records"
    :param context: Invocation context (unused)
    """
    sfn = boto3.client("stepfunctions")

    for record in event["Records"]:
        sns_event = record["Sns"]
        textract_job_string = sns_event["Message"]
        textract_job = json.loads(textract_job_string)

        task_token = get_task_token(sfn, textract_job["JobTag"], sns_event["TopicArn"])

        if task_token is None:
            # get_task_token found no token and has already stopped the
            # execution; there is no task left to signal for this record.
            continue

        # Indicate to the state machine that the textract job has succeeded or failed
        if textract_job["Status"] == "SUCCEEDED":
            sfn.send_task_success(taskToken=task_token, output=textract_job_string)
        else:
            failure_message = str(
                textract_job.get("StatusMessage", textract_job["Status"])
            )
            # The API caps `error` at 256 characters and `cause` at 32768, so
            # the short field is truncated and the full text goes in `cause`.
            sfn.send_task_failure(
                taskToken=task_token,
                error=failure_message[:256],
                cause=failure_message[:32768],
            )