# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from typing import Dict, Union

from openpulse import ast
from openqasm3.ast import DurationLiteral
from openqasm3.visitor import QASMTransformer

from braket.parametric.free_parameter_expression import FreeParameterExpression


class _FreeParameterExpressionIdentifier(ast.Identifier):
    """Placeholder identifier node that carries a FreeParameterExpression.

    The identifier's ``name`` is a readable rendering of the wrapped expression.
    The expression object itself is kept on the node so that an AST pass can
    later substitute parameter values into it.
    """

    def __init__(self, expression: FreeParameterExpression):
        """Wrap ``expression`` in an identifier node.

        Args:
            expression (FreeParameterExpression): The expression to attach to the node.
        """
        display_name = f"FreeParameterExpression({expression})"
        super().__init__(name=display_name)
        # Kept as a private attribute; exposed read-only through ``expression``.
        self._expression = expression

    @property
    def expression(self) -> FreeParameterExpression:
        """FreeParameterExpression: The expression wrapped by this node."""
        return self._expression


class _FreeParameterTransformer(QASMTransformer):
    """Walk the AST and substitute values into FreeParameterExpressions.

    Expressions that become fully numeric after substitution are replaced by plain
    float values. Expressions that still contain unbound parameters are kept as
    FreeParameterExpressions so that they can be bound by a later pass.
    """

    def __init__(self, param_values: Dict[str, float]):
        """Initialize the transformer.

        Args:
            param_values (Dict[str, float]): Mapping passed to
                ``FreeParameterExpression.subs``, from free parameter name to the
                value to substitute for it.
        """
        self.param_values = param_values
        super().__init__()

    def visit__FreeParameterExpressionIdentifier(
        self, identifier: _FreeParameterExpressionIdentifier
    ) -> Union[_FreeParameterExpressionIdentifier, ast.FloatLiteral]:
        """Visit a FreeParameterExpressionIdentifier.

        Args:
            identifier (_FreeParameterExpressionIdentifier): The identifier whose
                attached expression should be substituted.

        Returns:
            Union[_FreeParameterExpressionIdentifier, FloatLiteral]: A new identifier
            wrapping the substituted expression if free parameters remain, otherwise
            a float literal holding the evaluated value.
        """
        new_value = identifier.expression.subs(self.param_values)
        if isinstance(new_value, FreeParameterExpression):
            return _FreeParameterExpressionIdentifier(new_value)
        # NOTE(review): a fully substituted value is presumably a symbolic number
        # (possibly an integer or rational). Coerce it to a built-in float so the
        # literal always holds a real float value.
        return ast.FloatLiteral(float(new_value))

    def visit_DurationLiteral(self, duration_literal: DurationLiteral) -> DurationLiteral:
        """Visit a DurationLiteral and substitute parameter values into its value.

        Args:
            duration_literal (DurationLiteral): The duration literal.

        Returns:
            DurationLiteral: The original node if its value is not a
            FreeParameterExpression. Otherwise, a new duration literal with the same
            unit, whose value is either the substituted expression (if free
            parameters remain) or the evaluated value as a float.
        """
        duration = duration_literal.value
        if not isinstance(duration, FreeParameterExpression):
            return duration_literal
        new_duration = duration.subs(self.param_values)
        if not isinstance(new_duration, FreeParameterExpression):
            # Same coercion as for identifiers: keep evaluated durations as plain floats.
            new_duration = float(new_duration)
        return DurationLiteral(new_duration, duration_literal.unit)