# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import asyncio
import time
from functools import singledispatch
from logging import Logger, getLogger
from typing import Any, Dict, Union

import boto3

from braket.ahs.analog_hamiltonian_simulation import AnalogHamiltonianSimulation
from braket.annealing.problem import Problem
from braket.aws.aws_session import AwsSession
from braket.circuits import Instruction
from braket.circuits.circuit import Circuit
from braket.circuits.circuit_helpers import validate_circuit_and_shots
from braket.circuits.compiler_directives import StartVerbatimBox
from braket.circuits.gates import PulseGate
from braket.circuits.serialization import (
    IRType,
    OpenQASMSerializationProperties,
    QubitReferenceType,
)
from braket.device_schema import GateModelParameters
from braket.device_schema.dwave import (
    Dwave2000QDeviceParameters,
    DwaveAdvantageDeviceParameters,
    DwaveDeviceParameters,
)
from braket.device_schema.dwave.dwave_2000Q_device_level_parameters_v1 import (
    Dwave2000QDeviceLevelParameters,
)
from braket.device_schema.dwave.dwave_advantage_device_level_parameters_v1 import (
    DwaveAdvantageDeviceLevelParameters,
)
from braket.device_schema.ionq import IonqDeviceParameters
from braket.device_schema.oqc import OqcDeviceParameters
from braket.device_schema.rigetti import RigettiDeviceParameters
from braket.device_schema.simulators import GateModelSimulatorDeviceParameters
from braket.error_mitigation import ErrorMitigation
from braket.ir.blackbird import Program as BlackbirdProgram
from braket.ir.openqasm import Program as OpenQASMProgram
from braket.pulse.pulse_sequence import PulseSequence
from braket.schema_common import BraketSchemaBase
from braket.task_result import (
    AnalogHamiltonianSimulationTaskResult,
    AnnealingTaskResult,
    GateModelTaskResult,
    PhotonicModelTaskResult,
)
from braket.tasks import (
    AnalogHamiltonianSimulationQuantumTaskResult,
    AnnealingQuantumTaskResult,
    GateModelQuantumTaskResult,
    PhotonicModelQuantumTaskResult,
    QuantumTask,
)
from braket.tracking.tracking_context import broadcast_event
from braket.tracking.tracking_events import _TaskCompletionEvent


class AwsQuantumTask(QuantumTask):
    """Amazon Braket implementation of a quantum task. A task can be a circuit or an annealing
    problem."""

    # TODO: Add API documentation that defines these states. Make it clear this is the contract.
    NO_RESULT_TERMINAL_STATES = {"FAILED", "CANCELLED"}
    RESULTS_READY_STATES = {"COMPLETED"}
    TERMINAL_STATES = RESULTS_READY_STATES.union(NO_RESULT_TERMINAL_STATES)

    DEFAULT_RESULTS_POLL_TIMEOUT = 432000
    DEFAULT_RESULTS_POLL_INTERVAL = 1
    RESULTS_FILENAME = "results.json"

    @staticmethod
    def create(
        aws_session: AwsSession,
        device_arn: str,
        task_specification: Union[
            Circuit,
            Problem,
            OpenQASMProgram,
            BlackbirdProgram,
            PulseSequence,
            AnalogHamiltonianSimulation,
        ],
        s3_destination_folder: AwsSession.S3DestinationFolder,
        shots: int,
        device_parameters: Dict[str, Any] = None,
        disable_qubit_rewiring: bool = False,
        tags: Dict[str, str] = None,
        inputs: Dict[str, float] = None,
        *args,
        **kwargs,
    ) -> AwsQuantumTask:
        """AwsQuantumTask factory method that serializes a quantum task specification
        (either a quantum circuit or annealing problem), submits it to Amazon Braket,
        and returns back an AwsQuantumTask tracking the execution.

        Args:
            aws_session (AwsSession): AwsSession to connect to AWS with.

            device_arn (str): The ARN of the quantum device.

            task_specification (Union[Circuit, Problem, OpenQASMProgram, BlackbirdProgram,PulseSequence, AnalogHamiltonianSimulation]): # noqa
                The specification of the task to run on device.

            s3_destination_folder (AwsSession.S3DestinationFolder): NamedTuple, with bucket
                for index 0 and key for index 1, that specifies the Amazon S3 bucket and folder
                to store task results in.

            shots (int): The number of times to run the task on the device. If the device is a
                simulator, this implies the state is sampled N times, where N = `shots`.
                `shots=0` is only available on simulators and means that the simulator
                will compute the exact results based on the task specification.

            device_parameters (Dict[str, Any]): Additional parameters to send to the device.

            disable_qubit_rewiring (bool): Whether to run the circuit with the exact qubits chosen,
                without any rewiring downstream, if this is supported by the device.
                Only applies to digital, gate-based circuits (as opposed to annealing problems).
                If ``True``, no qubit rewiring is allowed; if ``False``, qubit rewiring is allowed.
                Default: False

            tags (Dict[str, str]): Tags, which are Key-Value pairs to add to this quantum task.
                An example would be:
                `{"state": "washington"}`

            inputs (Dict[str, float]): Inputs to be passed along with the
                IR. If the IR supports inputs, the inputs will be updated with this value.
                Default: {}.

        Returns:
            AwsQuantumTask: AwsQuantumTask tracking the task execution on the device.

        Note:
            The following arguments are typically defined via clients of Device.
                - `task_specification`
                - `s3_destination_folder`
                - `shots`

        See Also:
            `braket.aws.aws_quantum_simulator.AwsQuantumSimulator.run()`
            `braket.aws.aws_qpu.AwsQpu.run()`
        """
        if len(s3_destination_folder) != 2:
            raise ValueError(
                "s3_destination_folder must be of size 2 with a 'bucket' and 'key' respectively."
            )

        create_task_kwargs = _create_common_params(
            device_arn,
            s3_destination_folder,
            shots if shots is not None else AwsQuantumTask.DEFAULT_SHOTS,
        )

        if tags is not None:
            create_task_kwargs.update({"tags": tags})
        inputs = inputs or {}

        if isinstance(task_specification, Circuit):
            param_names = {param.name for param in task_specification.parameters}
            unbounded_parameters = param_names - set(inputs.keys())
            if unbounded_parameters:
                raise ValueError(
                    f"Cannot execute circuit with unbound parameters: " f"{unbounded_parameters}"
                )

        return _create_internal(
            task_specification,
            aws_session,
            create_task_kwargs,
            device_arn,
            device_parameters or {},
            disable_qubit_rewiring,
            inputs,
            *args,
            **kwargs,
        )

    def __init__(
        self,
        arn: str,
        aws_session: AwsSession = None,
        poll_timeout_seconds: float = DEFAULT_RESULTS_POLL_TIMEOUT,
        poll_interval_seconds: float = DEFAULT_RESULTS_POLL_INTERVAL,
        logger: Logger = getLogger(__name__),
    ):
        """
        Args:
            arn (str): The ARN of the task.
            aws_session (AwsSession): The `AwsSession` for connecting to AWS services.
                Default is `None`, in which case an `AwsSession` object will be created with the
                region of the task.
            poll_timeout_seconds (float): The polling timeout for `result()`. Default: 5 days.
            poll_interval_seconds (float): The polling interval for `result()`. Default: 1 second.
            logger (Logger): Logger object with which to write logs, such as task statuses
                while waiting for task to be in a terminal state. Default is `getLogger(__name__)`

        Examples:
            >>> task = AwsQuantumTask(arn='task_arn')
            >>> task.state()
            'COMPLETED'
            >>> result = task.result()
            AnnealingQuantumTaskResult(...)

            >>> task = AwsQuantumTask(arn='task_arn', poll_timeout_seconds=300)
            >>> result = task.result()
            GateModelQuantumTaskResult(...)
        """

        self._arn: str = arn
        self._aws_session: AwsSession = aws_session or AwsQuantumTask._aws_session_for_task_arn(
            task_arn=arn
        )
        self._poll_timeout_seconds = poll_timeout_seconds
        self._poll_interval_seconds = poll_interval_seconds

        self._logger = logger

        self._metadata: Dict[str, Any] = {}
        self._result: Union[
            GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult
        ] = None

    @staticmethod
    def _aws_session_for_task_arn(task_arn: str) -> AwsSession:
        """
        Get an AwsSession for the Task ARN. The AWS session should be in the region of the task.

        Returns:
            AwsSession: `AwsSession` object with default `boto_session` in task's region.
        """
        task_region = task_arn.split(":")[3]
        boto_session = boto3.Session(region_name=task_region)
        return AwsSession(boto_session=boto_session)

    @property
    def id(self) -> str:
        """str: The ARN of the quantum task."""
        return self._arn

    def _cancel_future(self) -> None:
        """Cancel the future if it exists. Else, create a cancelled future."""
        if hasattr(self, "_future"):
            self._future.cancel()
        else:
            self._future = asyncio.Future()
            self._future.cancel()

    def cancel(self) -> None:
        """Cancel the quantum task. This cancels the future and the task in Amazon Braket."""
        self._cancel_future()
        self._aws_session.cancel_quantum_task(self._arn)

    def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]:
        """
        Get task metadata defined in Amazon Braket.

        Args:
            use_cached_value (bool): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation, if it exists; if not,
                `GetQuantumTask` will be called to retrieve the metadata. If `False`, always calls
                `GetQuantumTask`, which also updates the cached value. Default: `False`.
        Returns:
            Dict[str, Any]: The response from the Amazon Braket `GetQuantumTask` operation.
            If `use_cached_value` is `True`, Amazon Braket is not called and the most recently
            retrieved value is used, unless `GetQuantumTask` was never called, in which case
            it wil still be called to populate the metadata for the first time.
        """
        if not use_cached_value or not self._metadata:
            self._metadata = self._aws_session.get_quantum_task(self._arn)
        return self._metadata

    def state(self, use_cached_value: bool = False) -> str:
        """
        The state of the quantum task.

        Args:
            use_cached_value (bool): If `True`, uses the value most recently retrieved
                from the Amazon Braket `GetQuantumTask` operation. If `False`, calls the
                `GetQuantumTask` operation to retrieve metadata, which also updates the cached
                value. Default = `False`.
        Returns:
            str: The value of `status` in `metadata()`. This is the value of the `status` key
            in the Amazon Braket `GetQuantumTask` operation. If `use_cached_value` is `True`,
            the value most recently returned from the `GetQuantumTask` operation is used.
        See Also:
            `metadata()`
        """
        return self._status(use_cached_value)

    def _status(self, use_cached_value: bool = False) -> str:
        metadata = self.metadata(use_cached_value)
        status = metadata.get("status")
        if not use_cached_value and status in self.NO_RESULT_TERMINAL_STATES:
            self._logger.warning(f"Task is in terminal state {status} and no result is available.")
            if status == "FAILED":
                failure_reason = metadata.get("failureReason", "unknown")
                self._logger.warning(f"Task failure reason is: {failure_reason}.")
        return status

    def _update_status_if_nonterminal(self) -> str:
        # If metadata has not been populated, the first call to _status will fetch it,
        # so the second _status call will no longer need to
        metadata_absent = not self._metadata
        cached = self._status(True)
        return cached if cached in self.TERMINAL_STATES else self._status(metadata_absent)

    def result(
        self,
    ) -> Union[
        GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult
    ]:
        """
        Get the quantum task result by polling Amazon Braket to see if the task is completed.
        Once the task is completed, the result is retrieved from S3 and returned as a
        `GateModelQuantumTaskResult` or `AnnealingQuantumTaskResult`

        This method is a blocking thread call and synchronously returns a result.
        Call `async_result()` if you require an asynchronous invocation.
        Consecutive calls to this method return a cached result from the preceding request.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]: # noqa
            The result of the task, if the task completed successfully; returns `None` if the task
            did not complete successfully or the future timed out.
        """
        if self._result or (
            self._metadata and self._status(True) in self.NO_RESULT_TERMINAL_STATES
        ):
            return self._result
        if self._metadata and self._status(True) in self.RESULTS_READY_STATES:
            return self._download_result()
        try:
            async_result = self.async_result()
            return async_result.get_loop().run_until_complete(async_result)
        except asyncio.CancelledError:
            # Future was cancelled, return whatever is in self._result if anything
            self._logger.warning("Task future was cancelled")
            return self._result

    def _get_future(self) -> asyncio.Future:
        try:
            asyncio.get_event_loop()
        except Exception as e:
            self._logger.debug(e)
            self._logger.info("No event loop found; creating new event loop")
            asyncio.set_event_loop(asyncio.new_event_loop())
        if not hasattr(self, "_future") or (
            self._future.done()
            and not self._future.cancelled()
            and self._result is None
            # timed out and no result
            and self._update_status_if_nonterminal() not in self.NO_RESULT_TERMINAL_STATES
        ):
            self._future = asyncio.get_event_loop().run_until_complete(self._create_future())
        return self._future

    def async_result(self) -> asyncio.Task:
        """
        Get the quantum task result asynchronously. Consecutive calls to this method return
        the result cached from the most recent request.
        """
        return self._get_future()

    async def _create_future(self) -> asyncio.Task:
        """
        Wrap the `_wait_for_completion` coroutine inside a future-like object.
        Invoking this method starts the coroutine and returns back the future-like object
        that contains it. Note that this does not block on the coroutine to finish.

        Returns:
            asyncio.Task: An asyncio Task that contains the `_wait_for_completion()` coroutine.
        """
        return asyncio.create_task(self._wait_for_completion())

    async def _wait_for_completion(
        self,
    ) -> Union[
        GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult
    ]:
        """
        Waits for the quantum task to be completed, then returns the result from the S3 bucket.

        Returns:
            Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult]: If the task is in the
                `AwsQuantumTask.RESULTS_READY_STATES` state within the specified time limit,
                the result from the S3 bucket is loaded and returned.
                `None` is returned if a timeout occurs or task state is in
                `AwsQuantumTask.NO_RESULT_TERMINAL_STATES`.
        Note:
            Timeout and sleep intervals are defined in the constructor fields
                `poll_timeout_seconds` and `poll_interval_seconds` respectively.
        """
        self._logger.debug(f"Task {self._arn}: start polling for completion")
        start_time = time.time()

        while (time.time() - start_time) < self._poll_timeout_seconds:
            # Used cached metadata if cached status is terminal
            task_status = self._update_status_if_nonterminal()
            self._logger.debug(f"Task {self._arn}: task status {task_status}")
            if task_status in AwsQuantumTask.RESULTS_READY_STATES:
                return self._download_result()
            elif task_status in AwsQuantumTask.NO_RESULT_TERMINAL_STATES:
                self._result = None
                return None
            else:
                await asyncio.sleep(self._poll_interval_seconds)

        # Timed out
        self._logger.warning(
            f"Task {self._arn}: polling for task completion timed out after "
            + f"{time.time() - start_time} seconds. Please increase the timeout; "
            + "this can be done by creating a new AwsQuantumTask with this task's ARN "
            + "and a higher value for the `poll_timeout_seconds` parameter."
        )
        self._result = None
        return None

    def _download_result(
        self,
    ) -> Union[
        GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult
    ]:
        current_metadata = self.metadata(True)
        result_string = self._aws_session.retrieve_s3_object_body(
            current_metadata["outputS3Bucket"],
            current_metadata["outputS3Directory"] + f"/{AwsQuantumTask.RESULTS_FILENAME}",
        )
        self._result = _format_result(BraketSchemaBase.parse_raw_schema(result_string))
        task_event = {"arn": self.id, "status": self.state(), "execution_duration": None}
        try:
            task_event[
                "execution_duration"
            ] = self._result.additional_metadata.simulatorMetadata.executionDuration
        except AttributeError:
            pass
        broadcast_event(_TaskCompletionEvent(**task_event))
        return self._result

    def __repr__(self) -> str:
        return f"AwsQuantumTask('id/taskArn':'{self.id}')"

    def __eq__(self, other) -> bool:
        if isinstance(other, AwsQuantumTask):
            return self.id == other.id
        return False

    def __hash__(self) -> int:
        return hash(self.id)


@singledispatch
def _create_internal(
    task_specification: Union[Circuit, Problem, BlackbirdProgram],
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    device_arn: str,
    device_parameters: Union[dict, BraketSchemaBase],
    disable_qubit_rewiring: bool,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    raise TypeError("Invalid task specification type")


@_create_internal.register
def _(
    pulse_sequence: PulseSequence,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    device_arn: str,
    _device_parameters: Union[dict, BraketSchemaBase],  # Not currently used for OpenQasmProgram
    _disable_qubit_rewiring: bool,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    create_task_kwargs.update({"action": OpenQASMProgram(source=pulse_sequence.to_ir()).json()})
    task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
def _(
    openqasm_program: OpenQASMProgram,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    device_arn: str,
    device_parameters: Union[dict, BraketSchemaBase],
    _disable_qubit_rewiring: bool,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    if inputs:
        inputs_copy = openqasm_program.inputs.copy() if openqasm_program.inputs is not None else {}
        inputs_copy.update(inputs)
        openqasm_program = OpenQASMProgram(
            source=openqasm_program.source,
            inputs=inputs_copy,
        )
    create_task_kwargs.update({"action": openqasm_program.json()})
    if device_parameters:
        final_device_parameters = (
            _circuit_device_params_from_dict(
                device_parameters,
                device_arn,
                GateModelParameters(qubitCount=0),  # qubitCount unused
            )
            if type(device_parameters) is dict
            else device_parameters
        )
        create_task_kwargs.update(
            {"deviceParameters": final_device_parameters.json(exclude_none=True)}
        )

    task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
def _(
    blackbird_program: BlackbirdProgram,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, any],
    device_arn: str,
    _device_parameters: Union[dict, BraketSchemaBase],
    _disable_qubit_rewiring: bool,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    create_task_kwargs.update({"action": blackbird_program.json()})
    task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
def _(
    circuit: Circuit,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    device_arn: str,
    device_parameters: Union[dict, BraketSchemaBase],
    disable_qubit_rewiring: bool,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    validate_circuit_and_shots(circuit, create_task_kwargs["shots"])
    # TODO: Update this to use `deviceCapabilities` from Amazon Braket's GetDevice operation
    # in order to decide what parameters to build.
    paradigm_parameters = GateModelParameters(
        qubitCount=circuit.qubit_count, disableQubitRewiring=disable_qubit_rewiring
    )
    final_device_parameters = (
        _circuit_device_params_from_dict(device_parameters or {}, device_arn, paradigm_parameters)
        if type(device_parameters) is dict
        else device_parameters
    )

    qubit_reference_type = QubitReferenceType.VIRTUAL

    if (
        disable_qubit_rewiring
        or Instruction(StartVerbatimBox()) in circuit.instructions
        or any(isinstance(instruction.operator, PulseGate) for instruction in circuit.instructions)
    ):
        qubit_reference_type = QubitReferenceType.PHYSICAL

    serialization_properties = OpenQASMSerializationProperties(
        qubit_reference_type=qubit_reference_type
    )

    openqasm_program = circuit.to_ir(
        ir_type=IRType.OPENQASM, serialization_properties=serialization_properties
    )

    if inputs:
        inputs_copy = openqasm_program.inputs.copy() if openqasm_program.inputs is not None else {}
        inputs_copy.update(inputs)
        openqasm_program = OpenQASMProgram(
            source=openqasm_program.source,
            inputs=inputs_copy,
        )

    create_task_kwargs.update(
        {
            "action": openqasm_program.json(),
            "deviceParameters": final_device_parameters.json(exclude_none=True),
        }
    )
    task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
def _(
    problem: Problem,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    device_arn: str,
    device_parameters: Union[
        dict,
        DwaveDeviceParameters,
        DwaveAdvantageDeviceParameters,
        Dwave2000QDeviceParameters,
    ],
    _,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    device_params = _create_annealing_device_params(device_parameters, device_arn)
    create_task_kwargs.update(
        {
            "action": problem.to_ir().json(),
            "deviceParameters": device_params.json(exclude_none=True),
        }
    )

    task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


@_create_internal.register
def _(
    analog_hamiltonian_simulation: AnalogHamiltonianSimulation,
    aws_session: AwsSession,
    create_task_kwargs: Dict[str, Any],
    device_arn: str,
    device_parameters: dict,
    _,
    inputs: Dict[str, float],
    *args,
    **kwargs,
) -> AwsQuantumTask:
    create_task_kwargs.update({"action": analog_hamiltonian_simulation.to_ir().json()})
    task_arn = aws_session.create_quantum_task(**create_task_kwargs)
    return AwsQuantumTask(task_arn, aws_session, *args, **kwargs)


def _circuit_device_params_from_dict(
    device_parameters: dict, device_arn: str, paradigm_parameters: GateModelParameters
) -> GateModelSimulatorDeviceParameters:
    if "errorMitigation" in device_parameters:
        error_migitation = device_parameters["errorMitigation"]
        device_parameters["errorMitigation"] = (
            error_migitation.serialize()
            if isinstance(error_migitation, ErrorMitigation)
            else error_migitation
        )
    if "ionq" in device_arn:
        return IonqDeviceParameters(paradigmParameters=paradigm_parameters, **device_parameters)
    if "rigetti" in device_arn:
        return RigettiDeviceParameters(paradigmParameters=paradigm_parameters)
    if "oqc" in device_arn:
        return OqcDeviceParameters(paradigmParameters=paradigm_parameters)
    return GateModelSimulatorDeviceParameters(paradigmParameters=paradigm_parameters)


def _create_annealing_device_params(
    device_params: Dict[str, Any], device_arn: str
) -> Union[DwaveAdvantageDeviceParameters, Dwave2000QDeviceParameters]:
    """Gets Annealing Device Parameters.

    Args:
        device_params (Dict[str, Any]): Additional parameters for the device.
        device_arn (str): The ARN of the quantum device.

    Returns:
        Union[DwaveAdvantageDeviceParameters, Dwave2000QDeviceParameters]: The device parameters.

    """
    if type(device_params) is not dict:
        device_params = device_params.dict()

    # check for device level or provider level parameters
    device_level_parameters = device_params.get("deviceLevelParameters", None) or device_params.get(
        "providerLevelParameters", {}
    )

    # deleting since it may be the old version
    if "braketSchemaHeader" in device_level_parameters:
        del device_level_parameters["braketSchemaHeader"]

    if "Advantage" in device_arn:
        device_level_parameters = DwaveAdvantageDeviceLevelParameters.parse_obj(
            device_level_parameters
        )
        return DwaveAdvantageDeviceParameters(deviceLevelParameters=device_level_parameters)
    elif "2000Q" in device_arn:
        device_level_parameters = Dwave2000QDeviceLevelParameters.parse_obj(device_level_parameters)
        return Dwave2000QDeviceParameters(deviceLevelParameters=device_level_parameters)
    else:
        raise Exception(
            f"Amazon Braket could not find a device with ARN: {device_arn}. "
            "To continue, make sure that the value of the device_arn parameter "
            "corresponds to a valid QPU."
        )


def _create_common_params(
    device_arn: str, s3_destination_folder: AwsSession.S3DestinationFolder, shots: int
) -> Dict[str, Any]:
    return {
        "deviceArn": device_arn,
        "outputS3Bucket": s3_destination_folder[0],
        "outputS3KeyPrefix": s3_destination_folder[1],
        "shots": shots,
    }


@singledispatch
def _format_result(
    result: Union[GateModelTaskResult, AnnealingTaskResult, PhotonicModelTaskResult]
) -> Union[GateModelQuantumTaskResult, AnnealingQuantumTaskResult, PhotonicModelQuantumTaskResult]:
    raise TypeError("Invalid result specification type")


@_format_result.register
def _(result: GateModelTaskResult) -> GateModelQuantumTaskResult:
    GateModelQuantumTaskResult.cast_result_types(result)
    return GateModelQuantumTaskResult.from_object(result)


@_format_result.register
def _(result: AnnealingTaskResult) -> AnnealingQuantumTaskResult:
    return AnnealingQuantumTaskResult.from_object(result)


@_format_result.register
def _(result: PhotonicModelTaskResult) -> PhotonicModelQuantumTaskResult:
    return PhotonicModelQuantumTaskResult.from_object(result)


@_format_result.register
def _(
    result: AnalogHamiltonianSimulationTaskResult,
) -> AnalogHamiltonianSimulationQuantumTaskResult:
    return AnalogHamiltonianSimulationQuantumTaskResult.from_object(result)