# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import import importlib import logging import os import mxnet as mx from sagemaker_inference import environment from sagemaker_inference.default_handler_service import DefaultHandlerService from sagemaker_inference.transformer import Transformer from sagemaker_mxnet_serving_container.default_inference_handler import DefaultGluonBlockInferenceHandler, \ DefaultMXNetInferenceHandler from sagemaker_mxnet_serving_container.mxnet_module_transformer import MXNetModuleTransformer PYTHON_PATH_ENV = "PYTHONPATH" logging.basicConfig(level=logging.ERROR) class HandlerService(DefaultHandlerService): """Handler service that is executed by the model server. Determines specific default inference handlers to use based on the type MXNet model being used. This class extends ``DefaultHandlerService``, which define the following: - The ``handle`` method is invoked for all incoming inference requests to the model server. - The ``initialize`` method is invoked at model server start up. Based on: https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md """ def __init__(self): self._service = None @staticmethod def _user_module_transformer(model_dir=environment.model_dir): try: user_module = importlib.import_module(environment.Environment().module_name) except ModuleNotFoundError as e: logging.error("import_module exception: {}".format(e)) raise ValueError('import_module exception: {}'.format(e)) if hasattr(user_module, 'transform_fn'): return Transformer(default_inference_handler=DefaultMXNetInferenceHandler()) model_fn = getattr(user_module, 'model_fn', DefaultMXNetInferenceHandler().default_model_fn) model = model_fn(model_dir) if isinstance(model, mx.module.BaseModule): return MXNetModuleTransformer() elif isinstance(model, mx.gluon.block.Block): return Transformer(default_inference_handler=DefaultGluonBlockInferenceHandler()) else: raise ValueError('Unsupported model type: {}. Did you forget to implement ' '`transform_fn` or `model_fn` in your entry-point?' .format(model.__class__.__name__)) def initialize(self, context): """Calls the Transformer method that validates the user module against the SageMaker inference contract. """ properties = context.system_properties model_dir = properties.get("model_dir") # add model_dir/code to python path code_dir_path = "{}:".format(model_dir + '/code') if PYTHON_PATH_ENV in os.environ: os.environ[PYTHON_PATH_ENV] = code_dir_path + os.environ[PYTHON_PATH_ENV] else: os.environ[PYTHON_PATH_ENV] = code_dir_path self._service = self._user_module_transformer(model_dir) super(HandlerService, self).initialize(context)