# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import json

from sagemaker.workflow.conditions import (
    ConditionEquals,
    ConditionGreaterThan,
    ConditionGreaterThanOrEqualTo,
    ConditionIn,
    ConditionLessThan,
    ConditionLessThanOrEqualTo,
    ConditionNot,
    ConditionOr,
)
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
from sagemaker.workflow.properties import Properties
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered


def test_condition_step():
    """Check the request and ``Outcome`` property produced by a ``ConditionStep``.

    The step is created with one initial dependency and a second one is added
    through ``add_depends_on``; the expected request lists both, in that order,
    together with the single equality condition and the if/else branch steps.
    """
    int_param = ParameterInteger(name="MyInt")
    equals_one = ConditionEquals(left=int_param, right=1)
    if_branch = CustomStep(name="MyStep1")
    else_branch = CustomStep(name="MyStep2")

    condition_step = ConditionStep(
        name="MyConditionStep",
        depends_on=["TestStep"],
        conditions=[equals_one],
        if_steps=[if_branch],
        else_steps=[else_branch],
    )
    # Dependency registered after construction; expected to follow "TestStep".
    condition_step.add_depends_on(["SecondTestStep"])

    actual_request = condition_step.to_request()

    # The un-interpolated request is expected to carry the parameter object itself.
    expected_conditions = [
        {
            "Type": "Equals",
            "LeftValue": int_param,
            "RightValue": 1,
        },
    ]
    expected_if_steps = [{"Name": "MyStep1", "Type": "Training", "Arguments": {}}]
    expected_else_steps = [{"Name": "MyStep2", "Type": "Training", "Arguments": {}}]
    expected_request = {
        "Name": "MyConditionStep",
        "Type": "Condition",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "Conditions": expected_conditions,
            "IfSteps": expected_if_steps,
            "ElseSteps": expected_else_steps,
        },
    }

    assert actual_request == expected_request
    assert condition_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"}


def test_pipeline_condition_step_interpolated(sagemaker_session):
    """Check that every condition type is interpolated in the pipeline definition.

    A single ``ConditionStep`` holding one condition of each kind imported by
    this module (including nested ``Not`` and ``Or`` conditions) is placed in a
    pipeline, and the parsed JSON definition is compared with the expected
    document, where parameters, execution variables and properties appear as
    ``{"Get": ...}`` references.

    Args:
        sagemaker_session: Session object handed to ``Pipeline``; supplied by
            the test harness (presumably a pytest fixture defined outside this
            file).
    """
    int_param_a = ParameterInteger(name="MyInt1")
    int_param_b = ParameterInteger(name="MyInt2")
    str_param = ParameterString(name="MyStr")
    start_time = ExecutionVariables.START_DATETIME
    step_property = Properties("foo")

    equals = ConditionEquals(left=int_param_a, right=int_param_b)
    greater = ConditionGreaterThan(left=start_time, right="2020-12-01")
    greater_or_equal = ConditionGreaterThanOrEqualTo(left=start_time, right=str_param)
    less = ConditionLessThan(left=start_time, right="2020-12-01")
    less_or_equal = ConditionLessThanOrEqualTo(left=start_time, right=str_param)
    is_in = ConditionIn(value=str_param, in_values=["abc", "def"])
    is_in_mixed = ConditionIn(value=str_param, in_values=["abc", step_property, start_time])
    not_equals = ConditionNot(expression=equals)
    not_in = ConditionNot(expression=is_in)
    either = ConditionOr(conditions=[greater, is_in])

    if_branch = CustomStep(name="MyStep1")
    else_branch = CustomStep(name="MyStep2")
    all_conditions = [
        equals,
        greater,
        greater_or_equal,
        less,
        less_or_equal,
        is_in,
        is_in_mixed,
        not_equals,
        not_in,
        either,
    ]
    condition_step = ConditionStep(
        name="MyConditionStep",
        conditions=all_conditions,
        if_steps=[if_branch],
        else_steps=[else_branch],
    )

    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[int_param_a, int_param_b, str_param],
        steps=[condition_step],
        sagemaker_session=sagemaker_session,
    )
    actual_definition = json.loads(pipeline.definition())

    # Reference fragments that recur throughout the expected definition.
    int_a_ref = {"Get": "Parameters.MyInt1"}
    int_b_ref = {"Get": "Parameters.MyInt2"}
    str_ref = {"Get": "Parameters.MyStr"}
    start_time_ref = {"Get": "Execution.StartDateTime"}

    # Payloads reused both at the top level and nested inside Not / Or.
    expected_equals = {
        "Type": "Equals",
        "LeftValue": int_a_ref,
        "RightValue": int_b_ref,
    }
    expected_greater = {
        "Type": "GreaterThan",
        "LeftValue": start_time_ref,
        "RightValue": "2020-12-01",
    }
    expected_in = {
        "Type": "In",
        "QueryValue": str_ref,
        "Values": ["abc", "def"],
    }

    # Same order as the conditions passed to the step.
    expected_conditions = [
        expected_equals,
        expected_greater,
        {
            "Type": "GreaterThanOrEqualTo",
            "LeftValue": start_time_ref,
            "RightValue": str_ref,
        },
        {
            "Type": "LessThan",
            "LeftValue": start_time_ref,
            "RightValue": "2020-12-01",
        },
        {
            "Type": "LessThanOrEqualTo",
            "LeftValue": start_time_ref,
            "RightValue": str_ref,
        },
        expected_in,
        {
            "Type": "In",
            "QueryValue": str_ref,
            "Values": ["abc", {"Get": "Steps.foo"}, start_time_ref],
        },
        {"Type": "Not", "Expression": expected_equals},
        {"Type": "Not", "Expression": expected_in},
        {"Type": "Or", "Conditions": [expected_greater, expected_in]},
    ]

    expected_step = {
        "Name": "MyConditionStep",
        "Type": "Condition",
        "Arguments": {
            "Conditions": expected_conditions,
            "IfSteps": [{"Name": "MyStep1", "Type": "Training", "Arguments": {}}],
            "ElseSteps": [{"Name": "MyStep2", "Type": "Training", "Arguments": {}}],
        },
    }
    expected_definition = {
        "Version": "2020-12-01",
        "Metadata": {},
        "Parameters": [
            {"Name": "MyInt1", "Type": "Integer"},
            {"Name": "MyInt2", "Type": "Integer"},
            {"Name": "MyStr", "Type": "String"},
        ],
        "PipelineExperimentConfig": {
            "ExperimentName": {"Get": "Execution.PipelineName"},
            "TrialName": {"Get": "Execution.PipelineExecutionId"},
        },
        "Steps": [expected_step],
    }

    assert actual_definition == expected_definition


def test_pipeline(sagemaker_session):
    """Check the pipeline graph adjacency list built around a ``ConditionStep``.

    The condition step is expected to point at both of its branch steps, and
    the branch steps are expected to have no outgoing edges.

    Args:
        sagemaker_session: Session object handed to ``Pipeline``; supplied by
            the test harness (presumably a pytest fixture defined outside this
            file).
    """
    int_param = ParameterInteger(name="MyInt", default_value=2)
    equals_one = ConditionEquals(left=int_param, right=1)
    if_branch = CustomStep("IfStep")
    else_branch = CustomStep("ElseStep")
    condition_step = ConditionStep(
        name="CondStep",
        conditions=[equals_one],
        if_steps=[if_branch],
        else_steps=[else_branch],
    )
    pipeline = Pipeline(
        name="MyPipeline",
        steps=[condition_step],
        sagemaker_session=sagemaker_session,
        parameters=[int_param],
    )

    graph = PipelineGraph.from_pipeline(pipeline)
    expected_adjacency = {
        "CondStep": ["IfStep", "ElseStep"],
        "IfStep": [],
        "ElseStep": [],
    }

    # Both sides go through the ``ordered`` helper before being compared.
    assert ordered(graph.adjacency_list) == ordered(expected_adjacency)