# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=fixme, invalid-name, missing-docstring, no-init, old-style-class, multiple-statements
# pylint: disable=arguments-differ, too-many-arguments, no-member
"""Visualization callback function
"""
try:
    import datetime
except ImportError:
    class Datetime_Failed_To_Import: pass
    datetime = Datetime_Failed_To_Import

try:
    import bokeh.plotting
except ImportError:
    pass

try:
    from collections import defaultdict
except ImportError:
    class Defaultdict_Failed_To_Import: pass
    defaultdict = Defaultdict_Failed_To_Import

try:
    import pandas as pd
except ImportError:
    class Pandas_Failed_To_Import: pass
    pd = Pandas_Failed_To_Import

import time
# pylint: enable=missing-docstring, no-init, old-style-class, multiple-statements


def _add_new_columns(dataframe, metrics):
    """Add new metrics as new columns to selected pandas dataframe.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Selected dataframe needs to be modified.
    metrics : metric.EvalMetric
        New metrics to be added.
    """
    #TODO(leodirac): we don't really need to do this on every update.  Optimize
    new_columns = set(metrics.keys()) - set(dataframe.columns)
    for col in new_columns:
        dataframe[col] = None


def _extend(baseData, newData):
    """Assuming a is shorter than b, copy the end of b onto a
    """
    baseData.extend(newData[len(baseData):])


class PandasLogger(object):
    """Logs statistics about training run into Pandas dataframes.
    Records three separate dataframes: train, eval, epoch.

    Parameters
    ----------
    batch_size: int
        Number of records in one mini-batch.  Used to convert the measured
        batch throughput into record throughput.
    frequent: int
        How many training mini-batches between calculations.
        Defaults to calculating every 50 batches.
        (Eval data is stored once per epoch over the entire
        eval data set.)
    """
    def __init__(self, batch_size, frequent=50):
        self.batch_size = batch_size
        self.frequent = frequent
        self._dataframes = {
            'train': pd.DataFrame(),
            'eval': pd.DataFrame(),
            'epoch': pd.DataFrame(),
        }
        # Wall-clock reference points: ``last_time`` feeds the throughput
        # calculation, the datetime values feed the elapsed/epoch timings.
        self.last_time = time.time()
        self.start_time = datetime.datetime.now()
        self.last_epoch_time = datetime.datetime.now()

    @property
    def train_df(self):
        """The dataframe with training data.
        This has metrics for training minibatches, logged every
        "frequent" batches.  (frequent is a constructor param)
        """
        return self._dataframes['train']

    @property
    def eval_df(self):
        """The dataframe with evaluation data.
        This has one row per call to ``eval_cb``.
        """
        return self._dataframes['eval']

    @property
    def epoch_df(self):
        """The dataframe with epoch data.
        This has timing information.
        """
        return self._dataframes['epoch']

    @property
    def all_dataframes(self):
        """Return the dict of dataframes, keyed 'train', 'eval' and 'epoch'.
        """
        return self._dataframes

    def elapsed(self):
        """Calculate the time elapsed since this logger was created.

        Returns
        -------
        datetime.timedelta
        """
        return datetime.datetime.now() - self.start_time

    def append_metrics(self, metrics, df_name):
        """Append new metrics as one row of the selected dataframe.

        Parameters
        ----------
        metrics : dict
            Mapping of column name to value for the new row.
        df_name : str
            Name of the dataframe to be modified
            ('train', 'eval' or 'epoch').
        """
        dataframe = self._dataframes[df_name]
        # Columns must exist before the row is assigned by position.
        _add_new_columns(dataframe, metrics)
        dataframe.loc[len(dataframe)] = metrics

    def train_cb(self, param):
        """Callback function for training.
        Logs a row only every ``frequent`` mini-batches.
        """
        if param.nbatch % self.frequent == 0:
            self._process_batch(param, 'train')

    def eval_cb(self, param):
        """Callback function for evaluation
        """
        self._process_batch(param, 'eval')

    def _process_batch(self, param, dataframe):
        """Collect metrics and throughput and append them as one row.

        Parameters
        ----------
        param : object
            Callback parameter exposing ``eval_metric`` (may be None),
            ``nbatch`` and ``epoch``.  The metric, when present, is read
            and then reset.
        dataframe : str
            Name of the dataframe to append to ('train' or 'eval').
        """
        now = time.time()
        if param.eval_metric is not None:
            metrics = dict(param.eval_metric.get_name_value())
            param.eval_metric.reset()
        else:
            metrics = {}
        interval = now - self.last_time
        if interval > 0:
            batches_per_sec = self.frequent / interval
        else:
            # A coarse clock can report no elapsed time; record "unknown"
            # instead of failing with a division by zero.
            batches_per_sec = float('nan')
        # ``frequent`` counts mini-batches, so the raw rate is batches/sec
        # and the record rate is that multiplied by the batch size.
        metrics['batches_per_sec'] = batches_per_sec
        metrics['records_per_sec'] = batches_per_sec * self.batch_size
        metrics['elapsed'] = self.elapsed()
        metrics['minibatch_count'] = param.nbatch
        metrics['epoch'] = param.epoch
        self.append_metrics(metrics, dataframe)
        self.last_time = now

    def epoch_cb(self, *_args, **_kwargs):
        """Callback function after each epoch. Records the total elapsed
        time and the duration of the epoch in the epoch dataframe.

        Any arguments are accepted and ignored, so the method can be
        registered as an ``epoch_end_callback`` that the training loop
        invokes with arguments, and can still be called with none.
        """
        metrics = {}
        metrics['elapsed'] = self.elapsed()
        now = datetime.datetime.now()
        metrics['epoch_time'] = now - self.last_epoch_time
        self.append_metrics(metrics, 'epoch')
        self.last_epoch_time = now

    def callback_args(self):
        """returns **kwargs parameters for model.fit()
        to enable all callbacks.  e.g.
        model.fit(X=train, eval_data=test, **pdlogger.callback_args())
        """
        return {
            'batch_end_callback': self.train_cb,
            'eval_end_callback': self.eval_cb,
            'epoch_end_callback': self.epoch_cb,
        }


class LiveBokehChart(object):
    """Base class for callbacks that draw a bokeh chart in a jupyter
    notebook and refresh it while training proceeds.

    Each chart holds a PandasLogger; when the caller does not supply one,
    a new logger is built from ``batch_size`` and ``frequent``.

    The class is abstract: a concrete chart has to override
    ``setup_chart`` and ``update_chart_data``.
    """
    def __init__(self, pandas_logger, metric_name, display_freq=10,
                 batch_size=None, frequent=50):
        self.pandas_logger = pandas_logger or PandasLogger(batch_size=batch_size,
                                                           frequent=frequent)
        # Minimum number of seconds between two refreshes triggered by
        # batch callbacks.
        self.display_freq = display_freq
        self.last_update = time.time()
        #NOTE: would be nice to auto-detect the metric_name if there's only one.
        self.metric_name = metric_name
        bokeh.io.output_notebook()
        # setup_chart() is supplied by the subclass; it runs last so that
        # the attributes above are already in place.
        self.handle = self.setup_chart()

    def setup_chart(self):
        """Build the bokeh chart and return the handle used to refresh it.
        Must be overridden.
        """
        raise NotImplementedError("Incomplete base class: LiveBokehChart must be sub-classed")

    def update_chart_data(self):
        """Feed newly collected data into the bokeh chart.
        Must be overridden.
        """
        raise NotImplementedError("Incomplete base class: LiveBokehChart must be sub-classed")

    def interval_elapsed(self):
        """Tell whether the chart is due for a refresh.

        Returns
        -------
        bool
            True when more than ``display_freq`` seconds have passed since
            the last refresh.
        """
        since_last = time.time() - self.last_update
        return since_last > self.display_freq

    def _push_render(self):
        """Push the chart to the notebook and remember when that happened.
        """
        bokeh.io.push_notebook(handle=self.handle)
        self.last_update = time.time()

    def _do_update(self):
        """Refresh the chart data, then push the result to the notebook.
        """
        self.update_chart_data()
        self._push_render()

    def batch_cb(self, param):
        """Callback function after a completed batch.
        Refreshes the chart only when the display interval has elapsed.
        """
        if not self.interval_elapsed():
            return
        self._do_update()

    def eval_cb(self, param):
        """Callback function after an evaluation.
        Always refreshes the chart, regardless of the display interval.
        """
        self._do_update()

    def callback_args(self):
        """returns **kwargs parameters for model.fit()
        to enable all callbacks.  e.g.
        model.fit(X=train, eval_data=test, **pdlogger.callback_args())
        """
        return dict(batch_end_callback=self.batch_cb,
                    eval_end_callback=self.eval_cb)


class LiveTimeSeries(LiveBokehChart):
    """Live line chart of values plotted against elapsed time.

    Values are supplied one at a time through ``update_chart_data``.
    Extra keyword arguments are forwarded to the bokeh figure.
    """
    def __init__(self, **fig_params):
        # The figure has to exist before the base constructor runs, because
        # the base constructor calls setup_chart(), which draws on it.
        self.fig = bokeh.plotting.Figure(
            x_axis_type='datetime',
            x_axis_label='Elapsed time',
            **fig_params
        )
        super(LiveTimeSeries, self).__init__(None, None)  # TODO: clean up this class hierarchy

    def setup_chart(self):
        """Start the clock, draw an empty line and show the figure.
        Returns the notebook handle of the displayed figure.
        """
        self.start_time = datetime.datetime.now()
        self.x_axis_val = []
        self.y_axis_val = []
        self.fig.line(self.x_axis_val, self.y_axis_val)
        notebook_handle = bokeh.plotting.show(self.fig, notebook_handle=True)
        return notebook_handle

    def elapsed(self):
        """Calculate the time elapsed since the chart was set up.
        """
        current = datetime.datetime.now()
        return current - self.start_time

    def update_chart_data(self, value):
        """Record ``value`` at the current elapsed time and push the chart.
        """
        timestamp = self.elapsed()
        self.x_axis_val.append(timestamp)
        self.y_axis_val.append(value)
        self._push_render()


class LiveLearningCurve(LiveBokehChart):
    """Draws a learning curve with training & validation metrics
    over time as the network trains.

    Parameters
    ----------
    metric_name : str
        Name of the metric to plot.  It is used as the y-axis label and as
        the key looked up among the names reported by
        ``param.eval_metric.get_name_value()``.
    display_freq : int
        Minimum number of seconds between chart refreshes triggered by
        batch callbacks.
    frequent : int
        A training point is recorded every ``frequent`` mini-batches.
    """
    def __init__(self, metric_name, display_freq=10, frequent=50):
        self.frequent = frequent
        self.start_time = datetime.datetime.now()
        # Per-series storage: one list per metric name, plus 'elapsed'.
        self._data = {
            'train': {'elapsed': [],},
            'eval': {'elapsed': [],},
        }
        # ``frequent`` is passed by keyword: positionally it would land in
        # the base constructor's ``batch_size`` parameter.
        super(LiveLearningCurve, self).__init__(None, metric_name, display_freq,
                                                frequent=frequent)

    def setup_chart(self):
        """Create the figure with its train and validation glyphs and show
        it.  Returns the notebook handle of the displayed figure.
        """
        self.fig = bokeh.plotting.Figure(x_axis_type='datetime',
                                         x_axis_label='Training time')
        #TODO(leodirac): There's got to be a better way to
        # get a bokeh plot to dynamically update as a pandas dataframe changes,
        # instead of copying into a list.
        # I can't figure it out though.  Ask a pyData expert.
        self.x_axis_val1 = []
        self.y_axis_val1 = []
        self.train1 = self.fig.line(self.x_axis_val1, self.y_axis_val1, line_dash='dotted',
                                    alpha=0.3, legend="train")
        self.train2 = self.fig.circle(self.x_axis_val1, self.y_axis_val1, size=1.5,
                                      line_alpha=0.3, fill_alpha=0.3, legend="train")
        self.train2.visible = False  # Turn this on later.
        self.x_axis_val2 = []
        self.y_axis_val2 = []
        self.valid1 = self.fig.line(self.x_axis_val2, self.y_axis_val2,
                                    line_color='green',
                                    line_width=2,
                                    legend="validation")
        self.valid2 = self.fig.circle(self.x_axis_val2,
                                      self.y_axis_val2,
                                      line_color='green',
                                      line_width=2, legend=None)
        self.fig.legend.location = "bottom_right"
        self.fig.yaxis.axis_label = self.metric_name
        return bokeh.plotting.show(self.fig, notebook_handle=True)

    def batch_cb(self, param):
        """Callback function after a completed batch.
        Records a training point every ``frequent`` mini-batches and
        refreshes the chart when the display interval has elapsed.
        """
        if param.nbatch % self.frequent == 0:
            self._process_batch(param, 'train')
        if self.interval_elapsed():
            self._do_update()

    def eval_cb(self, param):
        """Callback function after an evaluation.
        Records an eval point and always refreshes the chart.
        """
        # After eval results, force an update.
        self._process_batch(param, 'eval')
        self._do_update()

    def _process_batch(self, param, df_name):
        """Record the current metrics into the selected data series.

        Parameters
        ----------
        param : object
            Callback parameter exposing ``eval_metric`` (may be None).
            The metric, when present, is read and then reset.
        df_name : str
            Name of the series to append to ('train' or 'eval').
        """
        if param.eval_metric is not None:
            metrics = dict(param.eval_metric.get_name_value())
            param.eval_metric.reset()
        else:
            metrics = {}
        metrics['elapsed'] = datetime.datetime.now() - self.start_time
        series = self._data[df_name]
        for key, value in metrics.items():
            series.setdefault(key, []).append(value)

    def update_chart_data(self):
        """Copy newly recorded points into the lists backing the glyphs.

        Raises
        ------
        KeyError
            If points were recorded but none of them carries
            ``metric_name``.
        """
        train_data = self._data['train']
        if train_data['elapsed']:
            _extend(self.x_axis_val1, train_data['elapsed'])
            _extend(self.y_axis_val1, train_data[self.metric_name])
        eval_data = self._data['eval']
        if eval_data['elapsed']:
            _extend(self.x_axis_val2, eval_data['elapsed'])
            _extend(self.y_axis_val2, eval_data[self.metric_name])
        # Count recorded eval points, not the number of keys in the dict,
        # to decide when to swap the dotted train line for markers.
        if len(eval_data['elapsed']) > 10:
            self.train1.visible = False
            self.train2.visible = True


def args_wrapper(*args):
    """Merge the callback arguments of several callback objects.

    Each positional argument is a callback object such as PandasLogger()
    or LiveLearningCurve() that provides ``callback_args()``.  The result
    maps every keyword reported by any of them to the list of callbacks
    registered under it, in the order the objects were passed in, ready to
    be expanded into model.fit().
    """
    merged = {}
    for provider in args:
        for keyword, handler in provider.callback_args().items():
            merged.setdefault(keyword, []).append(handler)
    return merged