package software.aws.mcs.auth; /*- * #%L * AWS SigV4 Auth Java Driver 4.x Plugin * %% * Copyright (C) 2020-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.time.Instant; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import javax.validation.constraints.NotNull; import org.apache.commons.codec.binary.Hex; import com.datastax.oss.driver.api.core.auth.AuthProvider; import com.datastax.oss.driver.api.core.auth.AuthenticationException; import com.datastax.oss.driver.api.core.auth.Authenticator; import com.datastax.oss.driver.api.core.config.DriverOption; import com.datastax.oss.driver.api.core.context.DriverContext; import com.datastax.oss.driver.api.core.metadata.EndPoint; import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; import software.amazon.awssdk.auth.signer.internal.Aws4SignerUtils; import software.amazon.awssdk.auth.signer.internal.SignerConstant; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain; import software.amazon.awssdk.services.sts.StsClient; import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; import static software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider.create; /** * This auth provider can be used with the Amazon MCS service to * authenticate with SigV4. It uses the AWSCredentialsProvider * interface provided by the official AWS Java SDK to provide * credentials for signing. */ public class SigV4AuthProvider implements AuthProvider { private static final byte[] SIGV4_INITIAL_RESPONSE_BYTES = "SigV4\0\0".getBytes(StandardCharsets.UTF_8); private static final ByteBuffer SIGV4_INITIAL_RESPONSE; static { ByteBuffer initialResponse = ByteBuffer.allocate(SIGV4_INITIAL_RESPONSE_BYTES.length); initialResponse.put(SIGV4_INITIAL_RESPONSE_BYTES); initialResponse.flip(); // According to the driver docs, it's safe to reuse a // read-only buffer, and in our case, the initial response has // no sensitive information SIGV4_INITIAL_RESPONSE = initialResponse.asReadOnlyBuffer(); } private static final int AWS_FRACTIONAL_TIMESTAMP_DIGITS = 3; // SigV4 expects three digits of nanoseconds for timestamps private static final DateTimeFormatter timestampFormatter = (new DateTimeFormatterBuilder()).appendInstant(AWS_FRACTIONAL_TIMESTAMP_DIGITS).toFormatter(); private static final byte[] NONCE_KEY = "nonce=".getBytes(StandardCharsets.UTF_8); private static final int EXPECTED_NONCE_LENGTH = 32; // These are static values because we don't need HTTP, but SigV4 assumes some amount of HTTP metadata private static final String CANONICAL_SERVICE = "cassandra"; private final AwsCredentialsProvider credentialsProvider; private final String signingRegion; /** * Create a new Provider, using the * DefaultAWSCredentialsProviderChain as its credentials provider. * The signing region is taking from the AWS_DEFAULT_REGION * environment variable or the "aws.region" system property. */ public SigV4AuthProvider() { this(create(), null); } private final static DriverOption REGION_OPTION = () -> "advanced.auth-provider.aws-region"; private final static DriverOption ROLE_OPTION = () -> "advanced.auth-provider.aws-role-arn"; /** * This constructor is provided so that the driver can create * instances of this class based on configuration. For example: * * <pre> * datastax-java-driver.advanced.auth-provider = { * aws-region = us-east-2 * class = software.aws.mcs.auth.SigV4AuthProvider * } * </pre> * * The signing region is taken from the * datastax-java-driver.advanced.auth-provider.aws-region * property, from the "aws.region" system property, or the * AWS_DEFAULT_REGION environment variable, in that order of * preference. * * For programmatic construction, use {@link #SigV4AuthProvider()} * or {@link #SigV4AuthProvider(AwsCredentialsProvider, String)}. * * @param driverContext the driver context for instance creation. * Unused for this plugin. */ public SigV4AuthProvider(DriverContext driverContext) { this(driverContext.getConfig().getDefaultProfile().getString(REGION_OPTION, getDefaultRegion()), driverContext.getConfig().getDefaultProfile().getString(ROLE_OPTION, null)); } /** * Create a new Provider, using the specified region. * @param region the region (e.g. us-east-1) to use for signing. A * null value indicates to use the AWS_REGION environment * variable, or the "aws.region" system property to configure it. */ public SigV4AuthProvider(final String region) { this(create(), region); } /** * Create a new Provider, using the specified region and IAM role to assume. * @param region the region (e.g. us-east-1) to use for signing. A * null value indicates to use the AWS_REGION environment * variable, or the "aws.region" system property to configure it. * @param roleArn The IAM Role ARN which the connecting client should assume before connecting with Amazon Keyspaces. */ public SigV4AuthProvider(final String region,final String roleArn) { this(Optional.ofNullable(roleArn).map(r->(AwsCredentialsProvider)createSTSRoleCredentialProvider(r,region)).orElse(create()), region); } /** * Create a new Provider, using the specified AWSCredentialsProvider and region. * @param credentialsProvider the credentials provider used to obtain signature material * @param region the region (e.g. us-east-1) to use for signing. A * null value indicates to use the AWS_REGION environment * variable, or the "aws.region" system property to configure it. */ public SigV4AuthProvider(@NotNull AwsCredentialsProvider credentialsProvider, final String region) { this.credentialsProvider = credentialsProvider; if (region == null) { this.signingRegion = getDefaultRegion(); } else { this.signingRegion = region.toLowerCase(); } if (this.signingRegion == null) { throw new IllegalStateException( "A region must be specified by constructor, AWS_REGION env variable, or aws.region system property" ); } } @Override public Authenticator newAuthenticator(EndPoint endPoint, String authenticator) throws AuthenticationException { return new SigV4Authenticator(); } @Override public void onMissingChallenge(EndPoint endPoint) { throw new AuthenticationException(endPoint, "SigV4 requires a challenge from the endpoint. None was sent"); } @Override public void close() { // We do not open any resources, so this is a NOOP } /** * This authenticator performs SigV4 MCS authentication. */ public class SigV4Authenticator implements Authenticator { @Override public CompletionStage<ByteBuffer> initialResponse() { return CompletableFuture.completedFuture(SIGV4_INITIAL_RESPONSE); } @Override public CompletionStage<ByteBuffer> evaluateChallenge(ByteBuffer challenge) { try { byte[] nonce = extractNonce(challenge); Instant requestTimestamp = Instant.now(); AwsCredentials credentials = credentialsProvider.resolveCredentials(); String signature = generateSignature(nonce, requestTimestamp, credentials); String response = String.format("signature=%s,access_key=%s,amzdate=%s", signature, credentials.accessKeyId(), timestampFormatter.format(requestTimestamp)); if (credentials instanceof AwsSessionCredentials) { response = response + ",session_token=" + ((AwsSessionCredentials)credentials).sessionToken(); } return CompletableFuture.completedFuture(ByteBuffer.wrap(response.getBytes(StandardCharsets.UTF_8))); } catch (UnsupportedEncodingException e) { throw new RuntimeException("This platform does not support the UTF-8encoding", e); } } @Override public CompletionStage<Void> onAuthenticationSuccess(ByteBuffer token) { return CompletableFuture.completedFuture(null); } } /** * Extracts the nonce value from the challenge */ static byte[] extractNonce(ByteBuffer challengeBuffer) { byte[] challenge = new byte[challengeBuffer.remaining()]; challengeBuffer.get(challenge); int nonceStart = indexOf(challenge, NONCE_KEY); if (nonceStart == -1) { throw new IllegalArgumentException("Did not find nonce in SigV4 challenge: " + new String(challenge, StandardCharsets.UTF_8)); } // We'll start extraction right after the nonce bytes nonceStart += NONCE_KEY.length; int nonceEnd = nonceStart; // Advance until we find the comma or hit the end of input while (nonceEnd < challenge.length && challenge[nonceEnd] != ',') { nonceEnd++; } int nonceLength = nonceEnd - nonceStart; if (nonceLength != EXPECTED_NONCE_LENGTH) { throw new IllegalArgumentException("Expected a nonce of " + EXPECTED_NONCE_LENGTH + " bytes but received " + nonceLength); } return Arrays.copyOfRange(challenge, nonceStart, nonceEnd); } private String generateSignature(byte[] nonce, Instant requestTimestamp, AwsCredentials credentials) throws UnsupportedEncodingException { String credentialScopeDate = Aws4SignerUtils.formatDateStamp(requestTimestamp.toEpochMilli()); String signingScope = String.format("%s/%s/%s/aws4_request", credentialScopeDate, signingRegion, CANONICAL_SERVICE); String nonceHash = sha256Digest(nonce); String canonicalRequest = canonicalizeRequest(credentials.accessKeyId(), signingScope, requestTimestamp, nonceHash); String stringToSign = String.format("%s\n%s\n%s\n%s", SignerConstant.AWS4_SIGNING_ALGORITHM, timestampFormatter.format(requestTimestamp), signingScope, sha256Digest(canonicalRequest)); byte[] signingKey = getSignatureKey(credentials.secretAccessKey(), credentialScopeDate, signingRegion, CANONICAL_SERVICE); byte[] signature = hmacSHA256(stringToSign, signingKey); return Hex.encodeHexString(signature, true); } private static final String AMZ_ALGO_HEADER = "X-Amz-Algorithm=" + SignerConstant.AWS4_SIGNING_ALGORITHM; private static final String AMZ_EXPIRES_HEADER = "X-Amz-Expires=900"; private static String canonicalizeRequest(String accessKey, String signingScope, Instant requestTimestamp, String payloadHash) throws UnsupportedEncodingException { List<String> queryStringHeaders = Arrays.asList( AMZ_ALGO_HEADER, String.format("X-Amz-Credential=%s%%2F%s", accessKey, URLEncoder.encode(signingScope, StandardCharsets.UTF_8.name())), "X-Amz-Date=" + URLEncoder.encode(timestampFormatter.format(requestTimestamp), StandardCharsets.UTF_8.name()), AMZ_EXPIRES_HEADER ); // IMPORTANT: This list must maintain alphabetical order for canonicalization Collections.sort(queryStringHeaders); String queryString = String.join("&", queryStringHeaders); return String.format("PUT\n/authenticate\n%s\nhost:%s\n\nhost\n%s", queryString, CANONICAL_SERVICE, payloadHash); } static String sha256Digest(byte[] bytes) { try { final MessageDigest md = MessageDigest.getInstance("SHA-256"); return Hex.encodeHexString(md.digest(bytes), true); } catch (NoSuchAlgorithmException e) { throw new RuntimeException("This platform does not support the SHA-256 digest algorithm", e); } } static String sha256Digest(String input) { return sha256Digest(input.getBytes(StandardCharsets.UTF_8)); } // Taken from https://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-java private static final String HMAC_ALGORITHM = "hmacSHA256"; static byte[] hmacSHA256(String data, byte[] key) { try { Mac mac = Mac.getInstance(HMAC_ALGORITHM); mac.init(new SecretKeySpec(key, HMAC_ALGORITHM)); return mac.doFinal(data.getBytes(StandardCharsets.UTF_8)); } catch (Exception e) { throw new RuntimeException("Failure computing HMAC-SHA256", e); } } static byte[] getSignatureKey(String key, String dateStamp, String regionName, String serviceName) { byte[] kSecret = ("AWS4" + key).getBytes(StandardCharsets.UTF_8); byte[] kDate = hmacSHA256(dateStamp, kSecret); byte[] kRegion = hmacSHA256(regionName, kDate); byte[] kService = hmacSHA256(serviceName, kRegion); byte[] kSigning = hmacSHA256("aws4_request", kService); return kSigning; } /* * Java does not natively provide a method for locating one array * within another, so we provide that here. While other libraries * also provide this, we want to minimize the dependencies that * this plugin brings in. */ static int indexOf(byte[] target, byte[] pattern) { final int lastCheckIndex = target.length - pattern.length; for (int i = 0; i <= lastCheckIndex; i++) { if (pattern[0] == target[i]) { int inner = 0; int outer = i; // A tight loop over target, comparing indices for (; inner < pattern.length && pattern[inner] == target[outer]; inner++, outer++) {} // If the inner loop reached the end of the pattern, then we have found the index if (inner == pattern.length) { return i; } } } // Loop exhaustion means we did not find it return -1; } /** * Creates a STS role credential provider * @param roleArn The ARN of the role to assume * @param stsRegion The region of the STS endpoint * @return The STS role credential provider */ private static StsAssumeRoleCredentialsProvider createSTSRoleCredentialProvider(@NotNull String roleArn, @NotNull String stsRegion) { final String sessionName="keyspaces-session-"+System.currentTimeMillis(); StsClient stsClient = StsClient.builder() .region(Region.of(stsRegion)) .build(); AssumeRoleRequest assumeRoleRequest=AssumeRoleRequest.builder() .roleArn(roleArn) .roleSessionName(sessionName) .build(); return StsAssumeRoleCredentialsProvider.builder() .stsClient(stsClient) .refreshRequest(assumeRoleRequest) .build(); } /** * Gets the default region for SigV4 if region is not provided. * @return Default region */ private static String getDefaultRegion() { DefaultAwsRegionProviderChain chain = new DefaultAwsRegionProviderChain(); return chain.getRegion().toString().toLowerCase(); } }