# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. # language governing permissions and limitations under the License. from __future__ import absolute_import from sagemaker.workflow.retry import ( RetryPolicy, StepRetryPolicy, SageMakerJobStepRetryPolicy, StepExceptionTypeEnum, SageMakerJobExceptionTypeEnum, ) def test_valid_step_retry_policy(): retry_policy = StepRetryPolicy( exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING], interval_seconds=5, max_attempts=3, ) assert retry_policy.to_request() == { "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], "IntervalSeconds": 5, "BackoffRate": 2.0, "MaxAttempts": 3, } retry_policy = StepRetryPolicy( exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING], interval_seconds=5, backoff_rate=2.0, expire_after_mins=30, ) assert retry_policy.to_request() == { "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], "IntervalSeconds": 5, "BackoffRate": 2.0, "ExpireAfterMin": 30, } def test_invalid_step_retry_policy(): try: StepRetryPolicy( exception_types=[SageMakerJobExceptionTypeEnum.INTERNAL_ERROR], interval_seconds=5, max_attempts=3, ) assert False except Exception: assert True def test_valid_sagemaker_job_step_retry_policy(): retry_policy = SageMakerJobStepRetryPolicy( exception_types=[SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT], failure_reason_types=[ SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, ], interval_seconds=5, max_attempts=3, ) assert retry_policy.to_request() == { "ExceptionType": [ "SageMaker.RESOURCE_LIMIT", "SageMaker.JOB_INTERNAL_ERROR", "SageMaker.CAPACITY_ERROR", ], "IntervalSeconds": 5, "BackoffRate": 2.0, "MaxAttempts": 3, } retry_policy = SageMakerJobStepRetryPolicy( exception_types=[SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT], failure_reason_types=[ SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, ], interval_seconds=5, max_attempts=3, ) assert retry_policy.to_request() == { "ExceptionType": [ "SageMaker.RESOURCE_LIMIT", "SageMaker.JOB_INTERNAL_ERROR", "SageMaker.CAPACITY_ERROR", ], "IntervalSeconds": 5, "BackoffRate": 2.0, "MaxAttempts": 3, } def test_invalid_retry_policy(): retry_policies = [ (-5, 2.0, 3, None), (5, -2.0, 3, None), (5, 2.0, -3, None), (5, 2.0, 21, None), (5, 2.0, None, -1), (5, 2.0, None, 14401), (5, 2.0, 10, 30), ] for (interval_sec, backoff_rate, max_attempts, expire_after) in retry_policies: try: RetryPolicy( interval_seconds=interval_sec, backoff_rate=backoff_rate, max_attempts=max_attempts, expire_after_mins=expire_after, ).to_request() assert False except Exception: assert True