// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package com.amazonaws.encryptionsdk.kmssdkv2;

import static com.amazonaws.encryptionsdk.TestUtils.assertThrows;
import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.*;

import com.amazonaws.encryptionsdk.*;
import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException;
import com.amazonaws.encryptionsdk.internal.VersionInfo;
import com.amazonaws.encryptionsdk.model.KeyBlob;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.function.Supplier;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import software.amazon.awssdk.awscore.AwsRequest;
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.*;

/**
 * Unit tests for the AWS SDK v2 variant of {@link KmsMasterKey}.
 *
 * <p>Every test runs a {@link KmsMasterKey} against a Mockito spy wrapped around a {@link
 * MockKmsClient}. The spy lets tests both inspect the requests that reach the client and replace
 * individual responses with malformed ones.
 */
public class KmsMasterKeyTest {

  private static final String AWS_KMS_PROVIDER_ID = "aws-kms";
  private static final String OTHER_PROVIDER_ID = "not-aws-kms";

  private static final CryptoAlgorithm ALGORITHM_SUITE =
      CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256;
  private static final SecretKey DATA_KEY =
      new SecretKeySpec(
          generate(ALGORITHM_SUITE.getDataKeyLength()), ALGORITHM_SUITE.getDataKeyAlgo());
  private static final List<String> GRANT_TOKENS = Collections.singletonList("testGrantToken");
  private static final Map<String, String> ENCRYPTION_CONTEXT =
      Collections.singletonMap("myKey", "myValue");

  /** Encrypts a data key owned by another provider, then decrypts the result again. */
  @Test
  public void testEncryptAndDecrypt() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);
    kmsMasterKey.setGrantTokens(GRANT_TOKENS);

    DataKey<KmsMasterKey> encryptDataKeyResponse =
        kmsMasterKey.encryptDataKey(
            ALGORITHM_SUITE, ENCRYPTION_CONTEXT, newForeignDataKey(DATA_KEY));

    ArgumentCaptor<EncryptRequest> er = ArgumentCaptor.forClass(EncryptRequest.class);
    verify(client, times(1)).encrypt(er.capture());

    EncryptRequest actualRequest = er.getValue();
    assertEquals(keyId, actualRequest.keyId());
    assertEquals(GRANT_TOKENS, actualRequest.grantTokens());
    assertEquals(ENCRYPTION_CONTEXT, actualRequest.encryptionContext());
    assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.plaintext().asByteArray());
    assertApiName(actualRequest);

    // JUnit's contract is assertEquals(expected, actual).
    assertEquals(kmsMasterKey, encryptDataKeyResponse.getMasterKey());
    assertEquals(AWS_KMS_PROVIDER_ID, encryptDataKeyResponse.getProviderId());
    assertArrayEquals(
        keyId.getBytes(StandardCharsets.UTF_8), encryptDataKeyResponse.getProviderInformation());
    assertNotNull(encryptDataKeyResponse.getEncryptedDataKey());

    DataKey<KmsMasterKey> decryptDataKeyResponse =
        kmsMasterKey.decryptDataKey(
            ALGORITHM_SUITE, Collections.singletonList(encryptDataKeyResponse), ENCRYPTION_CONTEXT);

    assertDecryptRequest(captureSingleDecryptRequest(client), encryptDataKeyResponse);

    assertEquals(DATA_KEY, decryptDataKeyResponse.getKey());
    assertArrayEquals(
        keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResponse.getProviderInformation());
  }

  /** Generates a fresh data key through KMS, then decrypts its encrypted form. */
  @Test
  public void testGenerateAndDecrypt() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);
    kmsMasterKey.setGrantTokens(GRANT_TOKENS);

    DataKey<KmsMasterKey> generateDataKeyResponse =
        kmsMasterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT);

    ArgumentCaptor<GenerateDataKeyRequest> gr =
        ArgumentCaptor.forClass(GenerateDataKeyRequest.class);
    verify(client, times(1)).generateDataKey(gr.capture());

    GenerateDataKeyRequest actualRequest = gr.getValue();
    assertEquals(keyId, actualRequest.keyId());
    assertEquals(GRANT_TOKENS, actualRequest.grantTokens());
    assertEquals(ENCRYPTION_CONTEXT, actualRequest.encryptionContext());
    assertEquals(ALGORITHM_SUITE.getDataKeyLength(), actualRequest.numberOfBytes().longValue());
    assertApiName(actualRequest);

    assertNotNull(generateDataKeyResponse.getKey());
    assertEquals(
        ALGORITHM_SUITE.getDataKeyLength(), generateDataKeyResponse.getKey().getEncoded().length);
    assertEquals(ALGORITHM_SUITE.getDataKeyAlgo(), generateDataKeyResponse.getKey().getAlgorithm());
    assertNotNull(generateDataKeyResponse.getEncryptedDataKey());

    DataKey<KmsMasterKey> decryptDataKeyResponse =
        kmsMasterKey.decryptDataKey(
            ALGORITHM_SUITE,
            Collections.singletonList(generateDataKeyResponse),
            ENCRYPTION_CONTEXT);

    assertDecryptRequest(captureSingleDecryptRequest(client), generateDataKeyResponse);

    assertEquals(generateDataKeyResponse.getKey(), decryptDataKeyResponse.getKey());
    assertArrayEquals(
        keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResponse.getProviderInformation());
  }

  /**
   * Builds the master key from the segment of the ARN after the first {@code '/'} and checks that
   * the request carries that raw id while the resulting provider information is the full ARN.
   */
  @Test
  public void testEncryptWithRawKeyId() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    String rawKeyId = keyId.split("/")[1];
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, rawKeyId);
    kmsMasterKey.setGrantTokens(GRANT_TOKENS);

    DataKey<KmsMasterKey> encryptDataKeyResponse =
        kmsMasterKey.encryptDataKey(
            ALGORITHM_SUITE, ENCRYPTION_CONTEXT, newForeignDataKey(DATA_KEY));

    ArgumentCaptor<EncryptRequest> er = ArgumentCaptor.forClass(EncryptRequest.class);
    verify(client, times(1)).encrypt(er.capture());

    EncryptRequest actualRequest = er.getValue();
    assertEquals(rawKeyId, actualRequest.keyId());
    assertEquals(GRANT_TOKENS, actualRequest.grantTokens());
    assertEquals(ENCRYPTION_CONTEXT, actualRequest.encryptionContext());
    assertArrayEquals(DATA_KEY.getEncoded(), actualRequest.plaintext().asByteArray());
    assertApiName(actualRequest);

    assertEquals(AWS_KMS_PROVIDER_ID, encryptDataKeyResponse.getProviderId());
    assertArrayEquals(
        keyId.getBytes(StandardCharsets.UTF_8), encryptDataKeyResponse.getProviderInformation());
    assertNotNull(encryptDataKeyResponse.getEncryptedDataKey());
  }

  /** A data key whose {@link SecretKey#getFormat()} is not usable must be rejected. */
  @Test
  public void testEncryptWrongKeyFormat() {
    SecretKey key = mock(SecretKey.class);
    when(key.getFormat()).thenReturn("BadFormat");

    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);
    DataKey<?> dataKey = newForeignDataKey(key);

    assertThrows(
        IllegalArgumentException.class,
        () -> kmsMasterKey.encryptDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT, dataKey));
  }

  /** A GenerateDataKey response with one byte too many of plaintext must be rejected. */
  @Test
  public void testGenerateBadKmsKeyLength() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);

    GenerateDataKeyResponse badResponse =
        GenerateDataKeyResponse.builder()
            .keyId(keyId)
            .plaintext(zeroedPlaintext(ALGORITHM_SUITE.getDataKeyLength() + 1))
            .build();
    doReturn(badResponse).when(client).generateDataKey(isA(GenerateDataKeyRequest.class));

    assertThrows(
        IllegalStateException.class,
        () -> kmsMasterKey.generateDataKey(ALGORITHM_SUITE, ENCRYPTION_CONTEXT));
  }

  /** A Decrypt response with one byte too many of plaintext must be rejected. */
  @Test
  public void testDecryptBadKmsKeyLength() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);

    DecryptResponse badResponse =
        DecryptResponse.builder()
            .keyId(keyId)
            .plaintext(zeroedPlaintext(ALGORITHM_SUITE.getDataKeyLength() + 1))
            .build();
    doReturn(badResponse).when(client).decrypt(isA(DecryptRequest.class));

    EncryptedDataKey edk = newEncryptedDataKey(keyId);

    assertThrows(
        IllegalStateException.class,
        () ->
            kmsMasterKey.decryptDataKey(
                ALGORITHM_SUITE, Collections.singletonList(edk), ENCRYPTION_CONTEXT));
  }

  /** A Decrypt response that carries no key id must be rejected with a specific message. */
  @Test
  public void testDecryptMissingKmsKeyId() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);

    DecryptResponse badResponse =
        DecryptResponse.builder()
            .plaintext(zeroedPlaintext(ALGORITHM_SUITE.getDataKeyLength()))
            .build();
    doReturn(badResponse).when(client).decrypt(isA(DecryptRequest.class));

    EncryptedDataKey edk = newEncryptedDataKey(keyId);

    assertThrows(
        IllegalStateException.class,
        "Received an empty keyId from KMS",
        () ->
            kmsMasterKey.decryptDataKey(
                ALGORITHM_SUITE, Collections.singletonList(edk), ENCRYPTION_CONTEXT));
  }

  /** A Decrypt response naming a different key than the master key must not be accepted. */
  @Test
  public void testDecryptMismatchedKmsKeyId() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);

    DecryptResponse badResponse =
        DecryptResponse.builder()
            .keyId("mismatchedID")
            .plaintext(zeroedPlaintext(ALGORITHM_SUITE.getDataKeyLength()))
            .build();
    doReturn(badResponse).when(client).decrypt(isA(DecryptRequest.class));

    EncryptedDataKey edk = newEncryptedDataKey(keyId);

    assertThrows(
        CannotUnwrapDataKeyException.class,
        () ->
            kmsMasterKey.decryptDataKey(
                ALGORITHM_SUITE, Collections.singletonList(edk), ENCRYPTION_CONTEXT));
  }

  /**
   * An encrypted data key whose provider information names another key is skipped, not treated as
   * a failure: only the matching EDK is sent to KMS and its plaintext is returned.
   */
  @Test
  public void testDecryptSkipsMismatchedIdEDK() {
    KmsClient client = newSpiedClient();
    String keyId = client.createKey().keyMetadata().arn();
    KmsMasterKey kmsMasterKey = newKmsMasterKey(client, keyId);

    // Mock expected KMS response to verify success if second EDK is ok,
    // and the mismatched EDK is skipped vs failing outright
    DecryptResponse kmsResponse =
        DecryptResponse.builder()
            .keyId(keyId)
            .plaintext(zeroedPlaintext(ALGORITHM_SUITE.getDataKeyLength()))
            .build();
    doReturn(kmsResponse).when(client).decrypt(isA(DecryptRequest.class));

    EncryptedDataKey edk = newEncryptedDataKey(keyId);
    EncryptedDataKey mismatchedEDK = newEncryptedDataKey("mismatchedID");

    DataKey<KmsMasterKey> decryptDataKeyResponse =
        kmsMasterKey.decryptDataKey(
            ALGORITHM_SUITE, Arrays.asList(mismatchedEDK, edk), ENCRYPTION_CONTEXT);

    // Exactly one KMS call, and it must be for the matching EDK.
    DecryptRequest actualDecryptRequest = captureSingleDecryptRequest(client);
    assertArrayEquals(
        edk.getProviderInformation(),
        actualDecryptRequest.keyId().getBytes(StandardCharsets.UTF_8));

    // The returned data key must come from the matching EDK and carry the stubbed plaintext.
    assertNotNull(decryptDataKeyResponse);
    assertArrayEquals(
        keyId.getBytes(StandardCharsets.UTF_8), decryptDataKeyResponse.getProviderInformation());
    assertArrayEquals(
        new byte[ALGORITHM_SUITE.getDataKeyLength()],
        decryptDataKeyResponse.getKey().getEncoded());
  }

  /** Returns a Mockito spy wrapped around a fresh {@link MockKmsClient}. */
  private static KmsClient newSpiedClient() {
    return spy(new MockKmsClient());
  }

  /**
   * Creates a {@link KmsMasterKey} for {@code keyId} that always uses {@code client}.
   *
   * @param client the client handed out by the master key's client supplier
   * @param keyId the key identifier the master key is created with
   * @return a master key whose mocked provider reports {@link #AWS_KMS_PROVIDER_ID}
   */
  private static KmsMasterKey newKmsMasterKey(KmsClient client, String keyId) {
    Supplier<KmsClient> supplier = () -> client;

    // Mockito can only hand back the raw type; the mock is used purely through
    // getDefaultProviderId(), so the unchecked conversion is harmless.
    @SuppressWarnings("unchecked")
    MasterKeyProvider<KmsMasterKey> mkp = mock(MasterKeyProvider.class);
    when(mkp.getDefaultProviderId()).thenReturn(AWS_KMS_PROVIDER_ID);

    return KmsMasterKey.getInstance(supplier, keyId, mkp);
  }

  /**
   * Wraps {@code key} in a {@link DataKey} attributed to a mocked master key from {@link
   * #OTHER_PROVIDER_ID}, with an empty encrypted data key.
   *
   * @param key the plaintext key material to wrap
   * @return a data key owned by the mocked, non-KMS master key
   */
  // MasterKey is self-bounded, so a Mockito mock of it cannot be given a precise type argument;
  // the raw types are confined to this helper.
  @SuppressWarnings({"unchecked", "rawtypes"})
  private static DataKey<?> newForeignDataKey(SecretKey key) {
    MasterKey otherMasterKey = mock(MasterKey.class);
    when(otherMasterKey.getProviderId()).thenReturn(OTHER_PROVIDER_ID);
    when(otherMasterKey.getKeyId()).thenReturn("someOtherId");
    return new DataKey(
        key, new byte[0], OTHER_PROVIDER_ID.getBytes(StandardCharsets.UTF_8), otherMasterKey);
  }

  /**
   * Builds an {@link EncryptedDataKey} for {@link #AWS_KMS_PROVIDER_ID} whose provider information
   * is the UTF-8 encoding of {@code providerInfo} and whose ciphertext is freshly generated bytes.
   */
  private static EncryptedDataKey newEncryptedDataKey(String providerInfo) {
    return new KeyBlob(
        AWS_KMS_PROVIDER_ID,
        providerInfo.getBytes(StandardCharsets.UTF_8),
        generate(ALGORITHM_SUITE.getDataKeyLength()));
  }

  /** Returns {@code length} zero bytes wrapped as {@link SdkBytes}. */
  private static SdkBytes zeroedPlaintext(int length) {
    return SdkBytes.fromByteArray(new byte[length]);
  }

  /** Verifies that {@code client} received exactly one decrypt call and returns its request. */
  private static DecryptRequest captureSingleDecryptRequest(KmsClient client) {
    ArgumentCaptor<DecryptRequest> decrypt = ArgumentCaptor.forClass(DecryptRequest.class);
    verify(client, times(1)).decrypt(decrypt.capture());
    return decrypt.getValue();
  }

  /**
   * Asserts that {@code actual} targets the key and ciphertext of {@code source} and carries the
   * shared grant tokens, encryption context and API name.
   */
  private static void assertDecryptRequest(DecryptRequest actual, EncryptedDataKey source) {
    assertArrayEquals(
        source.getProviderInformation(), actual.keyId().getBytes(StandardCharsets.UTF_8));
    assertEquals(GRANT_TOKENS, actual.grantTokens());
    assertEquals(ENCRYPTION_CONTEXT, actual.encryptionContext());
    assertArrayEquals(source.getEncryptedDataKey(), actual.ciphertextBlob().asByteArray());
    assertApiName(actual);
  }

  /**
   * Asserts that {@code request} has an override configuration listing an API name whose name and
   * version equal {@link VersionInfo#apiName()} and {@link VersionInfo#versionNumber()}.
   */
  private static void assertApiName(AwsRequest request) {
    Optional<AwsRequestOverrideConfiguration> overrideConfig = request.overrideConfiguration();
    assertTrue(overrideConfig.isPresent());
    assertTrue(
        overrideConfig.get().apiNames().stream()
            .anyMatch(
                api ->
                    api.name().equals(VersionInfo.apiName())
                        && api.version().equals(VersionInfo.versionNumber())));
  }
}