// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT-0 package aws.location.demo.okhttp import com.amazonaws.DefaultRequest import com.amazonaws.auth.AWS4Signer import com.amazonaws.auth.AWSCredentialsProvider import com.amazonaws.http.HttpMethodName import com.amazonaws.util.IOUtils import okhttp3.HttpUrl import okhttp3.Interceptor import okhttp3.Request import okhttp3.Response import okio.Buffer import java.io.ByteArrayInputStream import java.net.URI class SigV4Interceptor( private val credentialsProvider: AWSCredentialsProvider, private val serviceName: String ) : Interceptor { override fun intercept(chain: Interceptor.Chain): Response { val originalRequest = chain.request() if (originalRequest.url().host().contains("amazonaws.com")) { val signer = if (originalRequest.url().encodedPath().contains("@")) { // the presence of "@" indicates that it doesn't need to be double URL-encoded AWS4Signer(false) } else { AWS4Signer() } val awsRequest = toAWSRequest(originalRequest, serviceName) signer.setServiceName(serviceName) signer.sign(awsRequest, credentialsProvider.credentials) return chain.proceed(toSignedOkHttpRequest(awsRequest, originalRequest)) } return chain.proceed(originalRequest) } companion object { fun toAWSRequest(request: Request, serviceName: String): DefaultRequest { // clone the request (AWS-style) so that it can be populated with credentials val dr = DefaultRequest(serviceName) // copy request info dr.httpMethod = HttpMethodName.valueOf(request.method()) with(request.url()) { dr.resourcePath = uri().path dr.endpoint = URI.create("${scheme()}://${host()}") // copy parameters for (p in queryParameterNames()) { if (p != "") { dr.addParameter(p, queryParameter(p)) } } } // copy headers for (h in request.headers().names()) { dr.addHeader(h, request.header(h)) } // copy the request body val bodyBytes = request.body()?.let { body -> val buffer = Buffer() body.writeTo(buffer) IOUtils.toByteArray(buffer.inputStream()) } dr.content = ByteArrayInputStream(bodyBytes ?: ByteArray(0)) return dr } fun toSignedOkHttpRequest( awsRequest: DefaultRequest, originalRequest: Request ): Request { // copy signed request back into an OkHttp Request val builder = Request.Builder() // copy headers from the signed request for ((k, v) in awsRequest.headers) { builder.addHeader(k, v) } // start building an HttpUrl val urlBuilder = HttpUrl.Builder() .host(awsRequest.endpoint.host) .scheme(awsRequest.endpoint.scheme) .encodedPath(awsRequest.resourcePath) // copy parameters from the signed request for ((k, v) in awsRequest.parameters) { urlBuilder.addQueryParameter(k, v) } return builder.url(urlBuilder.build()) .method(originalRequest.method(), originalRequest.body()) .build() } } }