#!/usr/bin/env python3
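"""Run the SageMaker Feature Store PySpark integration test on an existing EMR cluster.

Uploads the built library tar and the test script to S3, submits preparation, test,
and cleanup steps to the most recently created active test cluster, then polls the
steps until they all complete.
"""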

import os
import sys
import time
import boto3

DIST_DIR = "../dist"
BATCH_INGESTION_TEST_FILE_NAME = "BatchIngestionTest.py"
EMR_TEST_CLUSTER_NAME = "SparkTestEmrCluster"


s3_client = boto3.client("s3")
emr_client = boto3.client("emr")
account_id = boto3.client("sts").get_caller_identity()["Account"]

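# The test bucket name is derived from the caller's account Id; the bucket is assumed to already exist.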
TEST_S3_BUCKET = f"spark-test-bucket-{account_id}"

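# Locate the built library tar under dist; exactly one build artifact is expected.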
tars = [file_name for file_name in os.listdir(DIST_DIR) if "sagemaker_feature_store_pyspark" in file_name]

if len(tars) == 0:
    sys.exit("No tar file detected; please build the library and try again.")
elif len(tars) > 1:
    sys.exit("Multiple tars detected in dist; please remove the redundant tars and try again.")

pyspark_tar_name = tars[0]
pyspark_tar_path = os.path.join(DIST_DIR, pyspark_tar_name)
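# A timestamped temp prefix keeps this run's artifacts isolated from other test runs.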
current_time_stamp = int(time.time())
test_temp_dir = f"spark-test-temp-{current_time_stamp}"

# Upload the library tar and all dependent integration test files to the test S3 bucket
print(f"Uploading the Spark library test files to {TEST_S3_BUCKET}.")
s3_client.upload_file(pyspark_tar_path, TEST_S3_BUCKET, f"{test_temp_dir}/sagemaker_feature_store_pyspark.tar.gz")
s3_client.upload_file(BATCH_INGESTION_TEST_FILE_NAME, TEST_S3_BUCKET, f"{test_temp_dir}/{BATCH_INGESTION_TEST_FILE_NAME}")
print("Upload finished.")

# Fetch the Id of the most recently created test cluster that is still active
clusters = emr_client.list_clusters(
    ClusterStates=["STARTING", "BOOTSTRAPPING", "RUNNING", "WAITING"],
)["Clusters"]
clusters = [cluster for cluster in clusters if cluster["Name"] == EMR_TEST_CLUSTER_NAME]
if len(clusters) == 0:
    sys.exit(f"No active EMR cluster named {EMR_TEST_CLUSTER_NAME} was found; please start the test cluster and try again.")
clusters.sort(key=lambda cluster: cluster["Status"]["Timeline"]["CreationDateTime"], reverse=True)
cluster_id = clusters[0]["Id"]
hadoop_test_temp_dir = f'/home/hadoop/{test_temp_dir}'

# Add workflow steps to EMR
step_ids = emr_client.add_job_flow_steps(
    JobFlowId=cluster_id,
    Steps=[
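        # Step 1: copy the test artifacts from S3 onto the cluster and install the library under test.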
        {
            'Name': 'Test Preparation',
            'ActionOnFailure': 'CONTINUE',
            'HadoopJarStep': {
                'Jar': 'command-runner.jar',
                'Args': [
                    "bash",
                    "-c",
                    f'mkdir {hadoop_test_temp_dir};'
                    f'aws s3 cp --recursive s3://{TEST_S3_BUCKET}/{test_temp_dir} {hadoop_test_temp_dir};'
                    f'cd {hadoop_test_temp_dir};'
                    'pip3 install --upgrade boto3;'
                    'sudo pip3 uninstall -y sagemaker-feature-store-pyspark;'
                    'sudo pip3 uninstall -y sagemaker-feature-store-pyspark-3.1;'
                    'pip3 install ./sagemaker_feature_store_pyspark.tar.gz'
                ]
            }
        },
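        # Step 2: run the batch ingestion integration test against the installed library.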
        {
            'Name': 'Run Spark Integration Test',
            'ActionOnFailure': 'CONTINUE',
            'HadoopJarStep': {
                'Jar': 'command-runner.jar',
                'Args': [
                    "bash",
                    "-c",
                    f'python3 {hadoop_test_temp_dir}/{BATCH_INGESTION_TEST_FILE_NAME}'
                ]
            }
        },
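        # Step 3: uninstall the library and remove the temp directory from the cluster.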
        {
            'Name': 'Spark Integration Test Cleanup',
            'ActionOnFailure': 'CONTINUE',
            'HadoopJarStep': {
                'Jar': 'command-runner.jar',
                'Args': [
                    "bash",
                    "-c",
                    'pip3 uninstall -y sagemaker-feature-store-pyspark;'
                    'pip3 uninstall -y sagemaker-feature-store-pyspark-3.1;'
                    f'rm -r {hadoop_test_temp_dir}'
                ]
            }
        }
    ]
)['StepIds']

pending_steps_id_list = list(step_ids)  # copy so the removals below do not mutate step_ids

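# Poll the submitted steps until they all complete, failing fast if any step reaches a terminal error state.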
while True:

    steps = emr_client.list_steps(
        ClusterId=cluster_id,
        StepIds=pending_steps_id_list,
    )["Steps"]

    for step in steps:
        if step["Status"]["State"] in ['CANCELLED', 'FAILED', 'INTERRUPTED', 'CANCEL_PENDING']:
            exit(f"Integration test failed, cluster id:{cluster_id}, step id: {step['Id']}, step name: {step['Name']}, "
                 f"please log in the console to root cause the test failure.")
        elif step["Status"]["State"] == 'COMPLETED':
            pending_steps_id_list.remove(step['Id'])
        else:
            print(f"Step [{step['Name']}] is in {step['Status']['State']} state, cluster id:{cluster_id}, "
                  f"step id: {step['Id']}.")
    if len(pending_steps_id_list) == 0:
        break

    # Sleep briefly before the next poll
    print("=================================================================")
    time.sleep(10)

print("Spark library integration test succeeded!")