# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for ``sagemaker_inference.model_server``."""
import os
import signal
import subprocess
import sys
import types

import botocore.session
from botocore.stub import Stubber
from mock import ANY, MagicMock, Mock, patch
import pytest

from sagemaker_inference import environment, model_server
from sagemaker_inference.model_server import MMS_NAMESPACE, REQUIREMENTS_PATH

# Placeholder values fed through the mocked environment / file reads.
PYTHON_PATH = "python_path"
DEFAULT_CONFIGURATION = "default_configuration"


@patch("subprocess.call")
@patch("subprocess.Popen")
@patch("sagemaker_inference.model_server._retry_retrieve_mms_server_process")
@patch("sagemaker_inference.model_server._add_sigterm_handler")
@patch("sagemaker_inference.model_server._install_requirements")
@patch("os.path.exists", return_value=True)
@patch("sagemaker_inference.model_server._create_model_server_config_file")
@patch("sagemaker_inference.model_server._adapt_to_mms_format")
@patch("sagemaker_inference.environment.Environment")
def test_start_model_server_default_service_handler(
    env,
    adapt,
    create_config,
    exists,
    install_requirements,
    sigterm,
    retrieve,
    subprocess_popen,
    subprocess_call,
):
    """start_model_server() without a handler uses the default handler service.

    Also checks that requirements are installed when the requirements path
    exists, that the multi-model-server command is launched via Popen, and
    that the SIGTERM handler is attached to the retrieved server process.
    """
    env.return_value.startup_timeout = 10000

    model_server.start_model_server()

    adapt.assert_not_called()
    create_config.assert_called_once_with(env.return_value, model_server.DEFAULT_HANDLER_SERVICE)
    exists.assert_called_once_with(REQUIREMENTS_PATH)
    install_requirements.assert_called_once_with()

    multi_model_server_cmd = [
        "multi-model-server",
        "--start",
        "--model-store",
        model_server.MODEL_STORE,
        "--mms-config",
        model_server.MMS_CONFIG_FILE,
        "--log-config",
        model_server.DEFAULT_MMS_LOG_FILE,
        "--models",
        "{}={}".format(model_server.DEFAULT_MMS_MODEL_NAME, environment.model_dir),
    ]

    subprocess_popen.assert_called_once_with(multi_model_server_cmd)
    sigterm.assert_called_once_with(retrieve.return_value)


@patch("subprocess.call")
@patch("subprocess.Popen")
@patch("sagemaker_inference.model_server._retry_retrieve_mms_server_process")
@patch("sagemaker_inference.model_server._add_sigterm_handler")
@patch("sagemaker_inference.model_server._create_model_server_config_file")
@patch("sagemaker_inference.model_server._adapt_to_mms_format")
@patch("sagemaker_inference.environment.Environment")
def test_start_model_server_custom_handler_service(
    env, adapt, create_config, sigterm, retrieve, subprocess_popen, subprocess_call
):
    """A handler service passed to start_model_server() reaches the config writer."""
    handler_service = Mock()

    model_server.start_model_server(handler_service)

    adapt.assert_not_called()
    create_config.assert_called_once_with(env.return_value, handler_service)


@patch("sagemaker_inference.model_server._set_python_path")
@patch("subprocess.check_call")
@patch("os.makedirs")
@patch("os.path.exists", return_value=False)
def test_adapt_to_mms_format(path_exists, make_dir, subprocess_check_call, set_python_path):
    """_adapt_to_mms_format() creates the export dir and runs model-archiver."""
    handler_service = Mock()

    model_server._adapt_to_mms_format(handler_service)

    path_exists.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)
    make_dir.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)

    model_archiver_cmd = [
        "model-archiver",
        "--model-name",
        model_server.DEFAULT_MMS_MODEL_NAME,
        "--handler",
        handler_service,
        "--model-path",
        environment.model_dir,
        "--export-path",
        model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY,
        "--archive-format",
        "no-archive",
    ]

    subprocess_check_call.assert_called_once_with(model_archiver_cmd)
    set_python_path.assert_called_once_with()


@patch("sagemaker_inference.model_server._set_python_path")
@patch("subprocess.check_call")
@patch("os.makedirs")
@patch("os.path.exists", return_value=True)
def test_adapt_to_mms_format_existing_path(
    path_exists, make_dir, subprocess_check_call, set_python_path
):
    """The export directory is not re-created when it already exists."""
    handler_service = Mock()

    model_server._adapt_to_mms_format(handler_service)

    path_exists.assert_called_once_with(model_server.DEFAULT_MMS_MODEL_EXPORT_DIRECTORY)
    make_dir.assert_not_called()


@patch.dict(os.environ, {model_server.PYTHON_PATH_ENV: PYTHON_PATH}, clear=True)
def test_set_existing_python_path():
    """The code dir is prepended (colon-separated) to an existing python path."""
    model_server._set_python_path()

    code_dir_path = "{}:".format(environment.code_dir)

    assert os.environ[model_server.PYTHON_PATH_ENV] == code_dir_path + PYTHON_PATH


@patch.dict(os.environ, {}, clear=True)
def test_new_python_path():
    """With no python path set, the variable becomes the code dir plus a colon."""
    model_server._set_python_path()

    code_dir_path = "{}:".format(environment.code_dir)

    assert os.environ[model_server.PYTHON_PATH_ENV] == code_dir_path


@patch("sagemaker_inference.model_server._generate_mms_config_properties")
@patch("sagemaker_inference.utils.write_file")
@patch("sagemaker_inference.environment.Environment")
def test_create_model_server_config_file(env, write_file, generate_mms_config_props):
    """The generated config properties are written to MMS_CONFIG_FILE."""
    model_server._create_model_server_config_file(env.return_value)

    write_file.assert_called_once_with(
        model_server.MMS_CONFIG_FILE, generate_mms_config_props.return_value
    )


@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION)
@patch("sagemaker_inference.environment.Environment")
def test_generate_mms_config_properties(env, read_file):
    """Environment overrides are appended after the default configuration."""
    model_server_timeout = "model_server_timeout"
    model_server_workers = "model_server_workers"
    http_port = "http_port"

    env.return_value.model_server_timeout = model_server_timeout
    env.return_value.model_server_workers = model_server_workers
    env.return_value.inference_http_port = http_port

    mms_config_properties = model_server._generate_mms_config_properties(env.return_value)

    inference_address = "inference_address=http://0.0.0.0:{}\n".format(http_port)
    server_timeout = "default_response_timeout={}\n".format(model_server_timeout)
    workers = "default_workers_per_model={}\n".format(model_server_workers)

    read_file.assert_called_once_with(model_server.DEFAULT_MMS_CONFIG_FILE)

    assert mms_config_properties.startswith(DEFAULT_CONFIGURATION)
    assert inference_address in mms_config_properties
    assert server_timeout in mms_config_properties
    assert workers in mms_config_properties


@patch("sagemaker_inference.utils.read_file", return_value=DEFAULT_CONFIGURATION)
@patch("sagemaker_inference.environment.Environment")
def test_generate_mms_config_properties_default_workers(env, read_file):
    """No workers entry is emitted when model_server_workers is None."""
    env.return_value.model_server_workers = None

    mms_config_properties = model_server._generate_mms_config_properties(env.return_value)

    workers = "default_workers_per_model={}".format(None)

    read_file.assert_called_once_with(model_server.DEFAULT_MMS_CONFIG_FILE)

    assert mms_config_properties.startswith(DEFAULT_CONFIGURATION)
    assert workers not in mms_config_properties


@patch("signal.signal")
def test_add_sigterm_handler(signal_call):
    """Exactly one handler is registered, for SIGTERM, and it is a function."""
    mms = Mock()

    model_server._add_sigterm_handler(mms)

    mock_calls = signal_call.mock_calls

    # Check the call count before indexing so a missing registration shows up
    # as an assertion failure rather than an IndexError.
    assert len(mock_calls) == 1

    # Each recorded call unpacks as (name, args, kwargs); index 1 is args.
    first_argument = mock_calls[0][1][0]
    second_argument = mock_calls[0][1][1]

    assert first_argument == signal.SIGTERM
    assert isinstance(second_argument, types.FunctionType)


@patch("subprocess.check_call")
def test_install_requirements(check_call):
    """Requirements are installed with pip run through the current interpreter."""
    model_server._install_requirements()

    install_cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-r",
        "/opt/ml/model/code/requirements.txt",
    ]
    check_call.assert_called_once_with(install_cmd)


@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
def test_install_requirements_installation_failed(check_call):
    """A failing pip invocation is surfaced as a ValueError."""
    with pytest.raises(ValueError) as e:
        model_server._install_requirements()

    assert "failed to install required packages" in str(e.value)


@patch.dict(os.environ, {"CA_REPOSITORY_ARN": "invalid_arn"}, clear=True)
def test_install_requirements_codeartifact_invalid_arn_installation_failed():
    """A malformed CA_REPOSITORY_ARN is rejected with a descriptive error."""
    with pytest.raises(Exception) as e:
        model_server._install_requirements()

    assert "invalid CodeArtifact repository arn invalid_arn" in str(e.value)


@patch("subprocess.check_call")
@patch.dict(
    os.environ,
    {
        "CA_REPOSITORY_ARN": "arn:aws:codeartifact:my_region:012345678900:repository/my_domain/my_repo"
    },
    clear=True,
)
def test_install_requirements_codeartifact(check_call):
    """With CA_REPOSITORY_ARN set, pip is pointed at the CodeArtifact index."""
    # mock/stub codeartifact client and its responses
    endpoint = "https://domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/"
    codeartifact = botocore.session.get_session().create_client(
        "codeartifact", region_name="myregion"
    )
    stubber = Stubber(codeartifact)
    stubber.add_response("get_authorization_token", {"authorizationToken": "the-auth-token"})
    stubber.add_response("get_repository_endpoint", {"repositoryEndpoint": endpoint})

    # Entering the stubber activates it and leaving deactivates it, so the
    # stub is torn down even if the call under test raises.
    with stubber, patch("boto3.client", MagicMock(return_value=codeartifact)):
        model_server._install_requirements()

    install_cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-r",
        "/opt/ml/model/code/requirements.txt",
        "-i",
        "https://aws:the-auth-token@domain-012345678900.d.codeartifact.region.amazonaws.com/pypi/my_repo/simple/",
    ]
    check_call.assert_called_once_with(install_cmd)


@patch("psutil.process_iter")
def test_retrieve_mms_server_process(process_iter):
    """The single process whose cmdline() returns MMS_NAMESPACE is returned."""
    server = Mock()
    server.cmdline.return_value = MMS_NAMESPACE

    process_iter.return_value = [server]

    process = model_server._retrieve_mms_server_process()

    assert process == server


@patch("psutil.process_iter", return_value=[])
def test_retrieve_mms_server_process_no_server(process_iter):
    """An error is raised when no process is found."""
    with pytest.raises(Exception) as e:
        model_server._retrieve_mms_server_process()

    assert "mms model server was unsuccessfully started" in str(e.value)


@patch("psutil.process_iter")
def test_retrieve_mms_server_process_too_many_servers(process_iter):
    """An error is raised when two processes report the MMS namespace."""
    server = Mock()
    second_server = Mock()
    server.cmdline.return_value = MMS_NAMESPACE
    second_server.cmdline.return_value = MMS_NAMESPACE

    process_iter.return_value = [server, second_server]

    with pytest.raises(Exception) as e:
        model_server._retrieve_mms_server_process()

    assert "multiple mms model servers are not supported" in str(e.value)


# ``retry`` is patched to return an identity decorator so the wrapped retrieval
# function is invoked exactly once and its return value passes straight through.
@patch("sagemaker_inference.model_server.retry", return_value=lambda f: f)
@patch("sagemaker_inference.model_server._retrieve_mms_server_process", return_value=17)
def test_retry_retrieve_mms_server_process(retrieve, retry):
    """A timeout of 100 is handed to retry as stop_max_delay=100 * 1000."""
    process_id = model_server._retry_retrieve_mms_server_process(100)

    assert process_id == 17
    retry.assert_called_once_with(wait_fixed=ANY, stop_max_delay=100 * 1000)