# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
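"""Distributed training entry point: trains an ENet segmentation model on the
CamVid dataset with SageMaker distributed data parallel (smdistributed).

The script is meant to run inside a SageMaker training job, one process per
GPU. For reference, a launch from the SageMaker Python SDK might look like the
sketch below; the entry-point file name, instance settings, and channel name
are illustrative assumptions, not taken from this repository:

    from sagemaker.tensorflow import TensorFlow

    estimator = TensorFlow(
        entry_point='train_data_parallel.py',   # hypothetical file name
        source_dir='.',
        role=role,
        instance_count=2,
        instance_type='ml.p3.16xlarge',
        framework_version='2.4.1',
        py_version='py37',
        distribution={'smdistributed': {'dataparallel': {'enabled': True}}},
    )
    estimator.fit({'training': camvid_s3_uri})   # hypothetical channel name
"""
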
import argparse
import os
from pathlib import Path

# Suppress TensorFlow's C++ logging; the variable is read when TensorFlow is
# imported, so it must be set before the import below to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import smdistributed.dataparallel.tensorflow as sdp

from datasets.camvid.loader import DatasetLoader
from enet.model import ENetParams, ENetModel
from trainer.logging import SMMetricsLogger
from trainer.data_parallel import TrainerParams, Trainer


def setup_smdistributed():
    """Initialize SMDataParallel and pin each worker process to one GPU."""
    sdp.init()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    # Allocate GPU memory on demand instead of reserving it all up front.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    # Make this process see only the GPU that matches its local rank.
    if gpus:
        tf.config.experimental.set_visible_devices(
            gpus[sdp.local_rank()], 'GPU')


def parse_args():
    parser = argparse.ArgumentParser()

    # Model and training hyperparameters.
    parser.add_argument('--dropout-rate1', type=float, default=0.01)
    parser.add_argument('--dropout-rate2', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=10)

    # Default locations follow the SageMaker training container layout.
    parser.add_argument('--data-path', type=str, default='/opt/ml/input/data')
    parser.add_argument('--checkpoint-path', type=str,
                        default='/opt/ml/checkpoints')
    parser.add_argument('--model-path', type=str, default='/opt/ml/model')

    # Ignore any extra arguments injected by the training environment.
    args, _ = parser.parse_known_args()
    return args


def build_trainer(args) -> Trainer:
    # Shard the dataset across workers so each rank trains on its own slice.
    loader = DatasetLoader(Path(args.data_path).expanduser(),
                           num_shards=sdp.size(), shard_id=sdp.rank())

    # Build the ENet segmentation model from the dataset's image shape and
    # class count.
    enet_params = ENetParams(
        input_dim=(loader.img_height, loader.img_width, 3),
        num_object_classes=loader.num_classes,
        dropout_rate1=args.dropout_rate1,
        dropout_rate2=args.dropout_rate2,
    )
    model = ENetModel(enet_params)

    trainer_params = TrainerParams(
        num_classes=loader.num_classes,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        epochs=args.epochs,
        save_checkpoint=True,
        checkpoint_path=Path(args.checkpoint_path).expanduser(),
        load_best_checkpoint_after_fit=True,
    )
    # Attach the callback that logs training metrics for SageMaker.
    metrics_logger = SMMetricsLogger()
    trainer = Trainer(trainer_params, loader, model,
                      extra_callbacks=[metrics_logger])

    return trainer


def train(args, trainer: Trainer):
    trainer.fit()
    # Export the trained model to the model output directory
    # (/opt/ml/model by default).
    trainer.save_model(Path(args.model_path).expanduser())


if __name__ == '__main__':
    # Initialize the distributed backend and GPU pinning before any other work.
    setup_smdistributed()
    args = parse_args()
    trainer = build_trainer(args)
    train(args, trainer)