import os
import numpy as np
import skimage
import torch
from tqdm import tqdm
from warnings import warn
from .model_io import get_model
from .transform import process_aug_dict
from .datagen import InferenceTiler
from ..raster.image import stitch_images
from ..utils.core import get_data_paths


class Inferer(object):
    """Object for running inference with `solaris` models (PyTorch or Keras).

    Arguments
    ---------
    config : dict
        The loaded configuration dict. The keys read here are
        ``'batch_size'``, ``'nn_framework'``, ``'model_name'``, ``'train'``,
        ``'training'`` (only when ``config['train']`` is truthy),
        ``'model_path'`` (optional), ``'inference'``, ``'data_specs'`` and
        ``'inference_augmentation'``.
    custom_model_dict : dict, optional
        Passed through to ``get_model``. When ``config['train']`` is truthy
        its ``'weight_path'`` entry is overwritten **in place** with the
        training output path.
    """

    def __init__(self, config, custom_model_dict=None):
        self.config = config
        self.batch_size = self.config['batch_size']
        self.framework = self.config['nn_framework']
        self.model_name = self.config['model_name']
        # check if the model was trained as part of the same pipeline; if so,
        # use the output from that. If not, use the pre-trained model directly.
        if self.config['train']:
            warn('Because the configuration specifies both training and '
                 'inference, solaris is switching the model weights path '
                 'to the training output path.')
            self.model_path = self.config['training']['model_dest_path']
            if custom_model_dict is not None:
                custom_model_dict['weight_path'] = self.config[
                    'training']['model_dest_path']
        else:
            self.model_path = self.config.get('model_path', None)
        self.model = get_model(self.model_name, self.framework,
                               self.model_path, pretrained=True,
                               custom_model_dict=custom_model_dict)

        inference_cfg = self.config['inference']
        # A missing (or explicitly None) step size falls back to the tile
        # dimension taken from data_specs.
        step_x = inference_cfg.get('window_step_size_x', None)
        step_y = inference_cfg.get('window_step_size_y', None)
        self.window_step_x = (self.config['data_specs']['width']
                              if step_x is None else step_x)
        self.window_step_y = (self.config['data_specs']['height']
                              if step_y is None else step_y)
        self.stitching_method = inference_cfg.get('stitching_method',
                                                  'average')
        self.output_dir = inference_cfg['output_dir']
        os.makedirs(self.output_dir, exist_ok=True)

    def __call__(self, infer_df=None):
        """Run inference.

        Each image listed in ``infer_df['image']`` is tiled, predicted on,
        stitched back together and saved to ``self.output_dir`` under the
        image's original file name.

        Arguments
        ---------
        infer_df : :class:`pandas.DataFrame` or `str`
            A :class:`pandas.DataFrame` with a column, ``'image'``, specifying
            paths to images for inference. Alternatively, `infer_df` can be a
            path to a CSV file containing the same information.  Defaults to
            ``None``, in which case the file path specified in the Inferer's
            configuration dict is used.

        Raises
        ------
        ValueError
            If ``self.framework`` is not ``'keras'``, ``'torch'`` or
            ``'pytorch'``.

        """
        # `import skimage` alone is not guaranteed to expose the `io`
        # submodule on every scikit-image release, so load it explicitly.
        import skimage.io

        if infer_df is None:
            infer_df = get_infer_df(self.config)
        elif isinstance(infer_df, str):
            # Honour the documented "path to a CSV" form of the argument,
            # loading it the same way get_infer_df loads the configured CSV.
            infer_df = get_data_paths(infer_df, infer=True)

        data_specs = self.config['data_specs']
        inf_tiler = InferenceTiler(
            self.framework,
            width=data_specs['width'],
            height=data_specs['height'],
            x_step=self.window_step_x,
            y_step=self.window_step_y,
            augmentations=process_aug_dict(
                self.config['inference_augmentation'])
            )

        is_torch = self.framework in ['torch', 'pytorch']
        device = None
        if is_torch:
            # eval mode / device placement only needs to happen once, not
            # once per image.
            device = self._prepare_torch_model()
        extra_input_names = data_specs.get('additional_inputs', None)

        image_paths = infer_df['image']
        for idx, im_path in tqdm(enumerate(image_paths),
                                 total=len(image_paths)):
            inf_input, idx_refs, (
                src_im_height, src_im_width) = inf_tiler(im_path)

            if self.framework == 'keras':
                subarr_preds = self.model.predict(inf_input,
                                                  batch_size=self.batch_size)
            elif is_torch:
                inf_input = torch.from_numpy(inf_input).float().to(device)
                # add additional input data, if applicable
                extra_inputs = None
                if extra_input_names is not None:
                    extra_inputs = [infer_df[name].iloc[idx].to(device)
                                    for name in extra_input_names]
                subarr_preds = self._predict_torch(inf_input, extra_inputs)
            else:
                raise ValueError(
                    'Unsupported nn_framework for inference: '
                    '{!r}'.format(self.framework))

            stitched_result = stitch_images(subarr_preds,
                                            idx_refs=idx_refs,
                                            out_width=src_im_width,
                                            out_height=src_im_height,
                                            method=self.stitching_method)
            skimage.io.imsave(os.path.join(self.output_dir,
                                           os.path.split(im_path)[1]),
                              stitched_result, check_contrast=False)

    def _prepare_torch_model(self):
        """Put the PyTorch model in eval mode and pick the inference device.

        The model is moved to the GPU when CUDA is available.

        Returns
        -------
        device : :class:`torch.device`
            The device that input tensors should be moved to.
        """
        self.model.eval()
        if torch.cuda.is_available():
            device = torch.device('cuda')
            self.model = self.model.cuda()
        else:
            device = torch.device('cpu')
        return device

    def _predict_torch(self, inf_input, additional_inputs=None):
        """Run the PyTorch model on tiled input without tracking gradients.

        Arguments
        ---------
        inf_input : :class:`torch.Tensor`
            Tiles to predict on; the first axis is iterated over in chunks of
            ``self.batch_size``.
        additional_inputs : list, optional
            Extra model inputs. When given (even if empty), the model is
            called once with ``[inf_input] + additional_inputs`` and no
            batching is applied, since how the extra inputs should be split
            per batch is not known here.

        Returns
        -------
        preds : :class:`numpy.ndarray`
            Model outputs, concatenated along the first axis when batched.
        """
        # The forward passes themselves must run under no_grad; otherwise
        # autograd keeps every activation alive and memory use balloons.
        with torch.no_grad():
            if additional_inputs is not None:
                preds = self.model([inf_input] + list(additional_inputs))
                return preds.detach().cpu().numpy()
            # Slicing clamps at the end of the tensor, so the final (possibly
            # short) batch needs no special case.
            batch_preds = [
                self.model(
                    inf_input[start:start + self.batch_size, ...]
                ).detach().cpu().numpy()
                for start in range(0, inf_input.shape[0], self.batch_size)
            ]
        return np.concatenate(batch_preds, axis=0)


def get_infer_df(config):
    """Load the table of inference image paths named in ``config`` .

    The CSV path is read from ``config['inference_data_csv']`` and handed to
    ``get_data_paths`` with ``infer=True``; whatever that call returns is
    passed straight back to the caller. See the docs and the comments in
    solaris/data/config_skeleton.yml for how the inference CSV is specified.

    Arguments
    ---------
    config : dict
        The loaded configuration dict for model training and/or inference.
        Must contain the key ``'inference_data_csv'``.

    Returns
    -------
    infer_df
        The object produced by ``get_data_paths``. Callers in this module
        index it with an ``'image'`` column holding the paths of the files to
        run inference on.
    """
    csv_path = config['inference_data_csv']
    return get_data_paths(csv_path, infer=True)