# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
# the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
from concurrent.futures import Future
from datetime import datetime
from typing import Callable, Optional, Protocol, TypedDict

from common.utils import check_command_output, time_is_up, validate_absolute_path

logger = logging.getLogger(__name__)

# Timestamp layout used by clustermgtd and computemgtd. It is ISO-8601-like but uses a space (not "T") between the
# date and the time:
# YYYY-MM-DD HH:MM:SS.ffffff+HH:MM[:SS[.ffffff]]
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f%z"
# Passed as ``timeout`` to check_command_output (presumably seconds -- confirm against common.utils).
DEFAULT_COMMAND_TIMEOUT = 30

ComputeInstanceDescriptor = TypedDict(
    "ComputeInstanceDescriptor",
    {
        "Name": str,
        "InstanceId": str,
    },
)


class TaskController(Protocol):
    """Structural interface for an object that queues tasks and coordinates a shutdown request."""

    class TaskShutdownError(RuntimeError):
        """Exception raised if shutdown has been requested."""

    def queue_task(self, task: Callable[[], None]) -> Optional[Future]:
        """Queue a task and return a Future for it, or None if the task could not be queued."""

    def is_shutdown(self) -> bool:
        """Return whether a shutdown has been requested."""

    def raise_if_shutdown(self) -> None:
        """Raise an error if a shutdown has been requested."""

    def wait_unless_shutdown(self, seconds_to_wait: float) -> None:
        """Wait for seconds_to_wait, or raise an error if a shutdown has been requested."""

    def shutdown(self, wait: bool, cancel_futures: bool) -> None:
        """Request that all tasks be shut down."""


def log_exception(
    logger,
    action_desc,
    log_level=logging.ERROR,
    catch_exception=Exception,
    raise_on_error=True,
    exception_to_raise=None,
):
    """
    Build a decorator that logs exceptions raised by the decorated function.

    :param logger: logger used to emit the failure message
    :param action_desc: description of the action, interpolated into "Failed when <action_desc> ..."
    :param log_level: logging level used for the failure message
    :param catch_exception: exception class (or tuple of classes) to intercept
    :param raise_on_error: if False the exception is swallowed after logging and the wrapper returns None
    :param exception_to_raise: optional exception class or instance raised in place of the caught exception;
        when the caught exception is already of that type it is re-raised unchanged so its message is preserved
    :return: a decorator preserving the wrapped function's metadata
    """

    def _log_exception(function):
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            try:
                return function(*args, **kwargs)
            except catch_exception as e:
                logger.log(log_level, "Failed when %s with exception %s, message: %s", action_desc, type(e).__name__, e)
                if not raise_on_error:
                    return None
                if not exception_to_raise:
                    raise
                # isinstance() needs a type: when an exception *instance* was supplied, compare against its class
                # instead of failing with a TypeError that would mask the real error.
                expected_type = (
                    type(exception_to_raise) if isinstance(exception_to_raise, BaseException) else exception_to_raise
                )
                if isinstance(e, expected_type):
                    # preserve exception message if exception to raise is same of actual exception
                    raise
                raise exception_to_raise from e

        return wrapper

    return _log_exception


def print_with_count(resource_list):
    """
    Return a string rendering of resource_list prefixed with its length, e.g. "(x2) ['a', 'b']".

    A plain string is returned unchanged. Nothing is printed despite the function name.
    """
    if isinstance(resource_list, str):
        return resource_list
    resource_list = [str(elem) for elem in resource_list]
    return f"(x{len(resource_list)}) {str(resource_list)}"


def get_clustermgtd_heartbeat(clustermgtd_heartbeat_file_path):
    """
    Get clustermgtd's last heartbeat.

    :param clustermgtd_heartbeat_file_path: path of the heartbeat file; checked with validate_absolute_path
    :return: datetime parsed from the last line of the file using TIMESTAMP_FORMAT
    :raises ValueError: if the command produced no output or the last line does not match TIMESTAMP_FORMAT
    """
    # Use subprocess based method to read shared file to prevent hanging when NFS is down
    # Do not copy to local. Different users need to access the file, but file should be writable by root only
    # Only use last line of output to avoid taking unexpected output in stdout
    # Validation to sanitize the input argument and make it safe to use the function affected by B604
    validate_absolute_path(clustermgtd_heartbeat_file_path)
    output = check_command_output(
        f"cat {clustermgtd_heartbeat_file_path}",
        timeout=DEFAULT_COMMAND_TIMEOUT,
        shell=True,  # nosec B604
    )
    output_lines = output.splitlines()
    if not output_lines:
        # Fail with an explicit message rather than an opaque IndexError from indexing an empty list.
        raise ValueError(f"No heartbeat found in clustermgtd heartbeat file {clustermgtd_heartbeat_file_path}")
    heartbeat = output_lines[-1].strip()
    # Note: heartbeat must be written with datetime.strftime to convert localized datetime into str
    # datetime.strptime will not work with str(datetime)
    # Example timestamp written to heartbeat file: 2020-07-30 19:34:02.613338+00:00
    return datetime.strptime(heartbeat, TIMESTAMP_FORMAT)


def expired_clustermgtd_heartbeat(last_heartbeat, current_time, clustermgtd_timeout):
    """
    Test if clustermgtd heartbeat is expired.

    Expiry is decided by time_is_up(last_heartbeat, current_time, clustermgtd_timeout); an error is logged when
    the heartbeat is expired.

    :return: True if the heartbeat is expired, False otherwise
    """
    if time_is_up(last_heartbeat, current_time, clustermgtd_timeout):
        logger.error(
            "Clustermgtd has been offline since %s. Current time is %s. Timeout of %s seconds has expired!",
            last_heartbeat,
            current_time,
            clustermgtd_timeout,
        )
        return True
    return False


def is_clustermgtd_heartbeat_valid(current_time, clustermgtd_timeout, clustermgtd_heartbeat_file_path):
    """
    Check whether the clustermgtd heartbeat can be read and has not expired.

    Any exception raised while reading or evaluating the heartbeat is logged and treated as an invalid heartbeat.

    :return: True if the heartbeat was retrieved and is not expired, False otherwise
    """
    try:
        last_heartbeat = get_clustermgtd_heartbeat(clustermgtd_heartbeat_file_path)
        logger.info("Latest heartbeat from clustermgtd: %s", last_heartbeat)
        return not expired_clustermgtd_heartbeat(last_heartbeat, current_time, clustermgtd_timeout)
    except Exception as e:
        logger.error("Unable to retrieve clustermgtd heartbeat with exception: %s", e)
        return False