# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Contains the SageMaker Experiments Tracker class.""" import datetime import os import mimetypes import urllib.parse import urllib.request import logging import botocore import json from math import isnan, isinf from numbers import Number from smexperiments._utils import get_module from os.path import join import dateutil from smexperiments import api_types, metrics, trial_component, _utils, _environment class Tracker(object): """A SageMaker Experiments Tracker. Use a tracker object to record experiment information to a SageMaker trial component. A new tracker can be created in two ways: - By loading an existing trial component with :meth:`~smexperiments.tracker.Tracker.load` - By creating a tracker for a new trial component with :meth:`~smexperiments.tracker.Tracker.create`. When creating a tracker within a SageMaker training or processing job, use the ``load`` method with no arguments to track artifacts to the trial component automatically created for your job. When tracking within a Jupyter notebook running in SageMaker, use the ``create`` method to automatically create a new trial component. Trackers are Python context managers and you can use them using the Python ``with`` keyword. Exceptions thrown within the with block will cause the tracker's trial component to be marked as failed. Start and end times are automatically set when using the with statement and the trial component is saved to SageMaker at the end of the block. Note that parameters and input/output artifacts are saved to SageMaker directly via the UpdateTrialComponent operation. In contrast metrics (via `log_metric` method) are saved to a file, which is then ingested into SageMaker via a metrics agent which only runs on training job hosts. As a result any metrics logged in non-training job host environments will not be ingested into SageMaker. Parameters: trial_component (TrialComponent): The trial component tracked. """ trial_component = None _metrics_writer = None _in_sagemaker_job = False _artifact_uploader = None def __init__(self, trial_component, metrics_writer, artifact_uploader, lineage_artifact_tracker): self.trial_component = trial_component self.trial_component.parameters = self.trial_component.parameters or {} self.trial_component.input_artifacts = self.trial_component.input_artifacts or {} self.trial_component.output_artifacts = self.trial_component.output_artifacts or {} self._artifact_uploader = artifact_uploader self._metrics_writer = metrics_writer self._warned_on_metrics = False self._lineage_artifact_tracker = lineage_artifact_tracker @classmethod def load( cls, trial_component_name=None, artifact_bucket=None, artifact_prefix=None, boto3_session=None, sagemaker_boto_client=None, training_job_name=None, processing_job_name=None, ): """Create a new ``Tracker`` by loading an existing trial component. Note that `log_metric` will only work from a training job host. Examples: .. code-block:: python from smexperiments import tracker # load tracker from already existing trial component my_tracker = tracker.Tracker.load(trial_component_name='xgboost') # load tracker from a training job name my_tracker = tracker.Tracker.load( training_job_name=estimator.latest_training_job.name) # load tracker from a processing job name my_tracker = tracker.Tracker.load( processing_job_name=my_processing_job.name) Args: trial_component_name: (str, optional). The name of the trial component to track. If specified, this trial component must exist in SageMaker. If you invoke this method in a running SageMaker training or processing job, then trial_component_name can be left empty. In this case, the Tracker will resolve the trial component automatically created for your SageMaker Job. artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to. artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket`` boto3_session: (boto3.Session, optional) The boto3.Session to use to interact with AWS services. If not specified a new default boto3 session will be created. sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client to use. If not specified a new client will be created from the specified ``boto3_session`` or default boto3.Session. training_job_name: (str, optional). The name of the training job to track via trial processing_job_name: (str, optional). The name of the processing job to track via trial component. Returns: Tracker: The tracker for the given trial component. Raises: ValueError: If the trial component failed to load. """ boto3_session = boto3_session or _utils.boto_session() sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client() tce = _environment.TrialComponentEnvironment.load() if training_job_name and not trial_component_name: trial_component_name = training_job_name + "-aws-training-job" elif processing_job_name and not trial_component_name: trial_component_name = processing_job_name + "-aws-processing-job" # Resolve the trial component for this tracker to track: If a trial component name was passed in, then load # and track that trial component. Otherwise, try to find a trial component given the current environment, # failing if we're unable to load one. if trial_component_name: tc = trial_component.TrialComponent.load( trial_component_name=trial_component_name, sagemaker_boto_client=sagemaker_boto_client ) elif tce: tc = tce.get_trial_component(sagemaker_boto_client) else: raise ValueError('Could not load TrialComponent. Specify a trial_component_name or invoke "create"') # metrics require the metrics agent running on training job hosts if not trial_component_name and tce.environment_type == _environment.EnvironmentType.SageMakerTrainingJob: metrics_writer = metrics.SageMakerFileMetricsWriter() else: metrics_writer = None tracker = cls( tc, metrics_writer, _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session), _LineageArtifactTracker(tc.trial_component_arn, sagemaker_boto_client), ) tracker._in_sagemaker_job = True if tce else False return tracker @classmethod def create( cls, base_trial_component_name="TrialComponent", display_name=None, artifact_bucket=None, artifact_prefix=None, boto3_session=None, sagemaker_boto_client=None, ): """Create a new ``Tracker`` by creating a new trial component. Note that `log_metric` will _not_ work when tracker is created this way. Examples .. code-block:: python from smexperiments import tracker my_tracker = tracker.Tracker.create() Args: base_trial_component_name: (str,optional). The name of the trial component resource that will be appended with a timestamp. Defaults to "TrialComponent". display_name: (str, optional). The display name of the trial component to track. artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to. artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket`` boto3_session: (boto3.Session, optional) The boto3.Session to use to interact with AWS services. If not specified a new default boto3 session will be created. sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client to use. If not specified a new client will be created from the specified ``boto3_session`` or default boto3.Session. Returns: Tracker: The tracker for the new trial component. """ boto3_session = boto3_session or _utils.boto_session() sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client() tc = trial_component.TrialComponent.create( trial_component_name=_utils.name(base_trial_component_name), display_name=display_name, sagemaker_boto_client=sagemaker_boto_client, ) # metrics require the metrics agent running on training job hosts and in which case the load # method should be used because it loads the trial component associated with the currently # running training job metrics_writer = None return cls( tc, metrics_writer, _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session), _LineageArtifactTracker(tc.trial_component_arn, sagemaker_boto_client), ) def log_parameter(self, name, value): """Record a single parameter value for this trial component. Overwrites any previous value recorded for the specified parameter name. Examples .. code-block:: python # log hyper parameter of learning rate my_tracker.log_parameter('learning_rate', 0.01) Args: name (str): The name of the parameter value (str or numbers.Number): The value of the parameter """ if self._is_input_valid("parameter", name, value): self.trial_component.parameters[name] = value def log_parameters(self, parameters): """Record a collection of parameter values for this trial component. Examples .. code-block:: python # log multiple hyper parameters used in training my_tracker.log_parameters({"learning_rate": 1.0, "gamma": 0.9, "dropout": 0.5}) Args: parameters (dict[str, str or numbers.Number]): The parameters to record. """ filtered_parameters = { key: value for (key, value) in parameters.items() if self._is_input_valid("parameter", key, value) } self.trial_component.parameters.update(filtered_parameters) def log_input(self, name, value, media_type=None): """Record a single input artifact for this trial component. Overwrites any previous value recorded for the specified input name. Examples .. code-block:: python # log input dataset s3 location my_tracker.log_input(name='input', value='s3://inputs/path') Args: name (str): The name of the input value. value (str): The value. media_type (str, optional): The MediaType (MIME type) of the value """ if len(self.trial_component.input_artifacts) >= 30: raise ValueError("Cannot add more than 30 input_artifacts under tracker trial_component.") self.trial_component.input_artifacts[name] = api_types.TrialComponentArtifact(value, media_type=media_type) def log_output(self, name, value, media_type=None): """Record a single output artifact for this trial component. Overwrites any previous value recorded for the specified output name. Examples .. code-block:: python # log output dataset s3 location my_tracker.log_output(name='prediction', value='s3://outputs/path') Args: name (str): The name of the output value. value (str): The value. media_type (str, optional): The MediaType (MIME type) of the value. """ if len(self.trial_component.output_artifacts) >= 30: raise ValueError("Cannot add more than 30 output_artifacts under tracker trial_component") self.trial_component.output_artifacts[name] = api_types.TrialComponentArtifact(value, media_type=media_type) def log_artifacts(self, directory, media_type=None): """Upload all the files under the directory to s3 and store it as artifacts in this trial component. The file name is used as the artifact name Examples .. code-block:: python # log local artifact my_tracker.log_artifact(directory='/local/path) Args: directory (str): The directory of the local files to upload. media_type (str, optional): The MediaType (MIME type) of the file. If not specified, this library will attempt to infer the media type from the file extension of ``file_path``. """ for dir_file in os.listdir(directory): file_path = join(directory, dir_file) artifact_name = os.path.splitext(dir_file)[0] self.log_artifact(file_path=file_path, name=artifact_name, media_type=media_type) def log_artifact(self, file_path, name=None, media_type=None): """Legacy overload method to prevent breaking existing code. Examples .. code-block:: python # log a local file my_tracker.log_artifact('output/artifact_data.csv', name='prediction') Args: file_path (str): Path to the file to log. name (str, optional): Name of the artifact. Defaults to None. media_type (str, optional): Media type of the artifact. Defaults to None. """ self.log_output_artifact(file_path, name, media_type) def log_output_artifact(self, file_path, name=None, media_type=None): """Upload a local file to s3 and store it as an output artifact in this trial component. Examples .. code-block:: python # log local artifact my_tracker.log_output_artifact(file_path='/local/path/artifact.tar.gz') Args: file_path (str): The path of the local file to upload. name (str, optional): The name of the artifact. media_type (str, optional): The MediaType (MIME type) of the file. If not specified, this library will attempt to infer the media type from the file extension of ``file_path``. """ if len(self.trial_component.output_artifacts) >= 30: raise ValueError("Cannot add more than 30 output_artifacts under tracker trial_component") media_type = media_type or _guess_media_type(file_path) name = name or _resolve_artifact_name(file_path) s3_uri, etag = self._artifact_uploader.upload_artifact(file_path) self.trial_component.output_artifacts[name] = api_types.TrialComponentArtifact( value=s3_uri, media_type=media_type ) self._lineage_artifact_tracker.add_output_artifact(name, s3_uri, etag, media_type) def log_input_artifact(self, file_path, name=None, media_type=None): """Upload a local file to s3 and store it as an input artifact in this trial component. Examples .. code-block:: python # log local artifact my_tracker.log_input_artifact(file_path='/local/path/artifact.tar.gz') Args: file_path (str): The path of the local file to upload. name (str, optional): The name of the artifact. media_type (str, optional): The MediaType (MIME type) of the file. If not specified, this library will attempt to infer the media type from the file extension of ``file_path``. """ if len(self.trial_component.input_artifacts) >= 30: raise ValueError("Cannot add more than 30 input_artifacts under tracker trial_component.") media_type = media_type or _guess_media_type(file_path) name = name or _resolve_artifact_name(file_path) s3_uri, etag = self._artifact_uploader.upload_artifact(file_path) self.trial_component.input_artifacts[name] = api_types.TrialComponentArtifact( value=s3_uri, media_type=media_type ) self._lineage_artifact_tracker.add_input_artifact(name, s3_uri, etag, media_type) def log_metric(self, metric_name, value, timestamp=None, iteration_number=None): """Record a custom scalar metric value for this TrialComponent. Note that this method is for manual custom metrics, for automatic metrics see the `enable_sagemaker_metrics` parameter on the `estimator` class in the main SageMaker SDK. Note that metrics logged with this method will only appear in SageMaker when this method is called from a training job host. Examples .. code-block:: python for epoch in range(epochs): # your training logic and calculate accuracy and loss my_tracker.log_metric(metric_name='accuracy', value=0.9, iteration_number=epoch) my_tracker.log_metric(metric_name='loss', value=0.03, iteration_number=epoch) Args: metric_name (str): The name of the metric. value (number): The value of the metric. timestamp (datetime.datetime|number, optional): The timestamp of the metric. If specified, should either be a datetime.datetime object or a number representing the seconds since the epoch. If not specified, the current local time will be used. iteration_number (number, optional): The integer iteration number of the metric value. Raises: AttributeError: If the metrics writer is not initialized. """ try: if self._is_input_valid("metric", metric_name, value): self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number) except AttributeError: if not self._metrics_writer: if not self._warned_on_metrics: logging.warning("Cannot write metrics in this environment.") self._warned_on_metrics = True else: raise def log_table(self, title=None, values=None, data_frame=None, output_artifact=True): """Record a table of values to an artifact. Rendering in Studio is not currently supported. Examples .. code-block:: python table_data = { "x": [1,2,3], "y": [4,5,6] } my_tracker.log_table('SampleData',table_data) # or log a data frame df = pd.DataFrame(data=table_data) my_tracker.log_table('SampleData',df) Args: title (str, optional): Title of the table. Defaults to None. values ([type], optional): A dictionary of values. i.e. {"x": [1,2,3], "y": [1,2,3]}. Defaults to None. data_frame (DataFrame, optional): Pandas dataframe alternative to values. Defaults to None. output_artifact (bool): Determines direction of association to the trial component. Defaults to output artifact. If False will be an input artifact. Raises: ValueError: If values or data_frame are invalid. """ if values is None and data_frame is None: raise ValueError("Either values or data_frame must be supplied.") if values is not None and data_frame is not None: raise ValueError("Only one of values or data_frame may be supplied.") if values is not None: for key in values: if "list" not in str(type(values[key])): raise ValueError( 'Table values should be list. i.e. {"x": [1,2,3]}, instead was ' + str(type(values[key])) ) if data_frame is not None: values = _ArtifactConverter.convert_data_frame_to_values(data_frame) fields = _ArtifactConverter.convert_data_frame_to_fields(data_frame) else: fields = _ArtifactConverter.convert_dict_to_fields(values) data = {"type": "Table", "version": 0, "title": title, "fields": fields, "data": values} self._log_graph_artifact(title, data, "Table", output_artifact) def log_precision_recall( self, y_true, predicted_probabilities, positive_label=None, title=None, output_artifact=True, no_skill=None, ): """Log a precision recall graph artifact. You can view the artifact in the charts tab of the Trial Component UI. If your job is created by a pipeline execution you can view the artifact by selecting the corresponding step in the pipelines UI. See also `SageMaker Pipelines `_ Requires sklearn. Examples .. code-block:: python y_true = [0, 0, 1, 1] y_scores = [0.1, 0.4, 0.35, 0.8] no_skill = len(y_true[y_true==1]) / len(y_true) my_tracker.log_precision_recall(y_true, y_scores, no_skill=no_skill) Args: y_true (array): True labels. If labels are not binary then positive_label should be given. predicted_probabilities (array): Estimated/predicted probabilities. positive_label (str or int, optional): Label of the positive class. Defaults to None. title (str, optional): Title of the graph, Defaults to none. output_artifact (boolean, optional): Determines if the artifact is associated with the Trial Component as an output artifact. If False will be an input artifact. no_skill (int): The precision threshold under which the classifier cannot discriminate between the classes and would predict a random class or a constant class in all cases. Raises: ValueError: If length mismatch between y_true and predicted_probabilities. """ if len(y_true) != len(predicted_probabilities): raise ValueError("Mismatch between actual values and predicted probabilities.") get_module("sklearn") from sklearn.metrics import precision_recall_curve, average_precision_score kwargs = {} if positive_label: kwargs["positive_label"] = positive_label precision, recall, thresholds = precision_recall_curve(y_true, predicted_probabilities, **kwargs) kwargs["average"] = "micro" ap = average_precision_score(y_true, predicted_probabilities, **kwargs) data = { "type": "PrecisionRecallCurve", "version": 0, "title": title, "precision": precision.tolist(), "recall": recall.tolist(), "averagePrecisionScore": ap, "noSkill": no_skill, } self._log_graph_artifact(title, data, "PrecisionRecallCurve", output_artifact) def log_roc_curve( self, y_true, y_score, title=None, output_artifact=True, ): """Log a receiver operating characteristic (ROC curve) artifact. You can view the artifact in the charts tab of the Trial Component UI. If your job is created by a pipeline execution you can view the artifact by selecting the corresponding step in the pipelines UI. See also `SageMaker Pipelines `. Requires sklearn. Examples .. code-block:: python y_true = [0, 0, 1, 1] y_scores = [0.1, 0.4, 0.35, 0.8] my_tracker.log_roc_curve(y_true, y_scores) Args: y_true (array): True labels. If labels are not binary then positive_label should be given. y_score (array): Estimated/predicted probabilities. title (str, optional): Title of the graph, Defaults to none. output_artifact (boolean, optional): Determines if the artifact is associated with the Trial Component as an output artifact. If False will be an input artifact. Raises: ValueError: If mismatch between y_true and predicted_probabilities. """ if len(y_true) != len(y_score): raise ValueError("Length mismatch between actual labels and predicted scores.") get_module("sklearn") from sklearn.metrics import roc_curve, auc fpr, tpr, thresholds = roc_curve(y_true, y_score) auc = auc(fpr, tpr) data = { "type": "ROCCurve", "version": 0, "title": title, "falsePositiveRate": fpr.tolist(), "truePositiveRate": tpr.tolist(), "areaUnderCurve": auc, } self._log_graph_artifact(title, data, "ROCCurve", output_artifact) def log_confusion_matrix( self, y_true, y_pred, title=None, output_artifact=True, ): """Log a confusion matrix artifact. You can view the artifact in the charts tab of the Trial Component UI. If your job is created by a pipeline execution you can view the artifact by selecting the corresponding step in the pipelines UI. See also `SageMaker Pipelines `_ Requires sklearn. Examples .. code-block:: python y_true = [2, 0, 2, 2, 0, 1] y_pred = [0, 0, 2, 2, 0, 2] my_tracker.log_confusion_matrix(y_true, y_pred) Args: y_true (array): True labels. y_pred (array): Predicted labels. title (str, optional): Title of the graph, Defaults to none. output_artifact (boolean, optional): Determines if the artifact is associated with the Trial Component as an output artifact. If False will be an input artifact. Raises: ValueError: If length mismatch between y_true and y_pred. """ if len(y_true) != len(y_pred): raise ValueError("Length mismatch between actual labels and predicted labels.") get_module("sklearn") from sklearn.metrics import confusion_matrix matrix = confusion_matrix(y_true, y_pred) data = {"type": "ConfusionMatrix", "version": 0, "title": title, "confusionMatrix": matrix.tolist()} self._log_graph_artifact(title, data, "ConfusionMatrix", output_artifact) def _log_graph_artifact(self, name, data, graph_type, output_artifact): """Logs an artifact by uploading data to S3, creating an artifact, and associating that artifact with the tracker's Trial Component. Args: name (str): Name of the artifact. data (dict): Artifacts data that will be saved to S3. graph_type (str): The type of the artifact. output_artifact (bool): Determines the direction of association with the trial component. Defaults to True (output association). If False will be input association. """ # generate an artifact name artifact_name = name if not artifact_name: artifact_name = ( graph_type + "-" + str(datetime.datetime.now(dateutil.tz.tzlocal()).timestamp()).split(".")[0] ) # create a json file in S3 s3_uri, etag = self._artifact_uploader.upload_object_artifact(artifact_name, data, file_extension="json") # create an artifact and association for the table if output_artifact: self._lineage_artifact_tracker.add_output_artifact(artifact_name, s3_uri, etag, graph_type) else: self._lineage_artifact_tracker.add_input_artifact(artifact_name, s3_uri, etag, graph_type) def _is_input_valid(self, input_type, field_name, field_value): if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)): logging.warning(f"Failed to log {input_type} {field_name}. Received invalid value: {field_value}.") return False return True def __enter__(self): """Updates the start time of the tracked trial component. Returns: obj: self. """ self._start_time = datetime.datetime.now(dateutil.tz.tzlocal()) if not self._in_sagemaker_job: self.trial_component.start_time = self._start_time self.trial_component.status = api_types.TrialComponentStatus(primary_status="InProgress") return self def __exit__(self, exc_type, exc_value, exc_traceback): """Updates the end time of the tracked trial component. exc_value (str): The exception value. exc_traceback (str): The stack trace of the exception. """ self._end_time = datetime.datetime.now(dateutil.tz.tzlocal()) if not self._in_sagemaker_job: self.trial_component.end_time = self._end_time if exc_value: self.trial_component.status = api_types.TrialComponentStatus( primary_status="Failed", message=str(exc_value) ) else: self.trial_component.status = api_types.TrialComponentStatus(primary_status="Completed") self.close() def close(self): """Close this tracker and save state to SageMaker.""" try: # update the trial component with additions from tracker self.trial_component.save() # create lineage entities for the artifacts self._lineage_artifact_tracker.save() finally: if self._metrics_writer: self._metrics_writer.close() def _resolve_artifact_name(file_path): _, filename = os.path.split(file_path) if filename: return filename else: return _utils.name("artifact") class _ArtifactUploader(object): def __init__(self, trial_component_name, artifact_bucket, artifact_prefix, boto_session): self.s3_client = boto_session.client("s3") self.boto_session = boto_session self.trial_component_name = trial_component_name self.artifact_bucket = artifact_bucket self.artifact_prefix = artifact_prefix or "trial-component-artifacts" def upload_artifact(self, file_path): """Upload an artifact file to S3 and record the artifact S3 key with this trial run. Args: file_path (str): the file path of the artifact Returns: (str, str): The s3 URI of the uploaded file and the etag of the file. Raises: ValueError: If file does not exist. """ file_path = os.path.expanduser(file_path) if not os.path.isfile(file_path): raise ValueError("{} does not exist or is not a file. Please supply a file path.".format(file_path)) if not self.artifact_bucket: self.artifact_bucket = _utils.get_or_create_default_bucket(self.boto_session) artifact_name = os.path.basename(file_path) artifact_s3_key = "{}/{}/{}".format(self.artifact_prefix, self.trial_component_name, artifact_name) self.s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key) etag = self._try_get_etag(artifact_s3_key) return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag def upload_object_artifact(self, artifact_name, obj, file_extension=None): """Upload an artifact object to S3 and record the artifact S3 key with this trial component. Args: artifact_name (str): the name of the artifact. obj (obj): the object of the artifact file_extension (str): Optional file extension. Returns: str: The s3 URI of the uploaded file and the version of the file """ if not self.artifact_bucket: self.artifact_bucket = _utils.get_or_create_default_bucket(self.boto_session) if file_extension: artifact_name = artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension artifact_s3_key = "{}/{}/{}".format(self.artifact_prefix, self.trial_component_name, artifact_name) self.s3_client.put_object(Body=json.dumps(obj), Bucket=self.artifact_bucket, Key=artifact_s3_key) etag = self._try_get_etag(artifact_s3_key) return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag def _try_get_etag(self, key): try: response = self.s3_client.head_object(Bucket=self.artifact_bucket, Key=key) return response["ETag"] except botocore.exceptions.ClientError: # requires read permissions pass return None def _guess_media_type(file_path): """Guesses the media type of a file based on its file name. Args: file_path (str): Path to file. Returns: str: The guessed media type. """ file_url = urllib.parse.urljoin("file:", urllib.request.pathname2url(file_path)) guessed_media_type, _ = mimetypes.guess_type(file_url, strict=False) return guessed_media_type class _LineageArtifactTracker(object): def __init__(self, trial_component_arn, sagemaker_client): self.trial_component_arn = trial_component_arn self.sagemaker_client = sagemaker_client self.artifacts = [] def add_input_artifact(self, name, source_uri, etag, artifact_type): artifact = _LineageArtifact( name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type ) self.artifacts.append(artifact) def add_output_artifact(self, name, source_uri, etag, artifact_type): artifact = _LineageArtifact( name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type ) self.artifacts.append(artifact) def save(self): for artifact in self.artifacts: artifact.create_artifact(self.sagemaker_client) artifact.add_association(self.sagemaker_client) class _LineageArtifact(object): def __init__(self, name, source_uri, etag, source_arn=None, dest_arn=None, artifact_type=None): self.name = name self.source_uri = source_uri self.etag = etag self.source_arn = source_arn self.dest_arn = dest_arn self.artifact_arn = None self.artifact_type = artifact_type if artifact_type else "Tracker" def create_artifact(self, sagemaker_client): source_ids = [] if self.etag: source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag}) response = sagemaker_client.create_artifact( ArtifactName=self.name, ArtifactType=self.artifact_type, Source={"SourceUri": self.source_uri, "SourceTypes": source_ids}, ) self.artifact_arn = response["ArtifactArn"] def add_association(self, sagemaker_client): source_arn = self.source_arn if self.source_arn else self.artifact_arn dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn # if the trial component (job) is the source then it produced the artifact, otherwise the # artifact contributed to the trial component (job) association_edge_type = "Produced" if self.source_arn else "ContributedTo" sagemaker_client.add_association( SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type ) class _ArtifactConverter(object): """Converts data to easily consumed by Studio.""" @classmethod def convert_dict_to_fields(cls, values): """Converts a dictionary to list of field types. Args: values (dict): The values of the dictionary. Returns: dict: Dictionary of fields. """ fields = [] for key in values: fields.append({"name": key, "type": "string"}) return fields @classmethod def convert_data_frame_to_values(cls, data_frame): """Converts a pandas data frame to a dictionary in the table artifact format. Args: data_frame (DataFrame): The pandas data frame to convert. Returns: [type]: dictionary of values in the format needed to log the artifact. """ df_dict = data_frame.to_dict() new_df = {} for key in df_dict: col_value = df_dict[key] values = [] for row_key in col_value: values.append(col_value[row_key]) new_df[key] = values return new_df @classmethod def convert_data_frame_to_fields(cls, data_frame): """Converts a dataframe to a dictionary describing the type of fields. Args: data_frame(str): The data frame to convert. Returns: dict: Dictionary of fields. """ fields = [] for key in data_frame: col_type = data_frame.dtypes[key] fields.append({"name": key, "type": _ArtifactConverter.convert_df_type_to_simple_type(col_type)}) return fields @classmethod def convert_df_type_to_simple_type(cls, data_frame_type): """Converts a dataframe type to a type for rendering a table in Studio. Args: data_frame_type (str): The pandas type. Returns: str: The type of the table field. """ type_pairs = [ ("datetime", "datetime"), ("float", "number"), ("int", "number"), ("uint", "number"), ("boolean", "boolean"), ] for pair in type_pairs: if str(data_frame_type).lower().startswith(pair[0]): return pair[1] return "string"