# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging

import sagemaker_containers.beta.framework as framework

# Root logging: timestamp, logger name, level and message on every record, at INFO.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
)

# Per-library levels for the AWS SDK stack: botocore only reports WARNING and
# above, while boto3 and s3transfer are pinned at INFO.
logging.getLogger('botocore').setLevel(logging.WARNING)
logging.getLogger('s3transfer').setLevel(logging.INFO)
logging.getLogger('boto3').setLevel(logging.INFO)

# Module logger; main() later replaces this level with the environment's log level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _as_bool(value):
    """Interpret a hyperparameter value as a boolean.

    ``bool()`` alone treats every non-empty string as true, so a value delivered
    as the text ``'False'`` would enable the feature it was meant to disable.
    Strings that spell false (case-insensitive, surrounding whitespace ignored):
    ``'false'``, ``'0'``, ``'no'``, ``'off'`` or an empty string, map to
    ``False``. Any other string is ``True``. Non-string values are passed to
    ``bool()`` unchanged.
    """
    try:
        text = value.strip().lower()
    except AttributeError:
        # Not text (e.g. an already-decoded bool or int): plain truthiness.
        return bool(value)
    return text not in ('false', '0', 'no', 'off', '')


def _mpi_options(hyperparameters):
    """Map the Chainer-specific MPI hyperparameters onto the framework's MPI runner options.

    Unset hyperparameters are passed through as ``None``, exactly as the
    process-count options always have been.
    """
    return {
        framework.params.MPI_PROCESSES_PER_HOST: hyperparameters.get('sagemaker_process_slots_per_host'),
        framework.params.MPI_NUM_PROCESSES: hyperparameters.get('sagemaker_num_processes'),
        # Previously documented but never forwarded, so user-supplied mpirun
        # options were silently dropped.
        framework.params.MPI_CUSTOM_OPTIONS: hyperparameters.get('sagemaker_additional_mpi_options'),
    }


def train(env, hyperparameters):
    """Runs Chainer training on a user supplied module in either a local or distributed
    SageMaker environment.

    The user supplied module and its dependencies are downloaded and installed from
    ``env.module_dir``. Training is then started by running the user entry point
    (``env.user_entry_point``) with the command-line arguments and environment
    variables derived from ``env``.

    If the environment contains multiple hosts, then a distributed learning
    task is started with the framework's MPI runner (mpirun); otherwise the entry
    point runs as a single process.

    The following is a list of other hyperparameters that can be used to change training behavior.
    None of these hyperparameters are required:

    * `sagemaker_use_mpi`: force use of MPI so that ChainerMN scripts can be run on a single host
      (or, when false, disable it on a multi-host cluster).
    * `sagemaker_process_slots_per_host`: the number of process slots per host.
    * `sagemaker_num_processes`: the total number of processes to run.
    * `sagemaker_additional_mpi_options`: a string of options to pass to mpirun.

    Args:
        env: the training environment. This function reads its ``module_dir``,
            ``hosts`` and ``user_entry_point`` attributes and calls its
            ``to_cmd_args()`` and ``to_env_vars()`` methods.
        hyperparameters (dict): the job's hyperparameters.
    """
    framework.modules.download_and_install(env.module_dir)

    # MPI defaults to on only for multi-host jobs; `sagemaker_use_mpi` overrides it.
    use_mpi = _as_bool(hyperparameters.get('sagemaker_use_mpi', len(env.hosts) > 1))

    if use_mpi:
        runner_type = framework.runner.MPIRunnerType
        opts = _mpi_options(hyperparameters)
    else:
        runner_type = framework.runner.ProcessRunnerType
        opts = {}

    framework.entry_point.run(env.module_dir,
                              env.user_entry_point,
                              env.to_cmd_args(),
                              env.to_env_vars(),
                              runner=runner_type,
                              extra_opts=opts)


def main():
    """Script entry point.

    Loads the job's hyperparameters and builds the training environment from
    them. It then switches the module logger to that environment's log level
    and hands both to ``train``.
    """
    hps = framework.env.read_hyperparameters()
    training_env = framework.training_env(hyperparameters=hps)

    # Honour the log level the environment asks for before training starts.
    logger.setLevel(training_env.log_level)

    train(training_env, hps)


# Run training when this module is executed directly as a script rather than
# imported. This branch is hit by mpi_script.sh (see the docker base directory).
if __name__ == '__main__':
    main()