import logging
from unittest.case import skipIf

import pytest
from parameterized import parameterized

from integration.config.service_names import (
    ARM,
    CODE_DEPLOY,
    EVENT_INVOKE_CONFIG,
    HTTP_API,
    KMS,
    LAMBDA_URL,
    XRAY,
)
from integration.helpers.base_test import BaseTest
from integration.helpers.resource import current_region_does_not_support

# Module-level logger named after this module (standard logging.getLogger(__name__) idiom).
LOG = logging.getLogger(__name__)


class TestBasicFunction(BaseTest):
    """
    Basic AWS::Lambda::Function tests.

    Each test deploys a SAM template from the integration test resources, then
    inspects the resulting Lambda/IAM resources through the boto3 clients
    exposed by ``self.client_provider``.
    """

    @parameterized.expand(
        [
            "single/basic_function",
            "single/basic_function_no_envvar",
            "single/basic_function_openapi",
        ]
    )
    def test_basic_function(self, file_name):
        """
        Creates a basic lambda function, then updates its Timeout and checks
        that the stack update completes.
        """
        self.create_and_verify_stack(file_name)

        self.set_template_resource_property("MyLambdaFunction", "Timeout", 10)
        self.update_stack()

        self.assertEqual(self.get_resource_status_by_logical_id("MyLambdaFunction"), "UPDATE_COMPLETE")

    def test_basic_function_with_role_path(self):
        """
        Creates a function whose IAM role has a custom path and checks that the
        path ("/foo/bar/") appears in the function's role ARN and on the role itself.
        """
        self.create_and_verify_stack("single/function_with_role_path")

        lambda_client = self.client_provider.lambda_client
        function_name = self.get_physical_id_by_type("AWS::Lambda::Function")
        role_name = self.get_physical_id_by_type("AWS::IAM::Role")
        response = lambda_client.get_function(FunctionName=function_name)

        role_arn = response.get("Configuration", {}).get("Role")
        self.assertIsNotNone(role_arn)
        self.assertIn("/foo/bar/", role_arn)

        iam_client = self.client_provider.iam_client
        response = iam_client.get_role(RoleName=role_name)

        self.assertEqual(response["Role"]["Path"], "/foo/bar/")

    @parameterized.expand(
        [
            "single/function_with_http_api_events",
            "single/function_alias_with_http_api_events",
        ]
    )
    @pytest.mark.flaky(reruns=5)
    @skipIf(current_region_does_not_support([HTTP_API]), "HTTP API is not supported in this testing region")
    def test_function_with_http_api_events(self, file_name):
        """
        Creates a function behind an HTTP API and checks that a GET on the API
        endpoint returns the function's expected output.
        """
        self.create_and_verify_stack(file_name)

        endpoint = self.get_api_v2_endpoint("MyHttpApi")

        self._verify_get_request(endpoint, self.FUNCTION_OUTPUT)

    @parameterized.expand(
        [
            ("single/basic_function", ["x86_64"]),
            ("single/basic_function_no_envvar", ["x86_64"]),
            ("single/basic_function_openapi", ["x86_64"]),
            ("single/basic_function_with_arm_architecture", ["arm64"]),
            ("single/basic_function_with_x86_architecture", ["x86_64"]),
        ]
    )
    @skipIf(current_region_does_not_support([ARM]), "ARM is not supported in this testing region")
    def test_basic_function_with_architecture(self, file_name, architecture):
        """
        Creates a basic lambda function and checks that its configured
        Architectures list equals ``architecture``.
        """
        # NOTE(review): the whole test (including the x86_64 cases) is skipped in
        # regions without ARM support; presumably because the Architectures
        # field may be unavailable there — confirm before narrowing the skip.
        self.create_and_verify_stack(file_name)

        lambda_client = self.client_provider.lambda_client
        function_name = self.get_physical_id_by_type("AWS::Lambda::Function")
        function_architecture = lambda_client.get_function_configuration(FunctionName=function_name)["Architectures"]

        self.assertEqual(function_architecture, architecture)

    @parameterized.expand(
        [
            ("single/basic_function_with_function_url_config", None),
            ("single/basic_function_with_function_url_with_autopuplishalias", "live"),
        ]
    )
    @skipIf(current_region_does_not_support([LAMBDA_URL]), "Lambda Url is not supported in this testing region")
    def test_basic_function_with_url_config(self, file_name, qualifier):
        """
        Creates a basic lambda function with Function Url enabled and checks the
        URL's auth type, CORS settings, and that the function can be invoked.
        ``qualifier`` selects an alias (e.g. "live") when the URL is attached to one.
        """
        self.create_and_verify_stack(file_name)

        lambda_client = self.client_provider.lambda_client

        function_name = self.get_physical_id_by_type("AWS::Lambda::Function")
        # Only pass Qualifier when one is given, mirroring _assert_invoke.
        url_config_params = {"FunctionName": function_name}
        if qualifier:
            url_config_params["Qualifier"] = qualifier
        function_url_config = lambda_client.get_function_url_config(**url_config_params)

        cors_config = {
            "AllowOrigins": ["https://foo.com"],
            "AllowMethods": ["POST"],
            "AllowCredentials": True,
            "AllowHeaders": ["x-custom-header"],
            "ExposeHeaders": ["x-amzn-header"],
            "MaxAge": 10,
        }

        self.assertEqual(function_url_config["AuthType"], "NONE")
        self.assertEqual(function_url_config["Cors"], cors_config)
        self._assert_invoke(lambda_client, function_name, qualifier, 200)

    @skipIf(current_region_does_not_support([CODE_DEPLOY]), "CodeDeploy is not supported in this testing region")
    def test_function_with_deployment_preference_alarms_intrinsic_if(self):
        """
        Checks that a template using an intrinsic Fn::If in deployment
        preference alarms deploys successfully (no further assertions).
        """
        self.create_and_verify_stack("single/function_with_deployment_preference_alarms_intrinsic_if")

    @parameterized.expand(
        [
            ("single/basic_function_with_sns_dlq", "sns:Publish"),
            ("single/basic_function_with_sqs_dlq", "sqs:SendMessage"),
        ]
    )
    def test_basic_function_with_dlq(self, file_name, action):
        """
        Creates a basic lambda function with a dead letter queue and checks that
        the role's DeadLetterQueuePolicy grants exactly ``action`` on the DLQ ARN.
        """
        dlq_policy_name = "DeadLetterQueuePolicy"
        self.create_and_verify_stack(file_name)

        lambda_function_name = self.get_physical_id_by_type("AWS::Lambda::Function")
        function_configuration = self.client_provider.lambda_client.get_function_configuration(
            FunctionName=lambda_function_name
        )
        # .get() so a missing DeadLetterConfig fails the assertion below with a
        # clear message instead of raising KeyError.
        dlq_arn = function_configuration.get("DeadLetterConfig", {}).get("TargetArn")
        self.assertIsNotNone(dlq_arn, "DLQ Arn should be set")

        role_name = self.get_physical_id_by_type("AWS::IAM::Role")
        role_policy_result = self.client_provider.iam_client.get_role_policy(
            RoleName=role_name, PolicyName=dlq_policy_name
        )
        statements = role_policy_result["PolicyDocument"]["Statement"]

        self.assertEqual(len(statements), 1, "Only one statement must be in policy")
        self.assertEqual(statements[0]["Action"], action)
        self.assertEqual(statements[0]["Resource"], dlq_arn)
        self.assertEqual(statements[0]["Effect"], "Allow")

    @skipIf(current_region_does_not_support([KMS]), "KMS is not supported in this testing region")
    def test_basic_function_with_kms_key_arn(self):
        """
        Creates a basic lambda function with KMS key arn and checks it is set.
        """
        self.create_and_verify_stack("single/basic_function_with_kmskeyarn")

        lambda_function_name = self.get_physical_id_by_type("AWS::Lambda::Function")
        function_configuration = self.client_provider.lambda_client.get_function_configuration(
            FunctionName=lambda_function_name
        )
        # .get() so the assertion (not a KeyError) reports a missing key.
        kms_key_arn = function_configuration.get("KMSKeyArn")

        self.assertIsNotNone(kms_key_arn, "Expecting KmsKeyArn to be set.")

    def test_basic_function_with_tags(self):
        """
        Creates a basic lambda function with tags and checks both the SAM
        "lambda:createdBy" tag and the user-defined tags.
        """
        self.create_and_verify_stack("single/basic_function_with_tags")
        lambda_function_name = self.get_physical_id_by_type("AWS::Lambda::Function")
        get_function_result = self.client_provider.lambda_client.get_function(FunctionName=lambda_function_name)
        # .get() so the assertIsNotNone below is reachable when Tags is absent.
        tags = get_function_result.get("Tags")

        self.assertIsNotNone(tags, "Expecting tags on function.")
        self.assertIn("lambda:createdBy", tags, "Expected 'lambda:createdBy' tag key, but not found.")
        self.assertEqual("SAM", tags["lambda:createdBy"], "Expected 'SAM' tag value, but not found.")
        self.assertIn("TagKey1", tags)
        self.assertEqual(tags["TagKey1"], "TagValue1")
        self.assertIn("TagKey2", tags)
        self.assertEqual(tags["TagKey2"], "")

    @skipIf(
        current_region_does_not_support([EVENT_INVOKE_CONFIG]),
        "EventInvokeConfig is not supported in this testing region",
    )
    def test_basic_function_event_destinations(self):
        """
        Creates basic lambda functions with event destinations and checks the
        event invoke config on $LATEST of the first and the "live" alias of the second.
        """
        self.create_and_verify_stack("single/basic_function_event_destinations")

        test_function_1 = self.get_physical_id_by_logical_id("MyTestFunction")
        test_function_2 = self.get_physical_id_by_logical_id("MyTestFunction2")

        self._verify_event_invoke_config(test_function_1, "$LATEST", 70, 1)
        self._verify_event_invoke_config(test_function_2, "live", 80, 2)

    @skipIf(current_region_does_not_support([XRAY]), "XRay is not supported in this testing region")
    def test_basic_function_with_tracing(self):
        """
        Creates basic lambda functions with tracing and checks one uses Active
        and the other PassThrough tracing mode.
        """
        self.create_and_verify_stack("single/basic_function_with_tracing", self.get_default_test_template_parameters())

        active_tracing_function_id = self.get_physical_id_by_logical_id("ActiveTracingFunction")
        pass_through_tracing_function_id = self.get_physical_id_by_logical_id("PassThroughTracingFunction")

        self._verify_tracing_mode(active_tracing_function_id, "Active")
        self._verify_tracing_mode(pass_through_tracing_function_id, "PassThrough")

    def _verify_event_invoke_config(self, function_name, qualifier, expected_max_event_age, expected_max_retries):
        """
        Assert that a function's event invoke config has destinations and the expected limits.

        Parameters
        ----------
        function_name : string
            Function name
        qualifier : string
            Version or alias whose event invoke config is inspected (e.g. "$LATEST")
        expected_max_event_age : int
            Expected MaximumEventAgeInSeconds
        expected_max_retries : int
            Expected MaximumRetryAttempts
        """
        result = self.client_provider.lambda_client.get_function_event_invoke_config(
            FunctionName=function_name, Qualifier=qualifier
        )
        self.assertIsNotNone(result.get("DestinationConfig"), "Expecting destination config to be set.")
        self.assertEqual(
            int(result["MaximumEventAgeInSeconds"]),
            expected_max_event_age,
            "MaximumEventAgeInSeconds value is not set or incorrect.",
        )
        self.assertEqual(
            int(result["MaximumRetryAttempts"]),
            expected_max_retries,
            "MaximumRetryAttempts value is not set or incorrect.",
        )

    def _verify_tracing_mode(self, function_name, expected_mode):
        """
        Assert that a function has a TracingConfig whose Mode equals ``expected_mode``.

        Parameters
        ----------
        function_name : string
            Function name (physical id)
        expected_mode : string
            Expected tracing mode, e.g. "Active" or "PassThrough"
        """
        configuration = self.client_provider.lambda_client.get_function_configuration(FunctionName=function_name)
        tracing_config = configuration.get("TracingConfig")
        self.assertIsNotNone(tracing_config, "Expecting tracing config to be set.")
        self.assertEqual(
            tracing_config["Mode"],
            expected_mode,
            f"Expecting tracing config mode to be set to {expected_mode}.",
        )

    def _assert_invoke(self, lambda_client, function_name, qualifier=None, expected_status_code=200):
        """
        Assert if a Lambda invocation returns the expected status code

        Parameters
        ----------
        lambda_client : boto3.BaseClient
            boto3 Lambda client
        function_name : string
            Function name
        qualifier : string
            Specify a version or alias to invoke a published version of the function
        expected_status_code : int
            Expected status code from the invocation
        """
        request_params = {
            "FunctionName": function_name,
            "Payload": "{}",
        }
        if qualifier:
            request_params["Qualifier"] = qualifier

        response = lambda_client.invoke(**request_params)
        self.assertEqual(response.get("StatusCode"), expected_status_code)

    def _verify_get_request(self, url, expected_text):
        """
        Assert that a GET on ``url`` returns HTTP 200 with body ``expected_text``.
        """
        response = self.verify_get_request_response(url, 200)
        self.assertEqual(response.text, expected_text)