# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Annotation consolidation Lambda for BBoxes+transcriptions in SageMaker Ground Truth
"""
# Python Built-Ins:
import json
import logging
from typing import List, Optional

# External Dependencies:
import boto3  # AWS SDK for Python

# Set up logger before local imports:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Local Dependencies:
from data_model import SMGTWorkerAnnotation  # Custom task data model (edit if needed!)
from smgt import (  # Generic SageMaker Ground Truth parsers/utilities
    ConsolidationRequest,
    ObjectAnnotationResult,
    PostConsolidationDatum,
)


s3 = boto3.client("s3")


def consolidate_object_annotations(
    object_data: ObjectAnnotationResult,
    label_attribute_name: str,
    label_categories: Optional[List[str]] = None,
) -> PostConsolidationDatum:
    """Consolidate the (potentially multiple) raw worker annotations for a dataset object

    TODO: Actual consolidation/reconciliation of multiple labels is not yet supported!

    This function just takes the "first" (not necessarily clock-first) worker's result and outputs
    a warning if others were found.

    Parameters
    ----------
    object_data :
        Object describing the raw annotations and metadata for a particular task in the SMGT job
    label_attribute_name :
        Target attribute on the output object to store consolidated label results (note this may
        not be the *only* attribute set/updated on the output object, hence provided as a param
        rather than abstracted away).
    label_categories :
        Label categories specified when creating the labelling job. If provided, this is used to
        translate from class names to numeric class_id similarly to SMGT's built-in bounding box
        task result.
    """
    warn_msgs: List[str] = []
    worker_anns: List[SMGTWorkerAnnotation] = []
    for worker_ann in object_data.annotations:
        ann_raw = worker_ann.fetch_data()
        worker_anns.append(SMGTWorkerAnnotation.parse(ann_raw, class_list=label_categories))

    if len(worker_anns) > 1:
        warn_msg = (
            "Reconciliation of multiple worker annotations is not currently implemented for this "
            "post-processor. Outputting annotation from worker %s and ignoring labels from %s"
            % (
                object_data.annotations[0].worker_id,
                [a.worker_id for a in object_data.annotations[1:]],
            )
        )
        logger.warning(warn_msg)
        warn_msgs.append(warn_msg)

    consolidated_label = worker_anns[0].to_jsonable()
    if len(warn_msgs):
        consolidated_label["consolidationWarnings"] = warn_msgs

    return PostConsolidationDatum(
        dataset_object_id=object_data.dataset_object_id,
        consolidated_content={
            label_attribute_name: consolidated_label,
            # Note: In our tests it's not possible to add a f"{label_attribute_name}-meta" field
            # here - it gets replaced by whatever post-processing happens, instead of merged.
        },
    )


def handler(event: dict, context) -> List[dict]:
    """Main Lambda handler for consolidation of SMGT worker annotations

    This function receives a batched request to consolidate (multiple?) workers' annotations for
    multiple objects, and outputs the consolidated results per object. For more docs see:

    https://docs.aws.amazon.com/sagemaker/latest/dg/sms-custom-templates-step3-lambda-requirements.html
    """
    logger.info("Received event: %s", json.dumps(event))
    req = ConsolidationRequest.parse(event)
    if req.label_categories and len(req.label_categories) > 0:
        label_cats = req.label_categories
    else:
        logger.warning(
            "Label categories list (see CreateLabelingJob.LabelCategoryConfigS3Uri) was not "
            "provided when creating this job. Post-consolidation outputs will be incompatible with "
            "built-in Bounding Box task, because we're unable to map class names to numeric IDs."
        )
        label_cats = None

    # Loop through the objects in this batch, consolidating annotations for each:
    return [
        consolidate_object_annotations(
            object_data,
            label_attribute_name=req.label_attribute_name,
            label_categories=label_cats,
        ).to_jsonable()
        for object_data in req.fetch_object_annotations()
    ]