# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

from lookup_tables_generator import generate_mel_LUT_files
from common_benchmark import analyze_footprints, Cloud_analyze, Cloud_benchmark, benchmark_model
from sklearn.metrics import accuracy_score
import random
from evaluation import _aggregate_predictions, compute_accuracy_score
from header_file_generator import gen_h_user_file
import load_models
from hydra.core.hydra_config import HydraConfig
from preprocessing import preprocessing
from callbacks import get_callbacks
from common_visualize import vis_training_curves
from visualize import _compute_confusion_matrix, _plot_confusion_matrix
from datasets import _esc10_csv_to_tf_dataset, load_ESC_10, load_custom_esc_like_multiclass
from data_augment import get_data_augmentation
from benchmark import evaluate_TFlite_quantized_model
from quantization import TFLite_PTQ_quantizer
import traceback
import numpy as np
import os
from tensorflow import keras
import tensorflow as tf
from omegaconf import OmegaConf
from munch import DefaultMunch
import mlflow
import ssl
# Disable HTTPS certificate verification for the whole process so that
# dataset/model downloads succeed behind proxies or self-signed certificates.
# NOTE(review): this weakens TLS security globally — confirm it is required.
ssl._create_default_https_context = ssl._create_unverified_context


# Set seeds
def setup_seed(seed):
    """Seed every RNG used by the pipeline for reproducible runs.

    Covers Python's hash seed, the `random` module, NumPy and TensorFlow,
    then forces single-threaded TF op scheduling to remove thread-race
    nondeterminism.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # TensorFlow graph-level seed (covers CPU kernels).
    tf.random.set_seed(seed)

    # One thread for both inter- and intra-op parallelism → deterministic order.
    threading_cfg = tf.config.threading
    threading_cfg.set_inter_op_parallelism_threads(1)
    threading_cfg.set_intra_op_parallelism_threads(1)


def get_config(cfg):
    """Convert a Hydra/OmegaConf config into an attribute-accessible Munch.

    Returns a DefaultMunch built from the plain-container form of `cfg`.
    """
    as_plain_dict = OmegaConf.to_container(cfg)
    return DefaultMunch.fromDict(as_plain_dict)


def mlflow_ini(cfg):
    """Initialize MLflow tracking from the configuration.

    Points MLflow at the configured tracking URI, selects the experiment
    named after the project, and turns on TensorFlow autologging.
    """
    tracking_uri = cfg.mlflow.uri
    experiment_name = cfg.general.project_name
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)
    # Models are saved/exported explicitly elsewhere, so don't autolog them.
    mlflow.tensorflow.autolog(log_models=False)


def get_optimizer(cfg):
    """Build the Keras optimizer requested in the configuration.

    The optimizer name (`cfg.train_parameters.optimizer`) is matched
    case-insensitively against SGD / RMSprop / Adam; any unrecognized name
    falls back to Adam, matching the original if/elif chain's behavior.

    Returns:
        A keras optimizer instance configured with
        `cfg.train_parameters.initial_learning` as the learning rate.
    """
    # Dispatch table replaces the repetitive `"SGD".lower()`-style chain.
    optimizers = {
        "sgd": keras.optimizers.SGD,
        "rmsprop": keras.optimizers.RMSprop,
        "adam": keras.optimizers.Adam,
    }
    name = cfg.train_parameters.optimizer.lower()
    optimizer_cls = optimizers.get(name, keras.optimizers.Adam)
    return optimizer_cls(learning_rate=cfg.train_parameters.initial_learning)


def get_loss(cfg):
    """Select the training loss from the configuration.

    Raises NotImplementedError for multi-label setups. Otherwise returns
    categorical cross-entropy for more than two classes and binary
    cross-entropy for two-class problems; `from_logits=False` means the
    model is expected to output probabilities.
    """
    num_classes = len(cfg.dataset.class_names)
    if cfg.model.multi_label:
        raise NotImplementedError("Multi-label classification not implemented yet, but will be in a future update.")
    if num_classes > 2:
        return tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    return tf.keras.losses.BinaryCrossentropy(from_logits=False)


def train(cfg):
    """Run the full training pipeline.

    Steps: build and compile the (optionally augmented) model, estimate its
    footprint with STM32Cube.AI (cloud first, offline fallback), train it,
    evaluate the float model at patch and clip level, optionally quantize it
    to TFLite and benchmark/evaluate the quantized model, generate the C
    header and look-up-table files for the Getting Started application, and
    log metrics/artifacts to MLflow.

    Args:
        cfg: configuration object (Munch/OmegaConf) describing the dataset,
            model, training, quantization and STM32Cube.AI settings.

    Raises:
        NotImplementedError: for multi-label setups or unknown dataset names.
        TypeError: for unsupported quantizer/quantization-type combinations.
    """
    # Build the Keras model described by the configuration.
    model = load_models.get_model(cfg)
    print("[INFO] Model summary")
    model.summary()

    # Loss, optimizer and training callbacks from the config.
    loss = get_loss(cfg)
    optimizer = get_optimizer(cfg)
    callbacks = get_callbacks(cfg)

    # `augment` is a flag telling whether any augmentation layer is enabled.
    data_augmentation, augment = get_data_augmentation(cfg)

    # Called for its side effects; the returned objects are unused here.
    # NOTE(review): presumably validates/prepares the preprocessing config — confirm.
    _, _ = preprocessing(cfg)

    # Load train/validation/test splits plus the per-patch clip labels used
    # later to aggregate patch predictions back to clip level.
    if cfg.dataset.name.lower() == "esc10":
        train_ds, valid_ds, test_ds, clip_labels = load_ESC_10(cfg)
    elif cfg.dataset.name.lower() == "custom" and not cfg.model.multi_label:
        train_ds, valid_ds, test_ds, clip_labels = load_custom_esc_like_multiclass(cfg)
    elif cfg.dataset.name.lower() == "custom" and cfg.model.multi_label:
        raise NotImplementedError("Multilabel support not implemented yet !")
    else:
        raise NotImplementedError("Please choose a valid dataset ('esc10' or 'custom')")

    # Prepend augmentation layers for training only; the bare `model` keeps
    # the deployable architecture (augmentation is stripped before export).
    if augment:
        augmented_model = tf.keras.models.Sequential([data_augmentation, model])
        augmented_model._name = "Augmented_model"
    else:
        augmented_model = model
        augmented_model._name = "Model"

    if cfg.model.expand_last_dim:
        augmented_model.build((None, cfg.model.input_shape[0], cfg.model.input_shape[1], 1))
    else:
        augmented_model.build((None, cfg.model.input_shape[0], cfg.model.input_shape[1]))

    print("[INFO] Augmented model summary")
    augmented_model.summary()

    # Compile the model
    augmented_model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

    print("[INFO] Model successfully compiled")

    # Export an (untrained) model so its flash/RAM footprint can be estimated
    # before spending time on training.
    if cfg.quantization.quantize:
        print("[INFO] : Estimating the model footprints...")
        if cfg.quantization.quantizer == "TFlite_converter" and cfg.quantization.quantization_type == "PTQ":
            # fake=True: quantize without real calibration data, just for sizing.
            TFLite_PTQ_quantizer(cfg, model, train_ds=None, fake=True)
            model_path = os.path.join(HydraConfig.get().runtime.output_dir,
                                      "{}/{}".format(cfg.quantization.export_dir, "quantized_model.tflite"))
        else:
            raise TypeError("Quantizer and quantization type not supported yet!")
    else:
        model_path = os.path.join(HydraConfig.get().runtime.output_dir,
                                  "{}/{}".format(cfg.general.saved_models_dir, "best_model.h5"))
        model.save(model_path)

    # Evaluate model footprints with STM32Cube.AI: try the Developer Cloud
    # first, fall back to the local/offline benchmark on any failure.
    if cfg.stm32ai.footprints_on_target:
        print("[INFO] : Establishing a connection to STM32Cube.AI Developer Cloud to launch the model benchmark on STM32 target...")
        try:
            output_analyze = Cloud_analyze(cfg, model_path)
            if output_analyze == 0:
                raise Exception("Connection failed, Offline benchmark will be launched.")
        except Exception as e:
            output_analyze = 0
            print("[FAIL] :", e)
            # Write out an error file. This will be returned as the failureReason in the
            # DescribeTrainingJob result.
            # NOTE(review): '/opt/ml/output' is the SageMaker convention; this
            # write raises outside such an environment — confirm intended.
            trc = traceback.format_exc()
            with open(os.path.join('/opt/ml/output', 'failure'), 'w') as s:
                s.write('Exception during training: ' + str(e) + '\n' + trc)

            print("[INFO] : Offline benchmark launched...")
            benchmark_model(cfg, model_path)

    else:
        benchmark_model(cfg, model_path)

    # Train the model.
    print("[INFO] : Starting training...")
    history = augmented_model.fit(train_ds, validation_data=valid_ds, callbacks=callbacks,
                                  epochs=cfg.train_parameters.training_epochs)

    # Visualize training curves
    vis_training_curves(history, cfg)

    # Evaluate the float model on the test set.

    # The last layer of the augmented model is the plain model (augmentation
    # layers are dropped for inference/export).
    best_model = augmented_model.layers[-1]
    best_model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
    best_model.save(
        os.path.join(HydraConfig.get().runtime.output_dir, cfg.general.saved_models_dir + '/' + "best_model.h5"))

    X_test, y_test = test_ds[0], test_ds[1]

    test_preds = best_model.predict(X_test)
    # Aggregate patch-level predictions/labels into clip-level ones.
    aggregated_preds = _aggregate_predictions(preds=test_preds,
                                              clip_labels=clip_labels,
                                              is_multilabel=cfg.model.multi_label,
                                              is_truth=False)
    aggregated_truth = _aggregate_predictions(preds=y_test,
                                              clip_labels=clip_labels,
                                              is_multilabel=cfg.model.multi_label,
                                              is_truth=True)

    # Accuracy and confusion matrices for the float model.
    patch_level_accuracy = compute_accuracy_score(y_test, test_preds,
                                                  is_multilabel=cfg.model.multi_label)
    print("[INFO] : Patch-level accuracy on test set : {}".format(patch_level_accuracy))
    clip_level_accuracy = compute_accuracy_score(aggregated_truth, aggregated_preds,
                                                 is_multilabel=cfg.model.multi_label)
    print("[INFO] : Clip-level accuracy on test set : {}".format(clip_level_accuracy))

    # Log accuracies in MLFLOW
    mlflow.log_metric("float_patch_test_acc", patch_level_accuracy)
    mlflow.log_metric("float_clip_test_acc", clip_level_accuracy)

    patch_level_confusion_matrix = _compute_confusion_matrix(y_test, test_preds,
                                                             is_multilabel=cfg.model.multi_label)
    clip_level_confusion_matrix = _compute_confusion_matrix(aggregated_truth, aggregated_preds,
                                                            is_multilabel=cfg.model.multi_label)

    _plot_confusion_matrix(patch_level_confusion_matrix,
                           class_names=cfg.dataset.class_names,
                           title="Patch-level CM",
                           test_accuracy=patch_level_accuracy)

    _plot_confusion_matrix(clip_level_confusion_matrix,
                           class_names=cfg.dataset.class_names,
                           title="Clip-level CM",
                           test_accuracy=clip_level_accuracy)

    # Quantize the trained model, using training data for calibration.
    if cfg.quantization.quantize:
        print("[INFO] : Quantizing the model ... This might take few minutes ...")

        if cfg.data_augmentation.VolumeAugment:
            # Apply volume augmentation to the calibration set so quantization
            # ranges cover the augmented amplitudes seen during training.
            print("Applying Volume augmentation to quantization dataset")
            def map_fn(x, y): return (data_augmentation(x), y)
            train_ds = train_ds.map(map_fn)

        if cfg.quantization.quantizer == "TFlite_converter" and cfg.quantization.quantization_type == "PTQ":
            TFLite_PTQ_quantizer(cfg, best_model, train_ds, fake=False)
            quantized_model_path = os.path.join(HydraConfig.get(
            ).runtime.output_dir, "{}/{}".format(cfg.quantization.export_dir, "quantized_model.tflite"))

            # Generating C model
            if cfg.stm32ai.footprints_on_target:
                try:
                    # Reuse the earlier cloud session handle; 0 means it failed.
                    if output_analyze != 0:
                        output_benchmark = Cloud_benchmark(cfg, quantized_model_path, output_analyze)
                        if output_benchmark == 0:
                            raise Exception("Connection failed, generating C model using local STM32Cube.AI.")
                    else:
                        raise Exception("Connection failed, generating C model using local STM32Cube.AI.")
                except Exception as e:
                    print("[FAIL] :", e)
                    print("[INFO] : Offline C code generation launched...")
                    benchmark_model(cfg, quantized_model_path)
            else:
                benchmark_model(cfg, quantized_model_path)

            # Evaluate the quantized model on the same test set.
            if cfg.quantization.evaluate:
                q_patch_level_acc, q_clip_level_acc = evaluate_TFlite_quantized_model(
                    quantized_model_path, X_test, y_test, clip_labels, cfg)
                mlflow.log_metric("int_patch_test_acc", q_patch_level_acc)
                mlflow.log_metric("int_clip_test_acc", q_clip_level_acc)

        else:
            raise TypeError("Quantizer and quantization type not supported yet!")

        # Generate Config.h for C embedded application
        print("Generating C header file for Getting Started...")
        gen_h_user_file(cfg)
        print("Done")
        # Generate LUT files
        print("Generating C look-up tables files for Getting Started...")
        generate_mel_LUT_files(cfg)
        print("Done")

    # Record the whole Hydra working directory so all run artifacts are kept.
    mlflow.log_artifact(HydraConfig.get().runtime.output_dir)
    mlflow.end_run()