# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import import re import pytest from sagemaker import get_execution_role, utils from sagemaker.workflow.lambda_step import ( LambdaStep, LambdaOutput, LambdaOutputTypeEnum, ) from sagemaker.lambda_helper import Lambda from sagemaker.workflow.parameters import ParameterInteger from sagemaker.workflow.pipeline import Pipeline @pytest.fixture def role(sagemaker_session): return get_execution_role(sagemaker_session) @pytest.fixture def pipeline_name(): return utils.unique_name_from_base("my-pipeline-lambda") @pytest.fixture def region_name(sagemaker_session): return sagemaker_session.boto_session.region_name def test_one_step_lambda_pipeline(sagemaker_session, role, pipeline_name, region_name): instance_count = ParameterInteger(name="InstanceCount", default_value=2) outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) step_lambda = LambdaStep( name="lambda-step", lambda_func=Lambda( function_arn=("arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"), session=sagemaker_session, ), inputs={"arg1": "foo"}, outputs=[outputParam1], ) pipeline = Pipeline( name=pipeline_name, parameters=[instance_count], steps=[step_lambda], sagemaker_session=sagemaker_session, ) try: response = pipeline.create(role) create_arn = response["PipelineArn"] assert re.match( rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn, ) pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)] response = pipeline.update(role) update_arn = response["PipelineArn"] assert re.match( rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", update_arn, ) finally: try: pipeline.delete() except Exception: pass def test_two_step_lambda_pipeline_with_output_reference( sagemaker_session, role, pipeline_name, region_name ): instance_count = ParameterInteger(name="InstanceCount", default_value=2) outputParam1 = LambdaOutput(output_name="output1", output_type=LambdaOutputTypeEnum.String) step_lambda1 = LambdaStep( name="lambda-step1", lambda_func=Lambda( function_arn=("arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"), session=sagemaker_session, ), inputs={"arg1": "foo"}, outputs=[outputParam1], ) step_lambda2 = LambdaStep( name="lambda-step2", lambda_func=Lambda( function_arn=("arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda"), session=sagemaker_session, ), inputs={"arg1": outputParam1}, outputs=[], ) pipeline = Pipeline( name=pipeline_name, parameters=[instance_count], steps=[step_lambda1, step_lambda2], sagemaker_session=sagemaker_session, ) try: response = pipeline.create(role) create_arn = response["PipelineArn"] assert re.match( rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn, ) finally: try: pipeline.delete() except Exception: pass