# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import import ast import pasta import pytest from sagemaker.cli.compatibility.v2.modifiers import serde from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import @pytest.mark.parametrize( "src, expected", [ ("sagemaker.predictor._CsvSerializer()", True), ("sagemaker.predictor._JsonSerializer()", True), ("sagemaker.predictor._NpySerializer()", True), ("sagemaker.predictor._CsvDeserializer()", True), ("sagemaker.predictor.BytesDeserializer()", True), ("sagemaker.predictor.StringDeserializer()", True), ("sagemaker.predictor.StreamDeserializer()", True), ("sagemaker.predictor._NumpyDeserializer()", True), ("sagemaker.predictor._JsonDeserializer()", True), ("sagemaker.predictor.OtherClass()", False), ("sagemaker.amazon.common.numpy_to_record_serializer()", True), ("sagemaker.amazon.common.record_deserializer()", True), ("_CsvSerializer()", True), ("_JsonSerializer()", True), ("_NpySerializer()", True), ("_CsvDeserializer()", True), ("BytesDeserializer()", True), ("StringDeserializer()", True), ("StreamDeserializer()", True), ("_NumpyDeserializer()", True), ("_JsonDeserializer()", True), ("numpy_to_record_serializer()", True), ("record_deserializer()", True), ("OtherClass()", False), ], ) def test_constructor_node_should_be_modified(src, expected): modifier = serde.SerdeConstructorRenamer() node = ast_call(src) assert modifier.node_should_be_modified(node) is expected @pytest.mark.parametrize( "src, expected", [ ("sagemaker.predictor._CsvSerializer()", "serializers.CSVSerializer()"), ("sagemaker.predictor._JsonSerializer()", "serializers.JSONSerializer()"), ("sagemaker.predictor._NpySerializer()", "serializers.NumpySerializer()"), ("sagemaker.predictor._CsvDeserializer()", "deserializers.CSVDeserializer()"), ("sagemaker.predictor.BytesDeserializer()", "deserializers.BytesDeserializer()"), ( "sagemaker.predictor.StringDeserializer()", "deserializers.StringDeserializer()", ), ( "sagemaker.predictor.StreamDeserializer()", "deserializers.StreamDeserializer()", ), ("sagemaker.predictor._NumpyDeserializer()", "deserializers.NumpyDeserializer()"), ("sagemaker.predictor._JsonDeserializer()", "deserializers.JSONDeserializer()"), ( "sagemaker.amazon.common.numpy_to_record_serializer()", "sagemaker.amazon.common.RecordSerializer()", ), ( "sagemaker.amazon.common.record_deserializer()", "sagemaker.amazon.common.RecordDeserializer()", ), ("_CsvSerializer()", "serializers.CSVSerializer()"), ("_JsonSerializer()", "serializers.JSONSerializer()"), ("_NpySerializer()", "serializers.NumpySerializer()"), ("_CsvDeserializer()", "deserializers.CSVDeserializer()"), ("BytesDeserializer()", "deserializers.BytesDeserializer()"), ("StringDeserializer()", "deserializers.StringDeserializer()"), ("StreamDeserializer()", "deserializers.StreamDeserializer()"), ("_NumpyDeserializer()", "deserializers.NumpyDeserializer()"), ("_JsonDeserializer()", "deserializers.JSONDeserializer()"), ("numpy_to_record_serializer()", "RecordSerializer()"), ("record_deserializer()", "RecordDeserializer()"), ], ) def test_constructor_modify_node(src, expected): modifier = serde.SerdeConstructorRenamer() node = ast_call(src) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node) assert isinstance(modified_node, ast.Call) @pytest.mark.parametrize( "src, expected", [ ( "sagemaker.predictor.csv_serializer", True, ), ( "sagemaker.predictor.json_serializer", True, ), ( "sagemaker.predictor.npy_serializer", True, ), ( "sagemaker.predictor.csv_deserializer", True, ), ( "sagemaker.predictor.json_deserializer", True, ), ( "sagemaker.predictor.numpy_deserializer", True, ), ( "csv_serializer", True, ), ( "json_serializer", True, ), ( "npy_serializer", True, ), ( "csv_deserializer", True, ), ( "json_deserializer", True, ), ( "numpy_deserializer", True, ), ], ) def test_name_node_should_be_modified(src, expected): modifier = serde.SerdeObjectRenamer() node = ast_call(src) assert modifier.node_should_be_modified(node) is True @pytest.mark.parametrize( "src, expected", [ ("sagemaker.predictor.csv_serializer", "serializers.CSVSerializer()"), ("sagemaker.predictor.json_serializer", "serializers.JSONSerializer()"), ("sagemaker.predictor.npy_serializer", "serializers.NumpySerializer()"), ("sagemaker.predictor.csv_deserializer", "deserializers.CSVDeserializer()"), ("sagemaker.predictor.json_deserializer", "deserializers.JSONDeserializer()"), ("sagemaker.predictor.numpy_deserializer", "deserializers.NumpyDeserializer()"), ("csv_serializer", "serializers.CSVSerializer()"), ("json_serializer", "serializers.JSONSerializer()"), ("npy_serializer", "serializers.NumpySerializer()"), ("csv_deserializer", "deserializers.CSVDeserializer()"), ("json_deserializer", "deserializers.JSONDeserializer()"), ("numpy_deserializer", "deserializers.NumpyDeserializer()"), ], ) def test_name_modify_node(src, expected): modifier = serde.SerdeObjectRenamer() node = ast_call(src) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node) assert isinstance(modified_node, ast.Call) @pytest.mark.parametrize( "src, expected", [ ("from sagemaker.predictor import _CsvSerializer", True), ("from sagemaker.predictor import _JsonSerializer", True), ("from sagemaker.predictor import _NpySerializer", True), ("from sagemaker.predictor import _CsvDeserializer", True), ("from sagemaker.predictor import BytesDeserializer", True), ("from sagemaker.predictor import StringDeserializer", True), ("from sagemaker.predictor import StreamDeserializer", True), ("from sagemaker.predictor import _NumpyDeserializer", True), ("from sagemaker.predictor import _JsonDeserializer", True), ("from sagemaker.predictor import csv_serializer", True), ("from sagemaker.predictor import json_serializer", True), ("from sagemaker.predictor import npy_serializer", True), ("from sagemaker.predictor import csv_deserializer", True), ("from sagemaker.predictor import json_deserializer", True), ("from sagemaker.predictor import numpy_deserializer", True), ("from sagemaker.predictor import RealTimePredictor, _CsvSerializer", True), ("from sagemaker.predictor import RealTimePredictor", False), ("from sagemaker.amazon.common import numpy_to_record_serializer", False), ], ) def test_import_from_predictor_node_should_be_modified(src, expected): modifier = serde.SerdeImportFromPredictorRenamer() node = ast_import(src) assert modifier.node_should_be_modified(node) is expected @pytest.mark.parametrize( "src, expected", [ ("from sagemaker.predictor import _CsvSerializer", None), ("from sagemaker.predictor import _JsonSerializer", None), ("from sagemaker.predictor import _NpySerializer", None), ("from sagemaker.predictor import _CsvDeserializer", None), ("from sagemaker.predictor import BytesDeserializer", None), ("from sagemaker.predictor import StringDeserializer", None), ("from sagemaker.predictor import StreamDeserializer", None), ("from sagemaker.predictor import _NumpyDeserializer", None), ("from sagemaker.predictor import _JsonDeserializer", None), ("from sagemaker.predictor import csv_serializer", None), ("from sagemaker.predictor import json_serializer", None), ("from sagemaker.predictor import npy_serializer", None), ("from sagemaker.predictor import csv_deserializer", None), ("from sagemaker.predictor import json_deserializer", None), ("from sagemaker.predictor import numpy_deserializer", None), ( "from sagemaker.predictor import RealTimePredictor, _NpySerializer", "from sagemaker.predictor import RealTimePredictor", ), ], ) def test_import_from_predictor_modify_node(src, expected): modifier = serde.SerdeImportFromPredictorRenamer() node = ast_import(src) modified_node = modifier.modify_node(node) assert expected == (modified_node and pasta.dump(modified_node)) @pytest.mark.parametrize( "import_statement, expected", [ ("from sagemaker.amazon.common import numpy_to_record_serializer", True), ("from sagemaker.amazon.common import record_deserializer", True), ("from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor", False), ], ) def test_import_from_amazon_common_node_should_be_modified(import_statement, expected): modifier = serde.SerdeImportFromAmazonCommonRenamer() node = ast_import(import_statement) assert modifier.node_should_be_modified(node) is expected @pytest.mark.parametrize( "import_statement, expected", [ ( "from sagemaker.amazon.common import numpy_to_record_serializer", "from sagemaker.amazon.common import RecordSerializer", ), ( "from sagemaker.amazon.common import record_deserializer", "from sagemaker.amazon.common import RecordDeserializer", ), ( "from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer", "from sagemaker.amazon.common import RecordSerializer, RecordDeserializer", ), ( "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, numpy_to_record_serializer", "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, RecordSerializer", ), ], ) def test_import_from_amazon_common_modify_node(import_statement, expected): modifier = serde.SerdeImportFromAmazonCommonRenamer() node = ast_import(import_statement) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node) @pytest.mark.parametrize( "src, expected", [ ("serializers.CSVSerializer()", True), ("serializers.JSONSerializer()", True), ("serializers.NumpySerializer()", True), ("pass", False), ], ) def test_serializer_module_node_should_be_modified(src, expected): modifier = serde.SerializerImportInserter() node = pasta.parse(src) assert modifier.node_should_be_modified(node) is expected @pytest.mark.parametrize( "src, expected", [ ( "serializers.CSVSerializer()", "from sagemaker import serializers\nserializers.CSVSerializer()", ), ( "serializers.JSONSerializer()", "from sagemaker import serializers\nserializers.JSONSerializer()", ), ( "serializers.NumpySerializer()", "from sagemaker import serializers\nserializers.NumpySerializer()", ), ( "pass\nimport random\nserializers.CSVSerializer()", "pass\nfrom sagemaker import serializers\nimport random\nserializers.CSVSerializer()", ), ], ) def test_serializer_module_modify_node(src, expected): modifier = serde.SerializerImportInserter() node = pasta.parse(src) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node) @pytest.mark.parametrize( "src, expected", [ ("deserializers.CSVDeserializer()", True), ("deserializers.BytesDeserializer()", True), ("deserializers.StringDeserializer()", True), ("deserializers.StreamDeserializer()", True), ("deserializers.NumpyDeserializer()", True), ("deserializers.JSONDeserializer()", True), ("pass", False), ], ) def test_deserializer_module_node_should_be_modified(src, expected): modifier = serde.DeserializerImportInserter() node = pasta.parse(src) assert modifier.node_should_be_modified(node) is expected @pytest.mark.parametrize( "src, expected", [ ( "deserializers.CSVDeserializer()", "from sagemaker import deserializers\ndeserializers.CSVDeserializer()", ), ( "deserializers.BytesDeserializer()", "from sagemaker import deserializers\ndeserializers.BytesDeserializer()", ), ( "deserializers.StringDeserializer()", "from sagemaker import deserializers\ndeserializers.StringDeserializer()", ), ( "deserializers.StreamDeserializer()", "from sagemaker import deserializers\ndeserializers.StreamDeserializer()", ), ( "deserializers.NumpyDeserializer()", "from sagemaker import deserializers\ndeserializers.NumpyDeserializer()", ), ( "deserializers.JSONDeserializer()", "from sagemaker import deserializers\ndeserializers.JSONDeserializer()", ), ( "pass\nimport random\ndeserializers.CSVDeserializer()", "pass\nfrom sagemaker import deserializers\nimport random\ndeserializers.CSVDeserializer()", ), ], ) def test_deserializer_module_modify_node(src, expected): modifier = serde.DeserializerImportInserter() node = pasta.parse(src) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node) @pytest.mark.parametrize( "src, expected", [ ('estimator.create_model(entry_point="inference.py")', False), ("estimator.create_model(serializer=CSVSerializer())", True), ("estimator.create_model(deserializer=CSVDeserializer())", True), ( "estimator.create_model(serializer=CSVSerializer(), deserializer=CSVDeserializer())", True, ), ("estimator.deploy(serializer=CSVSerializer())", False), ], ) def test_create_model_call_node_should_be_modified(src, expected): modifier = serde.SerdeKeywordRemover() node = ast_call(src) assert modifier.node_should_be_modified(node) is expected @pytest.mark.parametrize( "src, expected", [ ( 'estimator.create_model(entry_point="inference.py", serializer=CSVSerializer())', 'estimator.create_model(entry_point="inference.py")', ), ( 'estimator.create_model(entry_point="inference.py", deserializer=CSVDeserializer())', 'estimator.create_model(entry_point="inference.py")', ), ], ) def test_create_model_call_modify_node(src, expected): modifier = serde.SerdeKeywordRemover() node = ast_call(src) modified_node = modifier.modify_node(node) assert expected == pasta.dump(modified_node)