# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Contains classes to manage metrics for SageMaker Experiments."""
from __future__ import absolute_import

import datetime
import logging
import os
import time
import threading
import queue

import dateutil.tz

from sagemaker.session import Session

METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".")
METRIC_TS_LOWER_BOUND_TO_NOW = 1209600  # in seconds (two weeks)
METRIC_TS_UPPER_BOUND_FROM_NOW = 7200  # in seconds (two hours)

BATCH_SIZE = 10

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _RawMetricData(object):
    """A raw metric data object."""

    MetricName = None
    Value = None
    Timestamp = None
    Step = None

    def __init__(self, metric_name, value, timestamp=None, step=None):
        """Construct a `_RawMetricData` instance.

        Args:
            metric_name (str): The name of the metric.
            value (float): The value of the metric.
            timestamp (datetime.datetime or float or str): Timestamp of the metric.
                If not specified, the current UTC time will be used.
            step (int): Iteration number of the metric (default: None).
        """
        if timestamp is None:
            timestamp = time.time()
        elif isinstance(timestamp, datetime.datetime):
            # If the input is a datetime then convert it to UTC time.
            # Assume a naive datetime is in the local timezone.
            if not timestamp.tzinfo:
                timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal())
            timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc)
            timestamp = timestamp.timestamp()
        else:
            timestamp = float(timestamp)

        if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > (
            time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW
        ):
            raise ValueError(
                "Supplied timestamp %f is invalid."
                " Timestamps must be between two weeks before and two hours from now." % timestamp
            )

        self.MetricName = metric_name
        self.Value = float(value)
        self.Timestamp = timestamp
        if step is not None:
            if not isinstance(step, int):
                raise ValueError("step must be int.")
            self.Step = step

    def to_record(self):
        """Convert the `_RawMetricData` object to a dict."""
        return self.__dict__

    def to_raw_metric_data(self):
        """Convert the metric data to a BatchPutMetrics RawMetricData item."""
        # Convert the timestamp from float to an epoch-seconds string;
        # otherwise the request fails with ParamValidationError.
        raw_metric_data = {
            "MetricName": self.MetricName,
            "Value": self.Value,
            "Timestamp": str(int(self.Timestamp)),
        }
        if self.Step is not None:
            raw_metric_data["Step"] = int(self.Step)
        return raw_metric_data

    def __str__(self):
        """String representation of the `_RawMetricData` object."""
        return repr(self)

    def __repr__(self):
        """Return a string representation of this `_RawMetricData` object."""
        return "{}({})".format(
            type(self).__name__,
            ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
        )


class _MetricsManager(object):
    """Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""

    def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None:
        """Initialize a `_MetricsManager` instance.

        Args:
            trial_component_name (str): The name of the Trial Component to log metrics to.
            sagemaker_session (sagemaker.session.Session): Session object which manages
                interactions with Amazon SageMaker APIs and any other AWS services needed.
            sink (object): The metrics sink to use. If not specified, a `_SyncMetricsSink`
                is created from the session's SageMaker Metrics client.
        """
        if sink is None:
            self.sink = _SyncMetricsSink(
                trial_component_name, sagemaker_session.sagemaker_metrics_client
            )
        else:
            self.sink = sink

    def log_metric(self, metric_name, value, timestamp=None, step=None):
        """Sends a metric to the metrics service."""
        metric_data = _RawMetricData(metric_name, value, timestamp, step)
        self.sink.log_metric(metric_data)

    def __enter__(self):
        """Return self."""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Close the sink when the context exits."""
        self.sink.close()

    def close(self):
        """Close the metrics object."""
        self.sink.close()


class _SyncMetricsSink(object):
    """Collects metrics and sends them synchronously to the metrics service."""

    def __init__(self, trial_component_name, metrics_client) -> None:
        """Initialize a `_SyncMetricsSink` instance.

        Args:
            trial_component_name (str): The name of the Trial Component to log metrics to.
            metrics_client (boto3.client): boto client for the metrics service.
        """
        self._trial_component_name = trial_component_name
        self._metrics_client = metrics_client
        self._buffer = []

    def log_metric(self, metric_data):
        """Sends a metric to the metrics service."""
        # This is a simplistic solution which calls BatchPutMetrics
        # on the same thread as the client code.
        self._buffer.append(metric_data)
        self._drain()

    def _drain(self, close=False):
        """Pops all metrics off the buffer and sends them to the metrics service."""
        if not self._buffer:
            return

        if len(self._buffer) < BATCH_SIZE and not close:
            return

        # pop all the available metrics
        available_metrics, self._buffer = self._buffer, []
        self._send_metrics(available_metrics)

    def _send_metrics(self, metrics):
        """Calls BatchPutMetrics directly on the metrics service."""
        while metrics:
            batch, metrics = (
                metrics[:BATCH_SIZE],
                metrics[BATCH_SIZE:],
            )
            request = self._construct_batch_put_metrics_request(batch)
            response = self._metrics_client.batch_put_metrics(**request)
            errors = response["Errors"] if "Errors" in response else None
            if errors:
                message = errors[0]["Message"]
                raise Exception(f'{len(errors)} errors with message "{message}"')

    def _construct_batch_put_metrics_request(self, batch):
        """Creates the dictionary object used as the request to the metrics service."""
        return {
            "TrialComponentName": self._trial_component_name.lower(),
            "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
        }

    def close(self):
        """Drains any remaining metrics."""
        self._drain(close=True)


class _MetricQueue(object):
    """A thread-safe queue for sending metrics to SageMaker.

    Args:
        trial_component_name (str): the name of the Trial Component to log metrics to
        metric_name (str): the name of the metric
        metrics_client (boto3.client): the boto client for the SageMaker Metrics service
    """

    _CONSUMER_SLEEP_SECONDS = 5

    def __init__(self, trial_component_name, metric_name, metrics_client):
        # infinite queue size
        self._queue = queue.Queue()
        self._buffer = []
        self._thread = threading.Thread(target=self._run)
        self._started = False
        self._finished = False
        self._trial_component_name = trial_component_name
        self._metrics_client = metrics_client
        self._metric_name = metric_name
        self._logged_metrics = 0

    def log_metric(self, metric_data):
        """Adds a metric data point to the queue."""
        self._buffer.append(metric_data)

        if len(self._buffer) < BATCH_SIZE:
            return

        self._enqueue_all()

        if not self._started:
            self._thread.start()
            self._started = True

    def _run(self):
        """Runs the consumer loop, which sends queued metrics to SageMaker in batches."""
        while not self._queue.empty() or not self._finished:
            if self._queue.empty():
                time.sleep(self._CONSUMER_SLEEP_SECONDS)
            else:
                batch = self._queue.get()
                self._send_metrics(batch)

    def _send_metrics(self, metrics_batch):
        """Calls BatchPutMetrics directly on the metrics service."""
        request = self._construct_batch_put_metrics_request(metrics_batch)
        self._logged_metrics += len(metrics_batch)
        self._metrics_client.batch_put_metrics(**request)

    def _construct_batch_put_metrics_request(self, batch):
        """Creates the dictionary object used as the request to the metrics service."""
        return {
            "TrialComponentName": self._trial_component_name,
            "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
        }

    def _enqueue_all(self):
        """Enqueues all buffered metrics to be sent to SageMaker."""
        available_metrics, self._buffer = self._buffer, []
        if available_metrics:
            self._queue.put(available_metrics)

    def close(self):
        """Flushes any buffered metrics and signals the consumer thread to finish."""
        self._enqueue_all()
        self._finished = True
        # If fewer than BATCH_SIZE metrics were ever logged, the consumer thread
        # was never started; start it now so the enqueued metrics are still sent.
        if not self._started and not self._queue.empty():
            self._thread.start()
            self._started = True

    def is_active(self):
        """Whether the consumer thread is still active (i.e. still draining
metrics to SageMaker)""" return self._thread.is_alive() class _AsyncMetricsSink(object): """Collects metrics and sends them directly to metrics service.""" _COMPLETE_SLEEP_SECONDS = 1.0 def __init__(self, trial_component_name, metrics_client) -> None: """Initialize a `_AsyncMetricsSink` instance Args: trial_component_name (str): The Name of the Trial Component to log metrics to. metrics_client (boto3.client): boto client for metrics service """ self._trial_component_name = trial_component_name self._metrics_client = metrics_client self._buffer = [] self._is_draining = False self._metric_queues = {} def log_metric(self, metric_data): """Sends a metric to metrics service.""" if metric_data.MetricName in self._metric_queues: self._metric_queues[metric_data.MetricName].log_metric(metric_data) else: cur_metric_queue = _MetricQueue( self._trial_component_name, metric_data.MetricName, self._metrics_client ) self._metric_queues[metric_data.MetricName] = cur_metric_queue cur_metric_queue.log_metric(metric_data) def close(self): """Closes the metric file.""" logging.debug("Closing") for q in self._metric_queues.values(): q.close() # TODO should probably use join while any(map(lambda x: x.is_active(), self._metric_queues.values())): time.sleep(self._COMPLETE_SLEEP_SECONDS) logging.debug("Closed")