import json
import re


class InferenceSpecification:
    """Render an ``InferenceSpecification`` document from a JSON template.

    The document lists the container image, the instance types offered for
    transform and realtime inference, and the supported content / response
    MIME types. It can be obtained either as a JSON string or as a dict.
    """

    # JSON skeleton. Each *_REPLACE_ME token is substituted by
    # get_inference_specification_json(); IMAGE_REPLACE_ME sits inside a quoted
    # JSON string, the other tokens are replaced by complete JSON arrays.
    template = """
{    
    "InferenceSpecification": {
        "Containers" : [{"Image": "IMAGE_REPLACE_ME"}],
        "SupportedTransformInstanceTypes": INSTANCES_REPLACE_ME,
        "SupportedRealtimeInferenceInstanceTypes": INSTANCES_REPLACE_ME,
        "SupportedContentTypes": CONTENT_TYPES_REPLACE_ME,
        "SupportedResponseMIMETypes": RESPONSE_MIME_TYPES_REPLACE_ME
    }
}
"""

    # Instance types always included in the supported list.
    _CPU_INSTANCE_TYPES = (
        "ml.m4.xlarge",
        "ml.m4.2xlarge",
        "ml.m4.4xlarge",
        "ml.m4.10xlarge",
        "ml.m4.16xlarge",
        "ml.m5.large",
        "ml.m5.xlarge",
        "ml.m5.2xlarge",
        "ml.m5.4xlarge",
        "ml.m5.12xlarge",
        "ml.m5.24xlarge",
        "ml.c4.xlarge",
        "ml.c4.2xlarge",
        "ml.c4.4xlarge",
        "ml.c4.8xlarge",
        "ml.c5.xlarge",
        "ml.c5.2xlarge",
        "ml.c5.4xlarge",
        "ml.c5.9xlarge",
        "ml.c5.18xlarge",
    )

    # Instance types appended only when ``supports_gpu`` is truthy.
    _GPU_INSTANCE_TYPES = (
        "ml.p2.xlarge",
        "ml.p2.8xlarge",
        "ml.p2.16xlarge",
        "ml.p3.2xlarge",
        "ml.p3.8xlarge",
        "ml.p3.16xlarge",
    )

    def get_inference_specification_dict(
        self, ecr_image, supports_gpu, supported_content_types=None, supported_mime_types=None
    ):
        """Return the inference specification as a parsed dict.

        Takes the same arguments as ``get_inference_specification_json`` and
        returns the result of parsing that method's output with ``json.loads``.
        """
        return json.loads(
            self.get_inference_specification_json(
                ecr_image, supports_gpu, supported_content_types, supported_mime_types
            )
        )

    def get_inference_specification_json(
        self, ecr_image, supports_gpu, supported_content_types=None, supported_mime_types=None
    ):
        """Return the inference specification as a JSON string.

        Args:
            ecr_image: Image identifier placed in ``Containers[0].Image``.
                Must be a ``str``.
            supports_gpu: When truthy, GPU instance types are listed in
                addition to the CPU ones.
            supported_content_types: JSON-serializable sequence for
                ``SupportedContentTypes``; ``None`` means an empty list.
            supported_mime_types: JSON-serializable sequence for
                ``SupportedResponseMIMETypes``; ``None`` means an empty list.

        Returns:
            ``template`` with every placeholder substituted.

        Raises:
            TypeError: If ``ecr_image`` is not a string.
        """
        if not isinstance(ecr_image, str):
            raise TypeError(
                "ecr_image must be a str, not {}".format(type(ecr_image).__name__)
            )
        if supported_mime_types is None:
            supported_mime_types = []
        if supported_content_types is None:
            supported_content_types = []

        substitutions = {
            # The template already supplies the surrounding quotes, so drop the
            # ones json.dumps adds while keeping its escaping of quotes,
            # backslashes and control characters.
            "IMAGE_REPLACE_ME": json.dumps(ecr_image)[1:-1],
            "INSTANCES_REPLACE_ME": self.get_supported_instances(supports_gpu),
            "CONTENT_TYPES_REPLACE_ME": json.dumps(supported_content_types),
            "RESPONSE_MIME_TYPES_REPLACE_ME": json.dumps(supported_mime_types),
        }
        # Substitute in a single pass over the original template: text that was
        # inserted for one placeholder is never re-scanned, so a caller value
        # that happens to contain another placeholder token stays untouched.
        placeholder_pattern = re.compile("|".join(re.escape(key) for key in substitutions))
        return placeholder_pattern.sub(
            lambda match: substitutions[match.group(0)], self.template
        )

    def get_supported_instances(self, supports_gpu):
        """Return the supported instance types as a JSON array string.

        The CPU instance types are always included; the GPU instance types are
        appended after them when ``supports_gpu`` is truthy.
        """
        instance_types = list(self._CPU_INSTANCE_TYPES)
        if supports_gpu:
            instance_types.extend(self._GPU_INSTANCE_TYPES)
        return json.dumps(instance_types)