import json
import re

class InferenceSpecification:
    """Build an ``InferenceSpecification`` document for a container image.

    The document names the container image, the instance types it may run on
    (CPU types always, GPU types too when requested), and the supported request
    content types and response MIME types.

    NOTE(review): the key names look like the SageMaker model-package
    ``InferenceSpecification`` structure. Confirm this against the consumer.
    """

    # JSON skeleton with placeholders. The image placeholder is matched together
    # with its surrounding quotes (see get_inference_specification_json), so the
    # image can be inserted as a properly escaped JSON string literal.
    template = """
{    
    "InferenceSpecification": {
        "Containers" : [{"Image": "IMAGE_REPLACE_ME"}],
        "SupportedTransformInstanceTypes": INSTANCES_REPLACE_ME,
        "SupportedRealtimeInferenceInstanceTypes": INSTANCES_REPLACE_ME,
        "SupportedContentTypes": CONTENT_TYPES_REPLACE_ME,
        "SupportedResponseMIMETypes": RESPONSE_MIME_TYPES_REPLACE_ME
    }
}
"""

    # Instance types offered for every image, in the order they are emitted.
    _CPU_INSTANCE_TYPES = (
        "ml.m4.xlarge", "ml.m4.2xlarge", "ml.m4.4xlarge", "ml.m4.10xlarge",
        "ml.m4.16xlarge", "ml.m5.large", "ml.m5.xlarge", "ml.m5.2xlarge",
        "ml.m5.4xlarge", "ml.m5.12xlarge", "ml.m5.24xlarge", "ml.c4.xlarge",
        "ml.c4.2xlarge", "ml.c4.4xlarge", "ml.c4.8xlarge", "ml.c5.xlarge",
        "ml.c5.2xlarge", "ml.c5.4xlarge", "ml.c5.9xlarge", "ml.c5.18xlarge",
    )
    # Extra instance types appended after the CPU ones when the image supports GPUs.
    _GPU_INSTANCE_TYPES = (
        "ml.p2.xlarge", "ml.p2.8xlarge", "ml.p2.16xlarge",
        "ml.p3.2xlarge", "ml.p3.8xlarge", "ml.p3.16xlarge",
    )

    def get_inference_specification_dict(self, ecr_image, supports_gpu,
                                         supported_content_types=None,
                                         supported_mime_types=None):
        """Return the inference specification as a parsed Python dict.

        Args:
            ecr_image: Container image reference stored under ``Containers``.
            supports_gpu: When truthy, GPU instance types are included.
            supported_content_types: List of request content types
                (defaults to an empty list).
            supported_mime_types: List of response MIME types
                (defaults to an empty list).

        Returns:
            The result of ``json.loads`` on ``get_inference_specification_json``.
        """
        return json.loads(self.get_inference_specification_json(
            ecr_image, supports_gpu, supported_content_types, supported_mime_types))

    def get_inference_specification_json(self, ecr_image, supports_gpu,
                                         supported_content_types=None,
                                         supported_mime_types=None):
        """Return the inference specification as a JSON string.

        Takes the same arguments as ``get_inference_specification_dict``. The
        result is ``template`` with every placeholder substituted, so its layout
        follows the template.

        Returns:
            A JSON document string.
        """
        if supported_content_types is None:
            supported_content_types = []
        if supported_mime_types is None:
            supported_mime_types = []

        replacements = {
            # Replace the quotes too: json.dumps yields a correctly escaped
            # string literal, so quotes or backslashes in the image stay valid JSON.
            '"IMAGE_REPLACE_ME"': json.dumps(ecr_image),
            "INSTANCES_REPLACE_ME": self.get_supported_instances(supports_gpu),
            "CONTENT_TYPES_REPLACE_ME": json.dumps(supported_content_types),
            "RESPONSE_MIME_TYPES_REPLACE_ME": json.dumps(supported_mime_types),
        }
        pattern = re.compile("|".join(re.escape(key) for key in replacements))
        # Substitute in a single pass, so inserted values are never rescanned
        # for placeholders. The callback returns text verbatim, so re does not
        # interpret backslashes in the inserted values.
        return pattern.sub(lambda match: replacements[match.group(0)], self.template)

    def get_supported_instances(self, supports_gpu):
        """Return the supported instance types as a JSON array string.

        Args:
            supports_gpu: When truthy, the GPU instance types are appended
                after the CPU ones.

        Returns:
            ``json.dumps`` of the list of instance type names.
        """
        instance_types = list(self._CPU_INSTANCE_TYPES)
        if supports_gpu:
            instance_types.extend(self._GPU_INSTANCE_TYPES)
        return json.dumps(instance_types)