#!/usr/bin/env python
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this
# software and associated documentation files (the "Software"), to deal in the Software
# without restriction, including without limitation the rights to use, copy, modify,
# merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
# PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# Trains a TFFM factorization machine model.
# Inputs:
#   - sparsely formatted RecordIO-protobuf file(s) in the train channel
#   - a list of feature headers stored in S3
# Outputs:
#   - model state (TensorFlow checkpoint)
#   - the list of headers, saved alongside the model

from __future__ import print_function

import csv
import json
import os
import sys
import traceback
from io import FileIO

import boto3
import numpy as np
import pandas as pd
import scipy.sparse as sp
import tensorflow as tf
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, classification_report
from tffm import TFFMRegressor

import sagemaker.amazon.common as smac

# These paths are defined by the SageMaker training container contract.
param_file = '/opt/ml/input/config/hyperparameters.json'
input_file = '/opt/ml/input/config/inputdataconfig.json'
training_dir = '/opt/ml/input/data/train'
test_dir = '/opt/ml/input/data/test'
failure_file = '/opt/ml/output/failure'
model_dir = '/opt/ml/model'
output_path = '/opt/ml/output'
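# For context (a sketch, not executed here): a training job for this container
# is typically launched with the SageMaker Python SDK (v2 API assumed below),
# which is what populates the /opt/ml paths above. The image name, role, and
# S3 URIs are illustrative placeholders, not values from this repository:
#
#   import sagemaker
#   est = sagemaker.estimator.Estimator(
#       image_uri='<account>.dkr.ecr.<region>.amazonaws.com/<image>:latest',
#       role='<execution-role-arn>',
#       instance_count=1,
#       instance_type='ml.m5.xlarge',
#       hyperparameters={'order': 3, 'rank': 7, 'epochs': 20,
#                        'header_file_bucket': '<bucket>',
#                        'header_file_prefix': '<prefix>'},
#   )
#   est.fit({'train': 's3://<bucket>/train', 'test': 's3://<bucket>/test'})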
# The function to execute the training.
def train():
    print('Starting the training.')
    model = None
    try:
        # Load hyperparameters: order, rank, epochs, and the S3 location of
        # the header data. All five are required.
        with open(param_file) as f:
            param_data = json.load(f)

        order = param_data.get('order', None)
        if order is not None:
            order = int(order)
        else:
            raise ValueError('Missing required parameter order')

        rank = param_data.get('rank', None)
        if rank is not None:
            rank = int(rank)
        else:
            raise ValueError('Missing required parameter rank')

        epochs = param_data.get('epochs', None)
        if epochs is not None:
            epochs = int(epochs)
        else:
            raise ValueError('Missing required parameter epochs')

        header_file_bucket = param_data.get('header_file_bucket', None)
        if header_file_bucket is None:
            raise ValueError('Missing required parameter header_file_bucket')

        header_file_prefix = param_data.get('header_file_prefix', None)
        if header_file_prefix is None:
            raise ValueError('Missing required parameter header_file_prefix')

        # Read the feature headers from S3. Each row of each header file is
        # joined into a single column name; the header count determines the
        # width of the sparse design matrix.
        s3 = boto3.resource('s3')
        header_files = s3.Bucket(header_file_bucket).objects.filter(Prefix=header_file_prefix)
        headers = []
        for obj in header_files:
            destination_key = os.path.join('/tmp', os.path.basename(obj.key))
            s3.Bucket(header_file_bucket).download_file(obj.key, destination_key)
            with open(destination_key, 'r') as csvfile:
                headerreader = csv.reader(csvfile)
                for row in headerreader:
                    headers.append(''.join(row))
        if len(headers) == 0:
            raise ValueError('There are no header files in s3://{}/{}.\n'.format(
                header_file_bucket, header_file_prefix))
        num_features = len(headers)

        # Read RecordIO-protobuf records and convert them to a sparse matrix.
        input_files = [os.path.join(training_dir, file) for file in os.listdir(training_dir)]
        if len(input_files) == 0:
            raise ValueError(('There are no files in {}.\n' +
                              'This usually indicates that the channel ({}) was incorrectly specified,\n' +
                              'the data specification in S3 was incorrectly specified or the role specified\n' +
                              'does not have permission to access the data.').format(training_dir, 'train'))

        num_samples = 0
        tr_row_ind = []
        tr_col_ind = []
        tr_data = []
        tr_idx = 0
        y_tr = []
        for file in input_files:
            train_records = smac.read_records(FileIO(file))
            num_samples = num_samples + len(train_records)
            for train_record in train_records:
                # Each record stores one sparse row as parallel key/value
                # arrays, plus a scalar label.
                values = train_record.features['values'].float32_tensor
                tr_row_ind.extend([tr_idx] * len(values.values))
                tr_col_ind.extend(values.keys)
                tr_data.extend(values.values)
                tr_idx = tr_idx + 1
                y_tr.append(train_record.label['values'].float32_tensor.values[0])

        print('Number of non-zero entries: {0}'.format(len(tr_row_ind)))
        print('Creating sparse matrix of shape {0},{1}'.format(num_samples, num_features))
        X_tr_sparse = sp.csr_matrix(
            (np.array(tr_data), (np.array(tr_row_ind), np.array(tr_col_ind))),
            shape=(num_samples, num_features))
        print('X_tr shape: {0}'.format(X_tr_sparse.shape))

        # Define the model.
        model = TFFMRegressor(
            order=order,
            rank=rank,
            optimizer=tf.train.AdamOptimizer(learning_rate=0.1),
            n_epochs=epochs,
            batch_size=-1,
            init_std=0.001,
            input_type='sparse')

        # Train the model.
        model.fit(X_tr_sparse, np.array(y_tr), show_progress=True)

        # Save the model state.
        model.save_state('{0}/tffm_state.tf'.format(model_dir))
        print('Training complete.')

        # Save the headers next to the model so inference can rebuild the
        # feature space.
        with open('{0}/headers.csv'.format(model_dir), 'w') as csvfile:
            hwriter = csv.writer(csvfile, delimiter=' ', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            for h in headers:
                hwriter.writerow([h])
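        # For reference (a sketch, not executed here): the RecordIO-protobuf
        # inputs consumed above are typically produced from a scipy sparse
        # matrix with the same sagemaker.amazon.common helpers, e.g.:
        #
        #   import sagemaker.amazon.common as smac
        #   with open('train.protobuf', 'wb') as f:
        #       smac.write_spmatrix_to_sparse_tensor(f, X_sparse, labels=y)
        #
        # and then uploaded to the S3 prefix backing the 'train' channel.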
        # Load test data if available and report validation metrics.
        if os.path.exists(test_dir):
            test_files = [os.path.join(test_dir, file) for file in os.listdir(test_dir)]
            num_test_samples = 0
            te_row_ind = []
            te_col_ind = []
            te_data = []
            te_idx = 0
            y_te = []
            for file in test_files:
                test_records = smac.read_records(FileIO(file))
                num_test_samples = num_test_samples + len(test_records)
                for test_record in test_records:
                    values = test_record.features['values'].float32_tensor
                    te_row_ind.extend([te_idx] * len(values.values))
                    te_col_ind.extend(values.keys)
                    te_data.extend(values.values)
                    te_idx = te_idx + 1
                    y_te.append(test_record.label['values'].float32_tensor.values[0])

            print('Creating test sparse matrix of shape {0},{1}'.format(num_test_samples, num_features))
            X_te_sparse = sp.csr_matrix(
                (np.array(te_data), (np.array(te_row_ind), np.array(te_col_ind))),
                shape=(num_test_samples, num_features))
            print('X_te shape: {0}'.format(X_te_sparse.shape))

            # Validate and report scores. The regressor emits continuous
            # predictions, so threshold at 0.5 to recover binary labels.
            predictions = model.predict(X_te_sparse)
            predvec = np.where(predictions > 0.5, 1, 0)
            print('Weighted F1: {}'.format(f1_score(y_te, predvec, average='weighted')))
            print('Accuracy: {}'.format(accuracy_score(y_te, predvec)))
            print('Weighted ROC AUC: {}'.format(roc_auc_score(y_te, predvec, average='weighted')))
            print('Classification report: {}'.format(classification_report(y_te, predvec)))
            print('Confusion matrix')
            print(pd.crosstab(np.array(y_te), predvec, rownames=['actuals'], colnames=['predictions']))

    except Exception as e:
        # Write out an error file. This will be returned as the failureReason
        # in the DescribeTrainingJob result.
        trc = traceback.format_exc()
        with open(failure_file, 'w') as s:
            s.write('Exception during training: ' + str(e) + '\n' + trc)
        # Printing this causes the exception to appear in the training job logs as well.
        print('Exception during training: ' + str(e) + '\n' + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)
    finally:
        if model is not None:
            model.destroy()


if __name__ == '__main__':
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    sys.exit(0)
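# Smoke-testing note (an assumption about the surrounding build, not part of
# the SageMaker contract): if this script is baked into a container whose
# entrypoint dispatches on the "train" argument, it can be exercised locally
# by mirroring the /opt/ml layout and mounting it, e.g.:
#
#   docker run --rm -v $(pwd)/local_test/opt_ml:/opt/ml <your-image-name> train
#
# with hyperparameters.json placed under opt_ml/input/config/ and the
# RecordIO-protobuf data under opt_ml/input/data/train/.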