# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Helper tools for use in tests.""" from __future__ import division import base64 import copy import itertools import os from collections import defaultdict from decimal import Decimal import boto3 import pytest from boto3.dynamodb.types import Binary from botocore.exceptions import NoRegionError from mock import patch from moto import mock_dynamodb2 from dynamodb_encryption_sdk.delegated_keys.jce import JceNameLocalDelegatedKey from dynamodb_encryption_sdk.encrypted.client import EncryptedClient from dynamodb_encryption_sdk.encrypted.item import decrypt_python_item, encrypt_python_item from dynamodb_encryption_sdk.encrypted.resource import EncryptedResource from dynamodb_encryption_sdk.encrypted.table import EncryptedTable from dynamodb_encryption_sdk.identifiers import CryptoAction from dynamodb_encryption_sdk.internal.identifiers import ReservedAttributes from dynamodb_encryption_sdk.material_providers import CryptographicMaterialsProvider from dynamodb_encryption_sdk.material_providers.most_recent import CachingMostRecentProvider from dynamodb_encryption_sdk.material_providers.static import StaticCryptographicMaterialsProvider from dynamodb_encryption_sdk.material_providers.store.meta import MetaStore from dynamodb_encryption_sdk.material_providers.wrapped import WrappedCryptographicMaterialsProvider from dynamodb_encryption_sdk.materials import CryptographicMaterials from dynamodb_encryption_sdk.materials.raw import RawDecryptionMaterials, RawEncryptionMaterials from dynamodb_encryption_sdk.structures import AttributeActions, EncryptionContext from dynamodb_encryption_sdk.transform import ddb_to_dict, dict_to_ddb RUNNING_IN_TRAVIS = "TRAVIS" in os.environ _DELEGATED_KEY_CACHE = defaultdict(lambda: defaultdict(dict)) TEST_TABLE_NAME = "my_table" TEST_REGION_NAME = "us-west-2" TEST_INDEX = { "partition_attribute": {"type": "S", "value": "test_value"}, "sort_attribute": {"type": "N", "value": Decimal("99.233")}, } SECONDARY_INDEX = { "secondary_index_1": {"type": "B", "value": Binary(b"\x00\x01\x02")}, "secondary_index_2": {"type": "S", "value": "another_value"}, } TEST_KEY = {name: value["value"] for name, value in TEST_INDEX.items()} TEST_BATCH_INDEXES = [ { "partition_attribute": {"type": "S", "value": "test_value"}, "sort_attribute": {"type": "N", "value": Decimal("99.233")}, }, { "partition_attribute": {"type": "S", "value": "test_value"}, "sort_attribute": {"type": "N", "value": Decimal("92986745")}, }, { "partition_attribute": {"type": "S", "value": "test_value"}, "sort_attribute": {"type": "N", "value": Decimal("2231.0001")}, }, { "partition_attribute": {"type": "S", "value": "another_test_value"}, "sort_attribute": {"type": "N", "value": Decimal("732342")}, }, ] TEST_BATCH_KEYS = [{name: value["value"] for name, value in key.items()} for key in TEST_BATCH_INDEXES] @pytest.fixture(scope="module") def mock_ddb_service(): """Centralize service mock to avoid resetting service for tests that use multiple tables.""" with mock_dynamodb2(): yield boto3.client("dynamodb", region_name=TEST_REGION_NAME) @pytest.fixture def example_table(mock_ddb_service): mock_ddb_service.create_table( TableName=TEST_TABLE_NAME, KeySchema=[ {"AttributeName": "partition_attribute", "KeyType": "HASH"}, {"AttributeName": "sort_attribute", "KeyType": "RANGE"}, ], AttributeDefinitions=[ {"AttributeName": name, "AttributeType": value["type"]} for name, value in TEST_INDEX.items() ], ProvisionedThroughput={"ReadCapacityUnits": 100, "WriteCapacityUnits": 100}, ) yield mock_ddb_service mock_ddb_service.delete_table(TableName=TEST_TABLE_NAME) @pytest.fixture def table_with_local_secondary_indexes(mock_ddb_service): mock_ddb_service.create_table( TableName=TEST_TABLE_NAME, KeySchema=[ {"AttributeName": "partition_attribute", "KeyType": "HASH"}, {"AttributeName": "sort_attribute", "KeyType": "RANGE"}, ], LocalSecondaryIndexes=[ { "IndexName": "lsi-1", "KeySchema": [{"AttributeName": "secondary_index_1", "KeyType": "HASH"}], "Projection": {"ProjectionType": "ALL"}, }, { "IndexName": "lsi-2", "KeySchema": [{"AttributeName": "secondary_index_2", "KeyType": "HASH"}], "Projection": {"ProjectionType": "ALL"}, }, ], AttributeDefinitions=[ {"AttributeName": name, "AttributeType": value["type"]} for name, value in list(TEST_INDEX.items()) + list(SECONDARY_INDEX.items()) ], ProvisionedThroughput={"ReadCapacityUnits": 100, "WriteCapacityUnits": 100}, ) yield mock_ddb_service mock_ddb_service.delete_table(TableName=TEST_TABLE_NAME) @pytest.fixture def table_with_global_secondary_indexes(mock_ddb_service): mock_ddb_service.create_table( TableName=TEST_TABLE_NAME, KeySchema=[ {"AttributeName": "partition_attribute", "KeyType": "HASH"}, {"AttributeName": "sort_attribute", "KeyType": "RANGE"}, ], GlobalSecondaryIndexes=[ { "IndexName": "gsi-1", "KeySchema": [{"AttributeName": "secondary_index_1", "KeyType": "HASH"}], "Projection": {"ProjectionType": "ALL"}, "ProvisionedThroughput": {"ReadCapacityUnits": 100, "WriteCapacityUnits": 100}, }, { "IndexName": "gsi-2", "KeySchema": [{"AttributeName": "secondary_index_2", "KeyType": "HASH"}], "Projection": {"ProjectionType": "ALL"}, "ProvisionedThroughput": {"ReadCapacityUnits": 100, "WriteCapacityUnits": 100}, }, ], AttributeDefinitions=[ {"AttributeName": name, "AttributeType": value["type"]} for name, value in list(TEST_INDEX.items()) + list(SECONDARY_INDEX.items()) ], ProvisionedThroughput={"ReadCapacityUnits": 100, "WriteCapacityUnits": 100}, ) yield mock_ddb_service mock_ddb_service.delete_table(TableName=TEST_TABLE_NAME) class PassThroughCryptographicMaterialsProviderThatRequiresAttributes(CryptographicMaterialsProvider): """Cryptographic materials provider that passes through to another, but requires that attributes are set. If the EncryptionContext passed to decryption_materials or encryption_materials ever does not have attributes set, a ValueError is raised. Otherwise, it passes through to the passthrough CMP normally. """ def __init__(self, passthrough_cmp): self._passthrough_cmp = passthrough_cmp @staticmethod def _assert_attributes_set(encryption_context): # type: (EncryptionContext) -> None if not encryption_context.attributes: raise ValueError("Encryption context attributes MUST be set!") def decryption_materials(self, encryption_context): # type: (EncryptionContext) -> CryptographicMaterials self._assert_attributes_set(encryption_context) return self._passthrough_cmp.decryption_materials(encryption_context) def encryption_materials(self, encryption_context): # type: (EncryptionContext) -> CryptographicMaterials self._assert_attributes_set(encryption_context) return self._passthrough_cmp.encryption_materials(encryption_context) def refresh(self): # type: () -> None self._passthrough_cmp.refresh() def _get_from_cache(dk_class, algorithm, key_length): """Don't generate new keys every time. All we care about is that they are valid keys, not that they are unique.""" try: return _DELEGATED_KEY_CACHE[dk_class][algorithm][key_length] except KeyError: key = dk_class.generate(algorithm, key_length) _DELEGATED_KEY_CACHE[dk_class][algorithm][key_length] = key return key def build_static_jce_cmp(encryption_algorithm, encryption_key_length, signing_algorithm, signing_key_length): """Build a StaticCryptographicMaterialsProvider using ephemeral JceNameLocalDelegatedKeys as specified.""" encryption_key = _get_from_cache(JceNameLocalDelegatedKey, encryption_algorithm, encryption_key_length) authentication_key = _get_from_cache(JceNameLocalDelegatedKey, signing_algorithm, signing_key_length) encryption_materials = RawEncryptionMaterials(signing_key=authentication_key, encryption_key=encryption_key) decryption_materials = RawDecryptionMaterials(verification_key=authentication_key, decryption_key=encryption_key) return StaticCryptographicMaterialsProvider( encryption_materials=encryption_materials, decryption_materials=decryption_materials ) def _build_wrapped_jce_cmp(wrapping_algorithm, wrapping_key_length, signing_algorithm, signing_key_length): """Build a WrappedCryptographicMaterialsProvider using ephemeral JceNameLocalDelegatedKeys as specified.""" wrapping_key = _get_from_cache(JceNameLocalDelegatedKey, wrapping_algorithm, wrapping_key_length) signing_key = _get_from_cache(JceNameLocalDelegatedKey, signing_algorithm, signing_key_length) return WrappedCryptographicMaterialsProvider( wrapping_key=wrapping_key, unwrapping_key=wrapping_key, signing_key=signing_key ) def _all_encryption(): """All encryption configurations to test in slow tests.""" return itertools.chain(itertools.product(("AES",), (128, 256)), itertools.product(("RSA",), (1024, 2048, 4096))) def _all_authentication(): """All authentication configurations to test in slow tests.""" return itertools.chain( itertools.product(("HmacSHA224", "HmacSHA256", "HmacSHA384", "HmacSHA512"), (128, 256)), itertools.product(("SHA224withRSA", "SHA256withRSA", "SHA384withRSA", "SHA512withRSA"), (1024, 2048, 4096)), ) def _all_algorithm_pairs(): """All algorithm pairs (encryption + authentication) to test in slow tests.""" for encryption_pair, signing_pair in itertools.product(_all_encryption(), _all_authentication()): yield encryption_pair + signing_pair def _some_algorithm_pairs(): """Cherry-picked set of algorithm pairs (encryption + authentication) to test in fast tests.""" return (("AES", 256, "HmacSHA256", 256), ("AES", 256, "SHA256withRSA", 4096), ("RSA", 4096, "SHA256withRSA", 4096)) _cmp_builders = {"static": build_static_jce_cmp, "wrapped": _build_wrapped_jce_cmp} def _all_possible_cmps(algorithm_generator, require_attributes): """Generate all possible cryptographic materials providers based on the supplied generator. require_attributes determines whether the CMP will be wrapped in PassThroughCryptographicMaterialsProviderThatRequiresAttributes to require that attributes are set on every request. This should ONLY be disabled on the item encryptor tests. All high-level helper clients MUST set the attributes before passing the encryption context down. """ # The AES combinations do the same thing, but this makes sure that the AESWrap name works as expected. yield _build_wrapped_jce_cmp("AESWrap", 256, "HmacSHA256", 256) for builder_info, args in itertools.product(_cmp_builders.items(), algorithm_generator()): builder_type, builder_func = builder_info encryption_algorithm, encryption_key_length, signing_algorithm, signing_key_length = args if builder_type == "static" and encryption_algorithm != "AES": # Only AES keys are allowed to be used with static materials continue id_string = "{enc_algorithm}/{enc_key_length} {builder_type} {sig_algorithm}/{sig_key_length}".format( enc_algorithm=encryption_algorithm, enc_key_length=encryption_key_length, builder_type=builder_type, sig_algorithm=signing_algorithm, sig_key_length=signing_key_length, ) inner_cmp = builder_func(encryption_algorithm, encryption_key_length, signing_algorithm, signing_key_length) if require_attributes: outer_cmp = PassThroughCryptographicMaterialsProviderThatRequiresAttributes(inner_cmp) else: outer_cmp = inner_cmp yield pytest.param(outer_cmp, id=id_string) def set_parametrized_cmp(metafunc, require_attributes=True): """Set paramatrized values for cryptographic materials providers. require_attributes determines whether the CMP will be wrapped in PassThroughCryptographicMaterialsProviderThatRequiresAttributes to require that attributes are set on every request. This should ONLY be disabled on the item encryptor tests. All high-level helper clients MUST set the attributes before passing the encryption context down. """ for name, algorithm_generator in (("all_the_cmps", _all_algorithm_pairs), ("some_cmps", _some_algorithm_pairs)): if name in metafunc.fixturenames: metafunc.parametrize(name, _all_possible_cmps(algorithm_generator, require_attributes)) _ACTIONS = { "hypothesis_actions": ( pytest.param(AttributeActions(default_action=CryptoAction.ENCRYPT_AND_SIGN), id="encrypt all"), pytest.param(AttributeActions(default_action=CryptoAction.SIGN_ONLY), id="sign only all"), pytest.param(AttributeActions(default_action=CryptoAction.DO_NOTHING), id="do nothing"), ) } _ACTIONS["parametrized_actions"] = _ACTIONS["hypothesis_actions"] + ( pytest.param( AttributeActions( default_action=CryptoAction.ENCRYPT_AND_SIGN, attribute_actions={ "number_set": CryptoAction.SIGN_ONLY, "string_set": CryptoAction.SIGN_ONLY, "binary_set": CryptoAction.SIGN_ONLY, }, ), id="sign sets, encrypt everything else", ), pytest.param( AttributeActions( default_action=CryptoAction.ENCRYPT_AND_SIGN, attribute_actions={ "number_set": CryptoAction.DO_NOTHING, "string_set": CryptoAction.DO_NOTHING, "binary_set": CryptoAction.DO_NOTHING, }, ), id="ignore sets, encrypt everything else", ), pytest.param( AttributeActions( default_action=CryptoAction.DO_NOTHING, attribute_actions={"map": CryptoAction.ENCRYPT_AND_SIGN} ), id="encrypt map, ignore everything else", ), pytest.param( AttributeActions( default_action=CryptoAction.SIGN_ONLY, attribute_actions={ "number_set": CryptoAction.DO_NOTHING, "string_set": CryptoAction.DO_NOTHING, "binary_set": CryptoAction.DO_NOTHING, "map": CryptoAction.ENCRYPT_AND_SIGN, }, ), id="ignore sets, encrypt map, sign everything else", ), ) def set_parametrized_actions(metafunc): """Set parametrized values for attribute actions.""" for name, actions in _ACTIONS.items(): if name in metafunc.fixturenames: metafunc.parametrize(name, actions) def set_parametrized_item(metafunc): """Set parametrized values for items to cycle.""" if "parametrized_item" in metafunc.fixturenames: metafunc.parametrize("parametrized_item", (pytest.param(diverse_item(), id="diverse item"),)) def diverse_item(): base_item = { "int": 5, "decimal": Decimal("123.456"), "string": "this is a string", "binary": b"this is a bytestring! \x01", "number_set": set([5, 4, 3]), "string_set": set(["abc", "def", "geh"]), "binary_set": set([b"\x00\x00\x00", b"\x00\x01\x00", b"\x00\x00\x02"]), } base_item["list"] = [copy.copy(i) for i in base_item.values()] base_item["map"] = copy.deepcopy(base_item) return copy.deepcopy(base_item) _reserved_attributes = {attr.value for attr in ReservedAttributes} def return_requestitems_as_unprocessed(*args, **kwargs): return {"UnprocessedItems": kwargs["RequestItems"]} def check_encrypted_item(plaintext_item, ciphertext_item, attribute_actions): # Verify that all expected attributes are present ciphertext_attributes = set(ciphertext_item.keys()) plaintext_attributes = set(plaintext_item.keys()) if attribute_actions.take_no_actions: assert ciphertext_attributes == plaintext_attributes else: assert ciphertext_attributes == plaintext_attributes.union(_reserved_attributes) for name, value in ciphertext_item.items(): # Skip the attributes we add if name in _reserved_attributes: continue # If the attribute should have been encrypted, verify that it is Binary and different from the original if attribute_actions.action(name) is CryptoAction.ENCRYPT_AND_SIGN: assert isinstance(value, Binary) assert value != plaintext_item[name] # Otherwise, verify that it is the same as the original else: assert value == plaintext_item[name] def _matching_key(actual_item, expected): expected_item = [ i for i in expected if i["partition_attribute"] == actual_item["partition_attribute"] and i["sort_attribute"] == actual_item["sort_attribute"] ] assert len(expected_item) == 1 return expected_item[0] def _nop_transformer(item): return item def assert_items_exist_in_list(source, expected, transformer): for actual_item in source: expected_item = _matching_key(actual_item, expected) assert transformer(actual_item) == transformer(expected_item) def assert_equal_lists_of_items(actual, expected, transformer=_nop_transformer): assert len(actual) == len(expected) assert_items_exist_in_list(actual, expected, transformer) def assert_list_of_items_contains(full, subset, transformer=_nop_transformer): assert len(full) >= len(subset) assert_items_exist_in_list(subset, full, transformer) def check_many_encrypted_items(actual, expected, attribute_actions, transformer=_nop_transformer): assert len(actual) == len(expected) for actual_item in actual: expected_item = _matching_key(actual_item, expected) check_encrypted_item( plaintext_item=transformer(expected_item), ciphertext_item=transformer(actual_item), attribute_actions=attribute_actions, ) def _generate_items(initial_item, write_transformer): items = [] for key in TEST_BATCH_KEYS: _item = initial_item.copy() _item.update(key) items.append(write_transformer(_item)) return items def _cleanup_items(encrypted, write_transformer, table_name=TEST_TABLE_NAME): ddb_keys = [write_transformer(key) for key in TEST_BATCH_KEYS] _delete_result = encrypted.batch_write_item( # noqa RequestItems={table_name: [{"DeleteRequest": {"Key": _key}} for _key in ddb_keys]} ) def cycle_batch_item_check( raw, encrypted, initial_actions, initial_item, write_transformer=_nop_transformer, read_transformer=_nop_transformer, table_name=TEST_TABLE_NAME, delete_items=True, ): """Check that cycling (plaintext->encrypted->decrypted) item batch has the expected results.""" check_attribute_actions = initial_actions.copy() check_attribute_actions.set_index_keys(*list(TEST_KEY.keys())) items = _generate_items(initial_item, write_transformer) items_in_table = len(items) _put_result = encrypted.batch_write_item( # noqa RequestItems={table_name: [{"PutRequest": {"Item": _item}} for _item in items]} ) try: ddb_keys = [write_transformer(key) for key in TEST_BATCH_KEYS] encrypted_result = raw.batch_get_item(RequestItems={table_name: {"Keys": ddb_keys}}) check_many_encrypted_items( actual=encrypted_result["Responses"][table_name], expected=items, attribute_actions=check_attribute_actions, transformer=read_transformer, ) decrypted_result = encrypted.batch_get_item(RequestItems={table_name: {"Keys": ddb_keys}}) assert_equal_lists_of_items( actual=decrypted_result["Responses"][table_name], expected=items, transformer=read_transformer ) finally: if delete_items: _cleanup_items(encrypted, write_transformer, table_name) items_in_table = 0 del check_attribute_actions del items return items_in_table def cycle_batch_writer_check(raw_table, encrypted_table, initial_actions, initial_item): """Cycling (plaintext->encrypted->decrypted) items with the Table batch writer should have the expected results.""" check_attribute_actions = initial_actions.copy() check_attribute_actions.set_index_keys(*list(TEST_KEY.keys())) items = _generate_items(initial_item, _nop_transformer) with encrypted_table.batch_writer() as writer: for item in items: writer.put_item(item) ddb_keys = copy.copy(TEST_BATCH_KEYS) encrypted_items = [raw_table.get_item(Key=key, ConsistentRead=True)["Item"] for key in ddb_keys] check_many_encrypted_items( actual=encrypted_items, expected=items, attribute_actions=check_attribute_actions, transformer=_nop_transformer ) decrypted_result = [encrypted_table.get_item(Key=key, ConsistentRead=True)["Item"] for key in ddb_keys] assert_equal_lists_of_items(actual=decrypted_result, expected=items, transformer=_nop_transformer) with encrypted_table.batch_writer() as writer: for key in ddb_keys: writer.delete_item(key) del check_attribute_actions del items def batch_write_item_unprocessed_check( encrypted, initial_item, write_transformer=_nop_transformer, table_name=TEST_TABLE_NAME ): """Check that unprocessed items in a batch result are unencrypted.""" items = _generate_items(initial_item, write_transformer) request_items = {table_name: [{"PutRequest": {"Item": _item}} for _item in items]} _put_result = encrypted.batch_write_item(RequestItems=request_items) # we expect results to include Unprocessed items, or the test case is invalid! unprocessed_items = _put_result["UnprocessedItems"] assert unprocessed_items != {} unprocessed = [operation["PutRequest"]["Item"] for operation in unprocessed_items[TEST_TABLE_NAME]] assert_list_of_items_contains(items, unprocessed, transformer=_nop_transformer) del items def cycle_item_check(plaintext_item, crypto_config): """Check that cycling (plaintext->encrypted->decrypted) an item has the expected results.""" ciphertext_item = encrypt_python_item(plaintext_item, crypto_config) check_encrypted_item(plaintext_item, ciphertext_item, crypto_config.attribute_actions) cycled_item = decrypt_python_item(ciphertext_item, crypto_config) assert cycled_item == plaintext_item del ciphertext_item del cycled_item def table_cycle_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): check_attribute_actions = initial_actions.copy() check_attribute_actions.set_index_keys(*list(TEST_KEY.keys())) item = initial_item.copy() item.update(TEST_KEY) kwargs = {} if region_name is not None: kwargs["region_name"] = region_name table = boto3.resource("dynamodb", **kwargs).Table(table_name) e_table = EncryptedTable(table=table, materials_provider=materials_provider, attribute_actions=initial_actions) _put_result = e_table.put_item(Item=item) # noqa encrypted_result = table.get_item(Key=TEST_KEY, ConsistentRead=True) check_encrypted_item(item, encrypted_result["Item"], check_attribute_actions) decrypted_result = e_table.get_item(Key=TEST_KEY, ConsistentRead=True) assert decrypted_result["Item"] == item e_table.delete_item(Key=TEST_KEY) del item del check_attribute_actions def table_cycle_batch_writer_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): kwargs = {} if region_name is not None: kwargs["region_name"] = region_name table = boto3.resource("dynamodb", **kwargs).Table(table_name) e_table = EncryptedTable(table=table, materials_provider=materials_provider, attribute_actions=initial_actions) cycle_batch_writer_check(table, e_table, initial_actions, initial_item) def table_batch_writer_unprocessed_items_check( materials_provider, initial_actions, initial_item, table_name, region_name=None ): kwargs = {} if region_name is not None: kwargs["region_name"] = region_name resource = boto3.resource("dynamodb", **kwargs) table = resource.Table(table_name) items = _generate_items(initial_item, _nop_transformer) request_items = {table_name: [{"PutRequest": {"Item": _item}} for _item in items]} with patch.object(table.meta.client, "batch_write_item") as batch_write_mock: # Check that unprocessed items returned to a BatchWriter are successfully retried batch_write_mock.side_effect = [{"UnprocessedItems": request_items}, {"UnprocessedItems": {}}] e_table = EncryptedTable(table=table, materials_provider=materials_provider, attribute_actions=initial_actions) with e_table.batch_writer() as writer: for item in items: writer.put_item(item) del items def resource_cycle_batch_items_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): kwargs = {} if region_name is not None: kwargs["region_name"] = region_name resource = boto3.resource("dynamodb", **kwargs) e_resource = EncryptedResource( resource=resource, materials_provider=materials_provider, attribute_actions=initial_actions ) cycle_batch_item_check( raw=resource, encrypted=e_resource, initial_actions=initial_actions, initial_item=initial_item, table_name=table_name, ) raw_scan_result = resource.Table(table_name).scan(ConsistentRead=True) e_scan_result = e_resource.Table(table_name).scan(ConsistentRead=True) assert not raw_scan_result["Items"] assert not e_scan_result["Items"] def resource_batch_items_unprocessed_check( materials_provider, initial_actions, initial_item, table_name, region_name=None ): kwargs = {} if region_name is not None: kwargs["region_name"] = region_name resource = boto3.resource("dynamodb", **kwargs) with patch.object(resource, "batch_write_item", return_requestitems_as_unprocessed): e_resource = EncryptedResource( resource=resource, materials_provider=materials_provider, attribute_actions=initial_actions ) batch_write_item_unprocessed_check( encrypted=e_resource, initial_item=initial_item, write_transformer=dict_to_ddb, table_name=table_name ) def client_cycle_single_item_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): check_attribute_actions = initial_actions.copy() check_attribute_actions.set_index_keys(*list(TEST_KEY.keys())) item = initial_item.copy() item.update(TEST_KEY) ddb_item = dict_to_ddb(item) ddb_key = dict_to_ddb(TEST_KEY) kwargs = {} if region_name is not None: kwargs["region_name"] = region_name client = boto3.client("dynamodb", **kwargs) e_client = EncryptedClient(client=client, materials_provider=materials_provider, attribute_actions=initial_actions) _put_result = e_client.put_item(TableName=table_name, Item=ddb_item) # noqa encrypted_result = client.get_item(TableName=table_name, Key=ddb_key, ConsistentRead=True) check_encrypted_item(item, ddb_to_dict(encrypted_result["Item"]), check_attribute_actions) decrypted_result = e_client.get_item(TableName=table_name, Key=ddb_key, ConsistentRead=True) assert ddb_to_dict(decrypted_result["Item"]) == item e_client.delete_item(TableName=table_name, Key=ddb_key) del item del check_attribute_actions def client_cycle_batch_items_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): kwargs = {} if region_name is not None: kwargs["region_name"] = region_name client = boto3.client("dynamodb", **kwargs) e_client = EncryptedClient(client=client, materials_provider=materials_provider, attribute_actions=initial_actions) cycle_batch_item_check( raw=client, encrypted=e_client, initial_actions=initial_actions, initial_item=initial_item, write_transformer=dict_to_ddb, read_transformer=ddb_to_dict, table_name=table_name, ) raw_scan_result = client.scan(TableName=table_name, ConsistentRead=True) e_scan_result = e_client.scan(TableName=table_name, ConsistentRead=True) assert not raw_scan_result["Items"] assert not e_scan_result["Items"] def client_batch_items_unprocessed_check( materials_provider, initial_actions, initial_item, table_name, region_name=None ): kwargs = {} if region_name is not None: kwargs["region_name"] = region_name client = boto3.client("dynamodb", **kwargs) with patch.object(client, "batch_write_item", return_requestitems_as_unprocessed): e_client = EncryptedClient( client=client, materials_provider=materials_provider, attribute_actions=initial_actions ) batch_write_item_unprocessed_check( encrypted=e_client, initial_item=initial_item, write_transformer=dict_to_ddb, table_name=table_name ) def client_cycle_batch_items_check_scan_paginator( materials_provider, initial_actions, initial_item, table_name, region_name=None ): """Helper function for testing the "scan" paginator. Populate the specified table with encrypted items, scan the table with raw client paginator to get encrypted items, scan the table with encrypted client paginator to get decrypted items, then verify that all items appear to have been encrypted correctly. """ # noqa=D401 # pylint: disable=too-many-locals kwargs = {} if region_name is not None: kwargs["region_name"] = region_name client = boto3.client("dynamodb", **kwargs) e_client = EncryptedClient(client=client, materials_provider=materials_provider, attribute_actions=initial_actions) items_in_table = cycle_batch_item_check( raw=client, encrypted=e_client, initial_actions=initial_actions, initial_item=initial_item, write_transformer=dict_to_ddb, read_transformer=ddb_to_dict, table_name=table_name, delete_items=False, ) try: encrypted_items = [] raw_paginator = client.get_paginator("scan") for page in raw_paginator.paginate(TableName=table_name, ConsistentRead=True): encrypted_items.extend(page["Items"]) decrypted_items = [] encrypted_paginator = e_client.get_paginator("scan") for page in encrypted_paginator.paginate(TableName=table_name, ConsistentRead=True): decrypted_items.extend(page["Items"]) assert encrypted_items and decrypted_items assert len(encrypted_items) == len(decrypted_items) == items_in_table check_attribute_actions = initial_actions.copy() check_attribute_actions.set_index_keys(*list(TEST_KEY.keys())) check_many_encrypted_items( actual=encrypted_items, expected=decrypted_items, attribute_actions=check_attribute_actions, transformer=ddb_to_dict, ) finally: _cleanup_items(encrypted=e_client, write_transformer=dict_to_ddb, table_name=table_name) raw_scan_result = client.scan(TableName=table_name, ConsistentRead=True) e_scan_result = e_client.scan(TableName=table_name, ConsistentRead=True) assert not raw_scan_result["Items"] assert not e_scan_result["Items"] def build_metastore(): client = boto3.client("dynamodb", region_name=TEST_REGION_NAME) table_name = base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8").replace("=", ".") MetaStore.create_table(client, table_name, 1, 1) waiter = client.get_waiter("table_exists") waiter.wait(TableName=table_name) table = boto3.resource("dynamodb", region_name=TEST_REGION_NAME).Table(table_name) return MetaStore(table, build_static_jce_cmp("AES", 256, "HmacSHA256", 256)), table_name def delete_metastore(table_name): client = boto3.client("dynamodb", region_name=TEST_REGION_NAME) client.delete_table(TableName=table_name) # It sometimes takes a long time to delete a table. # If hanging, asynchronously deleting tables becomes an issue, # come back to this. # Otherwise, let's just let them take care of themselves. # waiter = client.get_waiter("table_not_exists") # waiter.wait(TableName=table_name) @pytest.fixture def mock_metastore(): with mock_dynamodb2(): metastore, table_name = build_metastore() yield metastore delete_metastore(table_name) def _count_entries(records, *messages): count = 0 for record in records: if all((message in record.getMessage() for message in messages)): count += 1 return count def _count_puts(records, table_name): return _count_entries(records, '"TableName": "{}"'.format(table_name), "OperationModel(name=PutItem)") def _count_gets(records, table_name): return _count_entries(records, '"TableName": "{}"'.format(table_name), "OperationModel(name=GetItem)") def check_metastore_cache_use_encrypt(metastore, table_name, log_capture): try: table = boto3.resource("dynamodb").Table(table_name) except NoRegionError: table = boto3.resource("dynamodb", region_name=TEST_REGION_NAME).Table(table_name) most_recent_provider = CachingMostRecentProvider(provider_store=metastore, material_name="test", version_ttl=600.0) e_table = EncryptedTable(table=table, materials_provider=most_recent_provider) item = diverse_item() item.update(TEST_KEY) e_table.put_item(Item=item) e_table.put_item(Item=item) e_table.put_item(Item=item) e_table.put_item(Item=item) try: primary_puts = _count_puts(log_capture.records, e_table.name) metastore_puts = _count_puts(log_capture.records, metastore._table.name) assert primary_puts == 4 assert metastore_puts == 1 e_table.get_item(Key=TEST_KEY) e_table.get_item(Key=TEST_KEY) e_table.get_item(Key=TEST_KEY) primary_gets = _count_gets(log_capture.records, e_table.name) metastore_gets = _count_gets(log_capture.records, metastore._table.name) metastore_puts = _count_puts(log_capture.records, metastore._table.name) assert primary_gets == 3 assert metastore_gets == 0 assert metastore_puts == 1 most_recent_provider.refresh() e_table.get_item(Key=TEST_KEY) e_table.get_item(Key=TEST_KEY) e_table.get_item(Key=TEST_KEY) primary_gets = _count_gets(log_capture.records, e_table.name) metastore_gets = _count_gets(log_capture.records, metastore._table.name) assert primary_gets == 6 assert metastore_gets == 1 finally: e_table.delete_item(Key=TEST_KEY)