// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package com.amazonaws.encryptionsdk; import static org.junit.Assert.assertArrayEquals; import com.amazonaws.encryptionsdk.internal.Utils; import com.amazonaws.encryptionsdk.jce.JceMasterKey; import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.File; import java.io.StringReader; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; import java.security.KeyFactory; import java.security.PrivateKey; import java.security.spec.PKCS8EncodedKeySpec; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import javax.crypto.spec.SecretKeySpec; import org.apache.commons.lang3.StringUtils; import org.bouncycastle.util.io.pem.PemReader; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; @RunWith(Parameterized.class) public class XCompatDecryptTest { private static final String STATIC_XCOMPAT_NAME = "static-aws-xcompat"; private static final String AES_GCM = "AES/GCM/NoPadding"; private static final byte XCOMPAT_MESSAGE_VERSION = 1; private String plaintextFileName; private String ciphertextFileName; private MasterKeyProvider masterKeyProvider; public XCompatDecryptTest( String plaintextFileName, String ciphertextFileName, MasterKeyProvider masterKeyProvider) throws Exception { this.plaintextFileName = plaintextFileName; this.ciphertextFileName = ciphertextFileName; this.masterKeyProvider = masterKeyProvider; } @Parameters(name = "{index}: testDecryptFromFile({0}, {1}, {2})") public static Collection data() throws Exception { String baseDirName; baseDirName = System.getProperty("staticCompatibilityResourcesDir"); if (baseDirName == null) { baseDirName = XCompatDecryptTest.class.getProtectionDomain().getCodeSource().getLocation().getPath() + "aws_encryption_sdk_resources"; } List testCases_ = new ArrayList(); String ciphertextManifestName = StringUtils.join( new String[] {baseDirName, "manifests", "ciphertext.manifest"}, File.separator); File ciphertextManifestFile = new File(ciphertextManifestName); if (!ciphertextManifestFile.exists()) { return Collections.emptySet(); } ObjectMapper ciphertextManifestMapper = new ObjectMapper(); Map ciphertextManifest = ciphertextManifestMapper.readValue( ciphertextManifestFile, new TypeReference>() {}); HashMap> staticKeyMap = new HashMap>(); Map testKeys = (Map) ciphertextManifest.get("test_keys"); for (Map.Entry keyType : testKeys.entrySet()) { Map keys = (Map) keyType.getValue(); HashMap thisKeyType = new HashMap(); for (Map.Entry key : keys.entrySet()) { Map thisKey = (Map) key.getValue(); String keyRaw = new String( StringUtils.join( (List) thisKey.get("key"), (String) thisKey.getOrDefault("line_separator", "")) .getBytes(), StandardCharsets.UTF_8); byte[] keyBytes; switch ((String) thisKey.get("encoding")) { case "base64": keyBytes = Utils.decodeBase64String(keyRaw); break; case "pem": PemReader pemReader = new PemReader(new StringReader(keyRaw)); keyBytes = pemReader.readPemObject().getContent(); break; case "raw": default: keyBytes = keyRaw.getBytes(); } thisKeyType.put((String) key.getKey(), keyBytes); } staticKeyMap.put((String) keyType.getKey(), thisKeyType); } final KeyFactory rsaKeyFactory = KeyFactory.getInstance("RSA"); List> testCases = (List>) ciphertextManifest.get("test_cases"); for (Map testCase : testCases) { Map plaintext = (Map) testCase.get("plaintext"); Map ciphertext = (Map) testCase.get("ciphertext"); short algId = (short) Integer.parseInt((String) testCase.get("algorithm"), 16); CryptoAlgorithm encryptionAlgorithm = CryptoAlgorithm.deserialize(XCOMPAT_MESSAGE_VERSION, algId); List> masterKeys = (List>) testCase.get("master_keys"); List allMasterKeys = new ArrayList(); for (Map aMasterKey : masterKeys) { String providerId = (String) aMasterKey.get("provider_id"); if (providerId.equals(STATIC_XCOMPAT_NAME) && (boolean) aMasterKey.get("decryptable")) { String paddingAlgorithm = (String) aMasterKey.getOrDefault("padding_algorithm", ""); String paddingHash = (String) aMasterKey.getOrDefault("padding_hash", ""); Integer keyBits = (Integer) aMasterKey.getOrDefault("key_bits", encryptionAlgorithm.getDataKeyLength() * 8); String keyId = (String) aMasterKey.get("encryption_algorithm") + "." + keyBits.toString() + "." + paddingAlgorithm + "." + paddingHash; String encAlg = (String) aMasterKey.get("encryption_algorithm"); switch (encAlg.toUpperCase()) { case "RSA": String cipherBase = "RSA/ECB/"; String cipherName; switch (paddingAlgorithm) { case "OAEP-MGF1": cipherName = cipherBase + "OAEPWith" + paddingHash + "AndMGF1Padding"; break; case "PKCS1": cipherName = cipherBase + paddingAlgorithm + "Padding"; break; default: throw new IllegalArgumentException( "Unknown padding algorithm: " + paddingAlgorithm); } PrivateKey privKey = rsaKeyFactory.generatePrivate( new PKCS8EncodedKeySpec(staticKeyMap.get("RSA").get(keyBits.toString()))); allMasterKeys.add( JceMasterKey.getInstance(null, privKey, STATIC_XCOMPAT_NAME, keyId, cipherName)); break; case "AES": SecretKeySpec spec = new SecretKeySpec( staticKeyMap.get("AES").get(keyBits.toString()), 0, encryptionAlgorithm.getDataKeyLength(), encryptionAlgorithm.getDataKeyAlgo()); allMasterKeys.add( JceMasterKey.getInstance(spec, STATIC_XCOMPAT_NAME, keyId, AES_GCM)); break; default: throw new IllegalArgumentException( "Unknown encryption algorithm: " + encAlg.toUpperCase()); } } } if (allMasterKeys.size() > 0) { final MasterKeyProvider provider = MultipleProviderFactory.buildMultiProvider(allMasterKeys); testCases_.add( new Object[] { baseDirName + File.separator + plaintext.get("filename"), baseDirName + File.separator + ciphertext.get("filename"), provider }); } } return testCases_; } @Test public void testDecryptFromFile() throws Exception { AwsCrypto crypto = AwsCrypto.standard(); byte ciphertextBytes[] = Files.readAllBytes(Paths.get(ciphertextFileName)); byte plaintextBytes[] = Files.readAllBytes(Paths.get(plaintextFileName)); final CryptoResult decryptResult = crypto.decryptData(masterKeyProvider, ciphertextBytes); assertArrayEquals(plaintextBytes, (byte[]) decryptResult.getResult()); } }