# coding=utf-8
# Original Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# Modifications Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
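"""FP16 / master-parameter utilities adapted from NVIDIA Apex for use with the
SageMaker model parallelism library (smdistributed.modelparallel)."""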

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel.torch.state_mod import state as smp_state


class tofp16(nn.Module):
    """
    Utility module that implements::

        def forward(self, input):
            return input.half()
    """

    def __init__(self):
        super(tofp16, self).__init__()

    def forward(self, input):
        return input.half()


def BN_convert_float(module):
    """
    Utility function for network_to_half().

    Retained for legacy purposes.
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
        module.float()
    for child in module.children():
        BN_convert_float(child)
    return module


def network_to_half(network):
    """
    Convert model to half precision in a batchnorm-safe way.

    Retained for legacy purposes. It is recommended to use FP16Model.
    """
    return nn.Sequential(tofp16(), BN_convert_float(network.half()))


def convert_module(module, dtype):
    """
    Converts a module's immediate parameters and buffers to dtype.
    """
    for param in module.parameters(recurse=False):
        if param is not None:
            if param.data.dtype.is_floating_point:
                param.data = param.data.to(dtype=dtype)
            if param._grad is not None and param._grad.data.dtype.is_floating_point:
                param._grad.data = param._grad.data.to(dtype=dtype)

    for buf in module.buffers(recurse=False):
        if buf is not None and buf.data.dtype.is_floating_point:
            buf.data = buf.data.to(dtype=dtype)


def convert_network(network, dtype):
    """
    Converts a network's parameters and buffers to dtype.
    """
    for module in network.modules():
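        # Leave BatchNorm affine parameters untouched so batch norm keeps running in FP32
        # (the "batchnorm-safe" behavior referenced by network_to_half and FP16Model).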
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
            continue
        convert_module(module, dtype)
    return network


class FP16Model(nn.Module):
    """
    Convert model to half precision in a batchnorm-safe way.
    """

    def __init__(self, network):
        super(FP16Model, self).__init__()
        self.network = convert_network(network, dtype=torch.half)

    def forward(self, *inputs):
        inputs = tuple(t.half() for t in inputs)
        return self.network(*inputs)
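

# Illustrative sketch (not part of the original utilities): how FP16Model might be used.
# It wraps a model so parameters and buffers are converted to half while BatchNorm
# modules stay in FP32, and inputs are cast to half in forward(). The layer sizes,
# input shape, and the assumption of an available CUDA device are arbitrary choices.
def _example_fp16_model_usage():
    model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU()).cuda()
    fp16_model = FP16Model(model)
    x = torch.randn(4, 16, device="cuda")  # FP32 input; forward() casts it to half
    return fp16_model(x)  # output tensor is torch.float16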


def backwards_debug_hook(grad):
    raise RuntimeError("master_params recieved a gradient in the backward pass!")


def prep_param_lists(model, flat_master=False):
    """
    Creates a list of FP32 master parameters for a given model, as in
    `Training Neural Networks with Mixed Precision:  Real Examples`_.

    Args:
        model (torch.nn.Module): Existing PyTorch model.
        flat_master (bool, optional, default=False):  Flatten the master parameters into a single tensor, as a performance optimization.
    Returns:
        A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`.  ``master_params`` is a list of FP32 master parameters.  If ``flat_master=True``, ``master_params`` will be a list with one element.

    Example::

        model_params, master_params = prep_param_lists(model)

    .. warning::
        Currently, if ``flat_master=True``, all the model's parameters must be the same type.  If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`.

    .. _`Training Neural Networks with Mixed Precision:  Real Examples`:
        http://on-demand.gputechconf.com/gtc/2018/video/S81012/
    """
    model_params = [param for param in model.parameters() if param.requires_grad]

    if flat_master:
        # Give the user some more useful error messages
        try:
            # flatten_dense_tensors returns a contiguous flat array.
            # http://pytorch.org/docs/master/_modules/torch/_utils.html
            master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
        except BaseException:
            print(
                "Error in prep_param_lists:  model may contain a mixture of parameters "
                "of different types.  Use flat_master=False, or use F16_Optimizer."
            )
            raise
        master_params = torch.nn.Parameter(master_params)
        master_params.requires_grad = True
        # master_params.register_hook(backwards_debug_hook)
        if master_params.grad is None:
            master_params.grad = master_params.new(*master_params.size())
        return model_params, [master_params]
    else:
        master_params = [param.clone().float().detach() for param in model_params]
        for param in master_params:
            param.requires_grad = True
        return model_params, master_params


def model_grads_to_master_grads(
    model_params, master_params, flat_master=False, loss_scale=1.0, params_have_main_grad=False
):
    """
    Copy model gradients to master gradients.

    Args:
        model_params:  List of model parameters created by :func:`prep_param_lists`.
        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`.
        loss_scale (float, optional, default=1.0):  Model gradients are multiplied by ``1.0 / loss_scale`` when copied into the master gradients.
        params_have_main_grad (bool, optional, default=False):  If ``True``, a missing master gradient is initialized directly from ``model.grad.float()`` instead of freshly allocated FP32 storage.
    """
    if flat_master:
        # The flattening may incur one more deep copy than is necessary.
        master_params[0].grad.data.copy_(
            _flatten_dense_tensors([p.grad.data for p in model_params])
        )
    else:
        for model, master in zip(model_params, master_params):
            if model.device.type == "cpu":
                continue
            if model.grad is not None:
                if master.grad is None:
                    if params_have_main_grad:
                        # If gradient_as_bucket_view is False, this will be a copy
                        master.grad = model.grad.float()
                    else:
                        master.grad = Variable(master.data.new(*master.data.size()))
            else:
                master.grad = None
        model_grads = [p.grad for p in model_params if p.grad is not None]
        master_grads = [p.grad for p in master_params if p.grad is not None]
        if len(model_grads) == 0 or len(master_grads) == 0:
            return
        _overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(
            amp_C.multi_tensor_scale, _overflow_buf, [model_grads, master_grads], 1.0 / loss_scale
        )


def master_params_to_model_params(model_params, master_params, flat_master=False):
    """
    Copy master parameters to model parameters.

    Args:
        model_params:  List of model parameters created by :func:`prep_param_lists`.
        master_params:  List of FP32 master parameters created by :func:`prep_param_lists`.  If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
    """
    if flat_master:
        for model, master in zip(
            model_params, _unflatten_dense_tensors(master_params[0].data, model_params)
        ):
            model.data.copy_(master)
    else:
        for model, master in zip(model_params, master_params):
            if model.device.type == "cpu":
                continue
            model.data.copy_(master.data)


def model_params_to_master_params(model_params, master_params, flat_master=False):
    """
    Copy model parameters to master parameters.
    """
    if flat_master:
        raise ValueError("Not supported")
    else:
        for model, master in zip(model_params, master_params):
            if model.device.type == "cpu":
                continue
            master.data.copy_(model.data)
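

# Illustrative sketch (not part of the original utilities): the classic FP16 training
# step built from the helpers above. The model, data shapes, loss scale, and learning
# rate are arbitrary assumptions; a CUDA device and apex's amp_C kernels are required
# by model_grads_to_master_grads. A real setup would add overflow checking and dynamic
# loss scaling (see FP16_Optimizer).
def _example_master_weights_step(loss_scale=128.0):
    model = nn.Linear(16, 4).cuda().half()
    model_params, master_params = prep_param_lists(model)
    optimizer = torch.optim.SGD(master_params, lr=1e-3)

    inputs = torch.randn(8, 16, device="cuda", dtype=torch.half)
    targets = torch.randn(8, 4, device="cuda", dtype=torch.half)

    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    (loss * loss_scale).backward()  # scale the loss so FP16 grads do not underflow

    # Copy (and unscale) the FP16 model grads into the FP32 master grads,
    # step in FP32, then copy the updated master weights back into the model.
    model_grads_to_master_grads(model_params, master_params, loss_scale=loss_scale)
    optimizer.step()
    master_params_to_model_params(model_params, master_params)
    optimizer.zero_grad()
    model.zero_grad()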


# Backward compatibility fixes


def to_python_float(t):
    if hasattr(t, "item"):
        return t.item()
    else:
        return t[0]


TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])


def get_tp_merged_fp32_from_fp16_param_groups(optimizer, cpu_fp32_from_fp16_groups):
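    """
    Merge the FP32-from-FP16 parameter groups across the tensor-parallel group.

    Parameters that are distributed along an axis are gathered from every tp_rank
    and concatenated along that axis; undistributed parameters are taken from
    tp_rank 0. Returns the merged parameter groups and, for each group, a mapping
    from param_name to its index within the group.
    """
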
    def _merge_param_group_tp_group(group_idx, param_group):
        result_fp32_from_fp16_param_group = []
        param_name_group = {}
        for i, param in enumerate(param_group):
            # for each param, look up its name via tp_rank 0's param_id -> index and index -> name mappings
            param_index = param_id_to_index_tp_group[rank_0][
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]
            ]
            param_name = param_index_to_name_tp_group[rank_0][param_index]
            # obtain the distribution axis for the param and check whether it is distributed
            # axis = master_distribution_axis_tp_rank_0[fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i]]
            axis = master_distribution_axis_tp_rank_0.get(
                fp32_from_fp16_paramid_groups_tp_group[rank_0][group_idx][i], None
            )
            if axis is not None:
                tensors = []
                for r in range(smp.tp_size()):
                    # if distributed, map the param name back to this rank's param id via its name -> index and index -> id mappings
                    param_index_r = param_name_to_index_tp_group[r][param_name]
                    param_id_r = param_index_to_id_tp_group[r][param_index_r]

                    # find the index of that param id within this rank's fp32_from_fp16 param id group
                    group_param_idx = fp32_from_fp16_paramid_groups_tp_group[r][group_idx].index(
                        param_id_r
                    )
                    # use the param at that index from this rank's fp32_from_fp16 group for concatenation along the axis
                    tensors.append(
                        fp32_from_fp16_param_groups_tp_group[r][group_idx][group_param_idx]
                    )
                result_fp32_from_fp16_param_group.append(torch.cat(tensors, axis))
            else:
                # if not distributed, take tp_rank 0's param as-is
                result_fp32_from_fp16_param_group.append(param)
            param_name_group[param_name] = i
        return result_fp32_from_fp16_param_group, param_name_group

    # get the param_index_to_name and param_name_to_index mappings gathered across the tp_group
    param_index_to_name_tp_group = smp_state.param_index_to_name_tp_group
    param_name_to_index_tp_group = smp_state.param_name_to_index_tp_group
    # build the param_id_to_index and param_index_to_id mappings for every tp_rank
    param_id_to_index = optimizer._param_id_to_index()
    param_id_to_index_tp_group = smp.allgather(param_id_to_index, smp.TP_GROUP)
    param_index_to_id_tp_group = _get_param_index_to_id(param_id_to_index_tp_group)
    # allgather all param ids and all params for fp32_from_fp16_groups
    fp32_from_fp16_paramid_groups = optimizer.fp32_from_fp16_paramid_groups
    fp32_from_fp16_paramid_groups_tp_group = smp.allgather(
        fp32_from_fp16_paramid_groups, smp.TP_GROUP
    )
    fp32_from_fp16_param_groups_tp_group = smp.allgather(cpu_fp32_from_fp16_groups, smp.TP_GROUP)
    # broadcast distribution axis from tp_rank 0 to all tp_ranks
    master_distribution_axis_tp_rank_0 = None
    if smp.tp_rank() == 0:
        master_distribution_axis_tp_rank_0 = optimizer.master_distribution_axis
        smp.broadcast(master_distribution_axis_tp_rank_0, smp.TP_GROUP)
    else:
        master_distribution_axis_tp_rank_0 = smp.recv_from(0, smp.RankType.TP_RANK)

    result_fp32_from_fp16_param_groups = []
    param_name_groups = []
    rank_0 = 0
    # iterate through all the param groups gathered from tp_rank 0
    for group_idx, param_group in enumerate(fp32_from_fp16_param_groups_tp_group[rank_0]):
        result_fp32_from_fp16_param_group, param_name_group = _merge_param_group_tp_group(
            group_idx, param_group
        )
        result_fp32_from_fp16_param_groups.append(result_fp32_from_fp16_param_group)
        param_name_groups.append(param_name_group)
    return result_fp32_from_fp16_param_groups, param_name_groups
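

# Illustrative sketch (not part of the original utilities): the core of the merge above
# is concatenating each distributed parameter's shards along its distribution axis;
# undistributed parameters are taken from tp_rank 0 directly. The shard shapes, the
# axis, and a tensor-parallel degree of 2 are arbitrary assumptions standing in for
# tensors gathered over smp.TP_GROUP.
def _example_merge_tp_shards():
    # Two ranks each hold half of a (4, 8) weight, split along axis 1.
    shard_rank0 = torch.randn(4, 4)
    shard_rank1 = torch.randn(4, 4)
    merged = torch.cat([shard_rank0, shard_rank1], dim=1)
    return merged.shape  # torch.Size([4, 8])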


def get_pp_merged_fp32_from_fp16_param_groups(
    optimizer, fp32_from_fp16_groups, param_name_groups=None
):
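    """
    Merge the FP32-from-FP16 parameter groups across the pipeline-parallel group.

    Each rank's groups are gathered over smp.PP_GROUP and concatenated group by group,
    and a param_name -> index mapping into the merged group is built from
    ``param_name_groups`` (as produced by get_tp_merged_fp32_from_fp16_param_groups),
    which therefore must not be None.
    """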
    pp_group_fp32_from_fp16_groups = smp.allgather(fp32_from_fp16_groups, smp.PP_GROUP)
    if param_name_groups is not None:
        index_to_param_name_groups = []
        # invert each param_name -> index mapping into index -> param_name
        for param_name_group in param_name_groups:
            index_to_param_name = {}
            for param_name, index in param_name_group.items():
                index_to_param_name[index] = param_name
            index_to_param_name_groups.append(index_to_param_name)
        # allgather the index_to_param_name_groups across the pp_group
        pp_index_to_param_name_groups = smp.allgather(index_to_param_name_groups, smp.PP_GROUP)
    else:
        raise ValueError("Merging is not supported when param_name_groups is None")

    pp_merged_fp32_from_fp16_groups = []
    result_param_groups = []

    # iterate through all the groups for rank 0
    for group_idx in range(len(pp_group_fp32_from_fp16_groups[0])):
        merged = []
        start_idx = 0
        result_param_group = {}
        # for each group iterate through all ranks and merge the param groups across pp_ranks
        for rank, group in enumerate(pp_group_fp32_from_fp16_groups):
            cur_g = group[group_idx]
            start_idx = len(merged)  # global offset of this rank's group within the merged list
            for i, _ in enumerate(cur_g):
                param_name = pp_index_to_param_name_groups[rank][group_idx][i]
                if param_name in result_param_group:
                    raise ValueError(
                        "same param_name present in the param_groups of different pipeline parallel partitions"
                    )
                result_param_group[param_name] = i + start_idx
            merged.extend(cur_g)
        pp_merged_fp32_from_fp16_groups.append(merged)
        result_param_groups.append(result_param_group)
    return pp_merged_fp32_from_fp16_groups, result_param_groups


def _get_param_index_to_id(param_id_to_index_tp_group):
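    """Invert each param_id -> index mapping into an index -> param_id mapping, per tp_rank."""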
    param_index_to_id_tp_group = []
    for param_id_to_index_map in param_id_to_index_tp_group:
        param_index_to_id_map = {}
        for param_id, param_index in param_id_to_index_map.items():
            param_index_to_id_map[param_index] = param_id
        param_index_to_id_tp_group.append(param_index_to_id_map)
    return param_index_to_id_tp_group


def register_optimizer_hooks(model):
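    """
    Register a post-partition hook that attaches two lookup helpers to the optimizer:
    ``param_name_to_index`` (parameter name -> optimizer parameter index) and
    ``_param_index_to_param_local`` (parameter index -> local parameter), resolving
    FP16 parameters to their FP32 copies via ``fp32paramid_from_fp16paramid``.
    """
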
    def param_name_to_index(self):
        param_id_to_index = self._param_id_to_index()
        name_to_index = {}
        if self.redefined_params:
            param_gen = model.virtual_named_parameters()
        else:
            param_gen = model.named_parameters()
        for name, param in param_gen:
            fp16_param_id = id(param)
            if fp16_param_id in self.fp32paramid_from_fp16paramid:
                param_id = self.fp32paramid_from_fp16paramid[fp16_param_id]
            else:
                param_id = fp16_param_id
            if param_id in param_id_to_index:
                name_to_index[name] = param_id_to_index[param_id]
        return name_to_index

    def _param_index_to_param_local(self):
        param_id_to_index = self._param_id_to_index()
        param_index_to_param = {}

        if model is None:
            return param_index_to_param

        if self.redefined_params:
            param_gen = model.virtual_named_parameters()
        else:
            param_gen = model.named_parameters()
        for name, param in param_gen:
            fp16_param_id = id(param)
            if fp16_param_id in self.fp32paramid_from_fp16paramid:
                param_id = self.fp32paramid_from_fp16paramid[fp16_param_id]
            else:
                param_id = fp16_param_id
            if param_id in param_id_to_index:
                param_index_to_param[param_id_to_index[param_id]] = param

        return param_index_to_param

    def hook_fn(model, optimizer):
        from functools import partial

        optimizer.param_name_to_index = partial(param_name_to_index, optimizer)
        optimizer._param_index_to_param_local = partial(_param_index_to_param_local, optimizer)

    model.register_post_partition_hook(hook_fn)