import json
import boto3
import logging

logger = logging.getLogger(__name__)
logger.setLevel('INFO')
client = boto3.client('sagemaker')
sns = boto3.client('sns')

def lambda_handler(event, context):
    """ Monitor active run, stop if issuesFound, and define retraining job. """
    try:
        job_name = event["state"]["history"]["latest_job_name"]
        logger.info(f'Job name: {job_name}')
        topic_arn = event['topic_arn']
        max_num_retraining = event['max_num_retraining']
        max_monitor_transitions = event['max_monitor_transitions']
    except KeyError as e:
        raise KeyError('Bad input. Missing key: {}. '
                       'The input received was: {}'.format(e, event))

    state = event.get("state")
    state['history']['num_monitor_transitions'] += 1

    try:
        job_description = client.describe_training_job(TrainingJobName=job_name)
    except Exception as e:
        logger.error(
            "Encountered error while trying to "
            "describe training job {}: {}".format(
                job_name, str(e)
            )
        )
        raise e

    job_status = job_description.get('TrainingJobStatus')
    state["job_status"] = job_status
    logger.info(f'Job status: {job_status}')

    if state['history']['num_monitor_transitions'] > max_monitor_transitions:
        # Too many monitoring transitions: stop the training job and end the state machine.
        sns.publish(TopicArn=topic_arn,
                    Message=f'Ending. Number of monitor transitions {state["history"]["num_monitor_transitions"]} > {max_monitor_transitions}')
        stop_job(job_name)
        return {
            'statusCode': 200,
            'body': {
                'state': state
            }
        }

    if job_status == 'Completed':
        state["job_status"] = job_status
        # making sure we don't have any lingering debugger jobs before we terminating the machine
        stop_processing_job(job_description['DebugRuleEvaluationStatuses'][0]['RuleEvaluationJobArn'].split('/')[-1])   
        state["next_action"] = "end"
        sns.publish(TopicArn=topic_arn, Message=f'Job completed. Final state is: {state}. Final training job name is {job_name}.')
        return {
            'statusCode': 200,
            'body': {
                'state': state
            }
        }
    elif job_status == 'Failed':
        reason = job_description.get('FailureReason', '')
        if 'AlgorithmError' not in reason:
            # The failure was not caused by the algorithm itself, so plan a relaunch.
            state["next_action"] = "launch_new"
            sns.publish(TopicArn=topic_arn, Message=f'Training failed with reason: {reason}\n'
                                                    f'Job description: {job_description}\n'
                                                    f'Machine state: {state}')
        else:
            state["next_action"] = "end"

    else:
        rules_eval_statuses = job_description.get('DebugRuleEvaluationStatuses', None)
        if rules_eval_statuses is None or len(rules_eval_statuses) == 0:
            logger.info("Couldn't find any debug rule statuses, skipping...")
            state["rule_status"] = "NotFound"
            state["next_action"] = "monitor"
            return {
                'statusCode': 200,
                'body': {
                    'state': state
                }
            }
        rule = rules_eval_statuses[0]
        if rule['RuleEvaluationStatus'] == "IssuesFound":
            logger.info(
                'Evaluation of rule configuration {} resulted in "IssuesFound". '
                'Attempting to stop training job {}'.format(
                    rule.get("RuleConfigurationName"), job_name
                )
            )
            stop_job(job_name)
            logger.info('Planning a new launch')
            state = plan_launch_spec(state)
            logger.info(f'New training spec {json.dumps(state["run_spec"])}')
            state["rule_status"] = "ExplodingTensors"
        elif rule['RuleEvaluationStatus'] == "InProgress":
            logger.info(
                'Evaluation of rule configuration {} of job {} is in progress. '.format(
                    rule.get("RuleConfigurationName"), job_name
                )
            )
            state["rule_status"] = "InProgress"
            state["next_action"] = "monitor"
        else:
            logger.info(
                'Status of rule configuration {} of job {} is unknown. '.format(
                    rule.get("RuleConfigurationName"), job_name
                )
            )
            state["rule_status"] = "Unknown"
            state["next_action"] = "monitor"

    if state["next_action"] == "launch_new" and state['history']['num_retraining'] >= max_num_retraining:
        state["next_action"] = "end"
        stop_job(job_name)
        sns.publish(TopicArn=topic_arn, Message='Maximum number of retraining runs reached. Terminating.')

    if state["next_action"] == "launch_new":
        sns.publish(TopicArn=topic_arn, Message=f'Retraining. \n'
                                                f'State: {state}')
        logger.info(f'Retraining. \n'
                    f'State: {state}')

    return {
        'statusCode': 200,
        'body': {
            'state': state
        }
    }

def plan_launch_spec(state):
    """  Read current job params, and prescribe the next training job to launch
    """

    last_run_spec = state['run_spec']
    last_warmup_rate = last_run_spec['warmup_learning_rate']
    add_batch_norm = last_run_spec['add_batch_norm']
    learning_rate = last_run_spec['learning_rate']

    if last_warmup_rate / 5 >= 1e-3:
        logger.info('Reducing warmup learning rate to 1/5 of its current value')
        state['history']['num_warmup_adjustments'] += 1
        state['run_spec']['warmup_learning_rate'] = last_warmup_rate / 5
        state['next_action'] = 'launch_new'
    elif add_batch_norm == 0:
        logger.info('Adding batch normalization layer')
        state['history']['num_batch_layer_adjustments'] += 1
        state['run_spec']['add_batch_norm'] = 1           # we are only changing the model by adding batch layers
                                                          # prior to ELU. But can make more tweaks here.
        state['next_action'] = 'launch_new'
    elif learning_rate * 0.9 > 0.001:
        state['run_spec']['learning_rate'] = learning_rate * 0.9
        state['history']['num_learning_rate_adjustments'] += 1
        state['next_action'] = 'launch_new'
    else:
        state['next_action'] = 'end'
    return state


def stop_job(job_name):
    """ stop given job"""
    try:
        client.stop_training_job(
            TrainingJobName=job_name
        )
    except client.exceptions.ClientError as e:
        logger.error(
            "Error while attempting to stop the training job; it may have finished already. " + str(e)
        )

def stop_processing_job(processing_job_name):
    """ stop given job"""
    try:
        client.stop_processing_job(
            ProcessingJobName=processing_job_name
        )
    except client.exceptions.ClientError as e:
        logger.error(
            "Error while attempting to stop the processing debugger job. Job may have finished already. " + str(e)
        )
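

# A minimal local smoke test of the retraining planner (illustrative only; the state
# values below are hypothetical). Note that importing this module creates boto3
# clients, so boto3 must be installed and an AWS region configured even though
# plan_launch_spec itself makes no AWS calls.
if __name__ == "__main__":
    example_state = {
        'run_spec': {'warmup_learning_rate': 0.01, 'add_batch_norm': 0, 'learning_rate': 0.01},
        'history': {'num_warmup_adjustments': 0,
                    'num_batch_layer_adjustments': 0,
                    'num_learning_rate_adjustments': 0},
    }
    print(json.dumps(plan_launch_spec(example_state), indent=2))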