#!/usr/bin/env python3
"""Simulated long-running job with periodic checkpoints.

The script reads its configuration from the DynamoDB table ``DDBTest1``
(keyed by this instance's id), then advances a progress counter one
percent at a time.  On every step it publishes "checkpointed"/"unsaved"
percentages to CloudWatch (namespace ``fisworkshop``), stores the last
checkpointed percentage in DynamoDB and sends a Step Functions task
heartbeat.  After every step it polls the EC2 instance metadata service
for a spot ``instance-action`` notice; when one is present it checkpoints
immediately, reports task success with the current percentage and stops.
"""

import json
import signal
import sys
import time
import urllib.request

import boto3

# Fail quickly
try:
    ssm_client = boto3.client('ssm')
    ddb_table = boto3.resource('dynamodb').Table('DDBTest1')
    cw_client = boto3.client('cloudwatch')
    ec2_client = boto3.client('ec2')
    sfn_client = boto3.client('stepfunctions')
except Exception as e:
    print("Could not connect to AWS, did you set credentials?")
    print(e)
    sys.exit(1)

# Global variables for tracking
checkpoint_saved_percentage = 0   # last percentage that was checkpointed
imsv2_token = None                # cached metadata-service session token
interrupt_latch = False           # True once an interrupt notice was seen


# Functions

def signal_handler(sig, frame):
    """Print the received signal and the last checkpoint, then exit with 0.

    The instance is deliberately NOT terminated from here - we assume the
    termination notice is guaranteed.
    """
    print(signal.Signals(sig))
    print("Graceful exit - reporting final metrics - checkpointed %f" % checkpoint_saved_percentage)
    sys.exit(0)


def get_ddb_configs(table, instance_id):
    """Return the job configuration item stored under ``RunId == instance_id``.

    Only the attributes JobId, JobDuration, CheckpointDuration,
    HeartbeatToken and Percentage are requested.  Returns an empty dict
    when the response contains no item.
    """
    res = table.get_item(
        Key={
            "RunId": instance_id
        },
        AttributesToGet=[
            "JobId",
            "JobDuration",
            "CheckpointDuration",
            "HeartbeatToken",
            "Percentage",
        ]
    )
    return res.get("Item", {})


def get_ssm_parameter(client, name, default_setting=5):
    """Read SSM parameter ``name`` and return its value as a float.

    Returns ``default_setting`` unchanged when the parameter cannot be
    read or its value cannot be converted to a float.
    """
    try:
        response = client.get_parameter(
            Name=name,
            WithDecryption=True
        )
        return float(response.get("Parameter", {}).get("Value", str(default_setting)))
    except Exception:
        # Not a bare ``except:`` - that would also swallow the SystemExit
        # raised by signal_handler if a signal arrives during the call.
        print("Couldn't read parameter %s, using default CheckPoint duration" % name)
        return default_setting


def put_cloudwatch_percentages(client, saved_percentage, unsaved_percentage):
    """Publish the "checkpointed" and "unsaved" percentages to CloudWatch.

    Both metrics go to namespace ``fisworkshop`` with unit Percent and a
    storage resolution of 1.
    """
    client.put_metric_data(
        MetricData=[
            {
                'MetricName': "unsaved",
                'Unit': 'Percent',
                'Value': unsaved_percentage,
                'StorageResolution': 1
            },
            {
                'MetricName': "checkpointed",
                'Unit': 'Percent',
                'Value': saved_percentage,
                'StorageResolution': 1
            },
        ],
        Namespace='fisworkshop'
    )


def put_ddb_saved_percentage(table, job_id, instance_id, percentage):
    """Store ``percentage`` (as a string) on both the instance and the job item.

    The same ``Percentage`` attribute is written to the item keyed by
    ``instance_id`` and to the item keyed by ``job_id``.
    """
    for run_id in (instance_id, job_id):
        table.update_item(
            Key={
                "RunId": run_id
            },
            UpdateExpression="set Percentage=:p",
            ExpressionAttributeValues={
                ':p': str(percentage)
            }
        )


def get_imsv2_token():
    """Fetch a metadata-service session token into the global ``imsv2_token``.

    On any failure the error is printed and ``imsv2_token`` is set to None.
    """
    global imsv2_token
    try:
        req = urllib.request.Request(
            'http://169.254.169.254/latest/api/token',
            headers={
                "X-aws-ec2-metadata-token-ttl-seconds": "21600"
            },
            method="PUT")
        imsv2_token = urllib.request.urlopen(req, timeout=1).read().decode()
    except Exception as e:
        print("Failed to get token to check spot interruption")
        print(e)
        imsv2_token = None


def get_instance_id():
    """Return this instance's id from the metadata service.

    Falls back to the placeholder ``"i-manualtest"`` when the lookup
    fails (e.g. when not running on an AWS instance).
    """
    global imsv2_token
    if not imsv2_token:
        get_imsv2_token()
    try:
        req = urllib.request.Request(
            'http://169.254.169.254/latest/meta-data/instance-id',
            headers={
                "X-aws-ec2-metadata-token": imsv2_token
            },
            method="GET")
        instance_id = urllib.request.urlopen(req, timeout=1).read().decode()
    except Exception:
        # No instance_id ... are we running in AWS
        # (``except Exception`` so a SystemExit from signal_handler propagates)
        instance_id = "i-manualtest"
    return instance_id


def terminate_self_instance(client):
    """Request termination of the instance this script is running on."""
    instance_id = get_instance_id()
    client.terminate_instances(
        InstanceIds=[
            instance_id
        ]
    )
    print("Successfully sent instance termination request for %s" % instance_id)


def check_interrupt_notice():
    """Return True when a spot ``instance-action`` notice is available.

    Once a notice has been received the result is latched in the global
    ``interrupt_latch`` and later calls return True without querying the
    metadata service again.  Any request failure is printed and treated
    as "no interrupt".
    """
    global imsv2_token
    global interrupt_latch

    if interrupt_latch:
        print("Already latched interrupt")
        return True

    if not imsv2_token:
        get_imsv2_token()
    try:
        req = urllib.request.Request(
            'http://169.254.169.254/latest/meta-data/spot/instance-action',
            headers={
                "X-aws-ec2-metadata-token": imsv2_token
            },
            method="GET")
        res = urllib.request.urlopen(req, timeout=1).read().decode()
    except Exception as e:
        print("No interrupt")
        print(e)
        return False

    print("Received interrupt notification: %s" % str(res))
    # Remember the notice; previously the latch was read but never set.
    interrupt_latch = True
    return True


def send_task_hearbeat(client, heartbeat_token):
    """Send a Step Functions heartbeat for the task ``heartbeat_token``."""
    client.send_task_heartbeat(
        taskToken=heartbeat_token
    )


def send_task_failure(client, heartbeat_token, error_code=1, error_text="General Error"):
    """Report failure of the Step Functions task ``heartbeat_token``.

    ``error_code`` is converted to a string because the API takes the
    error as a string while the default here is the integer 1.
    """
    client.send_task_failure(
        taskToken=heartbeat_token,
        error=str(error_code),
        cause=error_text
    )


def send_task_success(client, heartbeat_token, percentage):
    """Report success of the Step Functions task ``heartbeat_token``.

    The task output is a JSON object with ``Percentage`` (as a string)
    and ``JobFinished`` (True when ``percentage`` is at least 100).
    """
    client.send_task_success(
        taskToken=heartbeat_token,
        output=json.dumps({
            "Percentage": str(percentage),
            "JobFinished": (percentage >= 100)
        })
    )


# The code

# Route every catchable signal to the graceful-exit handler.
catchable_sigs = set(signal.Signals) - {signal.SIGKILL, signal.SIGSTOP, signal.SIGCHLD}
for sig in catchable_sigs:
    print("handle %s" % sig)
    signal.signal(sig, signal_handler)

instance_id = get_instance_id()

# Get configs: poll up to 20 times, 5 seconds apart, until the item for
# this instance carries a HeartbeatToken.
for attempt in range(20):
    configs = get_ddb_configs(ddb_table, instance_id)
    if "HeartbeatToken" in configs:
        break
    print("No configs in DynamoDB, waiting")
    print(configs)
    time.sleep(5)
else:
    # Only reached when no attempt found a HeartbeatToken.
    print("Something went wrong. Terminate run")
    terminate_self_instance(ec2_client)
    sys.exit(1)

# Duration until job completion in minutes (should be 2 < x < 15)
job_duration_minutes = float(configs["JobDuration"])
# Time between checkpoints
checkpoint_interval_minutes = float(configs["CheckpointDuration"])
job_id = configs["JobId"]
heartbeat_token = configs["HeartbeatToken"]

start_percentage = int(configs["Percentage"])
checkpoint_saved_percentage = start_percentage

# One loop step represents one percent of the job.
sleep_duration_seconds = 60.0 * job_duration_minutes / 100.0
checkpoint_counter_seconds = 0.0

print("Starting job (duration %f min / checkpoint %f min)" % (
    job_duration_minutes,
    checkpoint_interval_minutes
))

put_cloudwatch_percentages(cw_client, start_percentage, start_percentage)
put_ddb_saved_percentage(ddb_table, job_id, instance_id, start_percentage)

# NOTE(review): SOURCE had lost its indentation; the nesting below is
# reconstructed from the statement order and comments - confirm against
# the deployed version.
for ii in range(start_percentage, 100):
    time.sleep(sleep_duration_seconds)

    # checkpoint on time or interrupt notice
    checkpoint_counter_seconds += sleep_duration_seconds
    checkpoint_flag = ((checkpoint_counter_seconds / 60.0) > checkpoint_interval_minutes)
    print("%f%% complete - checkpoint=%s" % (ii + 1, checkpoint_flag))
    if checkpoint_flag:
        print("resetting flag")
        checkpoint_counter_seconds = 0.0
        checkpoint_saved_percentage = ii + 1

    # record progress data that can be lost
    put_cloudwatch_percentages(cw_client, checkpoint_saved_percentage, ii + 1)
    put_ddb_saved_percentage(ddb_table, job_id, instance_id, checkpoint_saved_percentage)
    send_task_hearbeat(sfn_client, heartbeat_token)

    # End on interrupt
    if check_interrupt_notice():
        checkpoint_saved_percentage = ii + 1
        put_cloudwatch_percentages(cw_client, checkpoint_saved_percentage, ii + 1)
        put_ddb_saved_percentage(ddb_table, job_id, instance_id, checkpoint_saved_percentage)
        send_task_success(sfn_client, heartbeat_token, checkpoint_saved_percentage)
        # Stop here: the task has just been reported as finished with a
        # partial percentage, so the loop must not keep using the token
        # and the "final data" section below must not record 100%.
        # The instance itself is left to the pending interruption.
        sys.exit(0)

# Write final data
put_cloudwatch_percentages(cw_client, 100, 100)
put_ddb_saved_percentage(ddb_table, job_id, instance_id, 100)
send_task_success(sfn_client, heartbeat_token, 100)

# At completion suicide instance
terminate_self_instance(ec2_client)