import itertools import time from pathlib import Path from typing import Any, List from unittest import skipIf import boto3 import pytest from samcli.lib.observability.util import OutputOption from tests.integration.deploy.deploy_integ_base import DeployIntegBase from tests.integration.traces.traces_integ_base import TracesIntegBase, RETRY_COUNT, RETRY_SLEEP from tests.testing_utils import ( run_command, RUNNING_ON_CI, RUNNING_TEST_FOR_MASTER_ON_CI, RUN_BY_CANARY, method_to_stack_name, kill_process, start_persistent_process, read_until, ) from datetime import datetime import logging from parameterized import parameterized LOG = logging.getLogger(__name__) SKIP_TRACES_TESTS = RUNNING_ON_CI and RUNNING_TEST_FOR_MASTER_ON_CI and not RUN_BY_CANARY @skipIf(SKIP_TRACES_TESTS, "Skip traces tests in CI/CD only") @pytest.mark.xdist_group(name="sam_traces") class TestTracesCommand(TracesIntegBase): stack_resources: List[Any] = [] stack_name = "" def setUp(self): self.lambda_client = boto3.client("lambda") self.sfn_client = boto3.client("stepfunctions") self.xray_client = boto3.client("xray") @pytest.fixture(scope="class") def deploy_testing_stack(self): test_data_path = Path(__file__).resolve().parents[1].joinpath("testdata", "traces") TestTracesCommand.stack_name = method_to_stack_name("test_traces_command") cfn_client = boto3.client("cloudformation") deploy_cmd = DeployIntegBase.get_deploy_command_list( stack_name=TestTracesCommand.stack_name, template_file=test_data_path.joinpath("python-apigw-sfn", "template.yaml"), resolve_s3=True, capabilities="CAPABILITY_IAM", ) deploy_result = run_command(deploy_cmd) yield deploy_result, cfn_client cfn_client.delete_stack(StackName=TestTracesCommand.stack_name) @pytest.fixture(autouse=True, scope="class") def sync_code_base(self, deploy_testing_stack): deploy_result = deploy_testing_stack[0] cfn_client = deploy_testing_stack[1] self.assertEqual( deploy_result.process.returncode, 0, f"Deployment of the test stack is failed with {deploy_result.stderr}" ) TestTracesCommand.stack_resources = cfn_client.describe_stack_resources( StackName=TestTracesCommand.stack_name ).get("StackResources", []) def _get_physical_id(self, logical_id: str): for stack_resource in self.stack_resources: if stack_resource["LogicalResourceId"] == logical_id: return stack_resource["PhysicalResourceId"] return None @parameterized.expand([("ApiGwFunction",), ("SfnFunction",)]) def test_function_traces(self, function_name): function_id = self._get_physical_id(function_name) expected_trace_output = [function_id] LOG.info("Invoking function %s", function_name) lambda_invoke_result = self.lambda_client.invoke(FunctionName=function_id) LOG.info("Lambda invoke result %s", lambda_invoke_result) cmd_list = self.get_traces_command_list() self._check_traces(cmd_list, expected_trace_output) @parameterized.expand([("ApiGwFunction",), ("SfnFunction",)]) def test_trace_id(self, function_name): function_id = self._get_physical_id(function_name) expected_trace_output = [function_id] start_time = datetime.utcnow() LOG.info("Invoking function %s", function_name) lambda_invoke_result = self.lambda_client.invoke(FunctionName=function_id) LOG.info("Lambda invoke result %s", lambda_invoke_result) for _ in range(RETRY_COUNT): end_time = datetime.utcnow() kwargs = {"TimeRangeType": "TraceId", "StartTime": start_time, "EndTime": end_time} trace_summaries_response = self.xray_client.get_trace_summaries(**kwargs) trace_summaries = trace_summaries_response.get("TraceSummaries", []) if trace_summaries: break time.sleep(RETRY_SLEEP) if not trace_summaries: self.fail("can't find any trace summaries") trace_id = trace_summaries[0].get("Id") LOG.info("Trace id: %s", trace_id) cmd_list = self.get_traces_command_list(trace_id=trace_id) self._check_traces(cmd_list, expected_trace_output, has_service_graph=False) @parameterized.expand([("ApiGwFunction",), ("SfnFunction",)]) def test_trace_start_time(self, function_name): function_id = self._get_physical_id(function_name) expected_trace_output = [function_id] start_time = datetime.utcnow() LOG.info("Invoking function %s", function_name) lambda_invoke_result = self.lambda_client.invoke(FunctionName=function_id) LOG.info("Lambda invoke result %s", lambda_invoke_result) cmd_list = self.get_traces_command_list(start_time=str(start_time)) self._check_traces(cmd_list, expected_trace_output) @parameterized.expand([("ApiGwFunction",), ("SfnFunction",)]) def test_trace_end_time(self, function_name): function_id = self._get_physical_id(function_name) expected_trace_output = [function_id] LOG.info("Invoking function %s", function_name) lambda_invoke_result = self.lambda_client.invoke(FunctionName=function_id) LOG.info("Lambda invoke result %s", lambda_invoke_result) end_time = datetime.utcnow() cmd_list = self.get_traces_command_list(end_time=str(end_time)) self._check_traces(cmd_list, expected_trace_output) @parameterized.expand([("ApiGwFunction",), ("SfnFunction",)]) def test_traces_with_tail(self, function_name: str): function_id = self._get_physical_id(function_name) expected_trace_output = function_id LOG.info("Invoking function %s", "HelloWorldFunction") lambda_invoke_result = self.lambda_client.invoke(FunctionName=function_id) LOG.info("Lambda invoke result %s", lambda_invoke_result) cmd_list = self.get_traces_command_list(tail=True) tail_process = start_persistent_process(cmd_list) def _check_traces(output: str, _: List[str]) -> bool: return expected_trace_output in output try: read_until(tail_process, _check_traces, timeout=RETRY_COUNT * RETRY_SLEEP) finally: kill_process(tail_process) @parameterized.expand( itertools.product(["ApiGwFunction", "SfnFunction"], [None, OutputOption.text.name, OutputOption.json.name]) ) def test_traces_with_output_option(self, function_name, output): function_id = self._get_physical_id(function_name) expected_trace_output = [function_id] LOG.info("Invoking function %s", function_name) lambda_invoke_result = self.lambda_client.invoke(FunctionName=function_id) LOG.info("Lambda invoke result %s", lambda_invoke_result) cmd_list = self.get_traces_command_list(output=output) output_check = OutputOption.json if output == OutputOption.json.name else OutputOption.text self._check_traces(cmd_list, expected_trace_output, output=output_check) def _check_traces(self, cmd_list, trace_strings, output=OutputOption.text, has_service_graph=True): for _ in range(RETRY_COUNT): cmd_result = run_command(cmd_list) self.assertEqual(cmd_result.process.returncode, 0) actual_output = cmd_result.stdout.decode("utf-8") if has_service_graph and not self._check_traces_with_service_graph(trace_strings, actual_output, output): time.sleep(RETRY_SLEEP) continue if not self._check_traces_with_xray_event(trace_strings, actual_output, output): time.sleep(RETRY_SLEEP) continue return self.fail(f"No match found for one of the expected trace outputs '{trace_strings}'") def _check_traces_with_service_graph(self, trace_strings, console_output, output=OutputOption.text): if output == OutputOption.text: return self._check_service_graph_with_output_text(trace_strings, console_output) if output == OutputOption.json: return self._check_service_graph_with_output_json(trace_strings, console_output) def _check_traces_with_xray_event(self, trace_strings, console_output, output=OutputOption.text): if output == OutputOption.text: return self._check_xray_event_with_output_text(trace_strings, console_output) if output == OutputOption.json: return self._check_xray_event_with_output_json(trace_strings, console_output) def _check_xray_event_with_output_text(self, trace_strings, console_output): # It's hard to verify the entire text output, just verify if some keywords exist and verify if expected # trace strings exist in the console output as well if "XRay Event" not in console_output: return False return self._check_trace_string_exist(trace_strings, console_output) def _check_xray_event_with_output_json(self, trace_strings, console_output): # It's hard to verify the entire json output, just verify if some keywords exist and verify if expected # trace strings exist in the console output as well if 'content-type": "application/json' not in console_output: return False return self._check_trace_string_exist(trace_strings, console_output) def _check_service_graph_with_output_text(self, trace_strings, console_output): # It's hard to verify the entire text output, just verify if some keywords exist and verify if expected # trace strings exist in the console output as well if "New XRay Service Graph" not in console_output: return False if "Start time" not in console_output: return False if "End time" not in console_output: return False if "Reference Id" not in console_output: return False if "Summary_statistics" not in console_output: return False return self._check_trace_string_exist(trace_strings, console_output) def _check_service_graph_with_output_json(self, trace_strings, console_output): # It's hard to verify the entire json output, just verify if some keywords exist and verify if expected # trace strings exist in the console output if "Segments" not in console_output: return False if "trace_id" not in console_output: return False return self._check_trace_string_exist(trace_strings, console_output) @staticmethod def _check_trace_string_exist(trace_strings, console_output): for trace_string in trace_strings: if trace_string not in console_output: return False return True