# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import annotations

import os
import time
from typing import Any, Dict, List, Union

from braket.aws.aws_session import AwsSession
from braket.jobs.config import CheckpointConfig, OutputDataConfig, S3DataSourceConfig
from braket.jobs.image_uris import Framework, retrieve_image
from braket.jobs.local.local_job_container import _LocalJobContainer
from braket.jobs.local.local_job_container_setup import setup_container
from braket.jobs.metrics_data.definitions import MetricStatistic, MetricType
from braket.jobs.metrics_data.log_metrics_parser import LogMetricsParser
from braket.jobs.quantum_job import QuantumJob
from braket.jobs.quantum_job_creation import prepare_quantum_job
from braket.jobs.serialization import deserialize_values
from braket.jobs_data import PersistedJobData


class LocalQuantumJob(QuantumJob):
    """Amazon Braket implementation of a quantum job that runs locally."""
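
    # Lifecycle sketch: `create` runs the customer script in a local Docker
    # container and persists its artifacts (the model output, log.txt, the
    # results persisted via save_job_result(), and any checkpoints) in a local
    # directory named after the job; an instance constructed later from the
    # same "local:job/<name>" ARN reads those artifacts back.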

    @classmethod
    def create(
        cls,
        device: str,
        source_module: str,
        entry_point: str = None,
        image_uri: str = None,
        job_name: str = None,
        code_location: str = None,
        role_arn: str = None,
        hyperparameters: Dict[str, Any] = None,
        input_data: Union[str, Dict, S3DataSourceConfig] = None,
        output_data_config: OutputDataConfig = None,
        checkpoint_config: CheckpointConfig = None,
        aws_session: AwsSession = None,
        local_container_update: bool = True,
    ) -> LocalQuantumJob:
        """Creates and runs a job by setting up and running the customer script in a local
        Docker container.

        Args:
            device (str): ARN for the AWS device which is primarily accessed for the execution
                of this job. Alternatively, a string of the format "local:<provider>/<simulator>"
                for using a local simulator for the job. This string will be available as the
                environment variable `AMZN_BRAKET_DEVICE_ARN` inside the job container when
                using a Braket container.

            source_module (str): Path (absolute, relative or an S3 URI) to a python module to be
                tarred and uploaded. If `source_module` is an S3 URI, it must point to a
                tar.gz file. Otherwise, source_module may be a file or directory.

            entry_point (str): A str that specifies the entry point of the job, relative to
                the source module. The entry point must be in the format
                `importable.module` or `importable.module:callable`. For example,
                `source_module.submodule:start_here` indicates the `start_here` function
                contained in `source_module.submodule`. If source_module is an S3 URI,
                entry point must be given. Default: source_module's name

            image_uri (str): A str that specifies the ECR image to use for executing the job.
                The `image_uris.retrieve_image()` function may be used for retrieving the ECR
                image URIs for the containers supported by Braket.
                Default = `<Braket base image_uri>`.

            job_name (str): A str that specifies the name with which the job is created.
                Default: f'{image_uri_type}-{timestamp}'.

            code_location (str): The S3 prefix URI where custom code will be uploaded.
                Default: f's3://{default_bucket_name}/jobs/{job_name}/script'.

            role_arn (str): This field is currently not used for local jobs. Local jobs will use
                the current role's credentials. This may be subject to change.

            hyperparameters (Dict[str, Any]): Hyperparameters accessible to the job.
                The hyperparameters are made accessible as a Dict[str, str] to the job.
                For convenience, this accepts other types for keys and values, but `str()`
                is called to convert them before being passed on. Default: None.

            input_data (Union[str, Dict, S3DataSourceConfig]): Information about the training
                data. Dictionary maps channel names to local paths or S3 URIs. Contents found
                at any local paths will be uploaded to S3 at
                f's3://{default_bucket_name}/jobs/{job_name}/data/{channel_name}'. If a local
                path, S3 URI, or S3DataSourceConfig is provided, it will be given a default
                channel name "input". Default: {}.

            output_data_config (OutputDataConfig): Specifies the location for the output of the
                job. Default: OutputDataConfig(s3Path=f's3://{default_bucket_name}/jobs/
                {job_name}/data', kmsKeyId=None).

            checkpoint_config (CheckpointConfig): Configuration that specifies the location
                where checkpoint data is stored.
                Default: CheckpointConfig(localPath='/opt/jobs/checkpoints',
                s3Uri=f's3://{default_bucket_name}/jobs/{job_name}/checkpoints').

            aws_session (AwsSession): AwsSession for connecting to AWS Services.
                Default: AwsSession()

            local_container_update (bool): Perform an update, if available, from ECR to the
                local container image. Optional. Default: True.

        Returns:
            LocalQuantumJob: The representation of a local Braket Job.
        """
        create_job_kwargs = prepare_quantum_job(
            device=device,
            source_module=source_module,
            entry_point=entry_point,
            image_uri=image_uri,
            job_name=job_name,
            code_location=code_location,
            role_arn=role_arn,
            hyperparameters=hyperparameters,
            input_data=input_data,
            output_data_config=output_data_config,
            checkpoint_config=checkpoint_config,
            aws_session=aws_session,
        )

        job_name = create_job_kwargs["jobName"]
        if os.path.isdir(job_name):
            raise ValueError(
                f"A local directory called {job_name} already exists. "
                f"Please use a different job name."
            )

        session = aws_session or AwsSession()
        algorithm_specification = create_job_kwargs["algorithmSpecification"]
        if "containerImage" in algorithm_specification:
            image_uri = algorithm_specification["containerImage"]["uri"]
        else:
            image_uri = retrieve_image(Framework.BASE, session.region)

        with _LocalJobContainer(
            image_uri=image_uri, force_update=local_container_update
        ) as container:
            env_variables = setup_container(container, session, **create_job_kwargs)
            container.run_local_job(env_variables)
            container.copy_from("/opt/ml/model", job_name)
            with open(os.path.join(job_name, "log.txt"), "w") as log_file:
                log_file.write(container.run_log)
            if "checkpointConfig" in create_job_kwargs:
                checkpoint_config = create_job_kwargs["checkpointConfig"]
                if "localPath" in checkpoint_config:
                    checkpoint_path = checkpoint_config["localPath"]
                    container.copy_from(checkpoint_path, os.path.join(job_name, "checkpoints"))
            run_log = container.run_log
        return LocalQuantumJob(f"local:job/{job_name}", run_log)
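
    # A minimal usage sketch for `create`. The device string and module path
    # below are illustrative placeholders (the device follows the documented
    # "local:<provider>/<simulator>" format), not values defined by this module:
    #
    #   job = LocalQuantumJob.create(
    #       device="local:provider/simulator-name",
    #       source_module="algorithm_script.py",
    #       entry_point="algorithm_script:start_here",
    #   )
    #   # Blocks until the container finishes; artifacts land in ./<job name>.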
""" if not arn.startswith("local:job/"): raise ValueError(f"Arn {arn} is not a valid local job arn") self._arn = arn self._run_log = run_log self._name = arn.partition("job/")[-1] if not run_log and not os.path.isdir(self.name): raise ValueError(f"Unable to find local job results for {self.name}") @property def arn(self) -> str: """str: The ARN (Amazon Resource Name) of the quantum job.""" return self._arn @property def name(self) -> str: """str: The name of the quantum job.""" return self._name @property def run_log(self) -> str: """Gets the run output log from running the job. Returns: str: The container output log from running the job. """ if not self._run_log: try: with open(os.path.join(self.name, "log.txt"), "r") as log_file: self._run_log = log_file.read() except FileNotFoundError: raise ValueError(f"Unable to find logs in the local job directory {self.name}.") return self._run_log def state(self, use_cached_value: bool = False) -> str: """The state of the quantum job. Args: use_cached_value (bool): If `True`, uses the value most recently retrieved value from the Amazon Braket `GetJob` operation. If `False`, calls the `GetJob` operation to retrieve metadata, which also updates the cached value. Default = `False`. Returns: str: Returns "COMPLETED". """ return "COMPLETED" def metadata(self, use_cached_value: bool = False) -> Dict[str, Any]: """When running the quantum job in local mode, the metadata is not available. Args: use_cached_value (bool): If `True`, uses the value most recently retrieved from the Amazon Braket `GetJob` operation, if it exists; if does not exist, `GetJob` is called to retrieve the metadata. If `False`, always calls `GetJob`, which also updates the cached value. Default: `False`. Returns: Dict[str, Any]: None """ pass def cancel(self) -> str: """When running the quantum job in local mode, the cancelling a running is not possible. Returns: str: None """ pass def download_result( self, extract_to: str = None, poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, ) -> None: """When running the quantum job in local mode, results are automatically stored locally. Args: extract_to (str): The directory to which the results are extracted. The results are extracted to a folder titled with the job name within this directory. Default= `Current working directory`. poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`. Default: 10 days. poll_interval_seconds (float): The polling interval, in seconds, for `result()`. Default: 5 seconds. """ pass def result( self, poll_timeout_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_TIMEOUT, poll_interval_seconds: float = QuantumJob.DEFAULT_RESULTS_POLL_INTERVAL, ) -> Dict[str, Any]: """Retrieves the job result persisted using save_job_result() function. Args: poll_timeout_seconds (float): The polling timeout, in seconds, for `result()`. Default: 10 days. poll_interval_seconds (float): The polling interval, in seconds, for `result()`. Default: 5 seconds. Returns: Dict[str, Any]: Dict specifying the job results. 
""" try: with open(os.path.join(self.name, "results.json"), "r") as f: persisted_data = PersistedJobData.parse_raw(f.read()) deserialized_data = deserialize_values( persisted_data.dataDictionary, persisted_data.dataFormat ) return deserialized_data except FileNotFoundError: raise ValueError(f"Unable to find results in the local job directory {self.name}.") def metrics( self, metric_type: MetricType = MetricType.TIMESTAMP, statistic: MetricStatistic = MetricStatistic.MAX, ) -> Dict[str, List[Any]]: """Gets all the metrics data, where the keys are the column names, and the values are a list containing the values in each row. Args: metric_type (MetricType): The type of metrics to get. Default: MetricType.TIMESTAMP. statistic (MetricStatistic): The statistic to determine which metric value to use when there is a conflict. Default: MetricStatistic.MAX. Example: timestamp energy 0 0.1 1 0.2 would be represented as: { "timestamp" : [0, 1], "energy" : [0.1, 0.2] } values may be integers, floats, strings or None. Returns: Dict[str, List[Any]]: The metrics data. """ parser = LogMetricsParser() current_time = str(time.time()) for line in self.run_log.splitlines(): if line.startswith("Metrics -"): parser.parse_log_message(current_time, line) return parser.get_parsed_metrics(metric_type, statistic) def logs(self, wait: bool = False, poll_interval_seconds: int = 5) -> None: """Display container logs for a given job Args: wait (bool): `True` to keep looking for new log entries until the job completes; otherwise `False`. Default: `False`. poll_interval_seconds (int): The interval of time, in seconds, between polling for new log entries and job completion (default: 5). """ return print(self.run_log)