// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may not
// use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
// either express or implied. See the License for the specific language governing
// permissions and limitations under the License.
package s3util
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/aws/amazon-ssm-agent/agent/appconfig"
"github.com/aws/amazon-ssm-agent/agent/context"
"github.com/aws/amazon-ssm-agent/agent/log"
contextmocks "github.com/aws/amazon-ssm-agent/agent/mocks/context"
logmocks "github.com/aws/amazon-ssm-agent/agent/mocks/log"
"github.com/aws/amazon-ssm-agent/agent/sdkutil/retryer"
identityMocks "github.com/aws/amazon-ssm-agent/common/identity/mocks"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockedHttpProvider struct {
mock.Mock
}
func (m *MockedHttpProvider) Head(url string) (*http.Response, error) {
args := m.Called(url)
return args.Get(0).(*http.Response), args.Error(1)
}
func setBucketRegionFromSignedHeadBucketRequest(bucketRegion string) {
getBucketRegionFromSignedHeadBucketRequestFunc = func(context context.T, region, regionalEndpoint, bucketName string) string {
return bucketRegion
}
}
func setS3Endpoint(region, endpoint string, err error) {
getS3Endpoint = func(context context.T, region string) (string, error) {
return endpoint, err
}
}
func setS3FallbackEndpoint(region, endpoint string) {
getFallbackS3EndpointFunc = func(context context.T, region string) string {
return endpoint
}
}
func TestBucketRegion_WithHeadBucketRequestSuccessful(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("us-east-1")
setS3Endpoint("us-east-1", "", fmt.Errorf("invalid region"))
resp := &http.Response{
StatusCode: 200,
Header: http.Header{
bucketRegionHeader: []string{"us-east-1"},
},
}
var err error = nil
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
expected := "us-east-1"
assert.Equal(t, expected, actual)
}
func TestGetBucketRegion_NoError_InvalidS3Endpoint_ReturnsRegionFromFallback(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setS3Endpoint("us-east-1", "", fmt.Errorf("invalid region"))
setS3FallbackEndpoint("us-east-1", "s3.amazonaws.com")
resp := &http.Response{
StatusCode: 200,
Header: http.Header{
bucketRegionHeader: []string{"us-east-1"},
},
}
var err error = nil
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
expected := "us-east-1"
assert.Equal(t, expected, actual)
}
func TestGetBucketRegion_NoError_InvalidFallbackS3Endpoint_ReturnsRegionFroms3(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setS3Endpoint("us-east-1", "s3.us-east-1.amazonaws.com", nil)
setS3FallbackEndpoint("us-east-1", "")
resp := &http.Response{
StatusCode: 200,
Header: http.Header{
bucketRegionHeader: []string{"us-east-1"},
},
}
var err error = nil
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.us-east-1.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
expected := "us-east-1"
assert.Equal(t, expected, actual)
}
func TestGetBucketRegion_NoError_InvalidS3Endpoint_Error(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setS3Endpoint("us-east-1", "", fmt.Errorf("invalid region"))
setS3FallbackEndpoint("us-east-1", "")
resp := &http.Response{
StatusCode: 200,
Header: http.Header{
bucketRegionHeader: []string{"us-east-1"},
},
}
var err error = nil
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.us-east-1.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
assert.Equal(t, "", actual)
}
func TestGetBucketRegion_NoError_NoRegionInResponse_ReturnsEmptyString(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setS3Endpoint("us-east-1", "s3.us-east-1.amazonaws.com", nil)
setS3FallbackEndpoint("us-east-1", "s3.amazonaws.com")
resp := &http.Response{
StatusCode: 401,
}
var err error = nil
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.us-east-1.amazonaws.com").Return(resp, err)
httpProvider.On("Head", "https://bucket-1.s3.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
assert.Equal(t, "", actual)
}
func TestGetBucketRegion_NoError_RegionInResponse_ReturnsRegion(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setS3Endpoint("us-east-1", "s3.us-east-1.amazonaws.com", nil)
setS3FallbackEndpoint("us-east-1", "s3.amazonaws.com")
resp := &http.Response{
StatusCode: 301,
Header: http.Header{
bucketRegionHeader: []string{"eu-west-1"},
},
}
var err error = nil
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.us-east-1.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
assert.Equal(t, "eu-west-1", actual)
}
func TestGetBucketRegion_AllUrlsFail_ReturnsEmptyString(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setS3Endpoint("us-east-1", "s3.us-east-1.amazonaws.com", nil)
setS3FallbackEndpoint("us-east-1", "s3.amazonaws.com")
var resp *http.Response = nil
err := fmt.Errorf("failed")
httpProvider := &MockedHttpProvider{}
httpProvider.On("Head", "https://bucket-1.s3.us-east-1.amazonaws.com").Return(resp, err)
httpProvider.On("Head", "https://bucket-1.s3.amazonaws.com").Return(resp, err)
httpProvider.On("Head", "http://bucket-1.s3.us-east-1.amazonaws.com").Return(resp, err)
httpProvider.On("Head", "http://bucket-1.s3.amazonaws.com").Return(resp, err)
actual := getBucketRegion(contextmocks.NewMockDefault(), "us-east-1", "bucket-1", httpProvider)
assert.Equal(t, "", actual)
httpProvider.AssertExpectations(t)
}
func TestGetS3CrossRegionCapableSession_regionFromHead_noConfigOverrides(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setupMocksForGetS3CrossRegionCapableSession("us-east-1", "bucket-1", "eu-west-1")
sess, err := GetS3CrossRegionCapableSession(contextmocks.NewMockDefault(), "bucket-1")
assert.NotNil(t, sess)
assert.Equal(t, *sess.Config.Region, "eu-west-1")
assert.Nil(t, sess.Config.Endpoint)
assert.NotNil(t, sess.Config.HTTPClient.Transport)
_, correctType := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, correctType)
assert.Nil(t, err)
}
func TestGetS3CrossRegionCapableSession_noRegionFromHead_noConfigOverrides(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
identityMock := &identityMocks.IAgentIdentity{}
identityMock.On("Region").Return("cn-north-1", nil)
contextMock := new(contextmocks.Mock)
contextMock.On("Identity").Return(identityMock)
contextMock.On("Log").Return(logmocks.NewMockLog())
contextMock.On("AppConfig").Return(appconfig.DefaultConfig())
setupMocksForGetS3CrossRegionCapableSession("cn-north-1", "bucket-1", "")
sess, err := GetS3CrossRegionCapableSession(contextMock, "bucket-1")
assert.NotNil(t, sess)
assert.Equal(t, "cn-north-1", *sess.Config.Region)
assert.Nil(t, sess.Config.Endpoint)
assert.NotNil(t, sess.Config.HTTPClient.Transport)
_, correctType := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, correctType)
assert.Nil(t, err)
}
func TestGetS3CrossRegionCapableSession_regionFromHead_withConfigOverrides(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
appConfig := appconfig.DefaultConfig()
appConfig.S3.Endpoint = "https://custom.endpoint.com"
identityMock := &identityMocks.IAgentIdentity{}
identityMock.On("Region").Return("us-east-1", nil)
contextMock := new(contextmocks.Mock)
contextMock.On("Identity").Return(identityMock)
contextMock.On("Log").Return(logmocks.NewMockLog())
contextMock.On("AppConfig").Return(appConfig)
setupMocksForGetS3CrossRegionCapableSession("us-east-1", "bucket-1", "eu-west-1")
sess, err := GetS3CrossRegionCapableSession(contextMock, "bucket-1")
assert.NotNil(t, sess)
assert.Equal(t, "eu-west-1", *sess.Config.Region)
assert.Equal(t, "https://custom.endpoint.com", *sess.Config.Endpoint)
assert.NotNil(t, sess.Config.HTTPClient.Transport)
_, correctType := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, correctType)
assert.Nil(t, err)
}
func TestGetS3CrossRegionCapableSession_noRegionFromHead_withConfigOverrides(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
appConfig := appconfig.DefaultConfig()
appConfig.S3.Endpoint = "https://custom.endpoint.com.cn"
identityMock := &identityMocks.IAgentIdentity{}
identityMock.On("Region").Return("cn-north-1", nil)
contextMock := new(contextmocks.Mock)
contextMock.On("Identity").Return(identityMock)
contextMock.On("Log").Return(logmocks.NewMockLog())
contextMock.On("AppConfig").Return(appConfig)
setupMocksForGetS3CrossRegionCapableSession("cn-north-1", "bucket-1", "")
sess, err := GetS3CrossRegionCapableSession(contextMock, "bucket-1")
assert.NotNil(t, sess)
assert.Equal(t, "cn-north-1", *sess.Config.Region)
assert.Equal(t, "https://custom.endpoint.com.cn", *sess.Config.Endpoint)
assert.NotNil(t, sess.Config.HTTPClient.Transport)
_, correctType := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, correctType)
assert.Nil(t, err)
}
func setupMocksForGetS3CrossRegionCapableSession(instanceRegion, bucketName, headBucketResponse string) {
setBucketRegionFromSignedHeadBucketRequest("")
setupMockHeadBucketResponse(bucketName, instanceRegion, headBucketResponse)
makeAwsConfig = func(context context.T, service, region string) *aws.Config {
result := aws.NewConfig()
result.Region = aws.String(region)
result.Credentials = credentials.NewCredentials(&mockCredentialsProvider{})
return result
}
}
func setupMockHeadBucketResponse(bucketName, instanceRegion, headBucketResponse string) {
setBucketRegionFromSignedHeadBucketRequest("")
s3Endpoint := "s3." + instanceRegion + ".amazonaws.com"
s3FallbackEndpoint := "s3.amazonaws.com"
if strings.HasPrefix(instanceRegion, "cn-") {
s3Endpoint += ".cn"
s3FallbackEndpoint = "s3.cn-north-1.amazonaws.com.cn"
}
setS3Endpoint(instanceRegion, s3Endpoint, nil)
setS3FallbackEndpoint(instanceRegion, s3FallbackEndpoint)
getHttpProvider = func(log.T, appconfig.SsmagentConfig) HttpProvider {
provider := &MockedHttpProvider{}
resp := &http.Response{
Header: http.Header{},
}
var err error = nil
if headBucketResponse != "" {
resp.Header.Add(bucketRegionHeader, headBucketResponse)
}
provider.On("Head", "https://"+bucketName+"."+s3Endpoint).Return(resp, err)
provider.On("Head", "https://"+bucketName+"."+s3FallbackEndpoint).Return(resp, err)
return provider
}
}
func TestRedirect_RedirectResponse_RetryWithCorrectRegion(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
appConfig := appconfig.DefaultConfig()
identityMock := &identityMocks.IAgentIdentity{}
identityMock.On("Region").Return("cn-northwest-1", nil)
contextMock := new(contextmocks.Mock)
contextMock.On("Identity").Return(identityMock)
contextMock.On("Log").Return(logmocks.NewMockLog())
contextMock.On("AppConfig").Return(appConfig)
setupMocksForGetS3CrossRegionCapableSession("cn-northwest-1", "bucket-1", "")
sess, err := GetS3CrossRegionCapableSession(contextMock, "bucket-1")
assert.Nil(t, err)
trans, transTypeOk := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, transTypeOk)
delegate := newMockTransport()
trans.delegate = delegate
svc := s3.New(sess)
input := &s3.HeadBucketInput{
Bucket: aws.String("bucket-1"),
}
// First attempt goes to the instance's home region. S3 returns a 301 PermanentRedirect
// response with header indicating the correct region for the bucket.
req1Url := "https://bucket-1.s3.cn-northwest-1.amazonaws.com.cn/"
resp1Header := http.Header{}
resp1Header.Add(bucketRegionHeader, "cn-north-1")
resp1 := &http.Response{
Status: "PermanentRedirect",
StatusCode: 301,
Header: resp1Header,
Body: ioutil.NopCloser(strings.NewReader("body contents")),
}
// The retry goes to the correct endpoint for the bucket, which is cn-north-1.
req2Url := "https://bucket-1.s3.cn-north-1.amazonaws.com.cn/"
resp2Header := http.Header{}
resp2 := &http.Response{
Status: "Success",
StatusCode: 200,
Header: resp2Header,
Body: ioutil.NopCloser(strings.NewReader("")),
}
delegate.AddResponse(req1Url, resp1)
delegate.AddResponse(req2Url, resp2)
_, err = svc.HeadBucket(input)
assert.Nil(t, err)
assert.Equal(t, 2, len(delegate.requestURLsReceived))
assert.Equal(t, "https://bucket-1.s3.cn-northwest-1.amazonaws.com.cn/", delegate.requestURLsReceived[0])
assert.Equal(t, "https://bucket-1.s3.cn-north-1.amazonaws.com.cn/", delegate.requestURLsReceived[1])
// Cleanup
getBucketRegionMap().Remove("bucket-1")
}
func TestRedirect_BadSigningRegionResponse_RetryWithCorrectRegion(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
setupMocksForGetS3CrossRegionCapableSession("us-east-1", "bucket-1", "")
sess, err := GetS3CrossRegionCapableSession(contextmocks.NewMockDefault(), "bucket-1")
assert.Nil(t, err)
trans, transTypeOk := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, transTypeOk)
delegate := newMockTransport()
trans.delegate = delegate
svc := s3.New(sess)
input := &s3.HeadBucketInput{
Bucket: aws.String("bucket-1"),
}
// For the first attempt, the client is initialized for us-east-1.
// However, DNS is able to resolve the virtual hosted bucket URL
// to the correct regional endpoint in eu-west-1. The eu-west-1
// endpoint returns an HTTP 400 "wrong signing region" error, with
// the bucket region set in the response body.
req1Url := "https://bucket-1.s3.amazonaws.com/"
resp1Header := http.Header{}
resp1Body := makeAuthorizationHeaderMalformedErrorResponse("us-east-1", "eu-west-1")
resp1 := &http.Response{
Status: "",
StatusCode: 400,
Header: resp1Header,
Body: ioutil.NopCloser(strings.NewReader(resp1Body)),
}
// The retry should have the correct regional endpoint in the request URL
req2Url := "https://bucket-1.s3.eu-west-1.amazonaws.com/"
resp2Header := http.Header{}
resp2 := &http.Response{
Status: "Success",
StatusCode: 200,
Header: resp2Header,
Body: ioutil.NopCloser(strings.NewReader("")),
}
delegate.AddResponse(req1Url, resp1)
delegate.AddResponse(req2Url, resp2)
_, err = svc.HeadBucket(input)
assert.Nil(t, err)
assert.Equal(t, 2, len(delegate.requestURLsReceived))
assert.Equal(t, "https://bucket-1.s3.amazonaws.com/", delegate.requestURLsReceived[0])
assert.Equal(t, "https://bucket-1.s3.eu-west-1.amazonaws.com/", delegate.requestURLsReceived[1])
// Cleanup
getBucketRegionMap().Remove("bucket-1")
}
func TestRedirect_CachedBucketRegion_FirstRequestGoesToCorrectRegion(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
// The correct region for the bucket is already cached. The first
// attempt should go to the correct region.
getBucketRegionMap().Put("bucket-1", "cn-north-1")
setupMocksForGetS3CrossRegionCapableSession("cn-northwest-1", "bucket-1", "")
sess, err := GetS3CrossRegionCapableSession(contextmocks.NewMockDefault(), "bucket-1")
assert.Nil(t, err)
trans, transTypeOk := sess.Config.HTTPClient.Transport.(*s3BucketRegionHeaderCapturingTransport)
assert.True(t, transTypeOk)
delegate := newMockTransport()
trans.delegate = delegate
svc := s3.New(sess)
input := &s3.GetBucketLocationInput{
Bucket: aws.String("bucket-1"),
}
// The first attempt goes to the correct endpoint for the bucket, which is cn-north-1.
reqUrl := "https://s3.cn-north-1.amazonaws.com.cn/bucket-1?location="
respHeader := http.Header{}
resp := &http.Response{
Status: "Success",
StatusCode: 200,
Header: respHeader,
Body: ioutil.NopCloser(strings.NewReader(makeGetBucketLocationResponseBodyText("cn-north-1"))),
}
delegate.AddResponse(reqUrl, resp)
output, err := svc.GetBucketLocation(input)
assert.Nil(t, err)
assert.Equal(t, "cn-north-1", *output.LocationConstraint)
assert.Equal(t, 1, len(delegate.requestURLsReceived))
assert.Equal(t, "https://s3.cn-north-1.amazonaws.com.cn/bucket-1?location=", delegate.requestURLsReceived[0])
// Cleanup
getBucketRegionMap().Remove("bucket-1")
}
type handlerTestCaseData struct {
bucketName string
op *request.Operation
input interface{}
output interface{}
}
var handlerTestCases = []handlerTestCaseData{
{
bucketName: "bucket-1",
op: &request.Operation{
Name: "PutObject",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
},
input: &s3.PutObjectInput{
Body: strings.NewReader("body contents"),
Key: aws.String("a/b"),
Bucket: aws.String("bucket-1"),
},
output: &s3.PutObjectOutput{},
},
{
bucketName: "bucket-1",
op: &request.Operation{
Name: "CreateMultipartUpload",
HTTPMethod: "POST",
HTTPPath: "/{Bucket}/{Key+}?uploads",
},
input: &s3.CreateMultipartUploadInput{
Bucket: aws.String("bucket-1"),
Key: aws.String("a/b"),
ContentType: aws.String("text/plain"),
ACL: aws.String("bucket-owner-full-control"),
},
output: &s3.CreateMultipartUploadOutput{},
},
{
bucketName: "bucket-1",
op: &request.Operation{
Name: "UploadPart",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
},
input: &s3.UploadPartInput{
Bucket: aws.String("bucket-1"),
Key: aws.String("a/b"),
Body: strings.NewReader("body contents"),
UploadId: aws.String("1324"),
PartNumber: aws.Int64(1),
},
output: &s3.UploadPartOutput{},
},
}
func TestHandlerAllCases(t *testing.T) {
for _, d := range handlerTestCases {
validationHandlerTestCase(t, d.bucketName, "cn-northwest-1", "cn-north-1", d.op, d.input, d.output)
validationHandlerTestCase(t, d.bucketName, "us-east-1", "us-west-1", d.op, d.input, d.output)
validationHandlerTestCase(t, d.bucketName, "us-gov-east-1", "us-gov-west-1", d.op, d.input, d.output)
retryHandlerTestCase(t, d.bucketName, "cn-northwest-1", "cn-north-1", d.op, d.input, d.output)
retryHandlerTestCase(t, d.bucketName, "us-east-1", "us-west-1", d.op, d.input, d.output)
retryHandlerTestCase(t, d.bucketName, "us-gov-east-1", "us-gov-west-1", d.op, d.input, d.output)
}
}
func validationHandlerTestCase(t *testing.T, bucketName, oldRegion, newRegion string,
op *request.Operation, input, output interface{}) {
// The request initially targets the old region
retryer := retryer.SsmRetryer{}
retryer.NumMaxRetries = 3
config := &aws.Config{
Retryer: retryer,
SleepDelay: func(d time.Duration) {
time.Sleep(d)
},
Region: &oldRegion,
}
// The correct region for the bucket has been discovered by s3BucketRegionHeaderCapturingTransport
getBucketRegionMap().Put(bucketName, newRegion)
sess := session.New(config)
sess.Handlers.Validate.PushBackNamed(makeS3RegionCorrectingValidateHandler(logmocks.NewMockLog()))
svc := s3.New(sess)
request := svc.NewRequest(op, input, output)
assert.Equal(t, oldRegion, *request.Config.Region)
request.Build()
var newRegionS3EndpointHostname string
if strings.HasPrefix(newRegion, "cn-") {
newRegionS3EndpointHostname = "s3." + newRegion + ".amazonaws.com.cn"
} else {
newRegionS3EndpointHostname = "s3." + newRegion + ".amazonaws.com"
}
assert.Equal(t, newRegion, *request.Config.Region)
assert.Equal(t, "https://"+newRegionS3EndpointHostname, request.ClientInfo.Endpoint)
assert.Equal(t, newRegion, request.ClientInfo.SigningRegion)
assert.Equal(t, bucketName+"."+newRegionS3EndpointHostname, request.HTTPRequest.URL.Host)
// Cleanup
getBucketRegionMap().Remove(bucketName)
}
func retryHandlerTestCase(t *testing.T, bucketName, oldRegion, newRegion string,
op *request.Operation, input, output interface{}) {
// The request initially targets the old region
retryer := retryer.SsmRetryer{}
retryer.NumMaxRetries = 3
config := &aws.Config{
Retryer: retryer,
SleepDelay: func(d time.Duration) {
time.Sleep(d)
},
Region: &oldRegion,
}
sess := session.New(config)
sess.Handlers.Retry.PushFrontNamed(makeS3RegionCorrectingRetryHandler(logmocks.NewMockLog()))
svc := s3.New(sess)
request := svc.NewRequest(op, input, output)
assert.Equal(t, oldRegion, *request.Config.Region)
request.Build()
// Simulate sending the request. S3 returns a 301, and the Transport
// captures the bucket region from the response headers.
getBucketRegionMap().Put(bucketName, newRegion)
request.HTTPResponse = &http.Response{
StatusCode: 301,
}
// Invoke the handler
request.Handlers.Retry.Run(request)
var newRegionS3EndpointHostname string
if strings.HasPrefix(newRegion, "cn-") {
newRegionS3EndpointHostname = "s3." + newRegion + ".amazonaws.com.cn"
} else {
newRegionS3EndpointHostname = "s3." + newRegion + ".amazonaws.com"
}
assert.Equal(t, newRegion, *request.Config.Region)
assert.Equal(t, "https://"+newRegionS3EndpointHostname, request.ClientInfo.Endpoint)
assert.Equal(t, newRegion, request.ClientInfo.SigningRegion)
assert.Equal(t, bucketName+"."+newRegionS3EndpointHostname, request.HTTPRequest.URL.Host)
// Cleanup
getBucketRegionMap().Remove(bucketName)
}
func TestValidateHandler_EndpointLookupFailure_NoChangeToRequest(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
// bucket-1 somehow got mapped to an unknown region
getBucketRegionMap().Put("bucket-1", "unknown-region-1")
config := &aws.Config{
Region: aws.String("us-east-1"),
// This simulates an endpoint lookup failure
EndpointResolver: mockEndpointResolver{
endpoints.ResolvedEndpoint{},
fmt.Errorf("ERROR"),
},
}
op := &request.Operation{
Name: "PutObject",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
}
input := &s3.PutObjectInput{
Body: strings.NewReader("body contents"),
Key: aws.String("a/b"),
Bucket: aws.String("bucket-1"),
}
output := &s3.PutObjectOutput{}
sess := session.New(config)
svc := s3.New(sess)
request := svc.NewRequest(op, input, output)
handler := makeS3RegionCorrectingValidateHandler(logmocks.NewMockLog())
handler.Fn(request)
assert.Equal(t, "us-east-1", *request.Config.Region)
// Cleanup
getBucketRegionMap().Remove("bucket-1")
}
func TestRetryHandler_EndpointLookupFailure_NoChangeToRequest(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
// bucket-1 somehow got mapped to an unknown region
getBucketRegionMap().Put("bucket-1", "unknown-region-1")
config := &aws.Config{
Region: aws.String("us-east-1"),
// This simulates an endpoint lookup failure
EndpointResolver: mockEndpointResolver{
endpoints.ResolvedEndpoint{},
fmt.Errorf("ERROR"),
},
}
op := &request.Operation{
Name: "PutObject",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
}
input := &s3.PutObjectInput{
Body: strings.NewReader("body contents"),
Key: aws.String("a/b"),
Bucket: aws.String("bucket-1"),
}
output := &s3.PutObjectOutput{}
sess := session.New(config)
svc := s3.New(sess)
request := svc.NewRequest(op, input, output)
request.HTTPResponse = &http.Response{
StatusCode: 301,
}
handler := makeS3RegionCorrectingRetryHandler(logmocks.NewMockLog())
handler.Fn(request)
assert.Equal(t, "us-east-1", *request.Config.Region)
// Cleanup
getBucketRegionMap().Remove("bucket-1")
}
func TestFixupRequest_NoHttpRequestUrl_NoCustomEndpoint_SetsRegionAndEndpoint(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
request := &request.Request{
Config: aws.Config{
Region: aws.String("us-east-1"),
},
ClientInfo: metadata.ClientInfo{},
HTTPRequest: &http.Request{},
}
fixupRequest(logmocks.NewMockLog(), request, "eu-west-1")
assert.Equal(t, "eu-west-1", *request.Config.Region)
assert.Nil(t, request.Config.Endpoint)
assert.Equal(t, "https://s3.eu-west-1.amazonaws.com", request.ClientInfo.Endpoint)
}
func TestFixupRequest_NoHttpRequestUrl_CustomEndpoint_SetsRegionAndEndpoint(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
request := &request.Request{
Config: aws.Config{
Region: aws.String("us-east-1"),
Endpoint: aws.String("https://my-custom-endpoint.com"),
},
ClientInfo: metadata.ClientInfo{},
HTTPRequest: &http.Request{},
}
fixupRequest(logmocks.NewMockLog(), request, "eu-west-1")
assert.Equal(t, "eu-west-1", *request.Config.Region)
assert.Equal(t, "https://my-custom-endpoint.com", *request.Config.Endpoint)
assert.Equal(t, "https://my-custom-endpoint.com", request.ClientInfo.Endpoint)
}
func TestFixupRequest_HttpRequestUrl_NoCustomEndpoint_SetsRegionAndHttpRequestUrl(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
request := &request.Request{
Config: aws.Config{
Region: aws.String("us-east-1"),
EndpointResolver: endpoints.DefaultResolver(),
},
ClientInfo: metadata.ClientInfo{
Endpoint: "http://s3.amazonaws.com",
},
HTTPRequest: &http.Request{
URL: &url.URL{
Scheme: "https",
Host: "s3.amazonaws.com",
},
},
}
fixupRequest(logmocks.NewMockLog(), request, "eu-west-1")
assert.Equal(t, "eu-west-1", *request.Config.Region)
assert.Equal(t, "https://s3.eu-west-1.amazonaws.com", request.HTTPRequest.URL.String())
assert.Nil(t, request.Config.Endpoint)
assert.Equal(t, "https://s3.eu-west-1.amazonaws.com", request.ClientInfo.Endpoint)
}
func TestFixupRequest_HttpRequestUrlPresent_RespectsCustomEndpoint(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
request := &request.Request{
Config: aws.Config{
Region: aws.String("us-east-1"),
Endpoint: aws.String("https://my-custom-endpoint.com"),
EndpointResolver: endpoints.DefaultResolver(),
},
ClientInfo: metadata.ClientInfo{
Endpoint: "https://my-custom-endpoint.com",
},
HTTPRequest: &http.Request{
URL: &url.URL{
Scheme: "https",
Host: "s3.amazonaws.com",
},
},
}
fixupRequest(logmocks.NewMockLog(), request, "eu-west-1")
assert.Equal(t, "eu-west-1", *request.Config.Region)
assert.Equal(t, "https://my-custom-endpoint.com", request.HTTPRequest.URL.String())
assert.Equal(t, "https://my-custom-endpoint.com", *request.Config.Endpoint)
assert.Equal(t, "https://my-custom-endpoint.com", request.ClientInfo.Endpoint)
}
func TestFixupRequest_HttpRequestUrlPresent_VirtualHostedUrlWithKey(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
request := &request.Request{
Config: aws.Config{
Region: aws.String("us-east-1"),
EndpointResolver: endpoints.DefaultResolver(),
},
ClientInfo: metadata.ClientInfo{
Endpoint: "https://s3.amazonaws.com",
},
HTTPRequest: &http.Request{
URL: &url.URL{
Scheme: "https",
Host: "bucket-1.s3.amazonaws.com",
Path: "/key",
},
},
}
fixupRequest(logmocks.NewMockLog(), request, "eu-west-1")
assert.Equal(t, "eu-west-1", *request.Config.Region)
assert.Equal(t, "https://bucket-1.s3.eu-west-1.amazonaws.com/key", request.HTTPRequest.URL.String())
assert.Nil(t, request.Config.Endpoint)
assert.Equal(t, "https://s3.eu-west-1.amazonaws.com", request.ClientInfo.Endpoint)
}
func TestFixupRequest_HttpRequestUrlPresent_PathStyleUrlWithKey(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
request := &request.Request{
Config: aws.Config{
Region: aws.String("us-east-1"),
EndpointResolver: endpoints.DefaultResolver(),
},
ClientInfo: metadata.ClientInfo{
Endpoint: "https://s3.amazonaws.com",
},
HTTPRequest: &http.Request{
URL: &url.URL{
Scheme: "https",
Host: "s3.amazonaws.com",
Path: "/bucket-1/key",
},
},
}
fixupRequest(logmocks.NewMockLog(), request, "eu-west-1")
assert.Equal(t, "eu-west-1", *request.Config.Region)
assert.Equal(t, "https://s3.eu-west-1.amazonaws.com/bucket-1/key", request.HTTPRequest.URL.String())
assert.Nil(t, request.Config.Endpoint)
assert.Equal(t, "https://s3.eu-west-1.amazonaws.com", request.ClientInfo.Endpoint)
}
func TestNewS3BucketRegionHeaderCapturingTransport(t *testing.T) {
transport := newS3BucketRegionHeaderCapturingTransport(logmocks.NewMockLog(), appconfig.SsmagentConfig{})
_, goodType := transport.delegate.(*http.Transport)
assert.True(t, goodType)
}
func TestRoundTrip_bucketRegionHeaderPresent(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
requestUrl := "https://test-bucket.s3.cn-northwest-1.amazonaws.com.cn/a/b"
request := makeRequest("GET", requestUrl)
responseHeader := http.Header{}
responseHeader.Add(bucketRegionHeader, "cn-north-1")
responseHeader.Add("x-amz-request-id", "123")
responseBodyContents := makeRedirectResponseBodyText("test-bucket.s3.cn-north-1.amazonaws.com.cn", "test-bucket")
response := makeResponse(301, responseHeader, responseBodyContents)
delegate := newMockTransport()
delegate.AddResponse(requestUrl, response)
transport := newS3BucketRegionHeaderCapturingTransportForTest(delegate)
actualResponse, err := transport.RoundTrip(request)
assert.NotNil(t, actualResponse)
assert.Nil(t, err)
cachedRegion, ok := getBucketRegionMap().Get("test-bucket")
assert.True(t, ok)
assert.Equal(t, "cn-north-1", cachedRegion)
// Cleanup
getBucketRegionMap().Remove("test-bucket")
}
func TestRoundTrip_bucketRegionInErrorResponseBody(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
requestUrl := "https://test-bucket.s3.cn-northwest-1.amazonaws.com.cn/a/b"
request := makeRequest("GET", requestUrl)
responseHeader := http.Header{}
responseBodyContents := makeAuthorizationHeaderMalformedErrorResponse("cn-northwest-1", "cn-north-1")
response := makeResponse(400, responseHeader, responseBodyContents)
delegate := newMockTransport()
delegate.AddResponse(requestUrl, response)
transport := newS3BucketRegionHeaderCapturingTransportForTest(delegate)
actualResponse, err := transport.RoundTrip(request)
assert.NotNil(t, actualResponse)
assert.Nil(t, err)
cachedRegion, ok := getBucketRegionMap().Get("test-bucket")
assert.True(t, ok)
assert.Equal(t, "cn-north-1", cachedRegion)
// Cleanup
getBucketRegionMap().Remove("test-bucket")
}
func TestRoundTrip_endpointInErrorResponseBody(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
requestUrl := "https://test-bucket.s3.cn-northwest-1.amazonaws.com.cn/a/b"
request := makeRequest("GET", requestUrl)
responseHeader := http.Header{}
responseBodyContents := makeRedirectResponseBodyText("test-bucket.s3.cn-north-1.amazonaws.com.cn", "test-bucket")
response := makeResponse(301, responseHeader, responseBodyContents)
delegate := newMockTransport()
delegate.AddResponse(requestUrl, response)
transport := newS3BucketRegionHeaderCapturingTransportForTest(delegate)
actualResponse, err := transport.RoundTrip(request)
assert.NotNil(t, actualResponse)
assert.Nil(t, err)
cachedRegion, ok := getBucketRegionMap().Get("test-bucket")
assert.True(t, ok)
assert.Equal(t, "cn-north-1", cachedRegion)
// Cleanup
getBucketRegionMap().Remove("test-bucket")
}
func TestRoundTrip_bucketRegionNotPresent(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
requestUrl := "https://test-bucket.s3.cn-north-1.amazonaws.com.cn/a/b"
request := makeRequest("GET", requestUrl)
response := makeResponse(200, http.Header{}, "Success")
delegate := newMockTransport()
delegate.AddResponse(requestUrl, response)
transport := newS3BucketRegionHeaderCapturingTransportForTest(delegate)
actualResponse, err := transport.RoundTrip(request)
assert.NotNil(t, actualResponse)
assert.Nil(t, err)
assert.Equal(t, actualResponse.StatusCode, 200)
_, ok := getBucketRegionMap().Get("test-bucket")
assert.False(t, ok)
// Cleanup
getBucketRegionMap().Remove("test-bucket")
}
func TestRoundTrip_error(t *testing.T) {
setBucketRegionFromSignedHeadBucketRequest("")
requestUrl := "https://test-bucket.s3.cn-north-1.amazonaws.com.cn/a/b"
request := makeRequest("GET", requestUrl)
delegate := newMockTransport()
transport := newS3BucketRegionHeaderCapturingTransportForTest(delegate)
actualResponse, err := transport.RoundTrip(request)
assert.Nil(t, actualResponse)
assert.NotNil(t, err)
}
func TestBucketRegionCache_keepsNMostRecentItems(t *testing.T) {
for i := 0; i < 2*bucketRegionCacheItemCountMax; i++ {
bucketName := fmt.Sprintf("bucket-%d", i)
getBucketRegionMap().Put(bucketName, "us-east-1")
}
// Only the most-recently-added bucketRegionCacheItemCountMax items should be in the cache
assert.Equal(t, uint64(bucketRegionCacheItemCountMax), getBucketRegionMap().bucketNameCache.Size())
for i := 0; i < bucketRegionCacheItemCountMax; i++ {
bucketName := fmt.Sprintf("bucket-%d", i)
v, ok := getBucketRegionMap().Get(bucketName)
assert.Equal(t, "", v)
assert.False(t, ok)
}
for i := bucketRegionCacheItemCountMax; i < 2*bucketRegionCacheItemCountMax; i++ {
bucketName := fmt.Sprintf("bucket-%d", i)
v, ok := getBucketRegionMap().Get(bucketName)
assert.Equal(t, "us-east-1", v)
assert.True(t, ok)
}
// Cleanup
for i := bucketRegionCacheItemCountMax; i < 2*bucketRegionCacheItemCountMax; i++ {
bucketName := fmt.Sprintf("bucket-%d", i)
getBucketRegionMap().Remove(bucketName)
}
}
// Constructor that allows tests to supply a mock Transport
func newS3BucketRegionHeaderCapturingTransportForTest(delegate http.RoundTripper) *s3BucketRegionHeaderCapturingTransport {
return &s3BucketRegionHeaderCapturingTransport{
delegate: delegate,
logger: logmocks.NewMockLog(),
}
}
type mockTransportResponse struct {
resp *http.Response
err error
}
// A mock Transport implementation with a map of hard-coded responses
// for a set of URLs.
type mockTransport struct {
urlToResponseAndError map[string]mockTransportResponse
requestURLsReceived []string
}
// Create a new mockTransport with an empty response map
func newMockTransport() *mockTransport {
return &mockTransport{
urlToResponseAndError: make(map[string]mockTransportResponse),
requestURLsReceived: make([]string, 0),
}
}
// Register a mock response for the specified url
func (t *mockTransport) AddResponse(url string, response *http.Response) {
t.urlToResponseAndError[url] = mockTransportResponse{response, nil}
}
// Register a transport error for the specified url
func (t *mockTransport) AddTransportError(url string, err error) {
t.urlToResponseAndError[url] = mockTransportResponse{nil, err}
}
// Mock RoundTrip implementation. If the request is for a URL that is in
// the response map, returns the response. Otherwise, returns a nil response
// and an error.
func (t *mockTransport) RoundTrip(request *http.Request) (*http.Response, error) {
t.requestURLsReceived = append(t.requestURLsReceived, request.URL.String())
if response, ok := t.urlToResponseAndError[request.URL.String()]; ok {
return response.resp, response.err
}
return nil, fmt.Errorf("ERROR")
}
func makeRequest(method, rawUrl string) *http.Request {
parsedUrl, _ := url.Parse(rawUrl)
return &http.Request{
Method: method,
URL: parsedUrl,
}
}
func makeResponse(statusCode int, header http.Header, bodyContents string) *http.Response {
return &http.Response{
StatusCode: statusCode,
Header: header,
Body: ioutil.NopCloser(strings.NewReader(bodyContents)),
ContentLength: int64(len(bodyContents)),
}
}
// A credentials.Provider implementation that returns fake credentials
// for testing.
type mockCredentialsProvider struct {
accessKey string
secretKey string
}
// Returns fake credentials.
func (c *mockCredentialsProvider) Retrieve() (credentials.Value, error) {
return credentials.Value{
AccessKeyID: "FAKEACCESSKEY",
SecretAccessKey: "FAKESECRETKEY",
SessionToken: "FAKESESSIONTOKEN",
ProviderName: "mockCredentialsProvider",
}, nil
}
// Always returns false to indicate the credentials are still valid.
func (c *mockCredentialsProvider) IsExpired() bool {
return false
}
// A Resolver implementation that returns a hard-coded endpoint
type mockEndpointResolver struct {
resolvedEndpoint endpoints.ResolvedEndpoint
err error
}
// Returns the hard-coded endpoint lookup response
func (r mockEndpointResolver) EndpointFor(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
return r.resolvedEndpoint, r.err
}
func makeGetBucketLocationResponseBodyText(region string) string {
return "\\n" +
"" + region + ""
}
func makeRedirectResponseBodyText(endpoint, bucketName string) string {
return "\\n" +
"PermanentRedirect
" +
"The bucket you are attempting to access must be addressed using the specified endpoint. " +
"Please send all future requests to this endpoint." +
"" + endpoint + "" +
"" + bucketName + "" +
"12345" +
"abcde"
}
func makeAuthorizationHeaderMalformedErrorResponse(wrongRegion, expRegion string) string {
return "" +
"AuthorizationHeaderMalformed
" +
"The authorization header is malformed; " +
"the region '" + wrongRegion + "' is wrong; expecting '" + expRegion + "'" +
"" + expRegion + "" +
"Request1Host1"
}
func TestExtractRegionFromBody_ErrorXmlWithRegion(t *testing.T) {
bodyContents := makeAuthorizationHeaderMalformedErrorResponse("us-east-1", "eu-west-1")
transport := newS3BucketRegionHeaderCapturingTransport(logmocks.NewMockLog(), appconfig.SsmagentConfig{})
assert.Equal(t, "eu-west-1", transport.extractRegionFromBody([]byte(bodyContents)))
}
func TestExtractRegionFromBody_ErrorXmlWithEndpoint(t *testing.T) {
bodyContents := makeRedirectResponseBodyText("bucket-1.s3.cn-north-1.amazonaws.com.cn", "cn-north-1")
transport := newS3BucketRegionHeaderCapturingTransport(logmocks.NewMockLog(), appconfig.SsmagentConfig{})
assert.Equal(t, "cn-north-1", transport.extractRegionFromBody([]byte(bodyContents)))
}
func TestExtractRegionFromBody_ErrorXmlWithEndpoint_PathStyleEndpointUrl(t *testing.T) {
bodyContents := makeRedirectResponseBodyText("s3.cn-north-1.amazonaws.com.cn/bucket-1", "cn-north-1")
transport := newS3BucketRegionHeaderCapturingTransport(logmocks.NewMockLog(), appconfig.SsmagentConfig{})
assert.Equal(t, "cn-north-1", transport.extractRegionFromBody([]byte(bodyContents)))
}
type mockReaderResponse struct {
data []byte
err error
}
type mockReader struct {
readResponses []mockReaderResponse
readResponseIndex int
}
func (r *mockReader) Read(buf []byte) (int, error) {
resp := r.readResponses[r.readResponseIndex]
r.readResponseIndex++
n := len(resp.data)
if n > len(buf) {
n = len(buf)
}
for i := 0; i < n; i++ {
buf[i] = resp.data[i]
}
return n, resp.err
}
func (r *mockReader) Close() error {
return nil
}
func TestGetResponseBody_SingleRead_EOFOnNonemptyRead(t *testing.T) {
readResponses := []mockReaderResponse{
{data: []byte("payload"), err: io.EOF},
}
getResponseBodyBufsize, getResponseBodyMaxLength = 16, 32
expResult := []byte("payload")
expErr := error(nil)
doGetResponseBodyTest(t, readResponses, expResult, expErr)
}
func TestGetResponseBody_MultipleReads_EOFOnNonemptyRead(t *testing.T) {
readResponses := []mockReaderResponse{
{data: []byte("payload"), err: nil},
{data: []byte("payload"), err: io.EOF},
}
getResponseBodyBufsize, getResponseBodyMaxLength = 7, 32
expResult := []byte("payloadpayload")
expErr := error(nil)
doGetResponseBodyTest(t, readResponses, expResult, expErr)
}
func TestGetResponseBody_MultipleReads_EOFOnEmptyRead(t *testing.T) {
readResponses := []mockReaderResponse{
{data: []byte("payload"), err: nil},
{data: []byte("payload"), err: nil},
{data: []byte(""), err: io.EOF},
}
getResponseBodyBufsize, getResponseBodyMaxLength = 7, 32
expResult := []byte("payloadpayload")
expErr := error(nil)
doGetResponseBodyTest(t, readResponses, expResult, expErr)
}
func TestGetResponseBody_MultipleReads_MaxLenExceeded(t *testing.T) {
readResponses := []mockReaderResponse{
{data: []byte("payload"), err: nil},
{data: []byte("payload"), err: nil},
{data: []byte("payload"), err: io.EOF},
}
getResponseBodyBufsize, getResponseBodyMaxLength = 7, 10
expResult := []byte("payloadpay")
expErr := fmt.Errorf("getResponseBody(): buffer length exceeded")
doGetResponseBodyTest(t, readResponses, expResult, expErr)
}
func doGetResponseBodyTest(t *testing.T, mockResponses []mockReaderResponse, expResult []byte, expErr error) {
body := &mockReader{
readResponses: mockResponses,
}
response := &http.Response{
Body: body,
}
transport := newS3BucketRegionHeaderCapturingTransport(logmocks.NewMockLog(), appconfig.SsmagentConfig{})
actualBody, actualErr := transport.getResponseBody(response)
assert.Equal(t, expResult, actualBody)
assert.Equal(t, expErr, actualErr)
}