# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Contains the SageMaker Experiment Run class."""
from __future__ import absolute_import
import datetime
import logging
from enum import Enum
from math import isnan, isinf
from numbers import Number
from typing import Optional, List, Dict, TYPE_CHECKING, Union
import dateutil
from numpy import array
from sagemaker.apiutils import _utils
from sagemaker.experiments import _api_types
from sagemaker.experiments._api_types import (
TrialComponentArtifact,
_TrialComponentStatusType,
)
from sagemaker.experiments._helper import (
_ArtifactUploader,
_LineageArtifactTracker,
_DEFAULT_ARTIFACT_PREFIX,
)
from sagemaker.experiments._environment import _RunEnvironment
from sagemaker.experiments._run_context import _RunContext
from sagemaker.experiments.experiment import Experiment
from sagemaker.experiments._metrics import _MetricsManager
from sagemaker.experiments.trial import _Trial
from sagemaker.experiments.trial_component import _TrialComponent
from sagemaker.utils import (
get_module,
unique_name_from_base,
)
from sagemaker.experiments._utils import (
guess_media_type,
resolve_artifact_name,
verify_length_of_true_and_predicted,
validate_invoked_inside_run_context,
get_tc_and_exp_config_from_job_env,
verify_load_input_names,
is_run_trial_component,
)
if TYPE_CHECKING:
from sagemaker import Session
logger = logging.getLogger(__name__)
RUN_NAME_BASE = "Sagemaker-Run".lower()
TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}"
MAX_RUN_TC_ARTIFACTS_LEN = 30
MAX_NAME_LEN_IN_BACKEND = 120
EXPERIMENT_NAME = "ExperimentName"
TRIAL_NAME = "TrialName"
RUN_NAME = "RunName"
DELIMITER = "-"
RUN_TC_TAG_KEY = "sagemaker:trial-component-source"
RUN_TC_TAG_VALUE = "run"
RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE}
class SortByType(Enum):
"""The type of property by which to sort the `list_runs` results."""
CREATION_TIME = "CreationTime"
NAME = "Name"
class SortOrderType(Enum):
"""The type of order to sort the list or search results."""
ASCENDING = "Ascending"
DESCENDING = "Descending"
class Run(object):
"""A collection of parameters, metrics, and artifacts to create a ML model."""
def __init__(
self,
experiment_name: str,
run_name: Optional[str] = None,
experiment_display_name: Optional[str] = None,
run_display_name: Optional[str] = None,
tags: Optional[List[Dict[str, str]]] = None,
sagemaker_session: Optional["Session"] = None,
artifact_bucket: Optional[str] = None,
artifact_prefix: Optional[str] = None,
):
"""Construct a `Run` instance.
SageMaker Experiments automatically tracks the inputs, parameters, configurations,
and results of your iterations as runs.
You can assign, group, and organize these runs into experiments.
You can also create, compare, and evaluate runs.
The code sample below shows how to initialize a run, log parameters to the Run object
and invoke a training job under the context of this Run object, which automatically
passes the run's ``experiment_config`` (including the experiment name, run name etc.)
to the training job.
Note:
All log methods (e.g. ``log_parameter``, ``log_metric``, etc.) have to be called within
the run context (i.e. the ``with`` statement). Otherwise, a ``RuntimeError`` is thrown.
.. code:: python
with Run(experiment_name="my-exp", run_name="my-run", ...) as run:
run.log_parameter(...)
...
estimator.fit(job_name="my-job") # Create a training job
        To reuse an existing run to log extra data, ``load_run`` is recommended.
        For example, use ``load_run`` instead of the ``Run`` constructor in a job
        script to load the existing run created before the job launch.
        Otherwise, a new run may be created each time you launch a job.
The code snippet below displays how to load the run initialized above
in a custom training job script, where no ``run_name`` or ``experiment_name``
is presented as they are automatically retrieved from the experiment config
in the job environment.
.. code:: python
with load_run(sagemaker_session=sagemaker_session) as run:
run.log_metric(...)
...
Args:
experiment_name (str): The name of the experiment. The name must be unique
within an account.
            run_name (str): The name of the run. If it is not specified, one is auto-generated.
experiment_display_name (str): Name of the experiment that will appear in UI,
such as SageMaker Studio. (default: None). This display name is used in
a create experiment call. If an experiment with the specified name already exists,
this display name won't take effect.
run_display_name (str): The display name of the run used in UI (default: None).
This display name is used in a create run call. If a run with the
specified name already exists, this display name won't take effect.
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
e.g. to create an experiment, a run group, etc. (default: None).
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using the
default AWS configuration chain.
artifact_bucket (str): The S3 bucket to upload the artifact to.
If not specified, the default bucket defined in `sagemaker_session`
will be used.
artifact_prefix (str): The S3 key prefix used to generate the S3 path
to upload the artifact to (default: "trial-component-artifacts").
"""
# TODO: we should revert the lower casting once backend fix reaches prod
self.experiment_name = experiment_name.lower()
sagemaker_session = sagemaker_session or _utils.default_session()
self.run_name = run_name or unique_name_from_base(RUN_NAME_BASE)
# avoid confusion due to mis-match in casing between run name and TC name
self.run_name = self.run_name.lower()
trial_component_name = Run._generate_trial_component_name(
run_name=self.run_name, experiment_name=self.experiment_name
)
self.run_group_name = Run._generate_trial_name(self.experiment_name)
self._experiment = Experiment._load_or_create(
experiment_name=self.experiment_name,
display_name=experiment_display_name,
tags=tags,
sagemaker_session=sagemaker_session,
)
self._trial = _Trial._load_or_create(
experiment_name=self.experiment_name,
trial_name=self.run_group_name,
tags=tags,
sagemaker_session=sagemaker_session,
)
        self._trial_component, already_exists = _TrialComponent._load_or_create(
trial_component_name=trial_component_name,
display_name=run_display_name,
tags=Run._append_run_tc_label_to_tags(tags),
sagemaker_session=sagemaker_session,
)
        if already_exists:
logger.info(
"The run (%s) under experiment (%s) already exists. Loading it.",
self.run_name,
self.experiment_name,
)
if not _TrialComponent._trial_component_is_associated_to_trial(
self._trial_component.trial_component_name, self._trial.trial_name, sagemaker_session
):
self._trial.add_trial_component(self._trial_component)
self._artifact_uploader = _ArtifactUploader(
trial_component_name=self._trial_component.trial_component_name,
sagemaker_session=sagemaker_session,
artifact_bucket=artifact_bucket,
artifact_prefix=_DEFAULT_ARTIFACT_PREFIX
if artifact_prefix is None
else artifact_prefix,
)
self._lineage_artifact_tracker = _LineageArtifactTracker(
trial_component_arn=self._trial_component.trial_component_arn,
sagemaker_session=sagemaker_session,
)
self._metrics_manager = _MetricsManager(
trial_component_name=self._trial_component.trial_component_name,
sagemaker_session=sagemaker_session,
)
self._inside_init_context = False
self._inside_load_context = False
self._in_load = False
@property
def experiment_config(self) -> dict:
"""Get experiment config from run attributes."""
return {
EXPERIMENT_NAME: self.experiment_name,
TRIAL_NAME: self.run_group_name,
RUN_NAME: self._trial_component.trial_component_name,
}
@validate_invoked_inside_run_context
def log_parameter(self, name: str, value: Union[str, int, float]):
"""Record a single parameter value for this run.
Overwrites any previous value recorded for the specified parameter name.
Args:
name (str): The name of the parameter.
value (str or int or float): The value of the parameter.
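        For example, assuming ``run`` is a ``Run`` object in an active run
        context (the parameter name and value below are illustrative):
        .. code:: python
            run.log_parameter("learning_rate", 0.01)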
"""
if self._is_input_valid("parameter", name, value):
self._trial_component.parameters[name] = value
@validate_invoked_inside_run_context
def log_parameters(self, parameters: Dict[str, Union[str, int, float]]):
"""Record a collection of parameter values for this run.
Args:
parameters (dict[str, str or int or float]): The parameters to record.
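        For example, assuming ``run`` is a ``Run`` object in an active run
        context (the parameter values below are illustrative):
        .. code:: python
            run.log_parameters({"batch_size": 64, "optimizer": "adam"})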
"""
filtered_parameters = {
key: value
for (key, value) in parameters.items()
if self._is_input_valid("parameter", key, value)
}
self._trial_component.parameters.update(filtered_parameters)
@validate_invoked_inside_run_context
def log_metric(
self,
name: str,
value: float,
timestamp: Optional[datetime.datetime] = None,
step: Optional[int] = None,
):
"""Record a custom scalar metric value for this run.
Note:
This method is for manual custom metrics, for automatic metrics see the
``enable_sagemaker_metrics`` parameter on the ``estimator`` class.
Args:
name (str): The name of the metric.
value (float): The value of the metric.
timestamp (datetime.datetime): The timestamp of the metric.
If not specified, the current UTC time will be used.
step (int): The integer iteration number of the metric value (default: None).
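        For example, assuming ``run`` is a ``Run`` object in an active run
        context, a per-epoch loss could be logged as follows (values are
        illustrative):
        .. code:: python
            for epoch, loss in enumerate([0.9, 0.5, 0.3]):
                run.log_metric(name="train:loss", value=loss, step=epoch)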
"""
if self._is_input_valid("metric", name, value):
self._metrics_manager.log_metric(
metric_name=name, value=value, timestamp=timestamp, step=step
)
@validate_invoked_inside_run_context
def log_precision_recall(
self,
y_true: Union[list, array],
predicted_probabilities: Union[list, array],
positive_label: Optional[Union[str, int]] = None,
title: Optional[str] = None,
is_output: bool = True,
no_skill: Optional[int] = None,
):
"""Create and log a precision recall graph artifact for Studio UI to render.
The artifact is stored in S3 and represented as a lineage artifact
with an association with the run.
You can view the artifact in the UI.
        If your job is created by a pipeline execution, you can view the artifact
        by selecting the corresponding step in the pipelines UI.
        See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_.
        This method requires the scikit-learn (``sklearn``) library.
Args:
y_true (list or array): True labels. If labels are not binary
then positive_label should be given.
predicted_probabilities (list or array): Estimated/predicted probabilities.
positive_label (str or int): Label of the positive class (default: None).
title (str): Title of the graph (default: None).
is_output (bool): Determines direction of association to the
run. Defaults to True (output artifact).
If set to False then represented as input association.
no_skill (int): The precision threshold under which the classifier cannot discriminate
between the classes and would predict a random class or a constant class in
all cases (default: None).
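        For example, assuming ``run`` is a ``Run`` object in an active run
        context and ``sklearn`` is installed (the labels and probabilities
        below are illustrative):
        .. code:: python
            run.log_precision_recall(
                y_true=[0, 1, 1, 0],
                predicted_probabilities=[0.1, 0.9, 0.8, 0.3],
                title="precision-recall",
            )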
"""
verify_length_of_true_and_predicted(
true_labels=y_true,
predicted_attrs=predicted_probabilities,
predicted_attrs_name="predicted probabilities",
)
get_module("sklearn")
from sklearn.metrics import precision_recall_curve, average_precision_score
kwargs = {}
if positive_label is not None:
kwargs["pos_label"] = positive_label
precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs)
kwargs["average"] = "micro"
ap = average_precision_score(y_true, predicted_probabilities, **kwargs)
data = {
"type": "PrecisionRecallCurve",
"version": 0,
"title": title,
"precision": precision.tolist(),
"recall": recall.tolist(),
"averagePrecisionScore": ap,
"noSkill": no_skill,
}
self._log_graph_artifact(
artifact_name=title,
data=data,
graph_type="PrecisionRecallCurve",
is_output=is_output,
)
@validate_invoked_inside_run_context
def log_roc_curve(
self,
y_true: Union[list, array],
y_score: Union[list, array],
title: Optional[str] = None,
is_output: bool = True,
):
"""Create and log a receiver operating characteristic (ROC curve) artifact.
The artifact is stored in S3 and represented as a lineage artifact
with an association with the run.
You can view the artifact in the UI.
        If your job is created by a pipeline execution, you can view the artifact
        by selecting the corresponding step in the pipelines UI.
        See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_.
        This method requires the scikit-learn (``sklearn``) library.
Args:
            y_true (list or array): True labels. The labels must be binary.
y_score (list or array): Estimated/predicted probabilities.
title (str): Title of the graph (default: None).
is_output (bool): Determines direction of association to the
run. Defaults to True (output artifact).
If set to False then represented as input association.
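        For example, assuming ``run`` is a ``Run`` object in an active run
        context and ``sklearn`` is installed (values below are illustrative):
        .. code:: python
            run.log_roc_curve(
                y_true=[0, 1, 1, 0],
                y_score=[0.1, 0.9, 0.8, 0.3],
                title="roc-curve",
            )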
"""
verify_length_of_true_and_predicted(
true_labels=y_true,
predicted_attrs=y_score,
predicted_attrs_name="predicted scores",
)
get_module("sklearn")
from sklearn.metrics import roc_curve, auc
fpr, tpr, _ = roc_curve(y_true, y_score)
        # avoid shadowing the imported ``auc`` function with its result
        area_under_curve = auc(fpr, tpr)
data = {
"type": "ROCCurve",
"version": 0,
"title": title,
"falsePositiveRate": fpr.tolist(),
"truePositiveRate": tpr.tolist(),
"areaUnderCurve": auc,
}
self._log_graph_artifact(
artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output
)
@validate_invoked_inside_run_context
def log_confusion_matrix(
self,
y_true: Union[list, array],
y_pred: Union[list, array],
title: Optional[str] = None,
is_output: bool = True,
):
"""Create and log a confusion matrix artifact.
The artifact is stored in S3 and represented as a lineage artifact
with an association with the run.
You can view the artifact in the UI.
        If your job is created by a pipeline execution, you can view the
        artifact by selecting the corresponding step in the pipelines UI.
        See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_.
        This method requires the scikit-learn (``sklearn``) library.
Args:
            y_true (list or array): True labels.
y_pred (list or array): Predicted labels.
title (str): Title of the graph (default: None).
is_output (bool): Determines direction of association to the
run. Defaults to True (output artifact).
If set to False then represented as input association.
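        For example, assuming ``run`` is a ``Run`` object in an active run
        context and ``sklearn`` is installed (labels below are illustrative):
        .. code:: python
            run.log_confusion_matrix(
                y_true=[0, 1, 2, 2],
                y_pred=[0, 2, 2, 2],
                title="confusion-matrix",
            )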
"""
verify_length_of_true_and_predicted(
true_labels=y_true,
predicted_attrs=y_pred,
predicted_attrs_name="predicted labels",
)
get_module("sklearn")
from sklearn.metrics import confusion_matrix
matrix = confusion_matrix(y_true, y_pred)
data = {
"type": "ConfusionMatrix",
"version": 0,
"title": title,
"confusionMatrix": matrix.tolist(),
}
self._log_graph_artifact(
artifact_name=title,
data=data,
graph_type="ConfusionMatrix",
is_output=is_output,
)
@validate_invoked_inside_run_context
def log_artifact(
self,
name: str,
value: str,
media_type: Optional[str] = None,
is_output: bool = True,
):
"""Record a single artifact for this run.
Overwrites any previous value recorded for the specified name.
Args:
name (str): The name of the artifact.
value (str): The value.
media_type (str): The MediaType (MIME type) of the value (default: None).
is_output (bool): Determines direction of association to the
run. Defaults to True (output artifact).
If set to False then represented as input association.
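        For example, assuming ``run`` is a ``Run`` object in an active run
        context (the artifact name and S3 URI below are illustrative):
        .. code:: python
            run.log_artifact(
                name="training-config",
                value="s3://my-bucket/config.json",
                media_type="application/json",
            )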
"""
self._verify_trial_component_artifacts_length(is_output=is_output)
if is_output:
self._trial_component.output_artifacts[name] = TrialComponentArtifact(
value, media_type=media_type
)
else:
self._trial_component.input_artifacts[name] = TrialComponentArtifact(
value, media_type=media_type
)
@validate_invoked_inside_run_context
def log_file(
self,
file_path: str,
name: Optional[str] = None,
media_type: Optional[str] = None,
is_output: bool = True,
):
"""Upload a file to s3 and store it as an input/output artifact in this run.
Args:
file_path (str): The path of the local file to upload.
name (str): The name of the artifact (default: None).
media_type (str): The MediaType (MIME type) of the file.
If not specified, this library will attempt to infer the media type
from the file extension of ``file_path``.
is_output (bool): Determines direction of association to the
run. Defaults to True (output artifact).
If set to False then represented as input association.
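        For example, assuming ``run`` is a ``Run`` object in an active run
        context (the file path and artifact name below are illustrative):
        .. code:: python
            run.log_file("output/metrics.json", name="metrics", is_output=True)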
"""
self._verify_trial_component_artifacts_length(is_output)
media_type = media_type or guess_media_type(file_path)
name = name or resolve_artifact_name(file_path)
s3_uri, _ = self._artifact_uploader.upload_artifact(file_path)
if is_output:
self._trial_component.output_artifacts[name] = TrialComponentArtifact(
value=s3_uri, media_type=media_type
)
else:
self._trial_component.input_artifacts[name] = TrialComponentArtifact(
value=s3_uri, media_type=media_type
)
def close(self):
"""Persist any data saved locally."""
try:
# Update the trial component with additions from the Run object
self._trial_component.save()
# Create Lineage entities for the artifacts
self._lineage_artifact_tracker.save()
finally:
if self._metrics_manager:
self._metrics_manager.close()
@staticmethod
def _generate_trial_name(base_name) -> str:
"""Generate the reserved trial name based on experiment name
Args:
base_name (str): The ``experiment_name`` of this ``Run`` object.
"""
available_length = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE)
return TRIAL_NAME_TEMPLATE.format(base_name[:available_length])
@staticmethod
def _is_input_valid(input_type, field_name, field_value) -> bool:
"""Check if the input is valid or not
Args:
input_type (str): The type of the input, one of ``parameter``, ``metric``.
field_name (str): The name of the field to be checked.
field_value (str or int or float): The value of the field to be checked.
"""
if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
logger.warning(
"Failed to log %s %s. Received invalid value: %s.",
input_type,
field_name,
field_value,
)
return False
return True
def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
"""Log an artifact.
Logs an artifact by uploading data to S3, creating an artifact, and associating that
artifact with the run trial component.
Args:
data (dict): Artifacts data that will be saved to S3.
graph_type (str): The type of the artifact.
is_output (bool): Determines direction of association to the
trial component. Defaults to True (output artifact).
If set to False then represented as input association.
artifact_name (str): Name of the artifact (default: None).
"""
        # generate an artifact name if one is not supplied
        if not artifact_name:
            artifact_name = unique_name_from_base(graph_type)
# create a json file in S3
s3_uri, etag = self._artifact_uploader.upload_object_artifact(
artifact_name, data, file_extension="json"
)
# create an artifact and association for the table
if is_output:
self._lineage_artifact_tracker.add_output_artifact(
name=artifact_name,
source_uri=s3_uri,
etag=etag,
artifact_type=graph_type,
)
else:
self._lineage_artifact_tracker.add_input_artifact(
name=artifact_name,
source_uri=s3_uri,
etag=etag,
artifact_type=graph_type,
)
def _verify_trial_component_artifacts_length(self, is_output):
"""Verify the length of trial component artifacts
Args:
is_output (bool): Determines direction of association to the
trial component.
Raises:
ValueError: If the length of trial component artifacts exceeds the limit.
"""
err_msg_template = "Cannot add more than {} {}_artifacts under run"
if is_output:
if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output"))
else:
if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input"))
@staticmethod
def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
"""Generate the TrialComponentName based on run_name and experiment_name
Args:
run_name (str): The run_name supplied by the user.
experiment_name (str): The experiment_name supplied by the user,
which is prepended to the run_name to generate the TrialComponentName.
Returns:
str: The TrialComponentName used to create a trial component
which is unique in an account.
Raises:
ValueError: If either the run_name or the experiment_name exceeds
the length limit.
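        For example, ``experiment_name="My-Exp"`` and ``run_name="My-Run"``
        yield the trial component name ``"my-exp-my-run"``.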
"""
buffer = 1 # leave length buffers for delimiters
        max_len = (MAX_NAME_LEN_IN_BACKEND // 2) - buffer
err_msg_template = "The {} (length: {}) must have length less than or equal to {}"
if len(run_name) > max_len:
raise ValueError(err_msg_template.format("run_name", len(run_name), max_len))
if len(experiment_name) > max_len:
raise ValueError(
err_msg_template.format("experiment_name", len(experiment_name), max_len)
)
trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name)
# due to mixed-case concerns on the backend
trial_component_name = trial_component_name.lower()
return trial_component_name
@staticmethod
def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str:
"""Extract the user supplied run name from a trial component name.
Args:
trial_component_name (str): The name of a run trial component.
experiment_name (str): The experiment_name supplied by the user,
which was prepended to the run_name to generate the trial_component_name.
Returns:
str: The name of the Run object supplied by a user.
"""
# TODO: we should revert the lower casting once backend fix reaches prod
return trial_component_name.replace(
"{}{}".format(experiment_name.lower(), DELIMITER), "", 1
)
@staticmethod
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
"""Append the run trial component label to tags used to create a trial component.
Args:
tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object.
Returns:
list: The updated tags with the appended run trial component label.
"""
if not tags:
tags = []
if RUN_TC_TAG not in tags:
tags.append(RUN_TC_TAG)
return tags
def __enter__(self):
"""Updates the start time of the run.
Returns:
object: self.
"""
nested_with_err_msg_template = (
"It is not allowed to use nested 'with' statements on the {}."
)
if self._in_load:
if self._inside_load_context:
raise RuntimeError(nested_with_err_msg_template.format("load_run"))
self._inside_load_context = True
if not self._inside_init_context:
# Add to run context only if the load_run is called separately
# without under a Run init context
_RunContext.add_run_object(self)
else:
if _RunContext.get_current_run():
raise RuntimeError(nested_with_err_msg_template.format("Run"))
self._inside_init_context = True
_RunContext.add_run_object(self)
if not self._trial_component.start_time:
start_time = datetime.datetime.now(dateutil.tz.tzlocal())
self._trial_component.start_time = start_time
self._trial_component.status = _api_types.TrialComponentStatus(
primary_status=_TrialComponentStatusType.InProgress.value,
message="Within a run context",
)
# Save the start_time and status changes to backend
self._trial_component.save()
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
"""Updates the end time of the run.
Args:
exc_type (str): The exception type.
exc_value (str): The exception value.
exc_traceback (str): The stack trace of the exception.
"""
if self._in_load:
self._inside_load_context = False
self._in_load = False
if not self._inside_init_context:
_RunContext.drop_current_run()
else:
self._inside_init_context = False
_RunContext.drop_current_run()
end_time = datetime.datetime.now(dateutil.tz.tzlocal())
self._trial_component.end_time = end_time
if exc_value:
self._trial_component.status = _api_types.TrialComponentStatus(
primary_status=_TrialComponentStatusType.Failed.value,
message=str(exc_value),
)
else:
self._trial_component.status = _api_types.TrialComponentStatus(
primary_status=_TrialComponentStatusType.Completed.value
)
self.close()
def __getstate__(self):
"""Overriding this method to prevent instance of Run from being pickled.
Raise:
NotImplementedError: If attempting to pickle this instance.
"""
raise NotImplementedError("Instance of Run type is not allowed to be pickled.")
def load_run(
run_name: Optional[str] = None,
experiment_name: Optional[str] = None,
sagemaker_session: Optional["Session"] = None,
artifact_bucket: Optional[str] = None,
artifact_prefix: Optional[str] = None,
) -> Run:
"""Load an existing run.
In order to reuse an existing run to log extra data, ``load_run`` is recommended.
It can be used in several ways:
1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``.
If ``run_name`` and ``experiment_name`` are passed in, they are honored over
the default experiment config in the job environment or the run context
(i.e. within the ``with`` block).
Note:
Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work.
Otherwise, you may get a ``ValueError``.
.. code:: python
with load_run(experiment_name="my-exp", run_name="my-run") as run:
run.log_metric(...)
...
    2. Use ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``.
In this case, the default experiment config (specified when creating the job) is fetched
from the job environment to load the run.
.. code:: python
# In a job script
with load_run() as run:
run.log_metric(...)
...
    3. Use ``load_run`` in a notebook within a run context (i.e. the ``with`` block)
but without supplying ``run_name`` and ``experiment_name``.
Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked
in the run context. Then when we call ``load_run()`` under this with statement, the ``run1``
in the context is loaded by default.
.. code:: python
# In a notebook
with Run(experiment_name="my-exp", run_name="my-run", ...) as run1:
run1.log_parameter(...)
with load_run() as run2: # run2 is the same object as run1
run2.log_metric(...)
...
Args:
run_name (str): The name of the run to be loaded (default: None).
If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be
fetched to load the run.
experiment_name (str): The name of the Experiment that the to be loaded run
is associated with (default: None).
Note: the experiment_name must be supplied along with a valid run_name.
Otherwise, it will be ignored.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using the
default AWS configuration chain.
artifact_bucket (str): The S3 bucket to upload the artifact to.
If not specified, the default bucket defined in `sagemaker_session`
will be used.
artifact_prefix (str): The S3 key prefix used to generate the S3 path
to upload the artifact to (default: "trial-component-artifacts").
Returns:
Run: The loaded Run object.
"""
environment = _RunEnvironment.load()
verify_load_input_names(run_name=run_name, experiment_name=experiment_name)
if run_name:
logger.warning(
"run_name is explicitly supplied in load_run, "
"which will be prioritized to load the Run object. "
"In other words, the run name in the experiment config, fetched from the "
"job environment or the current run context, will be ignored."
)
run_instance = Run(
experiment_name=experiment_name,
run_name=run_name,
sagemaker_session=sagemaker_session or _utils.default_session(),
artifact_bucket=artifact_bucket,
artifact_prefix=artifact_prefix,
)
elif _RunContext.get_current_run():
run_instance = _RunContext.get_current_run()
elif environment:
exp_config = get_tc_and_exp_config_from_job_env(
environment=environment,
sagemaker_session=sagemaker_session or _utils.default_session(),
)
run_name = Run._extract_run_name_from_tc_name(
trial_component_name=exp_config[RUN_NAME],
experiment_name=exp_config[EXPERIMENT_NAME],
)
experiment_name = exp_config[EXPERIMENT_NAME]
run_instance = Run(
experiment_name=experiment_name,
run_name=run_name,
sagemaker_session=sagemaker_session or _utils.default_session(),
artifact_bucket=artifact_bucket,
artifact_prefix=artifact_prefix,
)
else:
raise RuntimeError(
"Failed to load a Run object. "
"Please make sure a Run object has been initialized already."
)
run_instance._in_load = True
return run_instance
def list_runs(
experiment_name: str,
created_before: Optional[datetime.datetime] = None,
created_after: Optional[datetime.datetime] = None,
sagemaker_session: Optional["Session"] = None,
max_results: Optional[int] = None,
next_token: Optional[str] = None,
sort_by: SortByType = SortByType.CREATION_TIME,
sort_order: SortOrderType = SortOrderType.DESCENDING,
) -> list:
"""Return a list of ``Run`` objects matching the given criteria.
Args:
experiment_name (str): Only Run objects related to the specified experiment
are returned.
created_before (datetime.datetime): Return Run objects created before this instant
(default: None).
created_after (datetime.datetime): Return Run objects created after this instant
(default: None).
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, one is created using the
default AWS configuration chain.
max_results (int): Maximum number of Run objects to retrieve (default: None).
next_token (str): Token for next page of results (default: None).
sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME
(default: CREATION_TIME).
sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING).
Returns:
list: A list of ``Run`` objects.
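    For example (the experiment name below is illustrative):
    .. code:: python
        from sagemaker.experiments.run import list_runs
        runs = list_runs(experiment_name="my-exp", max_results=10)
        for run in runs:
            print(run.run_name)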
"""
# all trial components retrieved by default
tc_summaries = _TrialComponent.list(
experiment_name=experiment_name,
created_before=created_before,
created_after=created_after,
sort_by=sort_by.value,
sort_order=sort_order.value,
sagemaker_session=sagemaker_session,
max_results=max_results,
next_token=next_token,
)
run_list = []
for tc_summary in tc_summaries:
if not is_run_trial_component(
trial_component_name=tc_summary.trial_component_name,
sagemaker_session=sagemaker_session,
):
continue
run_instance = Run(
experiment_name=experiment_name,
run_name=Run._extract_run_name_from_tc_name(
trial_component_name=tc_summary.trial_component_name,
experiment_name=experiment_name,
),
sagemaker_session=sagemaker_session,
)
run_list.append(run_instance)
return run_list