"""Mocking helpers for running Amazon Braket code with no, or limited, real AWS access.

The mock level is chosen with :func:`set_level` before a :class:`Mocker` is constructed:

* ``"ALL"``: ``boto3.Session`` is replaced by :class:`Boto3SessionAllWrapper`, so every
  boto3 client/resource call is answered by a ``mock.Mock``.
* any other value: :class:`AwsSessionMinWrapper` patches selected
  ``braket.aws.aws_session.AwsSession`` methods with :class:`AwsSessionFacade`. The facade
  routes each call to either the real implementation or a mock, based on the
  ``MOCK_DEVICE_CONFIG`` and ``UNSUPPORTED_DEVICE_CONFIG`` environment variables.
"""

import json
import os
import unittest.mock as mock

import boto3
import braket.aws
import braket.tracking
import matplotlib.pyplot as plt

# Calls to plt.savefig become no-ops that return a Mock, so no figure files are written.
plt.savefig = mock.Mock()


class Mocker:
    """Install boto3/Braket mocks and expose setters for the mocked API responses.

    The mocking strategy is picked from the class attribute ``mock_level`` at construction
    time (see :func:`set_level`). Construction also replaces ``braket.tracking.Tracker``
    with a mock whose cost queries return 0.
    """

    # "ALL" mocks every boto3 call; any other value uses the selective AwsSession patching.
    mock_level = "ALL"

    def __init__(self):
        self._wrapper = (
            Boto3SessionAllWrapper() if Mocker.mock_level == "ALL" else AwsSessionMinWrapper()
        )
        braket.tracking.Tracker = mock.Mock()
        # A Mock returns the same child object on every call. Configuring this instance
        # therefore configures whatever user code later gets from Tracker().start().
        tracker = braket.tracking.Tracker().start()
        tracker.qpu_tasks_cost.return_value = 0
        tracker.simulator_tasks_cost.return_value = 0

    def set_get_device_result(self, result):
        """Set the value returned by the mocked ``get_device`` call."""
        self._wrapper.boto_client.get_device.return_value = result

    def set_create_quantum_task_result(self, result):
        """Set the value returned by the mocked ``create_quantum_task`` call."""
        self._wrapper.boto_client.create_quantum_task.return_value = result

    def set_get_quantum_task_result(self, result):
        """Set the value returned by the mocked ``get_quantum_task`` call."""
        self._wrapper.boto_client.get_quantum_task.return_value = result

    def set_cancel_quantum_task_result(self, result):
        """Set the value returned by the mocked ``cancel_quantum_task`` call."""
        self._wrapper.boto_client.cancel_quantum_task.return_value = result

    def set_task_result_return(self, result):
        """Set the task-result payload returned for every mocked result read."""
        self._wrapper.task_result_mock.return_value = result

    def set_task_result_side_effect(self, side_effect):
        """Set a side effect (e.g. a sequence of payloads) for mocked result reads."""
        self._wrapper.task_result_mock.side_effect = side_effect

    def set_search_result(self, result):
        """Set the pages yielded by ``get_paginator(...).paginate(...)``."""
        self._wrapper.boto_client.get_paginator.return_value.paginate.return_value = result

    def set_log_streams_result(self, result):
        """Set the value returned by the mocked ``describe_log_streams`` call."""
        self._wrapper.boto_client.describe_log_streams.return_value = result

    def set_start_query_result(self, result):
        """Set the value returned by the mocked ``start_query`` call."""
        self._wrapper.boto_client.start_query.return_value = result

    def set_get_query_results_result(self, result):
        """Set the value returned by the mocked ``get_query_results`` call."""
        self._wrapper.boto_client.get_query_results.return_value = result

    def set_list_objects_v2_result(self, result):
        """Set the value returned by the mocked ``list_objects_v2`` call."""
        self._wrapper.boto_client.list_objects_v2.return_value = result

    @property
    def region_name(self):
        """Region reported by the active session wrapper."""
        return self._wrapper.region_name


def read_file(name, file_path=None):
    """Return the text content of a data file.

    Args:
        name: Name of the file to read.
        file_path: Optional path of a file whose directory contains ``name``. If omitted,
            ``name`` is looked up in the ``default_data`` directory next to this module.

    Returns:
        The file's contents as a string.
    """
    if file_path:
        json_path = os.path.join(os.path.dirname(file_path), name)
    else:
        json_path = os.path.join(os.path.dirname(__file__), "default_data", name)
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(json_path, "r", encoding="utf-8") as file:
        return file.read()


def mock_default_device_calls(mocker):
    """Configure ``mocker`` with a default QPU device, a completed task and default results."""
    mocker.set_get_device_result(
        {"deviceType": "QPU", "deviceCapabilities": read_file("default_capabilities.json")}
    )
    mocker.set_create_quantum_task_result(
        {
            "quantumTaskArn": "arn:aws:braket:us-west-2:000000:quantum-task/TestARN",
        }
    )
    mocker.set_get_quantum_task_result(
        {
            "quantumTaskArn": "arn:aws:braket:us-west-2:000000:quantum-task/TestARN",
            "status": "COMPLETED",
            "outputS3Bucket": "Test Bucket",
            "outputS3Directory": "Test Directory",
            "shots": 10,
            "deviceArn": "Test Device Arn",
            "ResponseMetadata": {"HTTPHeaders": {"date": ""}},
        }
    )
    mocker.set_task_result_return(read_file("default_results.json"))


def set_level(mock_level):
    """Select the mock level used by :class:`Mocker` instances constructed afterwards."""
    Mocker.mock_level = mock_level


class SessionWrapper:
    """Mock objects shared by both session-wrapper flavors."""

    def __init__(self):
        self.boto_client = mock.Mock()
        self.task_result_mock = mock.Mock()
        self.resource_mock = mock.Mock()
        # resource.Object(...).get()["Body"].read().decode(...) calls task_result_mock,
        # so S3 reads return whatever result the Mocker configured.
        return_mock = mock.Mock()
        return_mock.read.return_value.decode = self.task_result_mock
        self.resource_mock.Object.return_value.get.return_value = {"Body": return_mock}
        self.boto_client.get_caller_identity.return_value = {"Account": "TestAccount"}
        self.boto_client.meta.region_name = "us-west-2"
        self.boto_client.get_authorization_token.return_value = {
            "authorizationData": [{"authorizationToken": "TestToken"}]
        }


class Boto3SessionAllWrapper(SessionWrapper):
    """Stand-in for ``boto3.Session`` whose clients and resources are all mocks."""

    def __init__(self):
        super().__init__()
        # __call__ returns self, so every later boto3.Session(...) yields this wrapper.
        boto3.Session = self

    def __call__(self, *args, **kwargs):
        return self

    def client(self, *args, **kwargs):
        return self.boto_client

    def resource(self, *args, **kwargs):
        return self.resource_mock

    def profile_name(self, *args, **kwargs):
        return mock.Mock()

    def get_credentials(self, *args, **kwargs):
        return mock.Mock()

    @property
    def region_name(self):
        return "us-west-2"


class AwsSessionMinWrapper(SessionWrapper):
    """Patch selected ``AwsSession`` methods so that only some calls are mocked.

    The genuine implementations are kept on :class:`AwsSessionFacade` as ``real_<name>``,
    so the facade can still delegate to them.
    """

    # AwsSession methods replaced by the AwsSessionFacade method of the same name, in the
    # order they are patched.
    _DELEGATED_METHODS = (
        "get_device",
        "create_quantum_task",
        "get_quantum_task",
        "cancel_quantum_task",
        "retrieve_s3_object_body",
    )

    def __init__(self):
        super().__init__()
        AwsSessionFacade._wrapper = self
        aws_session_cls = braket.aws.aws_session.AwsSession
        for method_name in AwsSessionMinWrapper._DELEGATED_METHODS:
            facade_method = getattr(AwsSessionFacade, method_name)
            current = getattr(aws_session_cls, method_name)
            # Only capture the original while it is still unpatched. On a repeated
            # instantiation AwsSession already holds the facade method. Storing that as
            # real_<name> would make the facade delegate to itself and recurse forever.
            if current is not facade_method:
                setattr(AwsSessionFacade, f"real_{method_name}", current)
            setattr(aws_session_cls, method_name, facade_method)
        # copy_s3_directory is replaced by a no-op; its real implementation is not kept.
        aws_session_cls.copy_s3_directory = AwsSessionFacade.copy_s3_directory
        AwsSessionMinWrapper.parse_device_config()

    @staticmethod
    def parse_device_config():
        """Load the device routing tables from environment variables.

        ``MOCK_DEVICE_CONFIG`` is JSON mapping a device name (the last ARN segment) to
        either ``"MOCK"`` or a substitute device ARN. ``UNSUPPORTED_DEVICE_CONFIG`` is JSON,
        expected to be an array of device names, whose device and task-creation calls are
        always mocked. An unset or empty variable yields an empty table.
        """
        mock_device_config_str = os.getenv("MOCK_DEVICE_CONFIG")
        AwsSessionFacade.mock_device_config = (
            json.loads(mock_device_config_str) if mock_device_config_str else {}
        )
        unsupported_device_config_str = os.getenv("UNSUPPORTED_DEVICE_CONFIG")
        # Always a set (the empty fallback used to be a dict), so the attribute has one type.
        AwsSessionFacade.unsupported_device_config = (
            set(json.loads(unsupported_device_config_str))
            if unsupported_device_config_str
            else set()
        )

    @property
    def region_name(self):
        return boto3.session.Session().region_name


def _mocked_task_arn(boto3_kwargs):
    """Create a task on the mocked boto client and return the ARN it reports."""
    return AwsSessionFacade._wrapper.boto_client.create_quantum_task(boto3_kwargs)[
        "quantumTaskArn"
    ]


class AwsSessionFacade(braket.aws.AwsSession):
    """Replacement methods that :class:`AwsSessionMinWrapper` patches onto ``AwsSession``.

    The methods are installed on ``AwsSession`` itself, so ``self`` is an ``AwsSession``
    instance. Shared state is therefore always accessed through the ``AwsSessionFacade``
    class. ``_wrapper``, the ``real_*`` methods, ``mock_device_config`` and
    ``unsupported_device_config`` are assigned by :class:`AwsSessionMinWrapper`.
    """

    # ARNs of tasks created through the real API, and their S3 output directories. These are
    # class-level so every patched AwsSession instance shares them.
    created_task_arns = set()
    created_task_locations = set()

    def get_device(self, arn):
        """Return device metadata: mocked for unsupported devices, real otherwise."""
        device_name = arn.split("/")[-1]
        if device_name in AwsSessionFacade.unsupported_device_config:
            return AwsSessionFacade._wrapper.boto_client.get_device(arn)
        return AwsSessionFacade.real_get_device(self, arn)

    def create_quantum_task(self, **boto3_kwargs):
        """Create a quantum task, either mocked or real, depending on the device config.

        Calls without a ``deviceArn``, calls for unsupported devices, and calls for devices
        mapped to ``"MOCK"`` are mocked. A device mapped to another ARN is swapped for that
        ARN. All remaining calls go to the real API, and the resulting ARN is recorded so
        later lookups of that task are also real.

        Returns:
            The ARN of the created (real or mocked) task.
        """
        device_arn = boto3_kwargs.get("deviceArn")
        if not device_arn:
            return _mocked_task_arn(boto3_kwargs)
        device_name = device_arn.split("/")[-1]
        if device_name in AwsSessionFacade.unsupported_device_config:
            return _mocked_task_arn(boto3_kwargs)
        if device_name in AwsSessionFacade.mock_device_config:
            device_sub = AwsSessionFacade.mock_device_config[device_name]
            if device_sub == "MOCK":
                return _mocked_task_arn(boto3_kwargs)
            # Run on the configured substitute device instead of the requested one.
            boto3_kwargs["deviceArn"] = device_sub
        task_arn = AwsSessionFacade.real_create_quantum_task(self, **boto3_kwargs)
        AwsSessionFacade.created_task_arns.add(task_arn)
        return task_arn

    def get_quantum_task(self, arn):
        """Return task metadata: real for tasks created for real, mocked otherwise."""
        if arn in AwsSessionFacade.created_task_arns:
            task_data = AwsSessionFacade.real_get_quantum_task(self, arn)
            # Remember where real results live, so retrieve_s3_object_body reads them for real.
            AwsSessionFacade.created_task_locations.add(task_data["outputS3Directory"])
            return task_data
        return AwsSessionFacade._wrapper.boto_client.get_quantum_task(arn)

    def cancel_quantum_task(self, arn):
        """Cancel a task: real for tasks created for real, mocked otherwise."""
        if arn in AwsSessionFacade.created_task_arns:
            return AwsSessionFacade.real_cancel_quantum_task(self, arn)
        return AwsSessionFacade._wrapper.boto_client.cancel_quantum_task(arn)

    def copy_s3_directory(self, source_s3_path, destination_s3_path):
        """No-op replacement: S3 copies are skipped entirely."""
        return

    def retrieve_s3_object_body(self, s3_bucket, s3_object_key):
        """Return an S3 object's body: real for real task outputs, mocked otherwise."""
        # Directory part of the key; "" when the key contains no "/" (rindex would raise).
        location = s3_object_key.rpartition("/")[0]
        if location in AwsSessionFacade.created_task_locations:
            return AwsSessionFacade.real_retrieve_s3_object_body(self, s3_bucket, s3_object_key)
        task_result_mock = AwsSessionFacade._wrapper.task_result_mock
        if task_result_mock.side_effect is not None:
            # Mock stores an iterable side_effect as an iterator, so each read yields the
            # next configured payload.
            return next(task_result_mock.side_effect)
        return task_result_mock.return_value