# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code related to SKLearn Processors which are used for Processing jobs.

These jobs let customers perform data pre-processing, post-processing, feature engineering,
data validation, and model evaluation and interpretation on SageMaker.
"""
from __future__ import absolute_import

from typing import Union, List, Dict, Optional

from sagemaker.network import NetworkConfig
from sagemaker import image_uris, Session
from sagemaker.processing import ScriptProcessor
from sagemaker.sklearn import defaults
from sagemaker.workflow.entities import PipelineVariable


class SKLearnProcessor(ScriptProcessor):
    """Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""

    def __init__(
        self,
        framework_version: str,
        role: Optional[Union[str, PipelineVariable]] = None,
        instance_count: Union[int, PipelineVariable] = None,
        instance_type: Union[str, PipelineVariable] = None,
        command: Optional[List[str]] = None,
        volume_size_in_gb: Union[int, PipelineVariable] = 30,
        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
        base_job_name: Optional[str] = None,
        sagemaker_session: Optional[Session] = None,
        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
        network_config: Optional[NetworkConfig] = None,
    ):
        """Initialize an ``SKLearnProcessor`` instance.

        The constructor looks up the scikit-learn container image for the
        requested framework version, the session's region and the instance
        type, then hands that image and all remaining settings to
        ``ScriptProcessor``.

        Args:
            framework_version (str): The version of scikit-learn.
            role (str or PipelineVariable): An AWS IAM role name or ARN.
                Forwarded unchanged to ``ScriptProcessor``.
            instance_count (int or PipelineVariable): The number of instances
                to run the Processing job with. The signature default of
                ``None`` is forwarded as-is to ``ScriptProcessor``.
            instance_type (str or PipelineVariable): Type of EC2 instance to
                use for processing, for example, 'ml.c4.xlarge'. Also used
                when looking up the container image.
            command ([str]): The command to run, along with any command-line
                flags. Example: ["python3", "-v"]. If omitted or empty,
                ["python3"] is used.
            volume_size_in_gb (int or PipelineVariable): Size in GB of the
                EBS volume to use for storing data during processing
                (default: 30).
            volume_kms_key (str or PipelineVariable): A KMS key for the
                processing volume.
            output_kms_key (str or PipelineVariable): The KMS key id for all
                ProcessingOutputs.
            max_runtime_in_seconds (int or PipelineVariable): Timeout in
                seconds. After this amount of time Amazon SageMaker
                terminates the job regardless of its current status.
            base_job_name (str): Prefix for processing name. If not
                specified, the processor generates a default job name, based
                on the training image name and current timestamp.
            sagemaker_session (sagemaker.session.Session): Session object
                which manages interactions with Amazon SageMaker APIs and
                any other AWS services needed. If not supplied, a new
                ``Session()`` is created here; the same session is used for
                the region lookup and passed on to ``ScriptProcessor``.
            env (dict[str, str] or dict[str, PipelineVariable]): Environment
                variables to be passed to the processing job.
            tags (list[dict[str, str] or list[dict[str, PipelineVariable]]):
                List of tags to be passed to the processing job.
            network_config (sagemaker.network.NetworkConfig): A NetworkConfig
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
        """
        # A missing or empty command falls back to the plain interpreter.
        command = command or ["python3"]

        # Resolve the session once so the region lookup and the parent
        # constructor both see the same object.
        active_session = sagemaker_session if sagemaker_session else Session()
        aws_region = active_session.boto_region_name

        sklearn_image_uri = image_uris.retrieve(
            defaults.SKLEARN_NAME,
            aws_region,
            version=framework_version,
            instance_type=instance_type,
        )

        super().__init__(
            role=role,
            image_uri=sklearn_image_uri,
            instance_count=instance_count,
            instance_type=instance_type,
            command=command,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            sagemaker_session=active_session,
            env=env,
            tags=tags,
            network_config=network_config,
        )