import json


class InferenceSpecification:
    """Builds the ``InferenceSpecification`` section of a SageMaker
    model-package definition for a container image, the instance types it
    supports, and its accepted/produced content types."""

    # Legacy placeholder template, kept only for backward compatibility with
    # callers that may read it directly.  Rendering no longer goes through
    # string replacement: substituting raw values into JSON text broke as
    # soon as a value contained a quote or backslash.
    template = """
{
    "InferenceSpecification": {
        "Containers": [{"Image": "IMAGE_REPLACE_ME"}],
        "SupportedTransformInstanceTypes": INSTANCES_REPLACE_ME,
        "SupportedRealtimeInferenceInstanceTypes": INSTANCES_REPLACE_ME,
        "SupportedContentTypes": CONTENT_TYPES_REPLACE_ME,
        "SupportedResponseMIMETypes": RESPONSE_MIME_TYPES_REPLACE_ME
    }
}
"""

    def get_inference_specification_dict(
        self, ecr_image, supports_gpu, supported_content_types=None, supported_mime_types=None
    ):
        """Return the inference specification as a Python dict.

        Args:
            ecr_image: ECR image URI for the inference container.
            supports_gpu: when truthy, GPU (p2/p3) instance types are included.
            supported_content_types: request content types; defaults to [].
            supported_mime_types: response MIME types; defaults to [].

        Returns:
            dict with a single ``"InferenceSpecification"`` key.
        """
        instances = json.loads(self.get_supported_instances(supports_gpu))
        return {
            "InferenceSpecification": {
                "Containers": [{"Image": ecr_image}],
                # Transform and realtime support the same instance set; copy
                # the list so the two entries do not alias each other.
                "SupportedTransformInstanceTypes": list(instances),
                "SupportedRealtimeInferenceInstanceTypes": list(instances),
                "SupportedContentTypes": list(supported_content_types or []),
                "SupportedResponseMIMETypes": list(supported_mime_types or []),
            }
        }

    def get_inference_specification_json(
        self, ecr_image, supports_gpu, supported_content_types=None, supported_mime_types=None
    ):
        """Return the inference specification as a JSON document (str).

        Values are serialized with ``json.dumps`` so an image URI or content
        type containing quotes/backslashes can no longer corrupt the output,
        unlike the previous ``str.replace`` templating.
        """
        return json.dumps(
            self.get_inference_specification_dict(
                ecr_image, supports_gpu, supported_content_types, supported_mime_types
            ),
            indent=4,
        )

    def get_supported_instances(self, supports_gpu):
        """Return a JSON array (str) of supported SageMaker instance types.

        CPU (m4/m5/c4/c5) types are always included; GPU (p2/p3) types are
        appended when ``supports_gpu`` is truthy.
        """
        cpu_list = [
            "ml.m4.xlarge",
            "ml.m4.2xlarge",
            "ml.m4.4xlarge",
            "ml.m4.10xlarge",
            "ml.m4.16xlarge",
            "ml.m5.large",
            "ml.m5.xlarge",
            "ml.m5.2xlarge",
            "ml.m5.4xlarge",
            "ml.m5.12xlarge",
            "ml.m5.24xlarge",
            "ml.c4.xlarge",
            "ml.c4.2xlarge",
            "ml.c4.4xlarge",
            "ml.c4.8xlarge",
            "ml.c5.xlarge",
            "ml.c5.2xlarge",
            "ml.c5.4xlarge",
            "ml.c5.9xlarge",
            "ml.c5.18xlarge",
        ]
        gpu_list = [
            "ml.p2.xlarge",
            "ml.p2.8xlarge",
            "ml.p2.16xlarge",
            "ml.p3.2xlarge",
            "ml.p3.8xlarge",
            "ml.p3.16xlarge",
        ]
        return json.dumps(cpu_list + gpu_list if supports_gpu else cpu_list)