import os
import sys
import time
import json
import logging
import argparse
from glob import glob

import monai
import torch
import numpy as np
import nibabel as nib
from monai.config import print_config
from monai.utils import set_determinism
from monai.data import partition_dataset, Dataset, CacheDataset, PersistentDataset, SmartCacheDataset, DataLoader 
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    CenterSpatialCropd,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    ToTensord,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
print_config()

class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(np.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d

def get_data_loaders(args):
    """
    This function loads input/output file paths, builds a MONAI DataLoader
    from a PersistentDataset (child of torch dataset) for them.
    It returns a DataLoader for train and validation splits 
    """    
    images = sorted(glob(os.path.join(args.train, 'imagesTr', 'BRATS*.nii.gz')))
    segs = sorted(glob(os.path.join(args.train, 'labelsTr', 'BRATS*.nii.gz')))
    data_list = [{"image": img, "label": seg} for img, seg in zip(images, segs)]
    logger.info('Total file pairs: %d' % (len(data_list)))
    
    # split the data_list into train and val
    train_files, val_files = partition_dataset(
        data_list,
        ratios=[0.99, 0.01],
        shuffle=True,
        seed=args.seed
    )
    logger.info('# of train and val: %d/%d' % (len(train_files), len(val_files)))

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            # load 4 Nifti images and stack them together
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            RandSpatialCropd(
                keys=["image", "label"], roi_size=[128, 128, 64], random_size=False
            ),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
            RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )
    
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AsChannelFirstd(keys="image"),
            ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
            NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )
                
    # create training/validation data loaders
    logger.info('Defining Train Dataset')
    persistent_cache = os.path.join(args.train, args.cache_dir)

    if args.torch_dataset_type == 'CacheDataset':
        train_ds = CacheDataset(
            data=train_files,
            transform=train_transforms,
            cache_rate=1.0,
            num_workers=args.num_workers
        )
    elif args.torch_dataset_type == 'PersistentDataset':
        train_ds = PersistentDataset(
            data=train_files,
            transform=train_transforms, 
            cache_dir=persistent_cache
        )
    elif args.torch_dataset_type == 'SmartCacheDataset':
        train_ds = SmartCacheDataset(
            data=train_files,
            transform=train_transforms,
            cache_rate = args.cache_rate,
            replace_rate = args.replace_rate,
            num_init_workers=args.num_workers,
            num_replace_workers = args.num_workers
        )
    else:
        train_ds = Dataset(data=train_files, transform=train_transforms)    
    
    logger.info('Defining Validation Dataset')
    #val_ds = Dataset(data=val_files, transform=val_transforms)
    val_ds = PersistentDataset(
        data=val_files,
        transform=val_transforms,
        cache_dir=persistent_cache
    )
    
    
    logger.info('Defining DataLoader')
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers
    )

    return train_loader, val_loader

    
def get_model(args):
    logger.info('Defining network')    
    device = torch.device("cuda") # WARNING: do not assign local rank here, i.e. cuda:0
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=4, 
        out_channels=3, 
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

    loss_function = monai.losses.DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    
    return device, model, loss_function, optimizer
    
def train(args, train_loader, val_loader, device, model, loss_function, optimizer):
    # Training epochs
    val_interval = 5
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    metric_values_tc = []
    metric_values_wt = []
    metric_values_et = []
    for epoch in range(args.epochs):
        logger.info("-" * 10)
        logger.info(f"epoch {epoch + 1}/{args.epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        tic = time.time()

        #train_sampler.set_epoch(epoch)
        #epoch_len = len(train_ds) // (train_loader.batch_size * args.world_size)
        
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            #logger.info(f"train_loss: {loss.item():.4f}")
            #logger.info(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        logger.info("secs_time_per_epoch: {}".format(time.time() - tic))
        
        # validation loop
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                dice_metric = monai.metrics.DiceMetric(include_background=True, reduction="mean")
                post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
                metric_sum = metric_sum_tc = metric_sum_wt = metric_sum_et = 0.0
                metric_count = (
                    metric_count_tc
                ) = metric_count_wt = metric_count_et = 0
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    val_outputs = model(val_inputs)
                    val_outputs = post_trans(val_outputs)
                    # compute overall mean dice
                    value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                    not_nans = not_nans.item()
                    metric_count += not_nans
                    metric_sum += value.item() * not_nans
                    # compute mean dice for TC
                    value_tc, not_nans = dice_metric(y_pred=val_outputs[:, 0:1], y=val_labels[:, 0:1])
                    not_nans = not_nans.item()
                    metric_count_tc += not_nans
                    metric_sum_tc += value_tc.item() * not_nans
                    # compute mean dice for WT
                    value_wt, not_nans = dice_metric(y_pred=val_outputs[:, 1:2], y=val_labels[:, 1:2])
                    not_nans = not_nans.item()
                    metric_count_wt += not_nans
                    metric_sum_wt += value_wt.item() * not_nans
                    # compute mean dice for ET
                    value_et, not_nans = dice_metric(y_pred=val_outputs[:, 2:3], y=val_labels[:, 2:3])
                    not_nans = not_nans.item()
                    metric_count_et += not_nans
                    metric_sum_et += value_et.item() * not_nans

                metric = metric_sum / metric_count
                metric_values.append(metric)
                metric_tc = metric_sum_tc / metric_count_tc
                metric_values_tc.append(metric_tc)
                metric_wt = metric_sum_wt / metric_count_wt
                metric_values_wt.append(metric_wt)
                metric_et = metric_sum_et / metric_count_et
                metric_values_et.append(metric_et)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(args.model_dir, "best_metric_model.pth"))
                    logger.info("saved new best metric model")
                logger.info(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )

    logger.info(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--torch_dataset_type', type=str, default='Dataset',
                        help='torch dataset for training dataloader (default: Dataset)')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='input batch size for training (default: 2)')
    parser.add_argument('--epochs', type=int, default=5,
                        help='number of epochs to train (default: 5)')
    parser.add_argument('--replace_rate', type=float, default=0.2,
                        help='replace rate for SmartCacheDataset (default: 0.2)')
    parser.add_argument('--cache_rate', type=float, default=0.5,
                        help='Cache rate for SmartCacheDataset (default: 0.5)')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='num cpu workers (default: 4)')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed (default: 42)')
    parser.add_argument('--cache_dir', type=str, default='monai_persistent_cache')
    parser.add_argument('--model_dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    #parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

    args = parser.parse_args()
    args.batch_size = max(args.batch_size, 1)
    logger.info('Arguments: %s' % args)
    if not torch.cuda.is_available():
        raise Exception("Must run this example on CUDA-capable devices.") 
        
    set_determinism(seed=args.seed)
    train_loader, val_loader = get_data_loaders(args=args)
    device, model, loss_function, optimizer = get_model(args)
    train(args, train_loader, val_loader, device, model, loss_function, optimizer)
    
if __name__ == "__main__":
    main()