from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler
from feature_store_manager import FeatureStoreManager
from pyspark.sql.functions import udf, datediff, to_date, lit
from pyspark.sql.types import IntegerType, DoubleType, StructType, StructField, StringType, FloatType
from pyspark.sql import SparkSession, DataFrame
from argparse import Namespace, ArgumentParser
from pyspark.ml.linalg import Vector
from pyspark.ml import Pipeline
from datetime import datetime
import argparse
import ast
import logging
import boto3
import time
import os


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())


def parse_args() -> argparse.Namespace:
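    """Parse the command line arguments passed to the job."""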
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_processes', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--feature_group_name', type=str)
    parser.add_argument('--feature_group_arn', type=str)
    parser.add_argument('--target_feature_store_list', type=str)
    parser.add_argument('--s3_uri_prefix', type=str)
    
    args, _ = parser.parse_known_args()
    return args

def transform_row(row) -> list:
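    """Convert a Spark Row into the Feature Store record format: a list of
    dicts with 'FeatureName' and 'ValueAsString' keys."""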
    columns = list(row.asDict())
    record = []
    for column in columns:
        feature = {'FeatureName': column, 'ValueAsString': str(row[column])}
        record.append(feature)
    return record

def batch_ingest_to_feature_store(args: argparse.Namespace, df: DataFrame) -> None:
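    """Batch ingest the transformed DataFrame into the feature group identified
    by the supplied feature group ARN, using the FeatureStoreManager connector."""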
    feature_group_name = args.feature_group_name
    logger.info(f'Feature Group name supplied is: {feature_group_name}')
    session = boto3.session.Session()

    logger.info('Instantiating FeatureStoreManager')
    feature_store_manager = FeatureStoreManager()

    logger.info('Loading feature definitions directly from the DataFrame schema')

    # Load the feature definitions from input schema. The feature definitions can be used to create a feature group
    feature_definitions = feature_store_manager.load_feature_definitions_from_schema(df)
    logger.info(f'Feature definitions loaded successfully: {feature_definitions}')
    feature_group_arn = args.feature_group_arn
    logger.info(f'Feature Group ARN supplied is: {feature_group_arn}')

    # The target stores are passed in as a string representation of a list, e.g. "['OfflineStore']"
    args.target_feature_store_list = ast.literal_eval(args.target_feature_store_list)
    logger.info(f'Ingesting into the following stores: {args.target_feature_store_list}')

    # If only the OfflineStore is selected, the connector batch writes the data to the offline store directly
    feature_store_manager.ingest_data(input_data_frame=df, feature_group_arn=feature_group_arn, target_stores=args.target_feature_store_list)
    logger.info('Feature ingestion successful!')

def scale_col(df: DataFrame, col_name: str) -> DataFrame:
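    """Min-max scale `col_name` (rounded to 2 decimals) and return the DataFrame
    with the scaled values stored under the original column name."""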
    # scale the column `col_name` with a min-max scaler and drop the original column;
    # the udf extracts the scalar from the one-element vector produced by MinMaxScaler and rounds it to 2 decimals
    unlist = udf(lambda x: round(float(list(x)[0]), 2), DoubleType())
    assembler = VectorAssembler(inputCols=[col_name], outputCol=f'{col_name}_vec')
    scaler = MinMaxScaler(inputCol=f'{col_name}_vec', outputCol=f'{col_name}_scaled')
    pipeline = Pipeline(stages=[assembler, scaler])
    df = pipeline.fit(df).transform(df).withColumn(f'{col_name}_scaled', unlist(f'{col_name}_scaled')) \
                                       .drop(f'{col_name}_vec')
    df = df.drop(col_name)
    df = df.withColumnRenamed(f'{col_name}_scaled', col_name)
    return df

def ordinal_encode_col(df: DataFrame, col_name: str) -> DataFrame:
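    """Ordinal encode `col_name` with a StringIndexer and return the DataFrame
    with the encoded values stored under the original column name."""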
    indexer = StringIndexer(inputCol=col_name, outputCol=f'{col_name}_new')
    df = indexer.fit(df).transform(df)
    df = df.drop(col_name)
    df = df.withColumnRenamed(f'{col_name}_new', col_name)
    return df

def run_spark_job():
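    """Read the raw CSV data from S3, apply the feature transformations, and
    batch ingest the result into the Feature Store."""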

    args = parse_args()
   
    spark = SparkSession.builder.getOrCreate()
    
    # set the time parser policy to LEGACY so dates in formats such as dd/MM/yyyy HH:mm:ss
    # are parsed as in Spark 2.4, preserving backwards compatibility
    spark.sql("set spark.sql.legacy.timeParserPolicy=LEGACY")

    logger.info(f'Using Spark version: {spark.version}')

    # get the total number of cores in the Spark cluster; if developing locally, there might be no executor
    try:
        spark_context = spark.sparkContext
        total_cores = int(spark_context._conf.get('spark.executor.instances')) * int(spark_context._conf.get('spark.executor.cores'))
        logger.info(f'Total available cores in the Spark cluster = {total_cores}')
    except Exception as e:
        total_cores = 1
        logger.error(f'Could not retrieve the total number of cores. Setting total cores to 1. Error message: {str(e)}')
       
    logger.info(f'Reading input file from S3. S3 uri is {args.s3_uri_prefix}')

    # define the schema of the input data
    csvSchema = StructType([
        StructField("order_id", StringType(), True),
        StructField("customer_id", StringType(), False),
        StructField("product_id", StringType(), False),
        StructField("purchase_amount", FloatType(), False),
        StructField("is_reordered", IntegerType(), False),
        StructField("purchased_on", StringType(), False),
        StructField("event_time", StringType(), False)])

    # read the CSV data into a Spark DataFrame using the schema defined above
    df = spark.read.option("header", "true").schema(csvSchema).csv(args.s3_uri_prefix)

    # transform 1 - ordinal encode the boolean-like `is_reordered` column and cast it back to int
    df = ordinal_encode_col(df, 'is_reordered')
    df = df.withColumn('is_reordered', df['is_reordered'].cast(IntegerType()))

    # transform 2 - min max scale `purchase_amount`
    df = df.withColumn('purchase_amount', df['purchase_amount'].cast(DoubleType()))
    df = scale_col(df, 'purchase_amount')
    
    # transform 3 - derive `n_days_since_last_purchase` column using the `purchased_on` col
    current_date = datetime.today().strftime('%Y-%m-%d')
    df = df.withColumn('n_days_since_last_purchase', datediff(to_date(lit(current_date)), to_date('purchased_on', 'yyyy-MM-dd')))
    df = df.drop('purchased_on')
    df = scale_col(df, 'n_days_since_last_purchase')
    
    
    logger.info(f'Number of partitions = {df.rdd.getNumPartitions()}')
    # Rule-of-thumb heuristic: repartition to 3-4x the product of #executors and #executor.cores
    df = df.repartition(total_cores * 3)
    logger.info(f'Number of partitions after re-partitioning = {df.rdd.getNumPartitions()}')
    logger.info(f'Feature Store ingestion start: {datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}')
    batch_ingest_to_feature_store(args, df)
    logger.info(f'Feature Store ingestion complete: {datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}')

if __name__ == '__main__':
    logger.info('BATCH INGESTION - STARTED')
    run_spark_job()
    logger.info('BATCH INGESTION - COMPLETED')