import json
import os
import unittest.mock as mock

import boto3
import braket.aws
import braket.tracking
import matplotlib.pyplot as plt

plt.savefig = mock.Mock()


class Mocker:
    mock_level = "ALL"

    def __init__(self):
        self._wrapper = (
            Boto3SessionAllWrapper() if Mocker.mock_level == "ALL" else AwsSessionMinWrapper()
        )
        braket.tracking.Tracker = mock.Mock()
        tracker = braket.tracking.Tracker().start()
        tracker.qpu_tasks_cost.return_value = 0
        tracker.simulator_tasks_cost.return_value = 0

    def set_get_device_result(self, result):
        self._wrapper.boto_client.get_device.return_value = result

    def set_create_quantum_task_result(self, result):
        self._wrapper.boto_client.create_quantum_task.return_value = result

    def set_get_quantum_task_result(self, result):
        self._wrapper.boto_client.get_quantum_task.return_value = result

    def set_cancel_quantum_task_result(self, result):
        self._wrapper.boto_client.cancel_quantum_task.return_value = result

    def set_task_result_return(self, result):
        self._wrapper.task_result_mock.return_value = result

    def set_task_result_side_effect(self, side_effect):
        self._wrapper.task_result_mock.side_effect = side_effect

    def set_search_result(self, result):
        self._wrapper.boto_client.get_paginator.return_value.paginate.return_value = result

    def set_log_streams_result(self, result):
        self._wrapper.boto_client.describe_log_streams.return_value = result

    def set_start_query_result(self, result):
        self._wrapper.boto_client.start_query.return_value = result

    def set_get_query_results_result(self, result):
        self._wrapper.boto_client.get_query_results.return_value = result

    def set_list_objects_v2_result(self, result):
        self._wrapper.boto_client.list_objects_v2.return_value = result

    @property
    def region_name(self):
        return self._wrapper.region_name


def read_file(name, file_path=None):
    if file_path:
        json_path = os.path.join(os.path.dirname(file_path), name)
    else:
        json_path = os.path.join(os.path.dirname(__file__), "default_data", name)
    with open(json_path, "r") as file:
        return file.read()


def mock_default_device_calls(mocker):
    mocker.set_get_device_result(
        {"deviceType": "QPU", "deviceCapabilities": read_file("default_capabilities.json")}
    )
    mocker.set_create_quantum_task_result(
        {
            "quantumTaskArn": "arn:aws:braket:us-west-2:000000:quantum-task/TestARN",
        }
    )
    mocker.set_get_quantum_task_result(
        {
            "quantumTaskArn": "arn:aws:braket:us-west-2:000000:quantum-task/TestARN",
            "status": "COMPLETED",
            "outputS3Bucket": "Test Bucket",
            "outputS3Directory": "Test Directory",
            "shots": 10,
            "deviceArn": "Test Device Arn",
            "ResponseMetadata": {"HTTPHeaders": {"date": ""}},
        }
    )
    mocker.set_task_result_return(read_file("default_results.json"))


def set_level(mock_level):
    Mocker.mock_level = mock_level


class SessionWrapper:
    def __init__(self):
        self.boto_client = mock.Mock()
        self.task_result_mock = mock.Mock()
        self.resource_mock = mock.Mock()

        return_mock = mock.Mock()
        return_mock.read.return_value.decode = self.task_result_mock
        self.resource_mock.Object.return_value.get.return_value = {"Body": return_mock}
        self.boto_client.get_caller_identity.return_value = {"Account": "TestAccount"}
        self.boto_client.meta.region_name = "us-west-2"
        self.boto_client.get_authorization_token.return_value = {
            "authorizationData": [{"authorizationToken": "TestToken"}]
        }


class Boto3SessionAllWrapper(SessionWrapper):
    def __init__(self):
        super().__init__()
        boto3.Session = self

    def __call__(self, *args, **kwargs):
        return self

    def client(self, *args, **kwargs):
        return self.boto_client

    def resource(self, *args, **kwargs):
        return self.resource_mock

    def profile_name(self, *args, **kwargs):
        return mock.Mock()

    def get_credentials(self, *args, **kwargs):
        return mock.Mock()

    @property
    def region_name(self):
        return "us-west-2"


class AwsSessionMinWrapper(SessionWrapper):
    def __init__(self):
        super().__init__()

        AwsSessionFacade._wrapper = self
        AwsSessionFacade.real_get_device = braket.aws.aws_session.AwsSession.get_device
        braket.aws.aws_session.AwsSession.get_device = AwsSessionFacade.get_device
        AwsSessionFacade.real_create_quantum_task = (
            braket.aws.aws_session.AwsSession.create_quantum_task
        )
        braket.aws.aws_session.AwsSession.create_quantum_task = AwsSessionFacade.create_quantum_task
        AwsSessionFacade.real_get_quantum_task = braket.aws.aws_session.AwsSession.get_quantum_task
        braket.aws.aws_session.AwsSession.get_quantum_task = AwsSessionFacade.get_quantum_task
        AwsSessionFacade.real_cancel_quantum_task = (
            braket.aws.aws_session.AwsSession.cancel_quantum_task
        )
        braket.aws.aws_session.AwsSession.cancel_quantum_task = AwsSessionFacade.cancel_quantum_task
        AwsSessionFacade.real_retrieve_s3_object_body = (
            braket.aws.aws_session.AwsSession.retrieve_s3_object_body
        )
        braket.aws.aws_session.AwsSession.retrieve_s3_object_body = (
            AwsSessionFacade.retrieve_s3_object_body
        )
        braket.aws.aws_session.AwsSession.copy_s3_directory = AwsSessionFacade.copy_s3_directory
        AwsSessionMinWrapper.parse_device_config()

    @staticmethod
    def parse_device_config():
        mock_device_config_str = os.getenv("MOCK_DEVICE_CONFIG")
        AwsSessionFacade.mock_device_config = (
            json.loads(mock_device_config_str) if mock_device_config_str else {}
        )
        unsupported_device_config_str = os.getenv("UNSUPPORTED_DEVICE_CONFIG")
        AwsSessionFacade.unsupported_device_config = (
            set(json.loads(unsupported_device_config_str)) if unsupported_device_config_str else {}
        )

    @property
    def region_name(self):
        return boto3.session.Session().region_name


class AwsSessionFacade(braket.aws.AwsSession):
    created_task_arns = set()
    created_task_locations = set()

    def get_device(self, arn):
        device_name = arn.split("/")[-1]
        if device_name in AwsSessionFacade.unsupported_device_config:
            return AwsSessionFacade._wrapper.boto_client.get_device(arn)
        return AwsSessionFacade.real_get_device(self, arn)

    def create_quantum_task(self, **boto3_kwargs):
        if boto3_kwargs and boto3_kwargs["deviceArn"]:
            device_arn = boto3_kwargs["deviceArn"]
            device_name = device_arn.split("/")[-1]
            if device_name in AwsSessionFacade.unsupported_device_config:
                return AwsSessionFacade._wrapper.boto_client.create_quantum_task(boto3_kwargs)[
                    "quantumTaskArn"
                ]
            if device_name in AwsSessionFacade.mock_device_config:
                device_sub = AwsSessionFacade.mock_device_config[device_name]
                if device_sub == "MOCK":
                    return AwsSessionFacade._wrapper.boto_client.create_quantum_task(boto3_kwargs)[
                        "quantumTaskArn"
                    ]
                else:
                    boto3_kwargs["deviceArn"] = device_sub
            task_arn = AwsSessionFacade.real_create_quantum_task(self, **boto3_kwargs)
            AwsSessionFacade.created_task_arns.add(task_arn)
            return task_arn
        return AwsSessionFacade._wrapper.boto_client.create_quantum_task(boto3_kwargs)[
            "quantumTaskArn"
        ]

    def get_quantum_task(self, arn):
        if arn in AwsSessionFacade.created_task_arns:
            task_data = AwsSessionFacade.real_get_quantum_task(self, arn)
            AwsSessionFacade.created_task_locations.add(task_data["outputS3Directory"])
            return task_data
        return AwsSessionFacade._wrapper.boto_client.get_quantum_task(arn)

    def cancel_quantum_task(self, arn):
        if arn in AwsSessionFacade.created_task_arns:
            return AwsSessionFacade.real_cancel_quantum_task(self, arn)
        return AwsSessionFacade._wrapper.boto_client.cancel_quantum_task(arn)

    def copy_s3_directory(self, source_s3_path, destination_s3_path):
        return

    def retrieve_s3_object_body(self, s3_bucket, s3_object_key):
        location = s3_object_key[: s3_object_key.rindex("/")]
        if location in AwsSessionFacade.created_task_locations:
            return AwsSessionFacade.real_retrieve_s3_object_body(self, s3_bucket, s3_object_key)
        if AwsSessionFacade._wrapper.task_result_mock.side_effect is not None:
            return next(AwsSessionFacade._wrapper.task_result_mock.side_effect)
        return AwsSessionFacade._wrapper.task_result_mock.return_value