"""Util function to get GPT or BLOOM model configs."""

import logging

from transformers import (  # pylint: disable=import-error
    AutoConfig,
    BloomConfig,
    GPT2Config,
    GPTNeoXConfig,
    T5Config,
)


def _get_gpt2_config_from_args(args):
    """Get GPT2 config."""

    return {
        "vocab_size": args.vocab_size,
        "n_positions": args.max_context_width,
        "n_embd": args.hidden_width,
        "n_layer": args.num_layers,
        "n_head": args.num_heads,
        "n_inner": None,
        "activation_function": "gelu_new",
        "resid_pdrop": args.resid_pdrop,
        "embd_pdrop": args.embd_pdrop,
        "attn_pdrop": args.attn_pdrop,
        "layer_norm_epsilon": 1e-05,
        "initializer_range": args.initializer_range,
        "summary_type": "cls_index",
        "summary_use_proj": True,
        "summary_activation": None,
        "summary_proj_to_labels": True,
        "summary_first_dropout": args.summary_first_pdrop,
        # "gradient_checkpointing": args.gradient_checkpointing > 0,
        "use_cache": False,
        "bos_token_id": 50256,
        "eos_token_id": 50256,
        "return_dict": True,
    }


def _get_gpt_neox_config_from_args(args):
    """Get GPTNeoX config."""

    return {
        "vocab_size": args.vocab_size,
        "hidden_size": args.hidden_width,
        "num_hidden_layers": args.num_layers,
        "num_attention_heads": args.num_heads,
        "hidden_act": "gelu",
        "intermediate_size": 4 * args.hidden_width,
        "rotary_pct": args.rotary_pct,
        "rotary_emb_base": args.rotary_emb_base,
        "max_position_embeddings": args.max_context_width,
        "layer_norm_epsilon": 1e-05,
        "initializer_range": args.initializer_range,
        "use_cache": False,
        "parallel_attn_output": True,
    }


def _get_bloom_config_from_args(args):
    """Get BLOOM config."""

    return {
        "vocab_size": args.vocab_size,
        "hidden_size": args.hidden_width,
        "n_layer": args.num_layers,
        "n_head": args.num_heads,
        "hidden_dropout": 0.0,
        "attention_dropout": 0.0,
        "layer_norm_epsilon": 1e-05,
        "initializer_range": args.initializer_range,
        "summary_type": "cls_index",
        "summary_use_proj": True,
        "summary_activation": None,
        "summary_proj_to_labels": True,
        "summary_first_dropout": args.summary_first_pdrop,
        # "gradient_checkpointing": args.gradient_checkpointing > 0,
        "use_cache": False,
        "bos_token_id": 50256,
        "eos_token_id": 50256,
        "return_dict": True,
    }


def _get_t5_config_from_args(args):
    """Get T5 config."""

    return {
        "vocab_size": args.vocab_size,
        "d_model": args.hidden_width,
        "d_kv": 64,
        "d_ff": args.intermediate_size,
        "num_layers": args.num_layers,
        "num_decoder_layers": args.num_layers,
        "num_heads": args.num_heads,
        "relative_attention_num_buckets": 32,
        "relative_attention_max_distance": 128,
        "dropout_rate": 0.1,
        "layer_norm_epsilon": 1e-6,
        "initializer_factor": 1.0,
        "feed_forward_proj": "gated-gelu",
        "is_encoder_decoder": True,
        "use_cache": False,
        "pad_token_id": 0,
        "eos_token_id": 1,
        "decoder_start_token_id": 0,
    }


def get_model_config_from_args(model_type, model_name, args, log=False):
    """Get model config for GPT or BLOOM: From cmd args."""
    if model_name:
        logging.info(f"Loading config from HF model {model_name}")
        return AutoConfig.from_pretrained(model_name), args

    if model_type == "gpt2":
        config_type = GPT2Config
        config_kwargs = _get_gpt2_config_from_args(args)
    elif model_type == "gpt_neox":
        config_type = GPTNeoXConfig
        config_kwargs = _get_gpt_neox_config_from_args(args)
    elif model_type == "bloom":
        config_type = BloomConfig
        config_kwargs = _get_bloom_config_from_args(args)
        if args.use_distributed_transformer > 0:
            args.use_distributed_transformer = 0
            logging.warning(
                "DistributedTransformer does not support Bloom, falling back "
                "to regular HF implementation."
            )
    elif model_type == "flan_t5":
        config_type = T5Config
        config_kwargs = _get_t5_config_from_args(args)
        if args.use_distributed_transformer > 0:
            args.use_distributed_transformer = 0
            logging.warning(
                "DistributedTransformer does not support T5, falling back "
                "to regular HF implementation."
            )

    if log:
        logging.info("Args for model %s:", model_type)
        for key, value in sorted(config_kwargs.items()):
            logging.info("  config %-20s: %s", key, value)

    return config_type(**config_kwargs), args