# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
#    http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats


def generate_griffiths_data(
    num_documents=5000,
    average_document_length=150,
    num_topics=5,
    vocabulary_size=25,
    alpha=None,
    eta=None,
    seed=0,
):
    """Returns example documents from Griffiths-Steyvers [1].

    Given an `alpha` and `eta`, the Dirichlet priors for the topic and
    topic-word distributions respectively, this function generates sample
    document word counts according to the Latent Dirichlet Allocation (LDA)
    model.

    Parameters
    ----------
    num_documents : int
        (Default: 5000) The number of example documents to create using LDA.
    average_document_length : int
        (Default: 150) The average number of words in each document. The
        document length is sampled from a Poisson distribution with this mean.
    num_topics : int
        (Default: 5) Can be set to either 5 or 10. The number of known topics.
    vocabulary_size : int
        (Default: 25) The number of distinct words. The known topics are the
        columns (and, for 10 topics, the rows) of a 5 x 5 grid of words, so
        25 is the only supported value.
    alpha : Numpy NDArray
        (Default: None) An array of length `num_topics` representing a given
        Dirichlet topic prior. If `None` is provided then a uniform
        distribution will be used.
    eta : Numpy NDArray
        (Default: None) An array of length `sqrt(vocabulary_size)`
        representing a given Dirichlet topic-word prior over the non-zero
        entries of a single topic. If `None` is provided then a constant
        prior of 100 per entry is used.
    seed : int
        (Default: 0) The random number generator seed. The global NumPy
        random state is re-seeded with this value.

    Returns
    -------
    alpha : Numpy NDArray
        A vector of length `num_topics` equal to the Dirichlet prior used to
        generate documents.
    beta : Numpy NDArray
        A matrix of size `num_topics` x `vocabulary_size` equal to the
        topic-word probability matrix used to generate documents.
    documents : Numpy NDArray
        A matrix of size `num_documents` x `vocabulary_size` holding the word
        counts of the documents generated by the LDA model defined by `alpha`
        and `beta`.
    thetas : Numpy NDArray
        A matrix of size `num_documents` x `num_topics` equal to the topic
        mixtures used to generate the output `documents`.

    Raises
    ------
    AssertionError
        If `num_topics` is not 5 or 10, if `vocabulary_size` is not 25, or if
        `alpha` is given with a length different from `num_topics`.

    References
    ----------
    [1] Thomas L Griffiths and Mark Steyvers. "Finding Scientific Topics."
        Proceedings of the National Academy of Sciences, 101(suppl 1):5228–5235,
        2004.

    """
    # `np.int` / `np.float` were removed from NumPy; the builtins are the
    # exact types those aliases referred to.
    image_dim = int(np.sqrt(vocabulary_size))

    # perform checks on input
    assert num_topics in [5, 10], "Example data only available for 5 or 10 topics"
    # each topic is one column / row of an image_dim x image_dim grid, so the
    # 5 (or 5 + 5) topics only line up with the vocabulary on a 5 x 5 grid
    assert vocabulary_size == 25, "Example data only available for vocabulary_size 25"
    # compare against None: the truth value of a multi-element array is
    # ambiguous and raises
    if alpha is not None:
        assert len(alpha) == num_topics, "len(alpha) must be equal to num_topics"

    # initialize Dirichlet alpha and eta distributions if not provided. here,
    # the eta distribution is only across `image_dim` elements since each
    # topic-word distribution will only have `image_dim` non-zero entries
    #
    np.random.seed(seed=seed)
    if alpha is None:
        alpha = np.ones(num_topics, dtype=float) / num_topics
    if eta is None:
        eta = [100] * image_dim  # make it close to a uniform distribution
    dirichlet_alpha = sp.stats.dirichlet(alpha)
    dirichlet_eta = sp.stats.dirichlet(eta)

    # initialize a known topic-word distribution (beta) using eta. these are
    # the "row" and "column" topics, respectively. when num_topics = 5 only
    # create the col topics. when num_topics = 10 add the row topics as well
    #
    beta = np.zeros((num_topics, image_dim, image_dim), dtype=float)
    for i in range(image_dim):
        beta[i, :, i] = dirichlet_eta.rvs(size=1)
    if num_topics == 10:
        for i in range(image_dim):
            beta[i + image_dim, i, :] = dirichlet_eta.rvs(size=1)
    # flatten every topic image into a row of word probabilities
    beta = beta.reshape(num_topics, vocabulary_size)

    # re-normalize each topic once, up front, so that floating point error
    # cannot make the sampler reject the probabilities; this is invariant
    # across all sampled words
    topic_word_distributions = [
        np.asarray(row).astype("float64") / np.sum(row) for row in beta
    ]

    # generate documents using the LDA model / process
    #
    document_lengths = sp.stats.poisson(average_document_length).rvs(size=num_documents)
    documents = np.zeros((num_documents, vocabulary_size), dtype=np.float64)
    thetas = dirichlet_alpha.rvs(
        size=num_documents
    )  # precompute topic distributions for performance
    for m in range(num_documents):
        document_length = document_lengths[m]
        theta = np.asarray(thetas[m]).astype("float64")
        theta = theta / np.sum(theta)
        topic = sp.stats.multinomial.rvs(
            1, theta, size=document_length
        )  # precompute topics for performance

        # generate word counts within document: each word is a one-hot draw
        # from the distribution of the topic assigned to that position
        for n in range(document_length):
            topic_index = np.argmax(topic[n])
            word = sp.stats.multinomial.rvs(
                1, topic_word_distributions[topic_index], size=1
            ).reshape(vocabulary_size)
            documents[m] += word

    return alpha, beta, documents, thetas


def plot_topic_data(data, nrows, ncols, with_colorbar=True, cmap=cm.viridis):
    """Helper function for plotting an array of images on a grid.

    Parameters
    ----------
    data : Numpy NDArray
        Array whose rows are flattened square images. Rows
        `0 .. nrows * ncols - 1` are drawn in row-major order, so `data`
        must contain at least `nrows * ncols` rows.
    nrows : int
        Number of rows in the grid of subplots.
    ncols : int
        Number of columns in the grid of subplots.
    with_colorbar : bool
        (Default: True) Whether to add a horizontal colorbar shared by all
        subplots.
    cmap : matplotlib colormap
        (Default: viridis) The colormap used for every image.

    Returns
    -------
    fig : matplotlib Figure
        The figure containing the grid of images.

    """
    # squeeze=False keeps `axes` two-dimensional for every grid shape, so a
    # single row, a single column or a single cell are all indexed the same way
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)

    # one color scale for all images so they are directly comparable
    vmin = 0
    vmax = data.max()
    side = int(np.sqrt(len(data[0])))

    for i in range(nrows):
        for j in range(ncols):
            image = data[i * ncols + j].reshape(side, side)
            im = axes[i, j].matshow(image, cmap=cmap, vmin=vmin, vmax=vmax)

    for axis in axes.ravel():
        axis.set_xticks([])
        axis.set_yticks([])

    if with_colorbar:
        # all images share vmin/vmax, so the last image can stand in for all
        fig.colorbar(im, ax=axes.ravel().tolist(), orientation="horizontal", fraction=0.2)
    return fig


def match_estimated_topics(topics_known, topics_estimated):
    """A dumb but fast way to match known topics to estimated topics.

    Every estimated topic is greedily assigned to its nearest known topic
    (Euclidean distance), in the order the estimated topics are given. An
    estimated topic whose nearest known topic is already taken is set aside
    and afterwards placed in the first known-topic slot that is still free.

    Parameters
    ----------
    topics_known : Numpy NDArray
        A two-dimensional array with one known topic per row.
    topics_estimated : Numpy NDArray
        A two-dimensional array with one estimated topic per row.

    Returns
    -------
    Numpy NDArray
        A copy of the rows of `topics_estimated`, reordered so that row `i`
        is the estimated topic matched to known topic `i`.

    Notes
    -----
    A known-topic slot that never receives a match keeps the index -1, which
    selects the last row of `topics_estimated`.

    """
    num_known, _ = topics_known.shape
    # -1 marks a known topic that has no estimated topic assigned yet.
    # (builtin `int` replaces the removed `np.int` alias)
    permutation = -1 * np.ones(num_known, dtype=int)
    unmatched_estimated_topics = []

    for estimated_topic_index, estimated_topic in enumerate(topics_estimated):
        distances = [
            np.linalg.norm(known_topic - estimated_topic) for known_topic in topics_known
        ]
        matched_known_topic_index = np.argmin(distances)
        if permutation[matched_known_topic_index] == -1:
            permutation[matched_known_topic_index] = estimated_topic_index
        else:
            # nearest known topic already taken; assign a leftover slot later
            unmatched_estimated_topics.append(estimated_topic_index)

    # hand the leftovers to the remaining free slots, first free slot first
    for estimated_topic_index in unmatched_estimated_topics:
        for i in range(num_known):
            if permutation[i] == -1:
                permutation[i] = estimated_topic_index
                break

    return (topics_estimated[permutation, :]).copy()