/*
 * Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *   http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package com.amazonaws.services.sagemaker.sparksdk

import java.util

import collection.JavaConverters.mapAsJavaMapConverter
import com.amazonaws.SdkClientException
import com.amazonaws.services.s3.AmazonS3
import com.amazonaws.services.s3.model.{AmazonS3Exception, ObjectListing, S3ObjectSummary}
import com.amazonaws.services.securitytoken.AWSSecurityTokenService
import com.amazonaws.services.securitytoken.model.{GetCallerIdentityRequest, GetCallerIdentityResult}
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.scalatest._
import org.scalatest.mockito.MockitoSugar
import scala.language.postfixOps

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.param.{BooleanParam, IntParam, Param}
import org.apache.spark.sql._

import com.amazonaws.services.sagemaker.AmazonSageMaker
import com.amazonaws.services.sagemaker.model._
import com.amazonaws.services.sagemaker.sparksdk.EndpointCreationPolicy.EndpointCreationPolicy
import com.amazonaws.services.sagemaker.sparksdk.internal.{DataUploader, ManifestDataUploadResult, ObjectPrefixUploadResult, TimeProvider}
import com.amazonaws.services.sagemaker.sparksdk.transformation.{RequestRowSerializer, ResponseRowDeserializer}
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.LibSVMResponseRowDeserializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.LibSVMRequestRowSerializer

/**
 * Unit tests for [[SageMakerEstimator]].
 *
 * Every collaborator (SageMaker, S3, STS, the Spark session/context/conf, the data uploader and
 * the time provider) is a Mockito mock recreated in `before`, so `fit()` and the path / role
 * resolution helpers can be exercised without any AWS or Spark infrastructure.
 */
class SageMakerEstimatorTests extends FlatSpec with Matchers with MockitoSugar with BeforeAndAfter {

  var dataset: Dataset[String] = _
  var dataUploaderMock: DataUploader = _
  var timeProviderMock: TimeProvider = _
  var sagemakerMock: AmazonSageMaker = _

  var s3Mock : AmazonS3 = _
  var stsMock : AWSSecurityTokenService = _
  var sparkConfMock : SparkConf = _

  // Bucket and prefixes used both by the DummyEstimator's default S3 locations and by the S3 mock
  // stubs below. Note that s3Bucket / s3Prefix double as the default projected column names.
  val s3Bucket = "a"
  val s3Prefix = "b"
  val s3TrainingPrefix = "b/test-training-job"
  val s3DataPrefix = "b/test-training-job/data.pbr"

  // Recreate all mocks for every test. The dataset is wired to sparkConfMock through
  // dataset.sparkSession.sparkContext.getConf; uploads report s3TrainingPrefix as their location;
  // listing that prefix returns a single object whose key is s3DataPrefix.
  before {
    dataset = mock[Dataset[String]]
    dataUploaderMock = mock[DataUploader]
    when(dataUploaderMock.uploadData(any[S3DataPath], any[Dataset[_]]))
      .thenReturn(ObjectPrefixUploadResult(S3DataPath(s3Bucket, s3TrainingPrefix)))
    timeProviderMock = mock[TimeProvider]

    sagemakerMock = mock[AmazonSageMaker]
    s3Mock = mock[AmazonS3]
    stsMock = mock[AWSSecurityTokenService]
    sparkConfMock = mock[SparkConf]

    val sparkSessionMock = mock[SparkSession]
    val sparkContextMock = mock[SparkContext]

    when(dataset.sparkSession).thenReturn(sparkSessionMock)
    when(sparkSessionMock.sparkContext).thenReturn(sparkContextMock)
    when(sparkContextMock.getConf).thenReturn(sparkConfMock)

    val objectSummaryMock = mock[S3ObjectSummary]
    when(objectSummaryMock.getKey).thenReturn(s3DataPrefix)
    val objectListMock = mock[ObjectListing]
    when(objectListMock.getObjectSummaries).thenReturn(util.Arrays.asList(objectSummaryMock))
    when(s3Mock.listObjects(s3Bucket, s3TrainingPrefix)).thenReturn(objectListMock)
    when(s3Mock.getRegionName).thenReturn("region")
  }

  "SageMakerEstimator" should "generate a UID" in {

    val estimator = new DummyEstimator()
    val estimator2 = new DummyEstimator(uid = "blah")

    assert(estimator.toString().startsWith("sagemaker"))
    assert(estimator.uid != estimator2.uid)
  }

  it should "have empty hyperparameter map when no params defined" in {
    val estimator = new DummyEstimator()
    assert(estimator.makeHyperParameters().isEmpty)
  }

  it should "have correct hyperparameter when empty params defined" in {
    val estimator = new DummyEstimator() {
      val stringParam: Param[String] = new Param(this, "stringParam", "")
      val intParam: IntParam = new IntParam(this, "intParam", "")
      val booleanParam: BooleanParam = new BooleanParam(this, "booleanParam", "")
      val otherStringParam: Param[String] = new Param[String](this, "otherStringParam", "")
    }
    // Params that are declared but neither set nor defaulted contribute no hyperparameters.
    assert(estimator.makeHyperParameters() == Map.empty[String, String].asJava)
  }

  it should "record the latest training job after calling fit()" in {
    val estimator = new DummyEstimator()
    assert(estimator.latestTrainingJob.isEmpty)

    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.Completed)
    estimator.fit(dataset)
    assert(estimator.latestTrainingJob.nonEmpty)
  }

  it should "select only the projected columns of the dataset to a given s3 location in fit()" in {
    val estimator = new DummyEstimator()
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.Completed)
    estimator.fit(dataset)
    // The DummyEstimator's default projected columns are List(s3Bucket, s3Prefix).
    verify(dataset).select(s3Bucket, s3Prefix)
  }

  it should "attempt all the columns of the dataset to a given s3 location in fit()" in {
    val estimator = new DummyEstimator(dummyTrainingProjectedColumns = None)
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.Completed)
    estimator.fit(dataset)
    verify(dataset, times(0)).select(any[String], any[String])
    verify(dataUploaderMock).uploadData(estimator.dummyS3InputDataPathWithTrainingJobName, dataset)
  }

  it should "attempt all the columns of the dataset to a given s3 location in fit() if given an " +
    "empty list of columns names" in {
    val estimator = new DummyEstimator(dummyTrainingProjectedColumns = Some(List[String]()))
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.Completed)
    estimator.fit(dataset)
    verify(dataset, times(0)).select(any[String], any[String])
    verify(dataUploaderMock).uploadData(estimator.dummyS3InputDataPathWithTrainingJobName, dataset)
  }

  it should "correctly create a training job request from training properties" in {
    val estimator = new DummyEstimator()
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.Completed)

    // Report the upload at the estimator's own input path so the S3 URI in the request matches.
    when(dataUploaderMock.uploadData(any[S3DataPath], any[Dataset[_]]))
      .thenReturn(ObjectPrefixUploadResult(
        estimator.trainingInputS3DataPath.asInstanceOf[S3DataPath]))
    estimator.fit(dataset)

    val createTrainingJobRequestCaptor = ArgumentCaptor.forClass(classOf[CreateTrainingJobRequest])
    verify(sagemakerMock).createTrainingJob(createTrainingJobRequestCaptor.capture())

    val createTrainingJobArgument = createTrainingJobRequestCaptor.getValue
    assert(estimator.trainingImage == createTrainingJobArgument.getAlgorithmSpecification
      .getTrainingImage)

    assert(estimator.trainingInputMode == createTrainingJobArgument.getAlgorithmSpecification
      .getTrainingInputMode)
    assert(estimator.trainingCompressionCodec.get == createTrainingJobArgument.getInputDataConfig
      .get(0).getCompressionType)
    assert(estimator.trainingChannelName == createTrainingJobArgument.getInputDataConfig.get(0)
      .getChannelName)
    assert(estimator.trainingContentType.get == createTrainingJobArgument.getInputDataConfig.get(0)
      .getContentType)
    assert(estimator.trainingInputS3DataPath.asInstanceOf[S3DataPath].toS3UriString
      == createTrainingJobArgument.getInputDataConfig.get(0).getDataSource.getS3DataSource.getS3Uri)
    assert(estimator.trainingS3DataDistribution == createTrainingJobArgument.getInputDataConfig
      .get(0).getDataSource.getS3DataSource.getS3DataDistributionType)
    assert(estimator.dummyS3OutputDataPathWithTrainingJobName.toS3UriString
      == createTrainingJobArgument.getOutputDataConfig.getS3OutputPath)
    assert(estimator.trainingInstanceCount == createTrainingJobArgument
      .getResourceConfig.getInstanceCount)
    assert(estimator.trainingInstanceType == createTrainingJobArgument.getResourceConfig
      .getInstanceType)
    assert(estimator.sagemakerRole
      .asInstanceOf[IAMRole].role == createTrainingJobArgument.getRoleArn)
    assert(estimator.trainingMaxRuntimeInSeconds == createTrainingJobArgument.getStoppingCondition
      .getMaxRuntimeInSeconds)
    assert(estimator.trainingKmsKeyId.get == createTrainingJobArgument.getOutputDataConfig
      .getKmsKeyId)
  }

  it should "have correct hyperparameter map when default and non default params set" in {

    val estimator = new DummyEstimator() {
      val stringParam : Param[String] = new Param(this, "stringParam", "")
      val intParam : IntParam = new IntParam(this, "intParam", "")
      val booleanParam : BooleanParam = new BooleanParam(this, "booleanParam", "")
      val otherStringParam : Param[String] = new Param[String](this, "otherStringParam", "")
      setDefault(stringParam, "default")
      setDefault(intParam, 55)
    }
    estimator.set(estimator.intParam, 66)
    estimator.set(estimator.otherStringParam, "Elizabeth")
    // Explicitly set values win over defaults; booleanParam (neither set nor defaulted) is absent.
    assert(estimator.makeHyperParameters() == Map(
      "stringParam" -> "default",
      "intParam" -> "66",
      "otherStringParam" -> "Elizabeth").asJava)
  }

  it should "poll for training job completion" in {
    val estimator = new DummyEstimator()
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.InProgress, TrainingJobStatus.InProgress,
      TrainingJobStatus.Completed)
    val sagemakerModel = estimator.fit(dataset)
    // Three polls plus (presumably) one further describe after completion; one sleep per
    // InProgress response.
    verify(sagemakerMock, times(4)).describeTrainingJob(any[DescribeTrainingJobRequest])
    verify(timeProviderMock, times(2)).sleep(SageMakerEstimator.TrainingJobPollInterval.toMillis)
    assert(sagemakerModel.uid == estimator.uid)
  }

  it should "ignore transient failures when polling for training completion" in {
    val estimator = new DummyEstimator()
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    // A 500 status code marks the SageMaker error as retryable.
    val retryableException = new AmazonSageMakerException("transient failure")
    retryableException.setStatusCode(500)
    when(sagemakerMock.describeTrainingJob(any[DescribeTrainingJobRequest]))
      .thenThrow(retryableException)
      .thenReturn(statusToResult(TrainingJobStatus.InProgress))
      .thenReturn(statusToResult(TrainingJobStatus.Completed))
    estimator.fit(dataset)
    verify(sagemakerMock, times(4)).describeTrainingJob(any[DescribeTrainingJobRequest])
    verify(timeProviderMock, times(2)).sleep(SageMakerEstimator.TrainingJobPollInterval.toMillis)
  }

  it should "throw an exception if the training job fails to create" in {
    when(sagemakerMock.createTrainingJob(any[CreateTrainingJobRequest]))
      .thenThrow(new AmazonSageMakerException("EASE is down."))
    val estimator = new DummyEstimator()

    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    intercept[RuntimeException] {
      estimator.fit(dataset)
    }

    // A job that was never created must never be polled.
    verify(sagemakerMock, never()).describeTrainingJob(any[DescribeTrainingJobRequest])
  }

  it should "throw an exception when training fails" in {
    val estimator = new DummyEstimator()
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.InProgress, TrainingJobStatus.InProgress,
      TrainingJobStatus.Failed)
    intercept[RuntimeException] {
      estimator.fit(dataset)
    }
  }

  it should "throw an exception when polling exceeds training timeout" in {
    val estimator = new DummyEstimator()

    when(timeProviderMock.currentTimeMillis)
      .thenReturn(0) // starting s3 upload time
      .thenReturn(0) // starting training time
      .thenReturn(1) // await training completion start time
      .thenReturn(2) // first while loop check
      .thenReturn(2) // second while loop check
      // third while loop check - should exit loop
      .thenReturn(estimator.trainingJobTimeout.toMillis + 1)

    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.InProgress)
    intercept[RuntimeException] {
      estimator.fit(dataset)
    }

    verify(sagemakerMock, times(2)).describeTrainingJob(any[DescribeTrainingJobRequest])
    verify(timeProviderMock, times(2)).sleep(SageMakerEstimator.TrainingJobPollInterval.toMillis)
  }

  it should "create a training job request with a manifest file if running on EMRFS" in {
    val estimator = new DummyEstimator()
    val trainingJobRequest = estimator.buildCreateTrainingJobRequest("blah",
      ManifestDataUploadResult(new S3DataPath("bucket", "objectPath")), sparkConfMock)
    val s3DataType = trainingJobRequest.getInputDataConfig.get(0).getDataSource.getS3DataSource
      .getS3DataType
    assert(S3DataType.ManifestFile.toString == s3DataType)
  }

  it should "create a training job request without a manifest file for input datasource if not " +
    "running on EMRFS" in {
    val estimator = new DummyEstimator()
    val trainingJobRequest = estimator.buildCreateTrainingJobRequest("blah",
      ObjectPrefixUploadResult(new S3DataPath("bucket", "objectPath")), sparkConfMock)
    val s3DataType = trainingJobRequest.getInputDataConfig.get(0).getDataSource.getS3DataSource
      .getS3DataType
    assert(S3DataType.S3Prefix.toString == s3DataType)
  }

  it should "resolve s3 locations from configuration" in {
    val estimator = new DummyEstimator()
    val trainingJobName = "training"
    when(sparkConfMock.get("test-config-key")).thenReturn("s3://bucket/path")
    val dp = estimator.resolveS3Path(S3PathFromConfig("test-config-key"), trainingJobName,
      sparkConfMock)
    assert(S3DataPath("bucket", "path/training") == dp)
    assert(dp.objectPath.endsWith(trainingJobName))
  }

  it should "resolve s3 locations with random paths from configuration" in {
    val estimator = new DummyEstimator()
    val trainingJobName = "training"
    // A bare bucket name: the object path is generated and ends with the training job name.
    when(sparkConfMock.get("test-config-key")).thenReturn("bucket")
    val dp = estimator.resolveS3Path(S3PathFromConfig("test-config-key"), trainingJobName,
      sparkConfMock)
    assert(dp.bucket == "bucket")
    assert(dp.objectPath.length == 45)
    assert(dp.objectPath.endsWith(trainingJobName))
  }

  it should "resolve role arn from configuration" in {
    val estimator = new DummyEstimator()
    when(sparkConfMock.get("test-config-key")).thenReturn("arn-role")
    assert(IAMRole("arn-role") == estimator.resolveRoleARN(IAMRoleFromConfig("test-config-key"),
      sparkConfMock))
  }

  it should "resolve role arn from role" in {
    val estimator = new DummyEstimator()
    assert(IAMRole("arn-role") == estimator.resolveRoleARN(IAMRole("arn-role"), sparkConfMock))
  }

  it should "create bucket and prefix with training job name" in {
    val estimator = new DummyEstimator()
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    val trainingJobName = "training"
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)
    val dp = estimator.resolveS3Path(S3AutoCreatePath(), trainingJobName, sparkConfMock)
    // Bucket name is "<account>-sagemaker-<region>"; the region comes from s3Mock.getRegionName.
    verify(s3Mock, times(1)).createBucket("1234-sagemaker-region")
    assert("1234-sagemaker-region" == dp.bucket)
    assert(dp.objectPath.endsWith(trainingJobName))
  }

  // NOTE(review): this test is currently identical to "fail to create bucket on exception" below;
  // it does not model an "already exists" error. Confirm the intended scenario.
  it should "refuse to create a bucket if it already exists" in {
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)

    when(s3Mock.createBucket(any[String])).thenThrow(new AmazonS3Exception(
      "not a bucket exists exception"))
    val estimator = new DummyEstimator(dummyTrainingInputS3DataPath = S3AutoCreatePath())
    intercept[AmazonS3Exception] {
      estimator.fit(dataset)
    }
  }

  it should "fail to create bucket on exception" in {
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)

    when(s3Mock.createBucket(any[String])).thenThrow(new AmazonS3Exception(
      "not a bucket exists exception"))
    val estimator = new DummyEstimator(dummyTrainingInputS3DataPath = S3AutoCreatePath())
    intercept[AmazonS3Exception] {
      estimator.fit(dataset)
    }
  }

  it should "take hyperparameter map in constructor" in {
    val estimator = new DummyEstimator(dummyHyperParameters = Map(
      s3Bucket -> "a value", s3Prefix -> "55"))
    // Previously the comparison result was discarded, so this test could never fail.
    assert(estimator.makeHyperParameters() == Map(s3Bucket -> "a value", s3Prefix -> "55").asJava)
  }

  it should "remove the training data when training completed" in {
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)

    val estimator = new DummyEstimator()

    when(sagemakerMock.describeTrainingJob(any[DescribeTrainingJobRequest]))
      .thenReturn(statusToResult(TrainingJobStatus.InProgress))
      .thenReturn(statusToResult(TrainingJobStatus.Completed))

    estimator.fit(dataset)

    // Both the listed data object and the training prefix itself are deleted.
    verify(s3Mock).deleteObject(s3Bucket, s3DataPrefix)
    verify(s3Mock).deleteObject(s3Bucket, s3TrainingPrefix)
  }

  it should "remove the training data when training failed" in {
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)

    // Stub a genuine training failure. Otherwise the unstubbed SageMaker mock returns null and
    // fit() fails for an unrelated reason.
    setupCreateTrainingJobResult()
    setupDescribeTrainingJobResponses(TrainingJobStatus.InProgress, TrainingJobStatus.Failed)

    val estimator = new DummyEstimator()

    intercept[RuntimeException] {
      estimator.fit(dataset)
    }

    verify(s3Mock).deleteObject(s3Bucket, s3DataPrefix)
    verify(s3Mock).deleteObject(s3Bucket, s3TrainingPrefix)
  }

  it should "swallow the s3 exception if failed to remove training data" in {
    when(s3Mock.deleteObject(any[String], any[String])).thenThrow(new SdkClientException(
      "failed to delete training data"))
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)

    val estimator = new DummyEstimator()

    when(sagemakerMock.describeTrainingJob(any[DescribeTrainingJobRequest]))
      .thenReturn(statusToResult(TrainingJobStatus.InProgress))
      .thenReturn(statusToResult(TrainingJobStatus.Completed))

    estimator.fit(dataset)

    // Exactly one delete attempt: the first failure stops further deletes without failing fit().
    verify(s3Mock).deleteObject(any[String], any[String])
  }

  it should "keep the training data when DeleteAfterTrain is false" in {
    when(timeProviderMock.currentTimeMillis).thenReturn(0)
    val mockAccount = "1234"
    val mockResult = new GetCallerIdentityResult().withAccount(mockAccount)
    when(stsMock.getCallerIdentity(any[GetCallerIdentityRequest])).thenReturn(mockResult)

    val estimator = new DummyEstimator(dummyDeleteAfterTraining = false)

    when(sagemakerMock.describeTrainingJob(any[DescribeTrainingJobRequest]))
      .thenReturn(statusToResult(TrainingJobStatus.InProgress))
      .thenReturn(statusToResult(TrainingJobStatus.Completed))

    estimator.fit(dataset)

    verify(s3Mock, never).deleteObject(any[String], any[String])
  }

  // Model artifact location reported by every stubbed DescribeTrainingJobResult.
  val dummyModelArtifactLocation : String = "s3://bucket/string"

  /** Name policy with fixed, predictable names all derived from the uid "test". */
  case class DummyNamePolicy(prefix : String = "") extends NamePolicy {
    val uid = "test"

    val trainingJobName = uid + "-training-job"
    val modelName = uid + "-model"
    val endpointConfigName = uid + "-endpoint-config"
    val endpointName = uid + "-endpoint"
  }

  /**
   * Concrete [[SageMakerEstimator]] whose every constructor argument has a `dummy*` default and
   * whose AWS clients are the suite's mocks. Its time provider and data uploader are replaced by
   * the suite's mocks in the constructor body.
   *
   * NOTE(review): `dummyNamePolicy` is only used by the two `dummyS3*DataPathWithTrainingJobName`
   * helpers. The NamePolicyFactory handed to the estimator always creates `DummyNamePolicy()`.
   */
  class DummyEstimator (
    override val uid : String = "sagemaker",
    val dummyTrainingImage : String = "training-image",
    val dummyModelImage : String = "model-image",
    val dummyRequestRowSerializer: RequestRowSerializer = new LibSVMRequestRowSerializer(),
    val dummyResponseRowDeserializer: ResponseRowDeserializer =
      new LibSVMResponseRowDeserializer(10),
    val dummySageMakerRole : IAMRoleResource = IAMRole("dummy-role"),
    val dummyTrainingInputS3DataPath : S3Resource = S3DataPath(s3Bucket, s3Prefix),
    val dummyTrainingOutputS3DataPath : S3Resource = S3DataPath(s3Bucket, s3Prefix),
    val dummyTrainingInstanceType : String = "m4.large",
    val dummyTrainingInstanceCount : Int = 1,
    val dummyTrainingInstanceVolumeSizeInGB : Int = 1024,
    val dummyTrainingProjectedColumns : Option[List[String]] = Some(List(s3Bucket, s3Prefix)),
    val dummyTrainingChannelName : String = "training",
    val dummyTrainingContentType : Option[String] = Some("application/x-record-protobuf"),
    val dummyTrainingS3DataDistribution : String = S3DataDistribution.ShardedByS3Key.toString,
    val dummyTrainingSparkDataFormat : String = "sagemaker",
    val dummyTrainingSparkDataFormatOptions : collection.immutable.Map[String, String] =
      collection.immutable.Map(),
    val dummyTrainingInputMode : String = TrainingInputMode.File.toString,
    val dummyTrainingCompressionCodec : Option[String] = Some("codec"),
    val dummytrainingMaxRuntimeInSeconds : Int = 24 * 60 * 60,
    val dummyTrainingKmsKeyId : Option[String] = Some("kms"),
    val dummyModelEnvironmentVariables : collection.immutable.Map[String, String] =
      collection.immutable.Map(),
    val dummyEndpointInstanceType : String = "m4.large",
    val dummyendpointInitialInstanceCount : Int = 1,
    val dummyEndpointCreationPolicy : EndpointCreationPolicy =
      EndpointCreationPolicy.CREATE_ON_TRANSFORM,
    val dummyModelPrependInputRowsToTransformationRows : Boolean = true,
    val dummyDeleteAfterTraining : Boolean = true,
    val dummyNamePolicy : NamePolicy = DummyNamePolicy(),
    val dummyHyperParameters : Map[String, String] = Map()) extends SageMakerEstimator (
    dummyTrainingImage,
    dummyModelImage,
    dummySageMakerRole,
    dummyTrainingInstanceType,
    dummyTrainingInstanceCount,
    dummyEndpointInstanceType,
    dummyendpointInitialInstanceCount,
    dummyRequestRowSerializer,
    dummyResponseRowDeserializer,
    dummyTrainingInputS3DataPath,
    dummyTrainingOutputS3DataPath,
    dummyTrainingInstanceVolumeSizeInGB,
    dummyTrainingProjectedColumns,
    dummyTrainingChannelName,
    dummyTrainingContentType,
    dummyTrainingS3DataDistribution,
    dummyTrainingSparkDataFormat,
    dummyTrainingSparkDataFormatOptions,
    dummyTrainingInputMode,
    dummyTrainingCompressionCodec,
    dummytrainingMaxRuntimeInSeconds,
    dummyTrainingKmsKeyId,
    dummyModelEnvironmentVariables,
    dummyEndpointCreationPolicy,
    sagemakerMock,
    s3Mock,
    stsMock,
    dummyModelPrependInputRowsToTransformationRows,
    dummyDeleteAfterTraining,
    new NamePolicyFactory { override def createNamePolicy: NamePolicy = DummyNamePolicy() },
    uid,
    dummyHyperParameters) {
    this.timeProvider = timeProviderMock
    this.dataUploader = dataUploaderMock

    /** Input location resolved against `dummyNamePolicy`'s training job name. */
    def dummyS3InputDataPathWithTrainingJobName : S3DataPath =
      resolveS3Path(dummyTrainingInputS3DataPath, dummyNamePolicy.trainingJobName, sparkConfMock)

    /** Output location resolved against `dummyNamePolicy`'s training job name. */
    def dummyS3OutputDataPathWithTrainingJobName : S3DataPath =
      resolveS3Path(dummyTrainingOutputS3DataPath, dummyNamePolicy.trainingJobName, sparkConfMock)
  }

  /** Stubs `createTrainingJob` to return a result carrying a fake ARN for any request. */
  private def setupCreateTrainingJobResult() = {
    val fakeArn = "arn"
    val mockCreateTrainingJobResult = new CreateTrainingJobResult().withTrainingJobArn(fakeArn)
    when(sagemakerMock.createTrainingJob(any[CreateTrainingJobRequest])).thenReturn(
      mockCreateTrainingJobResult)
  }

  /**
   * Builds a DescribeTrainingJobResult with the given status and model artifacts located at
   * `dummyModelArtifactLocation`.
   */
  private def statusToResult(status: TrainingJobStatus): DescribeTrainingJobResult = {
    val modelArtifacts = new ModelArtifacts().withS3ModelArtifacts(dummyModelArtifactLocation)
    new DescribeTrainingJobResult().withTrainingJobStatus(status).withModelArtifacts(modelArtifacts)
  }

  /**
   * Stubs `describeTrainingJob` to return the given statuses on consecutive calls. As with any
   * Mockito consecutive stubbing, the last status is repeated for all further calls.
   */
  private def setupDescribeTrainingJobResponses(firstStatus: TrainingJobStatus,
                                                moreStatuses : TrainingJobStatus*) = {
    when(sagemakerMock.describeTrainingJob(any[DescribeTrainingJobRequest]))
      .thenReturn(statusToResult(firstStatus), moreStatuses.map(statusToResult): _*)
  }

}