import os
import pickle
import numpy as np
import random as rnd
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.mlab as mlab
import seaborn
from PIL import Image, ImageColor
from collections import namedtuple
import warnings

# Silence every Python warning as soon as this module is imported,
# presumably to hide noisy TensorFlow / matplotlib deprecation output.
# NOTE(review): the warnings filter list is global, so this also hides
# warnings from unrelated code in the same process; consider scoping it with
# warnings.catch_warnings() around the TF/matplotlib calls instead.
warnings.filterwarnings("ignore")


def download_model_weights() -> str:
    """Ensure the pretrained handwriting model files exist next to this module.

    Any file missing from ``<module dir>/handwritten_model/`` is fetched from
    the TextRecognitionDataGenerator GitHub repository. Each file is first
    downloaded to a ``.part`` temporary name and only renamed into place once
    the download completes. An interrupted download therefore never leaves a
    truncated file that later runs would mistake for a complete one.

    Returns:
        The directory containing this module (the parent of
        ``handwritten_model/``).

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on network errors; the
        partial temporary file is removed first.
    """
    from pathlib import Path
    import urllib.request

    cwd = os.path.dirname(os.path.abspath(__file__))
    # Loop-invariant: create the target directory once, not once per file.
    download_dir = Path(cwd) / "handwritten_model"
    download_dir.mkdir(exist_ok=True, parents=True)

    for k in [
        "model-29.data-00000-of-00001",
        "model-29.index",
        "model-29.meta",
        "translation.pkl",
    ]:
        target = download_dir / f"{k}"
        if target.exists():
            continue
        print(f"file {k} not found, downloading from git repo..")
        partial = download_dir / f"{k}.part"
        try:
            urllib.request.urlretrieve(
                f"https://raw.github.com/Belval/TextRecognitionDataGenerator/master/trdg/handwritten_model/{k}",
                partial,
            )
            # Atomic rename: the final name only ever refers to a complete file.
            os.replace(partial, target)
        finally:
            # Clean up leftovers from a failed download (no-op after success).
            if partial.exists():
                partial.unlink()
        print(f"file {k} saved to disk")
    return cwd


def _sample(e, mu1, mu2, std1, std2, rho):
    """Draw one ``(x, y, end)`` point from a bivariate Gaussian plus a coin flip.

    ``(x, y)`` is sampled from a normal distribution centred on
    ``(mu1, mu2)`` with standard deviations ``std1``/``std2`` and correlation
    ``rho``. ``end`` is a Bernoulli draw with success probability ``e``.
    The Gaussian draw happens before the Bernoulli draw, both from NumPy's
    global random state.
    """
    off_diagonal = std1 * std2 * rho
    covariance = np.array(
        [
            [std1 * std1, off_diagonal],
            [off_diagonal, std2 * std2],
        ]
    )
    centre = np.array([mu1, mu2])

    dx, dy = np.random.multivariate_normal(centre, covariance)
    end_flag = np.random.binomial(1, e)
    return np.array([dx, dy, end_flag])


def _split_strokes(points):
    """Cut a flat sequence of ``(x, y, end)`` rows into separate strokes.

    A row whose ``end`` column equals 1 closes the current stroke. Rows after
    the last closing row are dropped.

    Returns:
        A list of independent ``(k, 2)`` array copies holding only x and y.
    """
    arr = np.array(points)
    if len(arr) == 0:
        return []
    closers = np.flatnonzero(arr[:, 2] == 1.0)
    openers = [0] + [int(i) + 1 for i in closers[:-1]]
    return [arr[lo : hi + 1, :2].copy() for lo, hi in zip(openers, closers)]


def _cumsum(points):
    """Turn per-row (dx, dy) offsets into running positions.

    The first two columns of the 2-D array *points* are cumulatively summed
    down the rows. Any remaining columns (e.g. the end-of-stroke flag) are
    passed through unchanged.
    """
    positions = points[:, :2].cumsum(axis=0)
    extra_columns = points[:, 2:]
    return np.hstack((positions, extra_columns))


def _sample_text(sess, args_text, translation):
    """Autoregressively sample pen points for ``args_text`` from the loaded model.

    The model's input/output tensors are looked up by name in the default
    graph's collections. The text is fed as a one-hot character sequence, and
    one pen point is sampled per step. Sampling stops when the model's
    ``finish`` output exceeds 0.8, or after ``60 * len(args_text)`` steps
    (counting the appended space).

    Args:
        sess: TensorFlow v1 session whose graph already contains the restored
            model. Every name in ``fields`` must have a non-empty collection,
            otherwise the ``[0]`` lookup raises IndexError.
        args_text: text to write; a space is appended before encoding.
        translation: mapping from character to integer index. Its ``len`` is
            used as the one-hot width, and characters missing from it are
            encoded as index 0.

    Returns:
        ``(phi_data, window_data, kappa_data, stroke_data, coords)``:

        - ``phi_data``, ``window_data``, ``kappa_data``: per-step lists of row
          0 of the ``phi``, ``window`` and ``kappa`` outputs (presumably the
          single batch element).
        - ``stroke_data``: per-step ``[mu1, mu2, std1, std2, rho, end]`` for
          the sampled mixture component, where ``end`` is the sampled flag.
        - ``coords``: array of shape ``(steps + 1, 3)`` with rows
          ``(x, y, end)``. Row 0 is the seed ``[0, 0, 1]`` and the last row's
          ``end`` is forced to 1. x/y appear to be per-step offsets, since
          callers cumulative-sum them (confirm against the model).
    """
    # Original creator said it helps (https://github.com/Grzego/handwriting-generation/issues/3)
    args_text += " "

    fields = [
        "coordinates",
        "sequence",
        "bias",
        "e",
        "pi",
        "mu1",
        "mu2",
        "std1",
        "std2",
        "rho",
        "window",
        "kappa",
        "phi",
        "finish",
        "zero_states",
    ]
    # Take the first tensor/op registered under each collection name.
    vs = namedtuple("Params", fields)(
        *[tf.compat.v1.get_collection(name)[0] for name in fields]
    )

    # One-hot encode the characters, then append an all-zero row.
    # Result shape: (1, len(args_text) + 1, len(translation)).
    text = np.array([translation.get(c, 0) for c in args_text])
    sequence = np.eye(len(translation), dtype=np.float32)[text]
    sequence = np.expand_dims(
        np.concatenate([sequence, np.zeros((1, len(translation)))]), axis=0
    )

    # Seed point: zero offset with the end flag set.
    coord = np.array([0.0, 0.0, 1.0])
    coords = [coord]

    phi_data, window_data, kappa_data, stroke_data = [], [], [], []
    # Presumably resets the model's recurrent state before sampling — confirm in the graph.
    sess.run(vs.zero_states)
    for s in range(1, 60 * len(args_text) + 1):
        # Feed the previously sampled point back in (shape (1, 1, 3)).
        e, pi, mu1, mu2, std1, std2, rho, finish, phi, window, kappa = sess.run(
            [
                vs.e,
                vs.pi,
                vs.mu1,
                vs.mu2,
                vs.std1,
                vs.std2,
                vs.rho,
                vs.finish,
                vs.phi,
                vs.window,
                vs.kappa,
            ],
            feed_dict={
                vs.coordinates: coord[None, None, ...],
                vs.sequence: sequence,
                vs.bias: 1.0,
            },
        )
        phi_data += [phi[0, :]]
        window_data += [window[0, :]]
        kappa_data += [kappa[0, :]]
        # Pick a mixture component with probabilities pi, then sample the next point from it.
        g = np.random.choice(np.arange(pi.shape[1]), p=pi[0])
        coord = _sample(
            e[0, 0], mu1[0, g], mu2[0, g], std1[0, g], std2[0, g], rho[0, g]
        )
        coords += [coord]
        stroke_data += [
            [mu1[0, g], mu2[0, g], std1[0, g], std2[0, g], rho[0, g], coord[2]]
        ]

        # Stop early once the model's finish output exceeds 0.8.
        if finish[0, 0] > 0.8:
            break

    coords = np.array(coords)
    # Force the last point to close a stroke; _split_strokes() only emits
    # strokes terminated by an end flag of 1.
    coords[-1, 2] = 1.0

    return phi_data, window_data, kappa_data, stroke_data, coords


def _crop_white_borders(image):
    """Trim fully white rows and columns from the edges of *image*.

    A pixel counts as ink when its ``convert("L")`` greyscale value is below
    255. The crop keeps the tightest box containing every ink pixel.

    Args:
        image: a PIL image in any mode that ``convert("L")`` supports.

    Returns:
        A new PIL image built from the cropped pixel data. If the image has
        no ink pixel at all, an unmodified copy is returned. Previously this
        case raised ``ValueError: min() arg is an empty sequence``.
    """
    image_data = np.asarray(image)
    grey_image_data = np.asarray(image.convert("L"))
    # np.flatnonzero returns sorted indices, so the ends are the min/max.
    non_empty_columns = np.flatnonzero(grey_image_data.min(axis=0) < 255)
    non_empty_rows = np.flatnonzero(grey_image_data.min(axis=1) < 255)
    if non_empty_rows.size == 0:
        # Blank image: nothing to crop to.
        return image.copy()

    top, bottom = non_empty_rows[0], non_empty_rows[-1]
    left, right = non_empty_columns[0], non_empty_columns[-1]
    # Slice only the two spatial axes so both single-band (2-D) and
    # multi-band (3-D) pixel arrays work.
    image_data_new = image_data[top : bottom + 1, left : right + 1]

    return Image.fromarray(image_data_new)


def _join_images(images, overlap=35):
    """Paste images left-to-right onto one transparent RGBA canvas.

    Each image starts ``overlap`` pixels before the right edge of the
    previous one. Images are pasted without a mask, so in overlapping columns
    the later image's pixels (alpha included) replace the earlier ones.

    Args:
        images: non-empty iterable of PIL images, pasted in order.
        overlap: horizontal shift, in pixels, of each image relative to the
            previous image's right edge. Defaults to 35, the value previously
            hard-coded here.

    Returns:
        An RGBA image as tall as the tallest input, and exactly wide enough to
        hold the right edge of every pasted image. The old width formula
        ``sum(widths) - overlap * len(images)`` was one overlap too narrow and
        clipped the last image.

    Raises:
        ValueError: if *images* is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("_join_images() requires at least one image")

    offsets = []
    x_offset = 0
    for im in images:
        offsets.append(x_offset)
        x_offset += im.size[0] - overlap

    # The canvas must reach the furthest right edge, not just the last offset.
    total_width = max(off + im.size[0] for off, im in zip(offsets, images))
    max_height = max(im.size[1] for im in images)

    compound_image = Image.new("RGBA", (total_width, max_height))
    for off, im in zip(offsets, images):
        compound_image.paste(im, (off, 0))

    return compound_image


def _pick_color(text_color):
    """Return a random ``#rrggbb`` colour between the first and last entries of *text_color*.

    *text_color* is a comma-separated list of colour specs understood by
    ``ImageColor.getrgb``; only its first and last entries are used. Each
    channel is drawn independently with ``random.randint`` (inclusive)
    between the two colours' values for that channel.
    """
    colors = [ImageColor.getrgb(c) for c in text_color.split(",")]
    c1, c2 = colors[0], colors[-1]
    return "#{:02x}{:02x}{:02x}".format(
        rnd.randint(min(c1[0], c2[0]), max(c1[0], c2[0])),
        rnd.randint(min(c1[1], c2[1]), max(c1[1], c2[1])),
        rnd.randint(min(c1[2], c2[2]), max(c1[2], c2[2])),
    )


def _render_word(sess, word, translation, color):
    """Sample handwriting for *word* and rasterise it to a cropped RGBA image."""
    *_, coords = _sample_text(sess, word, translation)

    fig, ax = plt.subplots(1, 1)
    try:
        fig.patch.set_visible(False)
        ax.axis("off")

        for stroke in _split_strokes(_cumsum(np.array(coords))):
            # y is negated, presumably because the model's y axis grows
            # downward while matplotlib's grows upward.
            ax.plot(stroke[:, 0], -stroke[:, 1], color=color)

        fig.patch.set_alpha(0)
        fig.patch.set_facecolor("none")

        canvas = fig.canvas
        canvas.draw()
        s, (width, height) = canvas.print_to_buffer()
        image = Image.frombytes("RGBA", (width, height), s)
    finally:
        # Always release the figure, even if rendering fails part-way.
        plt.close(fig)

    return _crop_white_borders(image)


def generate(text, text_color):
    """Render *text* as synthetic handwriting.

    Downloads the model files if needed, restores the TensorFlow v1 model on
    CPU, and picks one random colour for the whole text. Each space-separated
    word is rendered separately (consecutive spaces yield empty words, which
    are still sampled), and the word images are then joined horizontally.

    Args:
        text: text to render; split on single spaces into words.
        text_color: comma-separated colour specs; the ink colour is drawn
            between the first and last entries.

    Returns:
        ``(image, mask)``: the joined RGBA word image, and an all-black RGB
        mask of the same size. Previously the mask had the size of the last
        word's matplotlib canvas, which did not match the image.
    """
    cd = download_model_weights()
    model_dir = os.path.join(cd, "handwritten_model")

    # NOTE(review): pickle.load can execute arbitrary code. translation.pkl is
    # fetched over the network by download_model_weights(), so it is only as
    # trustworthy as that source.
    with open(os.path.join(model_dir, "translation.pkl"), "rb") as file:
        translation = pickle.load(file)

    config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
    tf.compat.v1.reset_default_graph()
    with tf.compat.v1.Session(config=config) as sess:
        saver = tf.compat.v1.train.import_meta_graph(
            os.path.join(model_dir, "model-29.meta")
        )
        saver.restore(sess, os.path.join(model_dir, "model-29"))

        color = _pick_color(text_color)
        images = [
            _render_word(sess, word, translation, color) for word in text.split(" ")
        ]

    joined = _join_images(images)
    mask = Image.new("RGB", joined.size, (0, 0, 0))
    return joined, mask