# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import absolute_import

import base64
import contextlib
import json
import os
import shlex
import shutil
import subprocess
import sys
import tempfile
import time

import boto3


def wait_for_s3_object(
    s3_bucket,
    key,
    local_dir,
    local_prefix="",
    aws_account=None,
    aws_region=None,
    timeout=1200,
    limit=20,
    fetch_only=None,
    training_job_name=None,
):
    """
    Poll S3 until at least one object under ``key`` exists, then download it.

    Objects whose key starts with ``key`` (optionally narrowed by
    ``fetch_only``) are downloaded into ``local_dir/local_prefix``, each saved
    under the last "/"-separated component of its key.

    Arguments:
        s3_bucket (string): s3 bucket name
        key (string): key prefix of the s3 object(s) to wait for
        local_dir (string): local directory path to save s3 objects
        local_prefix (string): local prefix path appended to the local directory
        aws_account (string): unused; accepted for backward compatibility
        aws_region (string): unused; accepted for backward compatibility
        timeout (int): seconds to wait for the object to appear before giving up
        limit (int): maximum number of files to download; when more objects
            match, only the last ``limit`` in listing order are downloaded
            (values below 0 are treated as 0)
        fetch_only (callable): predicate deciding whether an object should be fetched
        training_job_name (string): training job whose status is checked while
            waiting, so a failed or stopped job aborts the wait early

    Returns:
        A list of all downloaded files, as local filenames

    Raises:
        FileNotFoundError: no matching object appeared within ``timeout`` seconds
        RuntimeError: the training job failed or was stopped while waiting
    """
    session = boto3.Session()
    bucket = session.resource("s3").Bucket(s3_bucket)
    # Only build a SageMaker client when there is a job to monitor; creating it
    # unconditionally fails when no region is configured even if it is unused.
    sagemaker = session.client("sagemaker") if training_job_name else None

    print("Waiting for s3://%s/%s..." % (s3_bucket, key), end="", flush=True)
    start_time = time.time()
    polls = 0
    while True:
        objects = list(bucket.objects.filter(Prefix=key))
        if fetch_only:
            objects = list(filter(fetch_only, objects))
        if objects:
            break
        print(".", end="", flush=True)
        time.sleep(5)
        polls += 1
        # Wrap the progress dots every 80 polls so the line stays readable.
        if polls % 80 == 0:
            print("")
        if time.time() > start_time + timeout:
            raise FileNotFoundError(
                "S3 object s3://%s/%s never appeared after %d seconds" % (s3_bucket, key, timeout)
            )
        if sagemaker is not None:
            training_job_status = sagemaker.describe_training_job(
                TrainingJobName=training_job_name
            )["TrainingJobStatus"]
            # Failed and Stopped are both terminal: the object will not be
            # produced anymore, so stop waiting instead of running out the timeout.
            if training_job_status == "Failed":
                raise RuntimeError(
                    "Training job {} failed while waiting for S3 object s3://{}/{}".format(
                        training_job_name, s3_bucket, key
                    )
                )
            if training_job_status == "Stopped":
                raise RuntimeError(
                    "Training job {} was stopped while waiting for S3 object s3://{}/{}".format(
                        training_job_name, s3_bucket, key
                    )
                )

    print("\n", end="", flush=True)

    limit = max(limit, 0)
    if len(objects) > limit:
        print("Only downloading %d of %d files" % (limit, len(objects)))
        # Slice from an explicit start index: objects[-0:] would keep everything.
        objects = objects[len(objects) - limit:]

    target_dir = os.path.join(local_dir, local_prefix)
    if objects:
        # download_file cannot write into a directory that does not exist.
        os.makedirs(target_dir, exist_ok=True)

    fetched_files = []
    for obj in objects:
        print("Downloading %s" % obj.key)
        local_path = os.path.join(target_dir, obj.key.split("/")[-1])
        obj.Object().download_file(local_path)
        fetched_files.append(local_path)

    return fetched_files


def get_execution_role(role_name="sagemaker", aws_account=None, aws_region=None):
    """
    Return the ARN of an IAM execution role for SageMaker, creating it if needed.

    A newly created role trusts the sagemaker.amazonaws.com and
    robomaker.amazonaws.com service principals to assume it. On every call the
    AWS-managed AmazonSageMakerFullAccess policy is attached to the role, so a
    pre-existing role with this name also ends up with that policy.

    Args:
        role_name (string): name of the role to look up or create
        aws_account (string): unused; accepted for backward compatibility
        aws_region (string): unused; accepted for backward compatibility

    Returns:
        string: the ARN of the role
    """
    session = boto3.Session()
    client = session.client("iam")

    # Trust policy: which AWS services may assume this role.
    assume_role_policy_document = json.dumps(
        {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Principal": {
                        "Service": ["sagemaker.amazonaws.com", "robomaker.amazonaws.com"]
                    },
                    "Action": "sts:AssumeRole",
                }
            ],
        }
    )

    try:
        role = client.get_role(RoleName=role_name)["Role"]
    except client.exceptions.NoSuchEntityException:
        role = client.create_role(
            RoleName=role_name, AssumeRolePolicyDocument=assume_role_policy_document
        )["Role"]

        print("Created new sagemaker execution role: %s" % role_name)

    client.attach_role_policy(
        PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess", RoleName=role_name
    )

    # The ARN does not change when a policy is attached, so the role record
    # obtained above can be reused instead of issuing another get_role call.
    return role["Arn"]


def wait_for_training_job_to_complete(job_name):
    """
    Block until the named SageMaker training job finishes.

    Delegates to boto3's "training_job_completed_or_stopped" waiter on a
    SageMaker client from the default boto3 session.

    Arguments:
        job_name (string): name of the training job to wait on
    """
    waiter = boto3.client("sagemaker").get_waiter("training_job_completed_or_stopped")
    waiter.wait(TrainingJobName=job_name)