/*

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package hosting

import (
	"context"

	commonv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/common"
	endpointconfigv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/endpointconfig"
	hostingv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/hostingdeployment"
	modelv1 "github.com/aws/amazon-sagemaker-operator-for-k8s/api/v1/model"
	. "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/controllertest"
	endpointconfigcontroller "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/endpointconfig"
	modelcontroller "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/model"
	apierrs "k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

	"github.com/google/uuid"
	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"
	"k8s.io/apimachinery/pkg/types"
	ctrl "sigs.k8s.io/controller-runtime"
	// +kubebuilder:scaffold:imports
)

// Specs for EndpointConfigReconciler.Reconcile covering validation failures,
// Kubernetes client failures, the "already exists and is equal" case, and the
// creation of a new EndpointConfig from a valid HostingDeployment.
var _ = Describe("EndpointConfigReconciler.Reconcile", func() {

	var (
		reconciler EndpointConfigReconciler
		desired    hostingv1.HostingDeployment
	)

	// Default fixture: a reconciler backed by the shared test client and a
	// deployment with generated names and no production variants. Individual
	// specs and nested contexts overwrite these as needed.
	BeforeEach(func() {
		reconciler = NewEndpointConfigReconciler(k8sClient, ctrl.Log)

		desired = createHostingDeploymentWithGeneratedNames()

	})

	It("Returns an error when a ProductionVariant does not have a VariantName", func() {
		// VariantName is intentionally left unset.
		desired.Spec.ProductionVariants = []commonv1.ProductionVariant{
			{
				InitialInstanceCount: ToInt64Ptr(1),
				InstanceType:         "instance-type",
				ModelName:            ToStringPtr("model-name"),
			},
		}

		err := reconciler.Reconcile(context.Background(), &desired, true)

		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("ProductionVariant has nil VariantName"))
	})

	It("Returns an error when a ProductionVariant does not have a ModelName", func() {

		// ModelName is intentionally left unset.
		desired.Spec.ProductionVariants = []commonv1.ProductionVariant{
			{
				InitialInstanceCount: ToInt64Ptr(1),
				InstanceType:         "instance-type",
				VariantName:          ToStringPtr("variant-name"),
			},
		}

		err := reconciler.Reconcile(context.Background(), &desired, true)

		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("ProductionVariant"))
		Expect(err.Error()).To(ContainSubstring("has nil ModelName"))
	})

	Context("Fail to get k8s model", func() {
		BeforeEach(func() {
			// NOTE(review): FailToGetK8sClient presumably fails every Get call;
			// confirm against the controllertest package.
			reconciler = NewEndpointConfigReconciler(FailToGetK8sClient{}, ctrl.Log)
			desired = createHostingDeploymentWithBasicProductionVariant()
		})

		// The asserted message concerns model name resolution, so the spec
		// title names the model rather than the EndpointConfig.
		It("Returns error if unable to get k8s Model", func() {
			err := reconciler.Reconcile(context.Background(), &desired, true)

			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("Unable to resolve SageMaker model name for model"))
		})
	})

	It("Will not create if an EndpointConfig exists already and they are deep equal", func() {
		desired = createHostingDeploymentWithBasicProductionVariant()
		key := GetSubresourceNamespacedName(desired.ObjectMeta.GetName(), desired)

		// Pre-create an EndpointConfig whose variant matches the one produced by
		// createHostingDeploymentWithBasicProductionVariant.
		err := k8sClient.Create(context.Background(), &endpointconfigv1.EndpointConfig{
			ObjectMeta: metav1.ObjectMeta{
				Name:      key.Name,
				Namespace: key.Namespace,
				Labels:    GetResourceOwnershipLabelsForHostingDeployment(desired),
			},
			Spec: endpointconfigv1.EndpointConfigSpec{
				ProductionVariants: []commonv1.ProductionVariant{
					{
						InitialInstanceCount: ToInt64Ptr(1),
						InstanceType:         "instance-type",
						VariantName:          ToStringPtr("variant-name"),
						ModelName:            ToStringPtr("model-name"),
					},
				},
				Region: desired.Spec.Region,
			},
		})
		Expect(err).ToNot(HaveOccurred())
		// A failed status update would invalidate the fixture, so it must fail the spec.
		Expect(updateEndpointConfigStatus(key, endpointconfigcontroller.CreatedStatus, "sagemaker-endpoint-name")).ToNot(HaveOccurred())

		modelName := *desired.Spec.ProductionVariants[0].ModelName
		modelNamespacedName := GetSubresourceNamespacedName(modelName, desired)
		Expect(createCreatedModelWithAnySageMakerName(modelNamespacedName, desired)).ToNot(HaveOccurred())

		// Fail test if Create is called on k8s client.
		// We expect Create to not be called for an existing deep equal endpoint config.
		reconciler = NewEndpointConfigReconciler(FailTestOnCreateK8sClient{
			ActualClient: k8sClient,
		}, ctrl.Log)
		err = reconciler.Reconcile(context.Background(), &desired, true)

		Expect(err).ToNot(HaveOccurred())
	})

	Context("Fail to create k8s EndpointConfig", func() {

		BeforeEach(func() {
			// NOTE(review): FailToCreateK8sClient presumably fails Create and
			// delegates other calls to ActualClient; confirm in controllertest.
			reconciler = NewEndpointConfigReconciler(FailToCreateK8sClient{
				ActualClient: k8sClient,
			}, ctrl.Log)

			desired = createHostingDeploymentWithBasicProductionVariant()

			// Create the referenced model with Created status.
			modelName := *desired.Spec.ProductionVariants[0].ModelName
			modelNamespacedName := GetSubresourceNamespacedName(modelName, desired)

			Expect(createCreatedModelWithAnySageMakerName(modelNamespacedName, desired)).ToNot(HaveOccurred())
		})

		It("Returns error if unable to create k8s EndpointConfig", func() {
			err := reconciler.Reconcile(context.Background(), &desired, true)

			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("Unable to create Kubernetes EndpointConfig"))
		})
	})

	Context("The deployment spec is valid", func() {

		var (
			tagKey                               string
			tagValue                             string
			expectedEndpointConfigNamespacedName types.NamespacedName
		)

		// Reconcile a valid deployment once; each spec below then inspects the
		// EndpointConfig that is expected to exist afterwards.
		BeforeEach(func() {
			tagKey = "tag-key"
			tagValue = "tag-value"

			desired = createHostingDeploymentWithBasicProductionVariant()
			desired.Spec.Tags = []commonv1.Tag{
				{
					Key:   &tagKey,
					Value: &tagValue,
				},
			}

			expectedEndpointConfigNamespacedName = GetSubresourceNamespacedName(desired.ObjectMeta.GetName(), desired)

			// Create the model referenced by the production variant.
			modelName := *desired.Spec.ProductionVariants[0].ModelName
			modelNamespacedName := GetSubresourceNamespacedName(modelName, desired)
			Expect(createCreatedModelWithAnySageMakerName(modelNamespacedName, desired)).ToNot(HaveOccurred())

			// Surface a reconcile failure here instead of as a later NotFound.
			Expect(reconciler.Reconcile(context.Background(), &desired, true)).ToNot(HaveOccurred())
		})

		// Remove the EndpointConfig produced by the reconcile in BeforeEach.
		AfterEach(func() {
			var endpointConfig endpointconfigv1.EndpointConfig
			err := k8sClient.Get(context.Background(), expectedEndpointConfigNamespacedName, &endpointConfig)
			Expect(err).ToNot(HaveOccurred())

			err = k8sClient.Delete(context.Background(), &endpointConfig)
			Expect(err).ToNot(HaveOccurred())
		})

		It("Created the k8s endpointconfig with correct ProductionVariant", func() {
			var endpointConfig endpointconfigv1.EndpointConfig
			err := k8sClient.Get(context.Background(), expectedEndpointConfigNamespacedName, &endpointConfig)
			Expect(err).ToNot(HaveOccurred())

			Expect(len(endpointConfig.Spec.ProductionVariants)).To(Equal(1))
			Expect(*endpointConfig.Spec.ProductionVariants[0].InitialInstanceCount).To(Equal(*desired.Spec.ProductionVariants[0].InitialInstanceCount))
			Expect(endpointConfig.Spec.ProductionVariants[0].InstanceType).To(Equal(desired.Spec.ProductionVariants[0].InstanceType))
			Expect(*endpointConfig.Spec.ProductionVariants[0].VariantName).To(Equal(*desired.Spec.ProductionVariants[0].VariantName))
		})

		It("Created the k8s endpointconfig with correct region", func() {
			var endpointConfig endpointconfigv1.EndpointConfig
			err := k8sClient.Get(context.Background(), expectedEndpointConfigNamespacedName, &endpointConfig)
			Expect(err).ToNot(HaveOccurred())

			Expect(endpointConfig.Spec.Region).ToNot(BeNil())
			Expect(*endpointConfig.Spec.Region).To(Equal(*desired.Spec.Region))
		})

		It("Created the k8s endpointconfig with correct tags", func() {
			var endpointConfig endpointconfigv1.EndpointConfig
			err := k8sClient.Get(context.Background(), expectedEndpointConfigNamespacedName, &endpointConfig)
			Expect(err).ToNot(HaveOccurred())

			Expect(len(endpointConfig.Spec.Tags)).To(Equal(1))
			Expect(endpointConfig.Spec.Tags[0].Key).ToNot(BeNil())
			Expect(*endpointConfig.Spec.Tags[0].Key).To(Equal(tagKey))
			Expect(endpointConfig.Spec.Tags[0].Value).ToNot(BeNil())
			Expect(*endpointConfig.Spec.Tags[0].Value).To(Equal(tagValue))
		})
	})
})

// Specs verifying that an EndpointConfig created by Reconcile can be deleted
// from Kubernetes.
var _ = Describe("Delete EndpointConfigReconciler.Reconcile", func() {
	var (
		tagKey                               string
		tagValue                             string
		expectedEndpointConfigNamespacedName types.NamespacedName
		reconciler                           EndpointConfigReconciler
		desired                              hostingv1.HostingDeployment
	)

	// Reconcile a valid deployment so that an EndpointConfig exists, then
	// delete that EndpointConfig through the Kubernetes client.
	BeforeEach(func() {
		tagKey = "tag-key"
		tagValue = "tag-value"

		reconciler = NewEndpointConfigReconciler(k8sClient, ctrl.Log)

		desired = createHostingDeploymentWithBasicProductionVariant()
		desired.Spec.Tags = []commonv1.Tag{
			{
				Key:   &tagKey,
				Value: &tagValue,
			},
		}

		expectedEndpointConfigNamespacedName = GetSubresourceNamespacedName(desired.ObjectMeta.GetName(), desired)

		// Create the model referenced by the production variant.
		modelName := *desired.Spec.ProductionVariants[0].ModelName
		modelNamespacedName := GetSubresourceNamespacedName(modelName, desired)
		Expect(createCreatedModelWithAnySageMakerName(modelNamespacedName, desired)).ToNot(HaveOccurred())

		// Surface a reconcile failure here instead of as a NotFound on the Get below.
		Expect(reconciler.Reconcile(context.Background(), &desired, true)).ToNot(HaveOccurred())

		var endpointConfig endpointconfigv1.EndpointConfig
		err := k8sClient.Get(context.Background(), expectedEndpointConfigNamespacedName, &endpointConfig)
		Expect(err).ToNot(HaveOccurred())

		err = k8sClient.Delete(context.Background(), &endpointConfig)
		Expect(err).ToNot(HaveOccurred())
	})

	It("Verify that endpoint config has been deleted from k8s", func() {
		var endpointConfig endpointconfigv1.EndpointConfig
		err := k8sClient.Get(context.Background(), expectedEndpointConfigNamespacedName, &endpointConfig)
		Expect(err).To(HaveOccurred())
		Expect(apierrs.IsNotFound(err)).To(Equal(true))
	})

	It("Verify that reconciler returns error if deletion fails", func() {
		// TODO(Gautam Kumar): not implemented yet; this spec has no assertions
		// and currently passes vacuously.
	})
})

// Specs for the update path of EndpointConfigReconciler.Reconcile: an
// EndpointConfig already exists for the deployment and the deployment's
// production variant is then changed.
var _ = Describe("Update EndpointConfigReconciler.Reconcile", func() {

	var (
		subject EndpointConfigReconciler

		deployment        *hostingv1.HostingDeployment
		endpointConfigKey types.NamespacedName
		variantModelName  string
	)

	// Build a deployment in a fresh namespace and pre-create both the
	// EndpointConfig and the Model it refers to.
	BeforeEach(func() {
		ctx := context.Background()

		subject = NewEndpointConfigReconciler(k8sClient, ctrl.Log)
		variantModelName = "model-name"
		modelContainers := []*commonv1.ContainerDefinition{
			{
				ContainerHostname: ToStringPtr("present-container"),
				ModelDataUrl:      ToStringPtr("s3://bucket/model.tar.gz"),
			},
		}

		name := "k8s-name-" + uuid.New().String()
		namespace := "k8s-namespace-" + uuid.New().String()
		CreateMockNamespace(ctx, k8sClient, namespace)

		deployment = &hostingv1.HostingDeployment{
			ObjectMeta: metav1.ObjectMeta{
				Name:      name,
				Namespace: namespace,
				UID:       types.UID(uuid.New().String()),
			},
			Spec: hostingv1.HostingDeploymentSpec{
				Region: ToStringPtr("us-east-1"),
				ProductionVariants: []commonv1.ProductionVariant{
					{
						InitialVariantWeight: ToInt64Ptr(1),
						InitialInstanceCount: ToInt64Ptr(4),
						VariantName:          ToStringPtr("variant-A"),
						ModelName:            &variantModelName,
					},
				},
				Models: []commonv1.Model{
					{
						Name:             &variantModelName,
						Containers:       modelContainers,
						ExecutionRoleArn: ToStringPtr("xxx-yyy"),
					},
				},
			},
		}

		endpointConfigKey = GetSubresourceNamespacedName(deployment.ObjectMeta.GetName(), *deployment)
		Expect(createCreatedEndpointConfig(endpointConfigKey, *deployment, "")).ToNot(HaveOccurred())

		modelKey := GetSubresourceNamespacedName(variantModelName, *deployment)
		Expect(createCreatedModelWithAnySageMakerName(modelKey, *deployment)).ToNot(HaveOccurred())
	})

	// Delete the EndpointConfig and the Model created for the spec.
	AfterEach(func() {
		ctx := context.Background()

		var staleEndpointConfig endpointconfigv1.EndpointConfig
		Expect(k8sClient.Get(ctx, endpointConfigKey, &staleEndpointConfig)).ToNot(HaveOccurred())
		Expect(k8sClient.Delete(ctx, &staleEndpointConfig)).ToNot(HaveOccurred())

		modelKey := GetSubresourceNamespacedName(variantModelName, *deployment)
		var staleModel modelv1.Model
		Expect(k8sClient.Get(ctx, modelKey, &staleModel)).ToNot(HaveOccurred())
		Expect(k8sClient.Delete(ctx, &staleModel)).ToNot(HaveOccurred())
	})

	It("Updates Kubernetes EndpointConfig", func() {
		ctx := context.Background()

		// Change only the variant weight on a copy of the deployment.
		changedWeight := int64(5)
		changedDeployment := deployment.DeepCopy()
		changedDeployment.Spec.ProductionVariants[0].InitialVariantWeight = &changedWeight

		Expect(subject.Reconcile(ctx, changedDeployment, true)).ToNot(HaveOccurred())

		// The stored EndpointConfig must now carry the new weight.
		var stored endpointconfigv1.EndpointConfig
		Expect(k8sClient.Get(ctx, endpointConfigKey, &stored)).ToNot(HaveOccurred())

		Expect(*stored.Spec.ProductionVariants[0].InitialVariantWeight).To(Equal(changedWeight))
	})
})

// createCreatedModelWithAnySageMakerName creates a K8s model via
// createCreatedModelWithSageMakerName, using a fixed placeholder SageMaker
// model name for callers that do not care about its value.
func createCreatedModelWithAnySageMakerName(namespacedName types.NamespacedName, deployment hostingv1.HostingDeployment) error {
	const placeholderSageMakerName = "sagemaker-model-name"
	return createCreatedModelWithSageMakerName(namespacedName, deployment, placeholderSageMakerName)
}

// createCreatedModelWithSageMakerName creates a K8s Model at namespacedName,
// labelled as owned by deployment, then reads it back and sets its status to
// modelcontroller.CreatedStatus with the given SageMaker model name.
// It returns the first error from the Create, Get or status Update call.
// deployment.Spec.Region is dereferenced and therefore must be non-nil.
func createCreatedModelWithSageMakerName(namespacedName types.NamespacedName, deployment hostingv1.HostingDeployment, sageMakerName string) error {
	ctx := context.Background()

	toCreate := modelv1.Model{
		ObjectMeta: metav1.ObjectMeta{
			Name:      namespacedName.Name,
			Namespace: namespacedName.Namespace,
			Labels:    GetResourceOwnershipLabelsForHostingDeployment(deployment),
		},
		Spec: modelv1.ModelSpec{
			ExecutionRoleArn: ToStringPtr("xxx"),
			Region:           ToStringPtr(*deployment.Spec.Region),
		},
	}
	if err := k8sClient.Create(ctx, &toCreate); err != nil {
		return err
	}

	// Re-read the object before writing the status.
	var stored modelv1.Model
	if err := k8sClient.Get(ctx, namespacedName, &stored); err != nil {
		return err
	}

	stored.Status.Status = modelcontroller.CreatedStatus
	stored.Status.SageMakerModelName = sageMakerName
	return k8sClient.Status().Update(ctx, &stored)
}

// createCreatedEndpointConfig creates a K8s EndpointConfig at namespacedName,
// labelled as owned by deployment and carrying the deployment's region and
// production variants, then marks it with endpointconfigcontroller.CreatedStatus
// and the given SageMaker endpoint config name.
// It returns the error from the Create call, or from the subsequent status
// update. deployment.Spec.Region is dereferenced and therefore must be non-nil.
func createCreatedEndpointConfig(namespacedName types.NamespacedName, deployment hostingv1.HostingDeployment, sageMakerName string) error {
	err := k8sClient.Create(context.Background(), &endpointconfigv1.EndpointConfig{
		ObjectMeta: metav1.ObjectMeta{
			Name:      namespacedName.Name,
			Namespace: namespacedName.Namespace,
			Labels:    GetResourceOwnershipLabelsForHostingDeployment(deployment),
		},
		Spec: endpointconfigv1.EndpointConfigSpec{
			Region:             ToStringPtr(*deployment.Spec.Region),
			ProductionVariants: deployment.Spec.ProductionVariants,
		},
	})
	if err != nil {
		return err
	}

	// Reuse the shared get-then-update-status helper rather than duplicating it.
	return updateEndpointConfigStatus(namespacedName, endpointconfigcontroller.CreatedStatus, sageMakerName)
}

// createHostingDeploymentWithBasicProductionVariant returns a HostingDeployment
// with generated names whose spec holds exactly one production variant
// ("variant-name") referring to the model "model-name".
func createHostingDeploymentWithBasicProductionVariant() hostingv1.HostingDeployment {
	result := createHostingDeploymentWithGeneratedNames()

	basicVariant := commonv1.ProductionVariant{
		InitialInstanceCount: ToInt64Ptr(1),
		InstanceType:         "instance-type",
		VariantName:          ToStringPtr("variant-name"),
		ModelName:            ToStringPtr("model-name"),
	}
	result.Spec.ProductionVariants = []commonv1.ProductionVariant{basicVariant}

	return result
}

// createHostingDeploymentWithGeneratedNames generates a UUID-suffixed name and
// namespace, passes the namespace to CreateMockNamespace, and returns a
// HostingDeployment built by createHostingDeployment for that name and namespace.
func createHostingDeploymentWithGeneratedNames() hostingv1.HostingDeployment {
	generatedName := "endpointconfig-" + uuid.New().String()
	generatedNamespace := "namespace-" + uuid.New().String()

	CreateMockNamespace(context.Background(), k8sClient, generatedNamespace)

	return createHostingDeployment(generatedName, generatedNamespace)
}

// createHostingDeployment returns a HostingDeployment with the given name and
// namespace, a freshly generated UID, region "us-east-1", and empty (non-nil)
// ProductionVariants and Models slices.
func createHostingDeployment(k8sName, k8sNamespace string) hostingv1.HostingDeployment {
	var deployment hostingv1.HostingDeployment

	deployment.ObjectMeta = metav1.ObjectMeta{
		Name:      k8sName,
		Namespace: k8sNamespace,
		UID:       types.UID(uuid.New().String()),
	}
	deployment.Spec = hostingv1.HostingDeploymentSpec{
		ProductionVariants: []commonv1.ProductionVariant{},
		Models:             []commonv1.Model{},
		Region:             ToStringPtr("us-east-1"),
	}

	return deployment
}

// updateEndpointConfigStatus fetches the EndpointConfig at namespacedName, sets
// its status and SageMaker endpoint config name, and writes the status back.
// It returns the error from the Get or the status Update call, if any.
func updateEndpointConfigStatus(namespacedName types.NamespacedName, status, sageMakerEndpointConfigName string) error {
	ctx := context.Background()

	var current endpointconfigv1.EndpointConfig
	if err := k8sClient.Get(ctx, namespacedName, &current); err != nil {
		return err
	}

	current.Status.Status = status
	current.Status.SageMakerEndpointConfigName = sageMakerEndpointConfigName

	return k8sClient.Status().Update(ctx, &current)
}