# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import import pytest from sagemaker.cli.compatibility.v2.modifiers import parsing from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call def test_arg_from_keywords(): kw_name = "framework_version" kw_value = "1.6.0" call = ast_call("MXNet({}='{}', py_version='py3', entry_point='run')".format(kw_name, kw_value)) returned_kw = parsing.arg_from_keywords(call, kw_name) assert kw_name == returned_kw.arg assert kw_value == returned_kw.value.s def test_arg_from_keywords_absent_keyword(): call = ast_call("MXNet(entry_point='run')") assert parsing.arg_from_keywords(call, "framework_version") is None def test_arg_value(): call = ast_call("MXNet(framework_version='1.6.0')") assert "1.6.0" == parsing.arg_value(call, "framework_version") call = ast_call("MXNet(framework_version=mxnet_version)") assert "mxnet_version" == parsing.arg_value(call, "framework_version") call = ast_call("MXNet(instance_count=1)") assert 1 == parsing.arg_value(call, "instance_count") call = ast_call("MXNet(enable_network_isolation=True)") assert parsing.arg_value(call, "enable_network_isolation") is True call = ast_call("MXNet(source_dir=None)") assert parsing.arg_value(call, "source_dir") is None def test_arg_value_absent_keyword(): code = "MXNet(entry_point='run')" with pytest.raises(KeyError) as e: parsing.arg_value(ast_call(code), "framework_version") assert "arg 'framework_version' not found in call: {}".format(code) in str(e.value)