# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import re
import os
import subprocess
import time
import urllib3

from datetime import datetime
from urllib3.util.retry import Retry

import boto3
import pytest
from sagemaker.s3 import S3Downloader, S3Uploader
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
from tests.integ import DATA_DIR
from unittest.case import TestCase

HISTORY_SERVER_ENDPOINT = "http://0.0.0.0/proxy/15050"
JAVA_FILE_PATH = os.path.join("com", "amazonaws", "sagemaker", "spark", "test")
JAVA_VERSION_PATTERN = r"(\d+\.\d+).*"
SPARK_APPLICATION_URL_SUFFIX = "/history/application_1594922484246_0001/1/jobs/"
SPARK_PATH = os.path.join(DATA_DIR, "spark")


@pytest.fixture(scope="module")
def build_jar():
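    """Compile HelloJavaSparkApp.java and package it into hello-spark-java.jar.

    The generated jar and class files are removed when the fixture is torn down.
    """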
    jar_file_path = os.path.join(SPARK_PATH, "code", "java", "hello-java-spark")
    # compile the Java source file, targeting Java 8 bytecode when a newer JDK is installed
    java_version = subprocess.check_output(["java", "-version"], stderr=subprocess.STDOUT).decode(
        "utf-8"
    )
    java_version = re.search(JAVA_VERSION_PATTERN, java_version).groups()[0]

    if float(java_version) > 1.8:
        subprocess.run(
            [
                "javac",
                "--release",
                "8",
                os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.java"),
            ]
        )
    else:
        subprocess.run(
            ["javac", os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.java")]
        )

    subprocess.run(
        [
            "jar",
            "cfm",
            os.path.join(jar_file_path, "hello-spark-java.jar"),
            os.path.join(jar_file_path, "manifest.txt"),
            "-C",
            jar_file_path,
            ".",
        ]
    )
    yield
    subprocess.run(["rm", os.path.join(jar_file_path, "hello-spark-java.jar")])
    subprocess.run(["rm", os.path.join(jar_file_path, JAVA_FILE_PATH, "HelloJavaSparkApp.class")])


@pytest.fixture(scope="module")
def spark_py_processor(sagemaker_session, cpu_instance_type):
    spark_py_processor = PySparkProcessor(
        role="SageMakerRole",
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="2.4",
    )

    return spark_py_processor


@pytest.fixture(scope="module")
def spark_v3_py_processor(sagemaker_session, cpu_instance_type):
    spark_py_processor = PySparkProcessor(
        role="SageMakerRole",
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="3.1",
    )

    return spark_py_processor


@pytest.fixture(scope="module")
def spark_jar_processor(sagemaker_session, cpu_instance_type):
    spark_jar_processor = SparkJarProcessor(
        role="SageMakerRole",
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="2.4",
    )

    return spark_jar_processor


@pytest.fixture(scope="module")
def spark_v3_jar_processor(sagemaker_session, cpu_instance_type):
    spark_jar_processor = SparkJarProcessor(
        role="SageMakerRole",
        instance_count=2,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        framework_version="3.1",
    )

    return spark_jar_processor


@pytest.fixture
def configuration() -> list:
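    """Return an EMR-style configuration list exercising a variety of classifications."""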
    configuration = [
        {
            "Classification": "spark-defaults",
            "Properties": {"spark.executor.memory": "2g", "spark.executor.cores": "1"},
        },
        {
            "Classification": "hadoop-env",
            "Properties": {},
            "Configurations": [
                {
                    "Classification": "export",
                    "Properties": {
                        "HADOOP_DATANODE_HEAPSIZE": "2048",
                        "HADOOP_NAMENODE_OPTS": "-XX:GCTimeRatio=19",
                    },
                    "Configurations": [],
                }
            ],
        },
        {
            "Classification": "core-site",
            "Properties": {"spark.executor.memory": "2g", "spark.executor.cores": "1"},
        },
        {"Classification": "hadoop-log4j", "Properties": {"key": "value"}},
        {
            "Classification": "hive-env",
            "Properties": {},
            "Configurations": [
                {
                    "Classification": "export",
                    "Properties": {
                        "HADOOP_DATANODE_HEAPSIZE": "2048",
                        "HADOOP_NAMENODE_OPTS": "-XX:GCTimeRatio=19",
                    },
                    "Configurations": [],
                }
            ],
        },
        {"Classification": "hive-log4j", "Properties": {"key": "value"}},
        {"Classification": "hive-exec-log4j", "Properties": {"key": "value"}},
        {"Classification": "hive-site", "Properties": {"key": "value"}},
        {"Classification": "spark-defaults", "Properties": {"key": "value"}},
        {
            "Classification": "spark-env",
            "Properties": {},
            "Configurations": [
                {
                    "Classification": "export",
                    "Properties": {
                        "HADOOP_DATANODE_HEAPSIZE": "2048",
                        "HADOOP_NAMENODE_OPTS": "-XX:GCTimeRatio=19",
                    },
                    "Configurations": [],
                }
            ],
        },
        {"Classification": "spark-log4j", "Properties": {"key": "value"}},
        {"Classification": "spark-hive-site", "Properties": {"key": "value"}},
        {"Classification": "spark-metrics", "Properties": {"key": "value"}},
        {"Classification": "yarn-site", "Properties": {"key": "value"}},
        {
            "Classification": "yarn-env",
            "Properties": {},
            "Configurations": [
                {
                    "Classification": "export",
                    "Properties": {
                        "HADOOP_DATANODE_HEAPSIZE": "2048",
                        "HADOOP_NAMENODE_OPTS": "-XX:GCTimeRatio=19",
                    },
                    "Configurations": [],
                }
            ],
        },
    ]
    return configuration


def test_sagemaker_pyspark_v3(
    spark_v3_py_processor, spark_v3_jar_processor, sagemaker_session, configuration, build_jar
):
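    """Run the multinode PySpark and Spark jar tests against the Spark 3.1 processors."""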
    test_sagemaker_pyspark_multinode(spark_v3_py_processor, sagemaker_session, configuration)
    test_sagemaker_java_jar_multinode(
        spark_v3_jar_processor, sagemaker_session, configuration, build_jar
    )


def test_sagemaker_pyspark_multinode(spark_py_processor, sagemaker_session, configuration):
    """Test that basic multinode case works on 32KB of data"""
    bucket = spark_py_processor.sagemaker_session.default_bucket()
    timestamp = datetime.now().isoformat()
    output_data_uri = f"s3://{bucket}/spark/output/sales/{timestamp}"
    spark_event_logs_key_prefix = f"spark/spark-events/{timestamp}"
    spark_event_logs_s3_uri = f"s3://{bucket}/{spark_event_logs_key_prefix}"

    with open(os.path.join(SPARK_PATH, "files", "data.jsonl")) as data:
        body = data.read()
        input_data_uri = f"s3://{bucket}/spark/input/data.jsonl"
        S3Uploader.upload_string_as_file_body(
            body=body, desired_s3_uri=input_data_uri, sagemaker_session=sagemaker_session
        )

    spark_py_processor.run(
        submit_app=os.path.join(
            SPARK_PATH, "code", "python", "hello_py_spark", "hello_py_spark_app.py"
        ),
        submit_py_files=[
            os.path.join(SPARK_PATH, "code", "python", "hello_py_spark", "hello_py_spark_udfs.py")
        ],
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
        spark_event_logs_s3_uri=spark_event_logs_s3_uri,
        wait=False,
    )
    processing_job = spark_py_processor.latest_job

    s3_client = boto3.client(
        "s3", region_name=spark_py_processor.sagemaker_session.boto_region_name
    )

    file_size = 0
    latest_file_size = None
    updated_times_count = 0
    time_out = time.time() + 900
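    # Poll S3 while the job is still running: the size of the newest event log file
    # should keep growing, showing that Spark event logs are uploaded periodically.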
    while not processing_job_finished(
        sagemaker_session.sagemaker_client, processing_job.job_name
    ):
        response = s3_client.list_objects(Bucket=bucket, Prefix=spark_event_logs_key_prefix)
        if "Contents" in response:
            # list_objects sometimes returns a zero-sized object first; skip empty files
            # and track the size of the most recent non-empty event log file.
            for event_log_file in response["Contents"]:
                if event_log_file["Size"] != 0:
                    latest_file_size = event_log_file["Size"]

        # update the file size if it increased
        if latest_file_size and latest_file_size > file_size:
            updated_times_count += 1
            file_size = latest_file_size

        if time.time() > time_out:
            raise RuntimeError("Timeout")

        time.sleep(20)

    # verify that Spark event logs were periodically written to S3
    assert file_size != 0

    output_contents = S3Downloader.list(output_data_uri, sagemaker_session=sagemaker_session)
    assert len(output_contents) != 0


def test_sagemaker_java_jar_multinode(
    spark_jar_processor, sagemaker_session, configuration, build_jar
):
    """Test SparkJarProcessor using Java application jar"""
    bucket = spark_jar_processor.sagemaker_session.default_bucket()
    with open(os.path.join(SPARK_PATH, "files", "data.jsonl")) as data:
        body = data.read()
        input_data_uri = f"s3://{bucket}/spark/input/data.jsonl"
        S3Uploader.upload_string_as_file_body(
            body=body, desired_s3_uri=input_data_uri, sagemaker_session=sagemaker_session
        )
    output_data_uri = f"s3://{bucket}/spark/output/sales/{datetime.now().isoformat()}"

    java_project_dir = os.path.join(SPARK_PATH, "code", "java", "hello-java-spark")
    spark_jar_processor.run(
        submit_app=f"{java_project_dir}/hello-spark-java.jar",
        submit_class="com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
        arguments=["--input", input_data_uri, "--output", output_data_uri],
        configuration=configuration,
    )
    processing_job = spark_jar_processor.latest_job

    waiter = sagemaker_session.sagemaker_client.get_waiter("processing_job_completed_or_stopped")
    waiter.wait(
        ProcessingJobName=processing_job.job_name,
        # poll every 15 seconds. timeout after 15 minutes.
        WaiterConfig={"Delay": 15, "MaxAttempts": 60},
    )

    describe_response = sagemaker_session.sagemaker_client.describe_processing_job(
        ProcessingJobName=processing_job.job_name
    )
    assert describe_response["ProcessingJobStatus"] == "Completed"


def processing_job_finished(sagemaker_client, job_name):
    """Return True once the processing job has reached a terminal state."""
    response = sagemaker_client.describe_processing_job(ProcessingJobName=job_name)

    if not response or "ProcessingJobStatus" not in response:
        raise ValueError("Response is none or does not have ProcessingJobStatus")
    return response["ProcessingJobStatus"] in ("Failed", "Completed", "Stopped")


def test_integ_history_server(spark_py_processor, sagemaker_session):
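    """Upload a sample Spark event log to S3 and verify the local history server serves it."""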
    bucket = spark_py_processor.sagemaker_session.default_bucket()
    spark_event_logs_key_prefix = "spark/spark-history-fs"
    spark_event_logs_s3_uri = f"s3://{bucket}/{spark_event_logs_key_prefix}"

    with open(os.path.join(SPARK_PATH, "files", "sample_spark_event_logs")) as data:
        body = data.read()
        S3Uploader.upload_string_as_file_body(
            body=body,
            desired_s3_uri=spark_event_logs_s3_uri + "/sample_spark_event_logs",
            sagemaker_session=sagemaker_session,
        )

    # sleep 3 seconds to avoid S3 eventual-consistency issues
    time.sleep(3)
    spark_py_processor.start_history_server(spark_event_logs_s3_uri=spark_event_logs_s3_uri)

    try:
        response = _request_with_retry(HISTORY_SERVER_ENDPOINT)
        assert response.status == 200
    finally:
        spark_py_processor.terminate_history_server()


def test_integ_history_server_with_expected_failure(spark_py_processor):
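    """The history server should log an error when given an invalid S3 URI."""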
    # assertLogs must be called on a TestCase instance rather than the class itself.
    with TestCase().assertLogs("sagemaker", level="ERROR") as cm:
        spark_py_processor.start_history_server(spark_event_logs_s3_uri="invalids3uri")
    response = _request_with_retry(HISTORY_SERVER_ENDPOINT, max_retries=5)
    assert response is None
    assert (
        "History server failed to start. Please run 'docker logs history_server' to see logs"
        in cm.output[0]
    )


def _request_with_retry(url, max_retries=10):
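    """GET the URL with retries on 404/502 responses; return None if the request still fails."""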
    http = urllib3.PoolManager(
        retries=Retry(
            max_retries,
            redirect=max_retries,
            status=max_retries,
            status_forcelist=[502, 404],
            backoff_factor=0.2,
        )
    )
    try:
        return http.request("GET", url)
    except Exception:  # pylint: disable=W0703
        return None