from datetime import datetime
import os
import json

import boto3

# Step Functions client, created once at import time so that every call made
# through this module shares it.
sfn_client = boto3.client("stepfunctions")

# State machine ARNs are read from the environment; each constant is None when
# its variable is not set.
TRAINING_STATE_MACHINE = os.getenv("TRAINING_STATE_MACHINE")
INFERENCE_STATE_MACHINE = os.getenv("INFERENCE_STATE_MACHINE")
PERFORMANCE_STATE_MACHINE = os.getenv("PERFORMANCE_STATE_MACHINE")


def state_machine_start_execution(state_machines_arn, name, sf_input):
    """Start one execution of a Step Functions state machine.

    Args:
        state_machines_arn: ARN of the state machine to start.
        name: Name given to the new execution.
        sf_input: JSON-serialisable object; it is dumped to a JSON string and
            passed as the execution input.

    Returns:
        None. The response of the Step Functions API call is discarded.
    """
    request = {
        "stateMachineArn": state_machines_arn,
        "name": name,
        "input": json.dumps(sf_input),
    }
    sfn_client.start_execution(**request)


def _trigger_from_s3_event(event):
    """Start the training or inference state machine for an S3 event.

    Only the first record that carries an ``s3`` entry is considered. Its
    object key must split on "/" into exactly four segments, read here as
    ``<use_case>/<data_type>/<state>/<fourth segment>`` (the fourth segment is
    not used). Nothing is started unless ``data_type`` is "raw" and
    ``use_case`` is "training" or "inference".
    """
    s3_objects = [
        {
            "bucket": record["s3"]["bucket"]["name"],
            "key": record["s3"]["object"]["key"],
        }
        for record in event["Records"]
        if record.get("s3")
    ]

    if not s3_objects:
        print("No S3 object find in the event!")
        return

    # Only the first S3 record drives the decision; later records are ignored.
    bucket = s3_objects[0]["bucket"]
    key = s3_objects[0]["key"]

    # NOTE(review): S3 notifications deliver object keys URL-encoded; the
    # segments below are used exactly as received - confirm that the prefixes
    # in use never contain characters that get encoded.
    tok = key.split("/")

    if len(tok) != 4:
        print(f"S3 object prefix {key} is not matching the required format!")
        return

    use_case, data_type, state = tok[:3]

    if data_type != "raw":
        print("Incorrect data type. Only raw data could trigger the state machine!")
        return

    if use_case not in ("training", "inference"):
        # Any other use case is silently ignored.
        return

    current_time = datetime.now().strftime("%Y-%m-%d-%H-%M")
    job_name = f"{state}-{current_time}"

    state_machine_input = {
        "Comment": "Automatically triggered by S3 event",
        "bucket": bucket,
        "use_case": use_case,
        "state": state,
        "job_name": job_name,
    }

    state_machine_start_execution(
        TRAINING_STATE_MACHINE if use_case == "training" else INFERENCE_STATE_MACHINE,
        job_name,
        state_machine_input,
    )

    print("State machine successfully triggered")


def _trigger_from_state_machine_event(event):
    """Start the follow-up state machine after a state machine event.

    The finished execution's input is reused as the input of the next
    execution. An event coming from a state machine whose name ends in
    "-training" (case-insensitive) starts the inference state machine and adds
    the trained model name to the input; every other state machine starts the
    performance state machine. Events whose output reports the "inference"
    use case are ignored.

    Raises:
        KeyError: if a training event's output has no ``model.model_name``.
    """
    detail = event.get("detail") or {}
    raw_input = detail.get("input")
    raw_output = detail.get("output")

    # The payloads can be null (for example when the execution did not
    # succeed); such events cannot start a follow-up, so skip them instead of
    # crashing in json.loads.
    if raw_input is None or raw_output is None:
        print("State machine event has no input/output payload; nothing to trigger!")
        return

    event_input = json.loads(raw_input)
    event_output = json.loads(raw_output)

    if not isinstance(event_input, dict) or not isinstance(event_output, dict):
        print("State machine input/output is not a JSON object; nothing to trigger!")
        return

    state = event_input.get("state")

    if event_output.get("use_case") == "inference":
        print("Only training workflow receives event from state machine!")
        return

    # The type is the last "-"-separated token of the last ":"-separated ARN
    # component, lower-cased.
    state_machine_type = (
        detail.get("stateMachineArn").split(":")[-1].split("-")[-1].lower()
    )

    current_time = datetime.now().strftime("%Y-%m-%d-%H-%M")
    job_name = f"{state}-{current_time}"

    state_machine_input = event_input
    state_machine_input["Comment"] = "Automatically triggered by State Machine event"

    if state_machine_type == "training":
        state_machine_input["model_name"] = event_output["model"]["model_name"]
        next_state_machine = INFERENCE_STATE_MACHINE
    else:
        # NOTE(review): every non-training state machine is routed to the
        # performance state machine - confirm the event rule excludes the
        # performance state machine itself, otherwise it would re-trigger.
        next_state_machine = PERFORMANCE_STATE_MACHINE

    state_machine_start_execution(next_state_machine, job_name, state_machine_input)

    print("State machine successfully triggered")


def lambda_handler(event, context):
    """Route an incoming event to the matching state machine trigger.

    Events with a non-empty ``Records`` list are handled as S3 notifications;
    events whose ``source`` is "aws.states" are handled as state machine
    events; anything else is ignored.

    Args:
        event: The Lambda event payload (a dict).
        context: The Lambda context object (unused).

    Returns:
        None in every case.
    """
    print(event)

    if event.get("Records"):
        _trigger_from_s3_event(event)
    elif event.get("source") == "aws.states":
        _trigger_from_state_machine_event(event)