#  ___________________________________________________________________________
#
#  Pyomo: Python Optimization Modeling Objects
#  Copyright 2017 National Technology and Engineering Solutions of Sandia, LLC
#  Under the terms of Contract DE-NA0003525 with National Technology and 
#  Engineering Solutions of Sandia, LLC, the U.S. Government retains certain 
#  rights in this software.
#  This software is distributed under the 3-clause BSD License.
#  ___________________________________________________________________________

__all__ = ("ScenarioTreeActionManagerPyro",)

import time
import itertools
import logging
from collections import defaultdict
import base64
try:
    import cPickle as pickle
except:
    import pickle

from pyomo.common.dependencies import attempt_import
from pyomo.opt.parallel.manager import ActionStatus
from pyomo.opt.parallel.pyro import PyroAsynchronousActionManager

pyu_pyro = attempt_import('pyutilib.pyro', alt_names=['pyu_pyro'])[0]
Pyro4 = attempt_import('Pyro4')[0]

import six
from six import advance_iterator, iteritems, itervalues

logger = logging.getLogger('pyomo.pysp')

#
# a specialized asynchronous action manager for the scenariotreeserver
#

class ScenarioTreeActionManagerPyro(PyroAsynchronousActionManager):

    def __init__(self, *args, **kwds):
        super(ScenarioTreeActionManagerPyro, self).__init__(*args, **kwds)
        # the SPPyroScenarioTreeServer objects associated with
        # this manager
        self.server_pool = []
        self._server_name_to_dispatcher_name = {}
        self._dispatcher_name_to_server_names = {}
        # tells the action manager to ignore task errors
        # (it will still report them, just take no action)
        self.ignore_task_errors = False

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the manager."""
        if len(self.server_pool):
            self.release_servers()
        super(ScenarioTreeActionManagerPyro, self).close()

    def acquire_servers(self, servers_requested, timeout=None):

        if self._verbose:
            print("Attempting to acquire %s scenario tree servers"
                  % (servers_requested))
            if timeout is None:
                print("Timeout has been disabled")
            else:
                print("Automatic timeout in %s seconds" % (timeout))

        assert len(self.server_pool) == 0
        assert len(self._dispatcher_name_to_client) == 0
        assert len(self._server_name_to_dispatcher_name) == 0
        assert len(self._dispatcher_name_to_server_names) == 0
        assert len(self._dispatcher_proxies) == 0
        #
        # This process consists of the following steps:
        #
        # (1) Obtain the list of dispatchers from the nameserver
        # (2) Acquire all workers currently registered on each dispatcher
        # (3) Repeat (1) and (2) until we reach the timeout (if it exists)
        #     or until we obtain the number of servers requested
        # (4) Release any servers we don't need on dispatchers
        #
        wait_start = time.time()
        dispatcher_registered_servers = defaultdict(list)
        dispatcher_servers_to_release = defaultdict(list)
        dispatcher_proxies = {}
        servers_acquired = 0
        while servers_acquired < servers_requested:

            if (timeout is not None) and \
               ((time.time()-wait_start) > timeout):
                print("Timeout reached before %s servers could be acquired. "
                      "Proceeding with %s servers."
                      % (servers_requested, servers_acquired))
                break

            try:
                dispatchers = pyu_pyro.util.get_dispatchers(
                    host=self.host,
                    port=self.port,
                    caller_name="Client")
            except pyu_pyro.util._connection_problem:
                print("Failed to obtain one or more dispatchers from nameserver")
                continue
            for (name, uri) in dispatchers:
                dispatcher = None
                server_names = None
                if name not in dispatcher_proxies:
                    # connect to the dispatcher
                    if pyu_pyro.using_pyro3:
                        dispatcher = pyu_pyro.Pyro.core.getProxyForURI(uri)
                    else:
                        dispatcher = pyu_pyro.Pyro.Proxy(uri)
                        dispatcher._pyroTimeout = 10
                    try:
                        server_names = dispatcher.acquire_available_workers()
                    except pyu_pyro.util._connection_problem:
                        if pyu_pyro.using_pyro4:
                            dispatcher._pyroRelease()
                        else:
                            dispatcher._release()
                        continue
                    dispatcher_proxies[name] = dispatcher
                    if pyu_pyro.using_pyro4:
                        dispatcher._pyroTimeout = None
                else:
                    dispatcher = dispatcher_proxies[name]
                    server_names = dispatcher.acquire_available_workers()

                # collect the list of registered PySP workers
                servers_to_release = dispatcher_servers_to_release[name]
                registered_servers = dispatcher_registered_servers[name]
                for server_name in server_names:
                    if server_name.startswith("ScenarioTreeServerPyro_"):
                        registered_servers.append(server_name)
                    else:
                        servers_to_release.append(server_name)

                if (timeout is not None) and \
                   ((time.time()-wait_start) > timeout):
                    break

            servers_acquired = sum(len(_serverlist) for _serverlist
                                   in itervalues(dispatcher_registered_servers))
            # Don't overload the nameserver while trying to
            # collect dispatchers with registered workers.
            # If you haven't found them after the first few tries,
            # it's very likely that you are not going to.
            time.sleep(0.5)

        for name, servers_to_release in iteritems(dispatcher_servers_to_release):
            dispatcher_proxies[name].release_acquired_workers(servers_to_release)
        del dispatcher_servers_to_release

        #
        # Decide which servers we will utilize and do this in such a way
        # as to balance the workload we place on each dispatcher
        #
        server_to_dispatcher_map = {}
        dispatcher_servers_utilized = defaultdict(list)
        servers_utilized = 0
        dispatcher_names = itertools.cycle(dispatcher_registered_servers.keys())
        while servers_utilized < min(servers_requested, servers_acquired):
            name = advance_iterator(dispatcher_names)
            if len(dispatcher_registered_servers[name]) > 0:
                servername = dispatcher_registered_servers[name].pop()
                server_to_dispatcher_map[servername] = name
                dispatcher_servers_utilized[name].append(servername)
                servers_utilized += 1

        # copy the keys as we are modifying this list
        dispatcher_proxies_byURI = {}
        for name in list(dispatcher_proxies.keys()):
            dispatcher = dispatcher_proxies[name]
            servers = dispatcher_servers_utilized[name]
            if len(dispatcher_registered_servers[name]) > 0:
                # release any servers we do not need
                dispatcher.release_acquired_workers(
                    dispatcher_registered_servers[name])
            if len(servers) == 0:
                # release the proxy to this dispatcher,
                # we don't need it
                if pyu_pyro.using_pyro4:
                    dispatcher._pyroRelease()
                else:
                    dispatcher._release()
                del dispatcher_proxies[name]
            else:
                # when we initialize a client directly with a dispatcher
                # proxy it does not need to know the nameserver host or port
                client = self._create_client(dispatcher=dispatcher)
                self._dispatcher_name_to_server_names[client.URI] = servers
                dispatcher_proxies_byURI[client.URI] = dispatcher
                for servername in servers:
                    self._server_name_to_dispatcher_name[servername] = client.URI
                    self.server_pool.append(servername)
        self._dispatcher_proxies = dispatcher_proxies_byURI

    def release_servers(self):

        if self._verbose:
            print("Releasing scenario tree servers")

        for name in self._dispatcher_proxies:
            dispatcher = self._dispatcher_proxies[name]
            servers = self._dispatcher_name_to_server_names[name]
            # tell dispatcher that the servers we have acquired are no
            # longer needed
            dispatcher.release_acquired_workers(servers)

        self.server_pool = []
        self._server_name_to_dispatcher_name = {}
        self._dispatcher_name_to_server_names = {}

    #
    # Abstract Methods
    #

    def _get_dispatcher_name(self, queue_name):
        return self._server_name_to_dispatcher_name[queue_name]

    def _get_task_data(self, ah, **kwds):
        # Doing this serves two purposes:
        #   (1) It avoids issues with transmitting user-defined
        #       types over the wire that the dispatcher is not
        #       aware of (and therefore unable to de-serialize)
        #   (2) It improves performance on the dispatcher
        #       because de-serialization (and
        #       re-serialization) of raw bytes should be
        #       about as trivial as you can get for any
        #       serializer that Pyro/Pyro4 happens to be
        #       configured with (pickle is the fastest,
        #       but that is not the default in Pyro4 for
        #       security reasons).
        return pickle.dumps(kwds)

    def _download_results(self):

        found_results = False
        for client in itervalues(self._dispatcher_name_to_client):
            if len(self._dispatcher_name_to_client) == 1:
                # if there is a single dispatcher then we can do
                # a more efficient blocking call
                results = client.get_results(override_type=client.CLIENTNAME,
                                             block=True,
                                             timeout=None)
            else:
                results = client.get_results(override_type=client.CLIENTNAME,
                                             block=False)
            if len(results) > 0:
                found_results = True
                for task in results:
                    self.queued_action_counter -= 1

                    # The only reason we are go through this much
                    # effort to deal with the serpent serializer
                    # is because it is the default in Pyro4.
                    if pyu_pyro.using_pyro4 and \
                       (Pyro4.config.SERIALIZER == 'serpent'):
                        if six.PY3:
                            assert type(task['result']) is dict
                            assert task['result']['encoding'] == 'base64'
                            task['result'] = base64.b64decode(task['result']['data'])
                        else:
                            assert type(task['result']) is unicode
                            task['result'] = str(task['result'])
                    # ** See note in _get_task_data about why we pickle
                    #    all communication
                    task['result'] = pickle.loads(task['result'])

                    ah = self.event_handle.get(task['id'], None)
                    if ah is None:
                        # if we are here, this is really bad news!
                        raise RuntimeError(
                            "The %s found results for task with id=%s"
                            " - but no corresponding action handle "
                            "could be located! Showing task result "
                            "below:\n%s" % (type(self).__name__,
                                            task['id'],
                                            task.get('result', None)))
                    if type(task['result']) is pyu_pyro.TaskProcessingError:
                        ah.status = ActionStatus.error
                        self.event_handle[ah.id].update(ah)
                        msg = ("ScenarioTreeServer reported a processing "
                               "error for task with id=%s. Reason: \n%s"
                               % (task['id'], task['result'].args[0]))
                        if not self.ignore_task_errors:
                            raise RuntimeError(msg)
                        elif self.ignore_task_errors == 1:
                            logger.warning(msg)
                        # any value other than 0 or 1 will
                        # silently ignore task errors
                    else:
                        ah.status = ActionStatus.done
                        self.event_handle[ah.id].update(ah)
                        self.results[ah.id] = task['result']

        if not found_results:
            # If the queues are all empty, wait some time for things to
            # fill up. Constantly pinging dispatch servers wastes their
            # time, and inhibits task server communication. The good
            # thing about queues_to_check is that it simultaneously
            # grabs information for any queues with results => one
            # client query can yield many results.

            # TBD: We really need to parameterize the time-out value,
            #      but it isn't clear how to propagate this though the
            #      solver manager interface layers.
            time.sleep(0.01)