#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""
Source: `pytorch imagenet example <https://github.com/pytorch/examples/blob/master/imagenet/main.py>`_ # noqa B950

Modified and simplified to make the original pytorch example compatible with
torchelastic.distributed.launch.

Changes:

1. Removed ``rank``, ``gpu``, ``multiprocessing-distributed``, ``dist_url`` options.
   These are obsolete parameters when using ``torchelastic.distributed.launch``.

2. Removed ``seed``, ``evaluate``, ``pretrained`` options for simplicity.

3. Removed ``resume``, ``start-epoch`` options.
   Loads the most recent checkpoint by default.

4. ``batch-size`` is now per GPU (worker) batch size rather than for all GPUs.

5. Defaults ``workers`` (num data loader workers) to ``0``.

Usage

::

 >>> python -m torchelastic.distributed.launch
        --nnodes=$NUM_NODES
        --nproc_per_node=$WORKERS_PER_NODE
        --rdzv_id=$JOB_ID
        --rdzv_backend=etcd
        --rdzv_endpoint=$ETCD_HOST:$ETCD_PORT
        main.py
        --arch resnet18
        --epochs 20
        --batch-size 32
        <DATA_DIR>
"""

import argparse, io, os, shutil, time, logging, operator

from contextlib import contextmanager
from datetime import timedelta
from typing import List, Tuple
from pathlib import Path

from tqdm import tqdm
import wandb

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import LinearLR

os.environ["WANDB_START_METHOD"] = "thread"
os.environ['WANDB_API_KEY'] = ''
wandb.login()

from datasets import load_dataset, Features, ClassLabel, Value, load_metric
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    LEDForConditionalGeneration
)

from torch.distributed.elastic.utils.data import ElasticDistributedSampler

from pathlib import Path
import pandas as pd

from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

from collections import OrderedDict
from sklearn.metrics import recall_score, accuracy_score, f1_score, precision_score

logging.getLogger().setLevel(logging.INFO)

# TODO: Refactor load/save with Huggingface from_pretrained/save_pretrained api.
# Curently SGD and ADAMW produce problems with params (like momentum)

def run(args):

    local_rank = int(os.environ["LOCAL_RANK"])

    if local_rank == 0:
        wandb.init(config=args, project=args.wandb_project)
        args = wandb.config
        do_log = True
    else:
        
        do_log = False
    
    device_id = local_rank

    torch.cuda.set_device(device_id)
    logging.info(f"Set cuda device = {device_id}")

    dist.init_process_group(backend=args.dist_backend, init_method="env://", timeout=timedelta(seconds=120))

    model, criterion, optimizer = initialize_huggingface_model(
        args.arch, args.lr, args.momentum, args.weight_decay, args.optimizer, device_id
    )

    train_loader, val_loader = initialize_custom_data_loader(args.data, args.batch_size, args.workers)

    # resume from checkpoint if one exists;
    state = load_checkpoint(args.checkpoint_file, device_id, args.arch, model, optimizer)

    model_saver = SaveBestModel(args.checkpoint_file, do_log=do_log)

    # for most transformer based models Linear LR decays is best (1/10 total decay)
    scheduler = LinearLR(optimizer, start_factor=1., end_factor=0.1, total_iters=args.epochs)

    # start_epoch = state.epoch + 1

    print_freq = args.print_freq

    for epoch in range(args.epochs):
        state.epoch = epoch
        train_loader.batch_sampler.sampler.set_epoch(epoch)

        logging.info(f"training epoch {epoch}")
        train_loss_epoch = train(train_loader, model, criterion, optimizer, epoch, device_id, print_freq, do_log)
        scheduler.step()

        logging.info(f"validating epoch {epoch}")
        val_loss_epoch = validate(val_loader, model, criterion, device_id, print_freq, do_log)

        if device_id == 0:
            model_saver.save(state, val_loss_epoch)

        if device_id == 0:
            save_checkpoint(state, args.checkpoint_file)

    logging.info(f"Running predictions")
    if device_id == 0:
        # del model
        # torch.cuda.empty_cache()
        run_predictions(args.checkpoint_file, args.lr, args.optimizer, do_log)

    wandb.finish()


def main():
    parser = argparse.ArgumentParser(description="PyTorch Elastic HuggingFace Training")

    # Required paramaters
    parser.add_argument(
        "--data",
        metavar="DIR",
        default="/shared-efs/wandb-finbert",
        help="path to dataset",
    )
    parser.add_argument(
        "--wandb_project",
        default="aws_eks_demo",
        help="The wandb project name",
    )
    parser.add_argument(
        "--sweep_id",
        default=None,
        help="The Sweep id created by wandb",
    )

    # Other params
    parser.add_argument("--arch", default="HuggingFace")
    parser.add_argument("--workers", default=0, type=int, help="number of data loading workers")
    parser.add_argument("--epochs", default=1, type=int, help="number of total epochs to run")
    parser.add_argument("--batch-size", default=32, type=int, help="mini-batch size per worker (GPU)")
    parser.add_argument("--lr", default=1e-4, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, help="momentum")
    parser.add_argument("--weight-decay", default=1e-4, help="weight decay (default: 1e-4)")
    parser.add_argument("--print-freq", default=1, type=int, help="print frequency (default: 10)")
    parser.add_argument(
        "--dist-backend",
        default="nccl",
        choices=["nccl", "gloo"],
        help="distributed backend",
    )
    parser.add_argument(
        "--checkpoint-file",
        default="/shared-efs/checkpoint.pth.tar",
        help="checkpoint file path, to load and save to",
    )
    
    
    
    parser.add_argument("--optimizer", default="AdamW", help="optimizer type")

    args = parser.parse_args()

    wandb.require("service")
    wandb.setup()

    if args.sweep_id is not None:
        wandb.agent(args.sweep_id, lambda: run(args), project=args.wandb_project, count = 1)
    else:
        run(args=args)


class Dataset(torch.utils.data.Dataset):
    #'Characterizes a dataset for PyTorch'
    def __init__(self, data_dir, file_name):

        #'Initialization'
        self.data_dir = data_dir
        self.df = pd.read_csv(Path(data_dir) / file_name)

    def __len__(self):
        # 'Denotes the total number of samples'
        return len(self.df)

    def __getitem__(self, index):
        #'Generates one sample of data'
        # Select sample
        df = self.df
        one_line = df["Text"][index]
        label = df["labels"][index]

        return (one_line, label)


def collate_tokenize(data, tokenizer):
    text_batch = [element[0] for element in data]
    labels = [element[1] for element in data]
    tokenized_inputs = tokenizer(text_batch, padding="max_length", truncation=True, return_tensors="pt")

    tokenized_inputs["labels"] = torch.tensor(labels)
    tokenized_inputs["attention_mask"] = tokenized_inputs["attention_mask"]

    return tokenized_inputs


class MyCollator(object):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # do something with batch and self.params
        tokenized_inputs = collate_tokenize(batch, self.tokenizer)

        return tokenized_inputs


class State:
    """
    Container for objects that we want to checkpoint. Represents the
    current "state" of the worker. This object is mutable.
    """

    def __init__(self, arch, model, optimizer):
        self.epoch = -1
        self.best_acc1 = 10
        self.arch = arch
        self.model = model
        self.optimizer = optimizer

    def capture_snapshot(self):
        """
        Essentially a ``serialize()`` function, returns the state as an
        object compatible with ``torch.save()``. The following should work
        ::
        snapshot = state_0.capture_snapshot()
        state_1.apply_snapshot(snapshot)
        assert state_0 == state_1
        """
        return {
            "epoch": self.epoch,
            "best_acc1": self.best_acc1,
            "arch": self.arch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def apply_snapshot(self, obj, device_id):
        """
        The complimentary function of ``capture_snapshot()``. Applies the
        snapshot object that was returned by ``capture_snapshot()``.
        This function mutates this state object.
        """

        self.epoch = obj["epoch"]
        self.best_acc1 = obj["best_acc1"]
        self.state_dict = obj["state_dict"]
        self.model.load_state_dict(obj["state_dict"])
        self.optimizer.load_state_dict(obj["optimizer"])

    def save(self, f):
        torch.save(self.capture_snapshot(), f)

    def load(self, f, device_id):
        # Map model to be loaded to specified single gpu.
        snapshot = torch.load(f, map_location=f"cuda:{device_id}")
        self.apply_snapshot(snapshot, device_id)


def initialize_huggingface_model(
    arch: str,
    lr: float,
    momentum: float,
    weight_decay: float,
    optimizer_type,
    device_id: int,
):
    logging.info(f"=> creating model: {arch}")

    ## Initializing the model
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.

    model.cuda(device_id)

    cudnn.benchmark = True

    model = DistributedDataParallel(model, device_ids=[device_id])

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(device_id)

    # initialize optimizer
    if optimizer_type == "AdamW":
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    if optimizer_type == "SGD":
        optimizer = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    return model, criterion, optimizer


def initialize_custom_data_loader(data_dir, batch_size, num_data_workers) -> Tuple[DataLoader, DataLoader]:

    # Generators
    train_dataset = Dataset(data_dir, file_name="train.csv")
    logging.info("Train dataset done")

    train_sampler = ElasticDistributedSampler(train_dataset)
    logging.info("Train sampler done")

    model_name = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    my_collator = MyCollator(tokenizer)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_data_workers,
        pin_memory=True,
        collate_fn=my_collator,
        sampler=train_sampler,
    )

    logging.info("Train loader done")

    test_dataset = Dataset(data_dir, file_name="test.csv")

    logging.info("Test dataset done")

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_data_workers,
        pin_memory=True,
        collate_fn=my_collator,
    )

    logging.info("Test loader done")

    return train_loader, test_loader


def load_checkpoint(
    checkpoint_file: str,
    device_id: int,
    arch: str,
    model: DistributedDataParallel,
    optimizer,  # SGD
) -> State:
    """
    Loads a local checkpoint (if any). Otherwise, checks to see if any of
    the neighbors have a non-zero state. If so, restore the state
    from the rank that has the most up-to-date checkpoint.
    .. note:: when your job has access to a globally visible persistent storage
              (e.g. nfs mount, S3) you can simply have all workers load
              from the most recent checkpoint from such storage. Since this
              example is expected to run on vanilla hosts (with no shared
              storage) the checkpoints are written to local disk, hence
              we have the extra logic to broadcast the checkpoint from a
              surviving node.
    """

    state = State(arch, model, optimizer)
    
    print('***** Checkpoint File = '+checkpoint_file)

    if os.path.isfile(checkpoint_file):
        logging.info(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file, device_id)
        logging.info(f"=> loaded checkpoint file: {checkpoint_file}")

    logging.info(f"=> done restoring from previous checkpoint")
    return state


@contextmanager
def tmp_process_group(backend):
    cpu_pg = dist.new_group(backend=backend)
    try:
        yield cpu_pg
    finally:
        dist.destroy_process_group(cpu_pg)


class SaveBestModel:
    "A simple model saver Callback"

    def __init__(self, filename, min_metric=True, do_log=True):
        self.filename = filename
        self.min_metric = min_metric
        self.do_log = do_log
        self.checkpoint_dir = os.path.dirname(filename)
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best = 100 if min_metric else -1

    def save(self, state, metric_value):
        torch.save(state.capture_snapshot(), self.filename)
        op = operator.lt if self.min_metric else operator.gt
        if op(metric_value, self.best):
            logging.info(f"=> best model found at epoch {state.epoch}")
            self._save()

    def _save(self):
        best_model = os.path.join(self.checkpoint_dir, "model_best.pth.tar")
        shutil.copyfile(self.filename, best_model)
        if self.do_log:
             self.log_model(best_model)

    def log_model(self, path, metadata={}, description="trained model"):
        "Log model file"
        if wandb.run is None:
            raise ValueError("You must call wandb.init() before log_model()")
        path = Path(path)
        if not path.is_file():
            raise f"path must be a valid file: {path}"
        name = f"run-{wandb.run.id}-model"
        artifact_model = wandb.Artifact(name=name, type="model", metadata=metadata, description=description)
        with artifact_model.new_file(name, mode="wb") as fa:
            fa.write(path.read_bytes())
        wandb.run.log_artifact(artifact_model)


def save_checkpoint(state: State, filename: str):
    checkpoint_dir = os.path.dirname(filename)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # save to tmp, then commit by moving the file in case the job
    # gets interrupted while writing the checkpoint
    #tmp_filename = filename + ".tmp"
    torch.save(state.capture_snapshot(), filename)
    #os.rename(tmp_filename, filename)
    print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")
    # if is_best:
    #     best = os.path.join(checkpoint_dir, "model_best.pth.tar")
    #     print(f"=> best model found at epoch {state.epoch} saving to {best}")
    #     shutil.copyfile(filename, best)


def train(
    train_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    optimizer,  # AdamW,
    epoch: int,
    device_id: int,
    print_freq: int,
    do_log: bool,
):
    losses = AverageMeter("Loss", ":.4e")

    model.train()

    for batch in tqdm(train_loader, total=len(train_loader)):

        optimizer.zero_grad()
        input_ids = batch["input_ids"].cuda(device_id, non_blocking=True)
        attention_mask = batch["attention_mask"].cuda(device_id, non_blocking=True)
        labels = batch["labels"].cuda(device_id, non_blocking=True)

        # forward pass
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        # hf models return loss as 1st argument
        loss = outputs.loss
        if do_log:
            wandb.log({"train_loss": loss.item()})

        # # measure accuracy and record loss
        losses.update(loss.item(), input_ids.size(0))

        loss.backward()
        optimizer.step()

    return losses.avg


def validate(
    val_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    device_id: int,
    print_freq: int,
    do_log: bool,
):
    losses = AverageMeter("Loss", ":.4e")

    metrics = [Metric("val_recall", recall_score), 
               Metric("val_f1",f1_score),
               Metric("val_accuracy", accuracy_score),
               Metric("val_precision", precision_score)]

    # switch to evaluate mode
    model.eval()

    with torch.inference_mode():
        for batch in tqdm(val_loader, total=len(val_loader)):

            input_ids = batch["input_ids"].cuda(device_id, non_blocking=True)
            attention_mask = batch["attention_mask"].cuda(device_id, non_blocking=True)
            labels = batch["labels"].cuda(device_id, non_blocking=True)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            # compute output
            loss = outputs[0]

            # # measure accuracy and record loss
            losses.update(loss.item(), input_ids.size(0))

            #compute metrics
            pred_labels = outputs.logits.argmax(axis=1).cpu()
            true_labels = labels.cpu()
            
            for m in metrics:
                m.update(pred_labels, true_labels)

        if do_log:
            wandb.log({"val_loss": losses.avg})
            wandb.log({m.name:m.avg for m in metrics})

    return losses.avg


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1) -> None:
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

class Metric(AverageMeter):

    def __init__(self, name, func):
        super().__init__(name)
        self.func = func
    
    def update(self, y_pred, y_true):
        val = self.func(y_pred=y_pred, y_true=y_true)
        super().update(val)


# def accuracy(output, target, topk=(1,)):
#     """
#     Computes the accuracy over the k top predictions for the specified values of k
#     """
#     with torch.no_grad():
#         maxk = max(topk)
#         batch_size = target.size(0)

#         _, pred = output.topk(maxk, 1, True, True)
#         pred = pred.t()
#         correct = pred.eq(target.view(1, -1).expand_as(pred))

#         res = []
#         for k in topk:
#             correct_k = correct[:k].reshape(1, -1).view(-1).float().sum(0, keepdim=True)
#             res.append(correct_k.mul_(100.0 / batch_size))
#         return res


def run_predictions(checkpoint_filename, lr, optimizer, do_log):
    print("**********************\nRunning predictions\n**********************")
    
    checkpoint_dir = os.path.dirname(checkpoint_filename)
    logging.info(checkpoint_dir)
    
    # Run inference on CPU to avoid Cuda out of memory issue
    # device = torch.device("cuda")
    device = torch.device("cpu")

    model_name = "bert-base-cased"
    test_file = "/shared-efs/wandb-finbert/test.csv"

    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    
    checkpoint = torch.load(checkpoint_dir + "/checkpoint.tar")
    state_dict = checkpoint["state_dict"]
    # create new OrderedDict that does not contain `module.`

    logging.info("Doing some state dict magic")
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # name = k[7:] # remove `module.`
        name = k.replace("module.", "")  # removing ‘moldule.’ from key
        new_state_dict[name] = v
    # load params
    model.load_state_dict(new_state_dict)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    test_df = pd.read_csv(test_file)
    
    print('****Len test_df = '+str(test_df.shape[0]))

    logging.info("Tokening inputs")
    tokenized_test_inputs = tokenizer(
        list(test_df["Text"]),
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    
    tokenized_test_inputs.to(device)
    model.to(device)
    logging.info(f"Running inference on: {device}")

    model.eval()
    with torch.no_grad():
        # preds = model(**tokenized_test_inputs)
        preds = model(tokenized_test_inputs['input_ids'], tokenized_test_inputs['attention_mask'])

    pred_labels = preds.logits.argmax(axis=1).tolist()
    true_labels = list(test_df["labels"])

    recall = recall_score(y_pred=pred_labels, y_true=true_labels)
    f1 = f1_score(y_pred=pred_labels, y_true=true_labels)
    accuracy = accuracy_score(y_pred=pred_labels, y_true=true_labels)
    precision = precision_score(y_pred=pred_labels, y_true=true_labels)

    pred_dict = {
        "test_recall": recall,
        "test_f1": f1,
        "test_accuracy": accuracy,
        "test_precision": precision,
    }

    metrics_df = pd.DataFrame()
    metrics_df = metrics_df.append(pred_dict, ignore_index=True)

    run_name = checkpoint_dir.split("/")[-1]
    metrics_df["run_name"] = run_name
    metrics_df["lr"] = lr
    metrics_df["optimizer"] = optimizer
    
    print(metrics_df)

    out_file = "all_results.csv"

    logging.info(f"Logging metrics to : {out_file}")
    if os.path.exists(f"/shared-efs/wandb-finbert/{out_file}"):
        metrics_df.to_csv(f"/shared-efs/wandb-finbert/{out_file}", mode="a", index=False, header=False)
    else:
        metrics_df.to_csv(f"/shared-efs/wandb-finbert/{out_file}", index=False)
        
    if do_log:
        wandb.log(pred_dict)

    return None


if __name__ == "__main__":
    main()