# ######################################################################################################################
#  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.                                                  #
#                                                                                                                      #
#  Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance      #
#  with the License. You may obtain a copy of the License at                                                           #
#                                                                                                                      #
#   http://www.apache.org/licenses/LICENSE-2.0                                                                         #
#                                                                                                                      #
#  Unless required by applicable law or agreed to in writing, software distributed under the License is distributed    #
#  on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for   #
#  the specific language governing permissions and limitations under the License.                                      #
# ######################################################################################################################

import os

import pytest
from moto import mock_sts

from aws_solutions.core import (
    get_aws_region,
    get_service_client,
    get_aws_partition,
    get_aws_account,
    get_service_resource,
)


@pytest.fixture(autouse=True, scope="module")
def valid_solution_env():
    os.environ["AWS_REGION"] = "us-east-1"
    os.environ["SOLUTION_ID"] = "SO0100"
    os.environ["SOLUTION_VERSION"] = "1.0.0"
    yield
    del os.environ["AWS_REGION"]
    del os.environ["SOLUTION_ID"]
    del os.environ["SOLUTION_VERSION"]


def test_get_aws_region_valid():
    assert get_aws_region() == "us-east-1"


def test_get_service_client():
    cli = get_service_client("ec2")
    assert cli.meta.service_model.service_name == "ec2"


def test_get_service_resource():
    ec2 = get_service_resource("ec2")
    assert ec2.meta.service_name == "ec2"


@pytest.mark.parametrize(
    "region,partition",
    [
        ("us-east-1", "aws"),
        ("us-gov-west-1", "aws-us-gov"),
        ("us-gov-west-2", "aws-us-gov"),
        ("cn-north-1", "aws-cn"),
        ("cn-northwest-1", "aws-cn"),
    ],
)
def test_get_aws_partition(region, partition, mocker):
    mocker.patch("aws_solutions.core.helpers.get_aws_region", return_value=region)
    assert get_aws_partition() == partition


@mock_sts
def test_get_aws_account_id(mocker):
    assert get_aws_account() == "1" * 12