# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """The step definitions for workflow.""" from __future__ import absolute_import from typing import List, Dict, Optional, Union from enum import Enum import warnings import attr from sagemaker.workflow.entities import ( RequestType, ) from sagemaker.workflow.properties import ( Properties, ) from sagemaker.workflow.entities import ( DefaultEnumMeta, ) from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig from sagemaker.lambda_helper import Lambda class LambdaOutputTypeEnum(Enum, metaclass=DefaultEnumMeta): """LambdaOutput type enum.""" String = "String" Integer = "Integer" Boolean = "Boolean" Float = "Float" @attr.s class LambdaOutput: """Output for a lambdaback step. Attributes: output_name (str): The output name output_type (LambdaOutputTypeEnum): The output type """ output_name: str = attr.ib(default=None) output_type: LambdaOutputTypeEnum = attr.ib(default=LambdaOutputTypeEnum.String) def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return { "OutputName": self.output_name, "OutputType": self.output_type.value, } def expr(self, step_name) -> Dict[str, str]: """The 'Get' expression dict for a `LambdaOutput`.""" return LambdaOutput._expr(self.output_name, step_name) @classmethod def _expr(cls, name, step_name): """An internal classmethod for the 'Get' expression dict for a `LambdaOutput`. Args: name (str): The name of the lambda output. step_name (str): The name of the step the lambda step associated with this output belongs to. """ return {"Get": f"Steps.{step_name}.OutputParameters['{name}']"} class LambdaStep(Step): """Lambda step for workflow.""" def __init__( self, name: str, lambda_func: Lambda, display_name: str = None, description: str = None, inputs: dict = None, outputs: List[LambdaOutput] = None, cache_config: CacheConfig = None, depends_on: Optional[List[Union[str, Step, StepCollection]]] = None, ): """Constructs a LambdaStep. Args: name (str): The name of the lambda step. display_name (str): The display name of the Lambda step. description (str): The description of the Lambda step. lambda_func (str): An instance of sagemaker.lambda_helper.Lambda. If lambda arn is specified in the instance, LambdaStep just invokes the function, else lambda function will be created while creating the pipeline. inputs (dict): Input arguments that will be provided to the lambda function. outputs (List[LambdaOutput]): List of outputs from the lambda function. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection` names or `Step` instances or `StepCollection` instances that this `LambdaStep` depends on. """ super(LambdaStep, self).__init__( name, display_name, description, StepTypeEnum.LAMBDA, depends_on ) self.lambda_func = lambda_func self.outputs = outputs if outputs is not None else [] self.cache_config = cache_config self.inputs = inputs if inputs is not None else {} root_prop = Properties(step_name=name) property_dict = {} for output in self.outputs: property_dict[output.output_name] = Properties( step_name=name, path=f"OutputParameters['{output.output_name}']" ) root_prop.__dict__["Outputs"] = property_dict self._properties = root_prop @property def arguments(self) -> RequestType: """The arguments dict that is used to define the lambda step.""" return self.inputs @property def properties(self): """A Properties object representing the output parameters of the lambda step.""" return self._properties def to_request(self) -> RequestType: """Updates the dictionary with cache configuration.""" request_dict = super().to_request() if self.cache_config: request_dict.update(self.cache_config.config) function_arn = self._get_function_arn() request_dict["FunctionArn"] = function_arn request_dict["OutputParameters"] = list(map(lambda op: op.to_request(), self.outputs)) return request_dict def _get_function_arn(self): """Returns the lambda function arn It upserts a lambda function if function name is provided. It updates a lambda function if lambda arn and code is provided. It is a no-op if code is not provided but function arn is provided. """ if self.lambda_func.function_arn is None: response = self.lambda_func.upsert() return response["FunctionArn"] if self.lambda_func.zipped_code_dir is None and self.lambda_func.script is None: warnings.warn( "Lambda function won't be updated because zipped_code_dir \ or script is not provided." ) return self.lambda_func.function_arn response = self.lambda_func.update() return response["FunctionArn"]