// Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may not // use this file except in compliance with the License. A copy of the // License is located at // // http://aws.amazon.com/apache2.0/ // // or in the "license" file accompanying this file. This file is distributed // on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, // either express or implied. See the License for the specific language governing // permissions and limitations under the License. //go:build integration // +build integration // retryer overrides the default aws sdk retryer delay logic to better suit the mds needs package retryer import ( "bytes" "encoding/json" "fmt" "io" "io/ioutil" "net/http" "testing" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/stretchr/testify/assert" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting" ) type testData struct { Data string } func body(str string) io.ReadCloser { return ioutil.NopCloser(bytes.NewReader([]byte(str))) } func unmarshal(req *request.Request) { defer req.HTTPResponse.Body.Close() if req.Data != nil { json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data) } return } func unmarshalError(req *request.Request) { bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body) if err != nil { req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err) return } if len(bodyBytes) == 0 { req.Error = awserr.NewRequestFailure( awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")), req.HTTPResponse.StatusCode, "", ) return } var jsonErr jsonErrorResponse if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil { req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err) return } req.Error = awserr.NewRequestFailure( awserr.New(jsonErr.Code, jsonErr.Message, nil), req.HTTPResponse.StatusCode, "", ) } type jsonErrorResponse struct { Code string `json:"__type"` Message string `json:"message"` } // test that retries occur for 5xx status codes func TestRequestRecoverRetry5xx(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 502, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } retryer := SsmRetryer{} retryer.NumMaxRetries = 2 s := awstesting.NewClient(&aws.Config{Retryer: &retryer}) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() assert.Nil(t, err) assert.Equal(t, 2, int(r.RetryCount)) assert.Equal(t, "valid", out.Data) } // test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry` func TestRequestRecoverRetry4xxRetryable(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)}, {StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } retryer := SsmRetryer{} retryer.NumMaxRetries = 10 s := awstesting.NewClient(&aws.Config{Retryer: &retryer}) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() assert.Nil(t, err) assert.Equal(t, 2, int(r.RetryCount)) assert.Equal(t, "valid", out.Data) } // test that retries don't occur for 4xx status codes with a response type that can't be retried func TestRequest4xxUnretryable(t *testing.T) { retryer := SsmRetryer{} retryer.NumMaxRetries = 10 s := awstesting.NewClient(&aws.Config{Retryer: &retryer}) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)} }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() assert.NotNil(t, err) if e, ok := err.(awserr.RequestFailure); ok { assert.Equal(t, 401, e.StatusCode()) } else { assert.Fail(t, "Expected error to be a service failure") } assert.Equal(t, "SignatureDoesNotMatch", err.(awserr.Error).Code()) assert.Equal(t, "Signature does not match.", err.(awserr.Error).Message()) assert.Equal(t, 0, int(r.RetryCount)) } // test that retries delay increase over time func TestDelayIncreasesOverTime(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } previousDelay := time.Duration(0) sleepDelay := func(delay time.Duration) { if delay < previousDelay { assert.Fail(t, "Expect delay to increase") } previousDelay = delay } retryer := SsmRetryer{} retryer.NumMaxRetries = 10 s := awstesting.NewClient(&aws.Config{Retryer: &retryer, SleepDelay: sleepDelay}) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() assert.Nil(t, err) assert.Equal(t, 4, int(r.RetryCount)) } // test that retries delay increase by at least a second func TestDelayIncreasesByASecond(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } sleepDelay := func(delay time.Duration) { if delay < time.Duration(1*time.Second) { assert.Fail(t, "Expect delay to increase") } } retryer := SsmRetryer{} retryer.NumMaxRetries = 10 s := awstesting.NewClient(&aws.Config{Retryer: &retryer, SleepDelay: sleepDelay}) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() assert.Nil(t, err) assert.Equal(t, 1, int(r.RetryCount)) }