# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import datetime
import unittest.mock

import pytest

from smexperiments import trial_component, api_types, trial


@pytest.fixture
def sagemaker_boto_client():
    """Provide a fresh ``Mock`` standing in for the SageMaker boto client."""
    return unittest.mock.Mock()


def test_create(sagemaker_boto_client):
    """``create`` forwards name/display name and exposes the returned ARN."""
    sagemaker_boto_client.create_trial_component.return_value = {
        "TrialComponentArn": "bazz",
    }
    obj = trial_component.TrialComponent.create(
        trial_component_name="foo", display_name="bar", sagemaker_boto_client=sagemaker_boto_client
    )
    sagemaker_boto_client.create_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar")
    assert "foo" == obj.trial_component_name
    assert "bar" == obj.display_name
    assert "bazz" == obj.trial_component_arn


def test_create_with_tags(sagemaker_boto_client):
    """``create`` passes the ``tags`` argument through as ``Tags``."""
    sagemaker_boto_client.create_trial_component.return_value = {
        "TrialComponentArn": "bazz",
    }
    tags = [{"Key": "foo", "Value": "bar"}]
    trial_component.TrialComponent.create(
        trial_component_name="foo", display_name="bar", sagemaker_boto_client=sagemaker_boto_client, tags=tags
    )
    # Compare against a fresh literal so the check does not depend on `tags`
    # being left unmodified by the call.
    sagemaker_boto_client.create_trial_component.assert_called_with(
        TrialComponentName="foo", DisplayName="bar", Tags=[{"Key": "foo", "Value": "bar"}]
    )


def test_load(sagemaker_boto_client):
    """``load`` maps every field of the describe response onto the object."""
    now = datetime.datetime.now(datetime.timezone.utc)

    sagemaker_boto_client.describe_trial_component.return_value = {
        "TrialComponentArn": "A",
        "TrialComponentName": "B",
        "DisplayName": "C",
        "Status": {"PrimaryStatus": "InProgress", "Message": "D"},
        "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}},
        "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}},
        "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}},
        "Metrics": [
            {
                "MetricName": "J",
                "Count": 1,
                "Min": 1.0,
                "Max": 2.0,
                "Avg": 3.0,
                "StdDev": 4.0,
                "SourceArn": "K",
                "Timestamp": now,
            }
        ],
    }
    obj = trial_component.TrialComponent.load(trial_component_name="foo", sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.describe_trial_component.assert_called_with(TrialComponentName="foo")
    assert "A" == obj.trial_component_arn
    assert "B" == obj.trial_component_name
    assert "C" == obj.display_name
    assert api_types.TrialComponentStatus(primary_status="InProgress", message="D") == obj.status
    assert {"E": 1.0, "F": "G"} == obj.parameters
    # The three assertions below previously asserted on bare literals, which
    # are always truthy; they now compare against the loaded attributes.
    assert {"H": api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain")} == obj.input_artifacts
    assert {
        "I": api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain")
    } == obj.output_artifacts
    assert [
        api_types.TrialComponentMetricSummary(
            metric_name="J", count=1, min=1.0, max=2.0, avg=3.0, std_dev=4.0, source_arn="K", timestamp=now
        )
    ] == obj.metrics


def test_list(sagemaker_boto_client):
    """``list`` follows ``NextToken`` across pages and yields typed summaries."""
    now = datetime.datetime.now(datetime.timezone.utc)
    start_time = now + datetime.timedelta(hours=1)
    end_time = now + datetime.timedelta(hours=2)
    creation_time = now + datetime.timedelta(hours=3)
    last_modified_time = now + datetime.timedelta(hours=4)

    def boto_summary(i):
        """Build the i-th summary dict as it appears in a boto response page."""
        offset = datetime.timedelta(hours=i)
        return {
            "TrialComponentName": "A" + str(i),
            "TrialComponentArn": "B" + str(i),
            "DisplayName": "C" + str(i),
            "SourceArn": "D" + str(i),
            "Status": {"PrimaryStatus": "InProgress", "Message": "E" + str(i)},
            "StartTime": start_time + offset,
            "EndTime": end_time + offset,
            "CreationTime": creation_time + offset,
            "LastModifiedTime": last_modified_time + offset,
            "LastModifiedBy": {},
        }

    # Two pages: the first carries a NextToken, the second terminates paging.
    sagemaker_boto_client.list_trial_components.side_effect = [
        {
            "TrialComponentSummaries": [boto_summary(i) for i in range(10)],
            "NextToken": "100",
        },
        {
            "TrialComponentSummaries": [boto_summary(i) for i in range(10, 20)],
        },
    ]
    expected = [
        api_types.TrialComponentSummary(
            trial_component_name="A" + str(i),
            trial_component_arn="B" + str(i),
            display_name="C" + str(i),
            source_arn="D" + str(i),
            status=api_types.TrialComponentStatus(primary_status="InProgress", message="E" + str(i)),
            start_time=start_time + datetime.timedelta(hours=i),
            end_time=end_time + datetime.timedelta(hours=i),
            creation_time=creation_time + datetime.timedelta(hours=i),
            last_modified_time=last_modified_time + datetime.timedelta(hours=i),
            last_modified_by={},
        )
        for i in range(20)
    ]
    result = list(
        trial_component.TrialComponent.list(
            sagemaker_boto_client=sagemaker_boto_client,
            source_arn="foo",
            sort_by="CreationTime",
            sort_order="Ascending",
        )
    )

    assert expected == result
    expected_calls = [
        unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
        unittest.mock.call(NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls


def test_list_empty(sagemaker_boto_client):
    """``list`` yields nothing when the service returns no summaries."""
    sagemaker_boto_client.list_trial_components.return_value = {"TrialComponentSummaries": []}
    assert [] == list(trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client))


def test_list_trial_components_call_args(sagemaker_boto_client):
    """``list`` translates every filter/paging keyword into its boto argument."""
    created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
    created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
    trial_name = "foo-trial"
    experiment_name = "foo-experiment"
    next_token = "thetoken"
    max_results = 99

    sagemaker_boto_client.list_trial_components.return_value = {}
    assert [] == list(
        trial_component.TrialComponent.list(
            sagemaker_boto_client=sagemaker_boto_client,
            trial_name=trial_name,
            experiment_name=experiment_name,
            created_before=created_before,
            created_after=created_after,
            next_token=next_token,
            max_results=max_results,
            sort_by="CreationTime",
            sort_order="Ascending",
        )
    )

    expected_calls = [
        unittest.mock.call(
            TrialName="foo-trial",
            ExperimentName="foo-experiment",
            CreatedBefore=created_before,
            CreatedAfter=created_after,
            SortBy="CreationTime",
            SortOrder="Ascending",
            NextToken="thetoken",
            MaxResults=99,
        )
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls


def test_search(sagemaker_boto_client):
    """``search`` converts each ``TrialComponent`` result into a search result object."""
    sagemaker_boto_client.search.return_value = {
        "Results": [
            {
                "TrialComponent": {
                    "TrialComponentName": "tc-1",
                    "TrialComponentArn": "arn::tc-1",
                    "DisplayName": "TC1",
                }
            },
            {
                "TrialComponent": {
                    "TrialComponentName": "tc-2",
                    "TrialComponentArn": "arn::tc-2",
                    "DisplayName": "TC2",
                }
            },
        ]
    }
    expected = [
        api_types.TrialComponentSearchResult(
            trial_component_name="tc-1", trial_component_arn="arn::tc-1", display_name="TC1"
        ),
        api_types.TrialComponentSearchResult(
            trial_component_name="tc-2", trial_component_arn="arn::tc-2", display_name="TC2"
        ),
    ]
    assert expected == list(trial_component.TrialComponent.search(sagemaker_boto_client=sagemaker_boto_client))


def test_save(sagemaker_boto_client):
    """``save`` sends the name, display name and all ``*_to_remove`` lists."""
    obj = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
        parameters_to_remove=["E"],
        input_artifacts_to_remove=["F"],
        output_artifacts_to_remove=["G"],
    )
    sagemaker_boto_client.update_trial_component.return_value = {}
    obj.save()

    sagemaker_boto_client.update_trial_component.assert_called_with(
        TrialComponentName="foo",
        DisplayName="bar",
        ParametersToRemove=["E"],
        InputArtifactsToRemove=["F"],
        OutputArtifactsToRemove=["G"],
    )


def test_delete(sagemaker_boto_client):
    """``delete`` calls the boto delete API with the component name."""
    obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
    sagemaker_boto_client.delete_trial_component.return_value = {}
    obj.delete()
    sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")


def test_delete_with_force_disassociate(sagemaker_boto_client):
    """``delete(force_disassociate=True)`` disassociates from every listed trial first."""
    obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
    sagemaker_boto_client.delete_trial_component.return_value = {}

    # Two pages of trials so that paging via NextToken is exercised as well.
    sagemaker_boto_client.list_trials.side_effect = [
        {"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"},
        {"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]},
    ]

    obj.delete(force_disassociate=True)
    expected_calls = [
        unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"),
        unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"),
        unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"),
        unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"),
    ]
    assert expected_calls == sagemaker_boto_client.disassociate_trial_component.mock_calls
    sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")


def test_list_trials(sagemaker_boto_client):
    """``list_trials`` queries by component name and yields ``TrialSummary`` objects."""
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {
                "TrialName": "trial-1",
                "CreationTime": None,
                "LastModifiedTime": None,
            },
            {
                "TrialName": "trial-2",
                "CreationTime": None,
                "LastModifiedTime": None,
            },
        ]
    }
    # Keyword was previously misspelled as `discompplay_name`.
    obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
    expected = [
        api_types.TrialSummary(trial_name="trial-1", creation_time=None, last_modified_time=None),
        api_types.TrialSummary(trial_name="trial-2", creation_time=None, last_modified_time=None),
    ]
    assert expected == list(obj.list_trials())
    sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="foo")


def test_boto_ignore(sagemaker_boto_client):
    """``_boto_ignore`` lists the response keys that are not mapped to attributes."""
    # The fixture is requested explicitly; without the parameter the name
    # resolved to the module-level fixture function instead of a Mock.
    obj = trial_component.TrialComponent(sagemaker_boto_client, trial_component_name="foo", display_name="bar")
    assert obj._boto_ignore() == ["ResponseMetadata", "CreatedBy"]