# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Contains the TrialComponent class.""" from __future__ import absolute_import import time from botocore.exceptions import ClientError from sagemaker.apiutils import _base_types from sagemaker.experiments import _api_types from sagemaker.experiments._api_types import TrialComponentSearchResult class _TrialComponent(_base_types.Record): """This class represents a SageMaker trial component object. A trial component is a stage in a trial. Trial components are created automatically within the SageMaker runtime and may not be created directly. To automatically associate trial components with a trial and experiment, supply an experiment config when creating a job. For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html Attributes: trial_component_name (str): The name of the trial component. Generated by SageMaker from the name of the source job with a suffix specific to the type of source job. trial_component_arn (str): The ARN of the trial component. display_name (str): The name of the trial component that will appear in UI, such as SageMaker Studio. source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute. status (str): Status of the source job. start_time (datetime): When the source job started. end_time (datetime): When the source job ended. creation_time (datetime): When the source job was created. created_by (obj): Contextual info on which account created the trial component. last_modified_time (datetime): When the trial component was last modified. last_modified_by (obj): Contextual info on which account last modified the trial component. parameters (dict): Dictionary of parameters to the source job. input_artifacts (dict): Dictionary of input artifacts. output_artifacts (dict): Dictionary of output artifacts. metrics (obj): Aggregated metrics for the job. parameters_to_remove (list): The hyperparameters to remove from the component. input_artifacts_to_remove (list): The input artifacts to remove from the component. output_artifacts_to_remove (list): The output artifacts to remove from the component. tags (List[Dict[str, str]]): A list of tags to associate with the trial component. """ trial_component_name = None trial_component_arn = None display_name = None source = None status = None start_time = None end_time = None creation_time = None created_by = None last_modified_time = None last_modified_by = None parameters = None input_artifacts = None output_artifacts = None metrics = None parameters_to_remove = None input_artifacts_to_remove = None output_artifacts_to_remove = None tags = None _boto_load_method = "describe_trial_component" _boto_create_method = "create_trial_component" _boto_update_method = "update_trial_component" _boto_delete_method = "delete_trial_component" _custom_boto_types = { "source": (_api_types.TrialComponentSource, False), "status": (_api_types.TrialComponentStatus, False), "parameters": (_api_types.TrialComponentParameters, False), "input_artifacts": (_api_types.TrialComponentArtifact, True), "output_artifacts": (_api_types.TrialComponentArtifact, True), "metrics": (_api_types.TrialComponentMetricSummary, True), } _boto_update_members = [ "trial_component_name", "display_name", "status", "start_time", "end_time", "parameters", "input_artifacts", "output_artifacts", "parameters_to_remove", "input_artifacts_to_remove", "output_artifacts_to_remove", ] _boto_delete_members = ["trial_component_name"] def __init__(self, sagemaker_session=None, **kwargs): """Init for _TrialComponent""" super().__init__(sagemaker_session, **kwargs) self.parameters = self.parameters or {} self.input_artifacts = self.input_artifacts or {} self.output_artifacts = self.output_artifacts or {} @classmethod def _boto_ignore(cls): """Response fields to ignore by default.""" return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"] def save(self): """Save the state of this TrialComponent to SageMaker.""" return self._invoke_api(self._boto_update_method, self._boto_update_members) def delete(self, force_disassociate=False): """Delete this TrialComponent from SageMaker. Args: force_disassociate (boolean): Indicates whether to force disassociate the trial component with the trials before deletion (default: False). If set to true, force disassociate the trial component with associated trials first, then delete the trial component. If it's not set or set to false, it will delete the trial component directory without disassociation. Returns: dict: Delete trial component response. """ if force_disassociate: next_token = None while True: if next_token: list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( TrialComponentName=self.trial_component_name, NextToken=next_token ) else: list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( TrialComponentName=self.trial_component_name ) # Disassociate the trials and trial components for per_trial in list_trials_response["TrialSummaries"]: # to prevent DisassociateTrialComponent throttling time.sleep(1.2) self.sagemaker_session.sagemaker_client.disassociate_trial_component( TrialName=per_trial["TrialName"], TrialComponentName=self.trial_component_name, ) if "NextToken" in list_trials_response: next_token = list_trials_response["NextToken"] else: break return self._invoke_api(self._boto_delete_method, self._boto_delete_members) @classmethod def load(cls, trial_component_name, sagemaker_session=None): """Load an existing trial component and return an `_TrialComponent` object representing it. Args: trial_component_name (str): Name of the trial component sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object """ trial_component = cls._construct( cls._boto_load_method, trial_component_name=trial_component_name, sagemaker_session=sagemaker_session, ) return trial_component @classmethod def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None): """Create a trial component and return a `_TrialComponent` object representing it. Args: trial_component_name (str): The name of the trial component. display_name (str): Display name of the trial component used by Studio (default: None). tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. """ return super(_TrialComponent, cls)._construct( cls._boto_create_method, trial_component_name=trial_component_name, display_name=display_name, tags=tags, sagemaker_session=sagemaker_session, ) @classmethod def list( cls, source_arn=None, created_before=None, created_after=None, sort_by=None, sort_order=None, sagemaker_session=None, trial_name=None, experiment_name=None, max_results=None, next_token=None, ): """Return a list of trial component summaries. Args: source_arn (str): A SageMaker Training or Processing Job ARN (default: None). created_before (datetime.datetime): Return trial components created before this instant (default: None). created_after (datetime.datetime): Return trial components created after this instant (default: None). sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' (default: None). sort_order (str): One of 'Ascending', or 'Descending' (default: None). sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. trial_name (str): If provided only trial components related to the trial are returned (default: None). experiment_name (str): If provided only trial components related to the experiment are returned (default: None). max_results (int): maximum number of trial components to retrieve (default: None). next_token (str): token for next page of results (default: None). Returns: collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator over `TrialComponentSummary` objects. """ return super(_TrialComponent, cls)._list( "list_trial_components", _api_types.TrialComponentSummary.from_boto, "TrialComponentSummaries", source_arn=source_arn, created_before=created_before, created_after=created_after, sort_by=sort_by, sort_order=sort_order, sagemaker_session=sagemaker_session, trial_name=trial_name, experiment_name=experiment_name, max_results=max_results, next_token=next_token, ) @classmethod def search( cls, search_expression=None, sort_by=None, sort_order=None, max_results=None, sagemaker_session=None, ): """Search Experiment Trail Component. Returns SearchResults in the account matching the search criteria. Args: search_expression: (SearchExpression): A Boolean conditional statement (default: None). Resource objects must satisfy this condition to be included in search results. You must provide at least one subexpression, filter, or nested filter. sort_by (str): The name of the resource property used to sort the SearchResults (default: None). sort_order (str): How SearchResults are ordered. Valid values are Ascending or Descending (default: None). max_results (int): The maximum number of results to return in a SearchResponse (default: None). sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: collections.Iterator[SearchResult] : An iterator over search results matching the search criteria. """ return super(_TrialComponent, cls)._search( search_resource="ExperimentTrialComponent", search_item_factory=TrialComponentSearchResult.from_boto, search_expression=None if search_expression is None else search_expression.to_boto(), sort_by=sort_by, sort_order=sort_order, max_results=max_results, sagemaker_session=sagemaker_session, ) @classmethod def _load_or_create( cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None ): """Load a trial component by name and create a new one if it does not exist. Args: trial_component_name (str): The name of the trial component. display_name (str): Display name of the trial component used by Studio (default: None). This is used only when the given `trial_component_name` does not exist and a new trial component has to be created. tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). This is used only when the given `trial_component_name` does not exist and a new trial component has to be created. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created using the default AWS configuration chain. Returns: experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. bool: A boolean variable indicating whether the trail component already exists """ is_existed = False try: run_tc = _TrialComponent.create( trial_component_name=trial_component_name, display_name=display_name, tags=tags, sagemaker_session=sagemaker_session, ) except ClientError as ce: error_code = ce.response["Error"]["Code"] error_message = ce.response["Error"]["Message"] if not (error_code == "ValidationException" and "already exists" in error_message): raise ce # already exists run_tc = _TrialComponent.load(trial_component_name, sagemaker_session) is_existed = True return run_tc, is_existed @classmethod def _trial_component_is_associated_to_trial( cls, trial_component_name, trial_name=None, sagemaker_session=None ): """Returns a bool based on if trial_component is already associated with the trial. Args: trial_component_name (str): The name of the trial component. trial_name: (str): The name of the trial. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. Returns: bool: A boolean variable indicating whether the trial component is already associated with the trial. """ search_results = sagemaker_session.sagemaker_client.search( Resource="ExperimentTrialComponent", SearchExpression={ "Filters": [ { "Name": "TrialComponentName", "Operator": "Equals", "Value": str(trial_component_name), }, { "Name": "Parents.TrialName", "Operator": "Equals", "Value": str(trial_name), }, ] }, ) if search_results["Results"]: return True return False