package request_test import ( "bytes" "encoding/json" "errors" "fmt" "io" "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" "reflect" "runtime" "strconv" "strings" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/corehandlers" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/awstesting" "github.com/aws/aws-sdk-go/awstesting/unit" "github.com/aws/aws-sdk-go/private/protocol/rest" ) type tempNetworkError struct { op string msg string isTemp bool } func (e *tempNetworkError) Temporary() bool { return e.isTemp } func (e *tempNetworkError) Error() string { return fmt.Sprintf("%s: %s", e.op, e.msg) } var ( // net.OpError accept, are always temporary errAcceptConnectionResetStub = &tempNetworkError{ isTemp: true, op: "accept", msg: "connection reset", } // net.OpError read for ECONNRESET is not temporary. errReadConnectionResetStub = &tempNetworkError{ isTemp: false, op: "read", msg: "connection reset", } // net.OpError write for ECONNRESET may not be temporary, but is treaded as // temporary by the SDK. errWriteConnectionResetStub = &tempNetworkError{ isTemp: false, op: "write", msg: "connection reset", } // net.OpError write for broken pipe may not be temporary, but is treaded as // temporary by the SDK. errWriteBrokenPipeStub = &tempNetworkError{ isTemp: false, op: "write", msg: "broken pipe", } // Generic connection reset error errConnectionResetStub = errors.New("connection reset") // use of closed network connection error errUseOfClosedConnectionStub = errors.New("use of closed network connection") ) type testData struct { Data string } func body(str string) io.ReadCloser { return ioutil.NopCloser(bytes.NewReader([]byte(str))) } func unmarshal(req *request.Request) { defer req.HTTPResponse.Body.Close() if req.Data != nil { json.NewDecoder(req.HTTPResponse.Body).Decode(req.Data) } } func unmarshalError(req *request.Request) { bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body) if err != nil { req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err) return } if len(bodyBytes) == 0 { req.Error = awserr.NewRequestFailure( awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")), req.HTTPResponse.StatusCode, "", ) return } var jsonErr jsonErrorResponse if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil { req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err) return } req.Error = awserr.NewRequestFailure( awserr.New(jsonErr.Code, jsonErr.Message, nil), req.HTTPResponse.StatusCode, "", ) } type jsonErrorResponse struct { Code string `json:"__type"` Message string `json:"message"` } // test that retries occur for 5xx status codes func TestRequestRecoverRetry5xx(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 502, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } s := awstesting.NewClient(&aws.Config{ MaxRetries: aws.Int(10), SleepDelay: func(time.Duration) {}, }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err != nil { t.Fatalf("expect no error, but got %v", err) } if e, a := 2, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } if e, a := "valid", out.Data; e != a { t.Errorf("expect %q output got %q", e, a) } } // test that retries occur for 4xx status codes with a response type that can be retried - see `shouldRetry` func TestRequestRecoverRetry4xxRetryable(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)}, {StatusCode: 400, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)}, {StatusCode: 429, Body: body(`{"__type":"FooException","message":"Rate exceeded."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } s := awstesting.NewClient(&aws.Config{ MaxRetries: aws.Int(10), SleepDelay: func(time.Duration) {}, }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err != nil { t.Fatalf("expect no error, but got %v", err) } if e, a := 3, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } if e, a := "valid", out.Data; e != a { t.Errorf("expect %q output got %q", e, a) } } // test that retries don't occur for 4xx status codes with a response type that can't be retried func TestRequest4xxUnretryable(t *testing.T) { s := awstesting.NewClient(&aws.Config{ MaxRetries: aws.Int(1), SleepDelay: func(time.Duration) {}, }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &http.Response{ StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`), } }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err == nil { t.Fatalf("expect error, but did not get one") } aerr := err.(awserr.RequestFailure) if e, a := 401, aerr.StatusCode(); e != a { t.Errorf("expect %d status code, got %d", e, a) } if e, a := "SignatureDoesNotMatch", aerr.Code(); e != a { t.Errorf("expect %q error code, got %q", e, a) } if e, a := "Signature does not match.", aerr.Message(); e != a { t.Errorf("expect %q error message, got %q", e, a) } if e, a := 0, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } } func TestRequestExhaustRetries(t *testing.T) { delays := []time.Duration{} sleepDelay := func(delay time.Duration) { delays = append(delays, delay) } reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, } s := awstesting.NewClient(&aws.Config{ SleepDelay: sleepDelay, }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) err := r.Send() if err == nil { t.Fatalf("expect error, but did not get one") } aerr := err.(awserr.RequestFailure) if e, a := 500, aerr.StatusCode(); e != a { t.Errorf("expect %d status code, got %d", e, a) } if e, a := "UnknownError", aerr.Code(); e != a { t.Errorf("expect %q error code, got %q", e, a) } if e, a := "An error occurred.", aerr.Message(); e != a { t.Errorf("expect %q error message, got %q", e, a) } if e, a := 3, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } expectDelays := []struct{ min, max time.Duration }{{30, 60}, {60, 120}, {120, 240}} for i, v := range delays { min := expectDelays[i].min * time.Millisecond max := expectDelays[i].max * time.Millisecond if !(min <= v && v <= max) { t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max) } } } // test that the request is retried after the credentials are expired. func TestRequestRecoverExpiredCreds(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } s := awstesting.NewClient(&aws.Config{ MaxRetries: aws.Int(10), Credentials: credentials.NewStaticCredentials("AKID", "SECRET", ""), SleepDelay: func(time.Duration) {}, }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) credExpiredBeforeRetry := false credExpiredAfterRetry := false s.Handlers.AfterRetry.PushBack(func(r *request.Request) { credExpiredAfterRetry = r.Config.Credentials.IsExpired() }) s.Handlers.Sign.Clear() s.Handlers.Sign.PushBack(func(r *request.Request) { r.Config.Credentials.Get() }) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err != nil { t.Fatalf("expect no error, got %v", err) } if credExpiredBeforeRetry { t.Errorf("Expect valid creds before retry check") } if !credExpiredAfterRetry { t.Errorf("Expect expired creds after retry check") } if s.Config.Credentials.IsExpired() { t.Errorf("Expect valid creds after cred expired recovery") } if e, a := 1, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } if e, a := "valid", out.Data; e != a { t.Errorf("expect %q output got %q", e, a) } } func TestMakeAddtoUserAgentHandler(t *testing.T) { fn := request.MakeAddToUserAgentHandler("name", "version", "extra1", "extra2") r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}} r.HTTPRequest.Header.Set("User-Agent", "foo/bar") fn(r) if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); !strings.HasPrefix(a, e) { t.Errorf("expect %q user agent, got %q", e, a) } } func TestMakeAddtoUserAgentFreeFormHandler(t *testing.T) { fn := request.MakeAddToUserAgentFreeFormHandler("name/version (extra1; extra2)") r := &request.Request{HTTPRequest: &http.Request{Header: http.Header{}}} r.HTTPRequest.Header.Set("User-Agent", "foo/bar") fn(r) if e, a := "foo/bar name/version (extra1; extra2)", r.HTTPRequest.Header.Get("User-Agent"); !strings.HasPrefix(a, e) { t.Errorf("expect %q user agent, got %q", e, a) } } func TestRequestUserAgent(t *testing.T) { s := awstesting.NewClient(&aws.Config{ Region: aws.String("us-east-1"), }) req := s.NewRequest(&request.Operation{Name: "Operation"}, nil, &testData{}) req.HTTPRequest.Header.Set("User-Agent", "foo/bar") if err := req.Build(); err != nil { t.Fatalf("expect no error, got %v", err) } expectUA := fmt.Sprintf("foo/bar %s/%s (%s; %s; %s)", aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) if e, a := expectUA, req.HTTPRequest.Header.Get("User-Agent"); !strings.HasPrefix(a, e) { t.Errorf("expect %q user agent, got %q", e, a) } } func TestRequestThrottleRetries(t *testing.T) { var delays []time.Duration sleepDelay := func(delay time.Duration) { delays = append(delays, delay) } reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, {StatusCode: 500, Body: body(`{"__type":"Throttling","message":"An error occurred."}`)}, } s := awstesting.NewClient(&aws.Config{ SleepDelay: sleepDelay, }) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, nil) err := r.Send() if err == nil { t.Fatalf("expect error, but did not get one") } aerr := err.(awserr.RequestFailure) if e, a := 500, aerr.StatusCode(); e != a { t.Errorf("expect %d status code, got %d", e, a) } if e, a := "Throttling", aerr.Code(); e != a { t.Errorf("expect %q error code, got %q", e, a) } if e, a := "An error occurred.", aerr.Message(); e != a { t.Errorf("expect %q error message, got %q", e, a) } if e, a := 3, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } expectDelays := []struct{ min, max time.Duration }{{500, 1000}, {1000, 2000}, {2000, 4000}} for i, v := range delays { min := expectDelays[i].min * time.Millisecond max := expectDelays[i].max * time.Millisecond if !(min <= v && v <= max) { t.Errorf("Expect delay to be within range, i:%d, v:%s, min:%s, max:%s", i, v, min, max) } } } // test that retries occur for request timeouts when response.Body can be nil func TestRequestRecoverTimeoutWithNilBody(t *testing.T) { reqNum := 0 reqs := []*http.Response{ {StatusCode: 0, Body: nil}, // body can be nil when requests time out {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } errors := []error{ errTimeout, nil, } s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.AfterRetry.Clear() // force retry on all errors s.Handlers.AfterRetry.PushBack(func(r *request.Request) { if r.Error != nil { r.Error = nil r.Retryable = aws.Bool(true) r.RetryCount++ } }) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = reqs[reqNum] r.Error = errors[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err != nil { t.Fatalf("expect no error, but got %v", err) } if e, a := 1, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } if e, a := "valid", out.Data; e != a { t.Errorf("expect %q output got %q", e, a) } } func TestRequestRecoverTimeoutWithNilResponse(t *testing.T) { reqNum := 0 reqs := []*http.Response{ nil, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } errors := []error{ errTimeout, nil, } s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.AfterRetry.Clear() // force retry on all errors s.Handlers.AfterRetry.PushBack(func(r *request.Request) { if r.Error != nil { r.Error = nil r.Retryable = aws.Bool(true) r.RetryCount++ } }) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = reqs[reqNum] r.Error = errors[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err != nil { t.Fatalf("expect no error, but got %v", err) } if e, a := 1, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } if e, a := "valid", out.Data; e != a { t.Errorf("expect %q output got %q", e, a) } } func TestRequest_NoBody(t *testing.T) { cases := []string{ "GET", "HEAD", "DELETE", "PUT", "POST", "PATCH", } for i, c := range cases { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if v := r.TransferEncoding; len(v) > 0 { t.Errorf("%d, expect no body sent with Transfer-Encoding, %v", i, v) } outMsg := []byte(`{"Value": "abc"}`) if b, err := ioutil.ReadAll(r.Body); err != nil { t.Fatalf("%d, expect no error reading request body, got %v", i, err) } else if n := len(b); n > 0 { t.Errorf("%d, expect no request body, got %d bytes", i, n) } w.Header().Set("Content-Length", strconv.Itoa(len(outMsg))) if _, err := w.Write(outMsg); err != nil { t.Fatalf("%d, expect no error writing server response, got %v", i, err) } })) s := awstesting.NewClient(&aws.Config{ Region: aws.String("mock-region"), MaxRetries: aws.Int(0), Endpoint: aws.String(server.URL), DisableSSL: aws.Bool(true), }) s.Handlers.Build.PushBack(rest.Build) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) in := struct { Bucket *string `location:"uri" locationName:"bucket"` Key *string `location:"uri" locationName:"key"` }{ Bucket: aws.String("mybucket"), Key: aws.String("myKey"), } out := struct { Value *string }{} r := s.NewRequest(&request.Operation{ Name: "OpName", HTTPMethod: c, HTTPPath: "/{bucket}/{key+}", }, &in, &out) err := r.Send() server.Close() if err != nil { t.Fatalf("%d, expect no error sending request, got %v", i, err) } } } func TestIsSerializationErrorRetryable(t *testing.T) { testCases := []struct { err error expected bool }{ { err: awserr.New(request.ErrCodeSerialization, "foo error", nil), expected: false, }, { err: awserr.New("ErrFoo", "foo error", nil), expected: false, }, { err: nil, expected: false, }, { err: awserr.New(request.ErrCodeSerialization, "foo error", errAcceptConnectionResetStub), expected: true, }, } for i, c := range testCases { r := &request.Request{ Error: c.err, } if r.IsErrorRetryable() != c.expected { t.Errorf("Case %d: Expected %v, but received %v", i, c.expected, !c.expected) } } } func TestWithLogLevel(t *testing.T) { r := &request.Request{} opt := request.WithLogLevel(aws.LogDebugWithHTTPBody) r.ApplyOptions(opt) if !r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) { t.Errorf("expect log level to be set, but was not, %v", r.Config.LogLevel.Value()) } } func TestWithGetResponseHeader(t *testing.T) { r := &request.Request{} var val, val2 string r.ApplyOptions( request.WithGetResponseHeader("x-a-header", &val), request.WithGetResponseHeader("x-second-header", &val2), ) r.HTTPResponse = &http.Response{ Header: func() http.Header { h := http.Header{} h.Set("x-a-header", "first") h.Set("x-second-header", "second") return h }(), } r.Handlers.Complete.Run(r) if e, a := "first", val; e != a { t.Errorf("expect %q header value got %q", e, a) } if e, a := "second", val2; e != a { t.Errorf("expect %q header value got %q", e, a) } } func TestWithGetResponseHeaders(t *testing.T) { r := &request.Request{} var headers http.Header opt := request.WithGetResponseHeaders(&headers) r.ApplyOptions(opt) r.HTTPResponse = &http.Response{ Header: func() http.Header { h := http.Header{} h.Set("x-a-header", "headerValue") return h }(), } r.Handlers.Complete.Run(r) if e, a := "headerValue", headers.Get("x-a-header"); e != a { t.Errorf("expect %q header value got %q", e, a) } } type testRetryer struct { shouldRetry bool maxRetries int } func (d *testRetryer) MaxRetries() int { return d.maxRetries } // RetryRules returns the delay duration before retrying this request again func (d *testRetryer) RetryRules(r *request.Request) time.Duration { return 0 } func (d *testRetryer) ShouldRetry(r *request.Request) bool { return d.shouldRetry } func TestEnforceShouldRetryCheck(t *testing.T) { retryer := &testRetryer{ shouldRetry: true, maxRetries: 3, } s := awstesting.NewClient(&aws.Config{ Region: aws.String("mock-region"), MaxRetries: aws.Int(0), Retryer: retryer, EnforceShouldRetryCheck: aws.Bool(true), SleepDelay: func(time.Duration) {}, }) s.Handlers.Validate.Clear() s.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{ Name: "TestEnforceShouldRetryCheck", Fn: func(r *request.Request) { r.HTTPResponse = &http.Response{ Header: http.Header{}, Body: ioutil.NopCloser(bytes.NewBuffer(nil)), } r.Retryable = aws.Bool(false) }, }) s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err == nil { t.Fatalf("expect error, but got nil") } if e, a := 3, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } if !retryer.shouldRetry { t.Errorf("expect 'true' for ShouldRetry, but got %v", retryer.shouldRetry) } } type errReader struct { err error } func (reader *errReader) Read(b []byte) (int, error) { return 0, reader.err } func (reader *errReader) Close() error { return nil } func TestIsNoBodyReader(t *testing.T) { cases := []struct { reader io.ReadCloser expect bool }{ {ioutil.NopCloser(bytes.NewReader([]byte("abc"))), false}, {ioutil.NopCloser(bytes.NewReader(nil)), false}, {nil, false}, {request.NoBody, true}, } for i, c := range cases { if e, a := c.expect, request.NoBody == c.reader; e != a { t.Errorf("%d, expect %t match, but was %t", i, e, a) } } } func TestRequest_TemporaryRetry(t *testing.T) { done := make(chan struct{}) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", "1024") w.WriteHeader(http.StatusOK) w.Write(make([]byte, 100)) f := w.(http.Flusher) f.Flush() <-done })) defer server.Close() client := &http.Client{ Timeout: 100 * time.Millisecond, } svc := awstesting.NewClient(&aws.Config{ Region: unit.Session.Config.Region, MaxRetries: aws.Int(1), HTTPClient: client, DisableSSL: aws.Bool(true), Endpoint: aws.String(server.URL), }) req := svc.NewRequest(&request.Operation{ Name: "name", HTTPMethod: "GET", HTTPPath: "/path", }, &struct{}{}, &struct{}{}) req.Handlers.Unmarshal.PushBack(func(r *request.Request) { defer req.HTTPResponse.Body.Close() _, err := io.Copy(ioutil.Discard, req.HTTPResponse.Body) r.Error = awserr.New(request.ErrCodeSerialization, "error", err) }) err := req.Send() if err == nil { t.Errorf("expect error, got none") } close(done) aerr := err.(awserr.Error) if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { t.Errorf("expect %q error code, got %q", e, a) } if e, a := 1, req.RetryCount; e != a { t.Errorf("expect %d retries, got %d", e, a) } type temporary interface { Temporary() bool } terr := aerr.OrigErr().(temporary) if !terr.Temporary() { t.Errorf("expect temporary error, was not") } } func TestRequest_Presign(t *testing.T) { presign := func(r *request.Request, expire time.Duration) (string, http.Header, error) { u, err := r.Presign(expire) return u, nil, err } presignRequest := func(r *request.Request, expire time.Duration) (string, http.Header, error) { return r.PresignRequest(expire) } mustParseURL := func(v string) *url.URL { u, err := url.Parse(v) if err != nil { panic(err) } return u } cases := []struct { Expire time.Duration PresignFn func(*request.Request, time.Duration) (string, http.Header, error) SignerFn func(*request.Request) URL string Header http.Header Err string }{ { PresignFn: presign, Err: request.ErrCodeInvalidPresignExpire, }, { PresignFn: presignRequest, Err: request.ErrCodeInvalidPresignExpire, }, { Expire: -1, PresignFn: presign, Err: request.ErrCodeInvalidPresignExpire, }, { // Presign clear NotHoist Expire: 1 * time.Minute, PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) { r.NotHoist = true return presign(r, dur) }, SignerFn: func(r *request.Request) { r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL") if r.NotHoist { r.Error = fmt.Errorf("expect NotHoist to be cleared") } }, URL: "https://endpoint/presignedURL", }, { // PresignRequest does not clear NotHoist Expire: 1 * time.Minute, PresignFn: func(r *request.Request, dur time.Duration) (string, http.Header, error) { r.NotHoist = true return presignRequest(r, dur) }, SignerFn: func(r *request.Request) { r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL") if !r.NotHoist { r.Error = fmt.Errorf("expect NotHoist not to be cleared") } }, URL: "https://endpoint/presignedURL", }, { // PresignRequest returns signed headers Expire: 1 * time.Minute, PresignFn: presignRequest, SignerFn: func(r *request.Request) { r.HTTPRequest.URL = mustParseURL("https://endpoint/presignedURL") r.HTTPRequest.Header.Set("UnsigndHeader", "abc") r.SignedHeaderVals = http.Header{ "X-Amzn-Header": []string{"abc", "123"}, "X-Amzn-Header2": []string{"efg", "456"}, } }, URL: "https://endpoint/presignedURL", Header: http.Header{ "X-Amzn-Header": []string{"abc", "123"}, "X-Amzn-Header2": []string{"efg", "456"}, }, }, } svc := awstesting.NewClient() svc.Handlers.Clear() for i, c := range cases { req := svc.NewRequest(&request.Operation{ Name: "name", HTTPMethod: "GET", HTTPPath: "/path", }, &struct{}{}, &struct{}{}) req.Handlers.Sign.PushBack(c.SignerFn) u, h, err := c.PresignFn(req, c.Expire) if len(c.Err) != 0 { if e, a := c.Err, err.Error(); !strings.Contains(a, e) { t.Errorf("%d, expect %v to be in %v", i, e, a) } continue } else if err != nil { t.Errorf("%d, expect no error, got %v", i, err) continue } if e, a := c.URL, u; e != a { t.Errorf("%d, expect %v URL, got %v", i, e, a) } if e, a := c.Header, h; !reflect.DeepEqual(e, a) { t.Errorf("%d, expect %v header got %v", i, e, a) } } } func TestSanitizeHostForHeader(t *testing.T) { cases := []struct { url string expectedRequestHost string }{ {"https://estest.us-east-1.es.amazonaws.com:443", "estest.us-east-1.es.amazonaws.com"}, {"https://estest.us-east-1.es.amazonaws.com", "estest.us-east-1.es.amazonaws.com"}, {"https://localhost:9200", "localhost:9200"}, {"http://localhost:80", "localhost"}, {"http://localhost:8080", "localhost:8080"}, } for _, c := range cases { r, _ := http.NewRequest("GET", c.url, nil) request.SanitizeHostForHeader(r) if h := r.Host; h != c.expectedRequestHost { t.Errorf("expect %v host, got %q", c.expectedRequestHost, h) } } } func TestRequestWillRetry_ByBody(t *testing.T) { svc := awstesting.NewClient() cases := []struct { WillRetry bool HTTPMethod string Body io.ReadSeeker IsReqNoBody bool }{ { WillRetry: true, HTTPMethod: "GET", Body: bytes.NewReader([]byte{}), IsReqNoBody: true, }, { WillRetry: true, HTTPMethod: "GET", Body: bytes.NewReader(nil), IsReqNoBody: true, }, { WillRetry: true, HTTPMethod: "POST", Body: bytes.NewReader([]byte("abc123")), }, { WillRetry: true, HTTPMethod: "POST", Body: aws.ReadSeekCloser(bytes.NewReader([]byte("abc123"))), }, { WillRetry: true, HTTPMethod: "GET", Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)), IsReqNoBody: true, }, { WillRetry: true, HTTPMethod: "POST", Body: aws.ReadSeekCloser(bytes.NewBuffer(nil)), IsReqNoBody: true, }, { WillRetry: false, HTTPMethod: "POST", Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("abc123"))), }, } for i, c := range cases { req := svc.NewRequest(&request.Operation{ Name: "Operation", HTTPMethod: c.HTTPMethod, HTTPPath: "/", }, nil, nil) req.SetReaderBody(c.Body) req.Build() req.Error = fmt.Errorf("some error") req.Retryable = aws.Bool(true) req.HTTPResponse = &http.Response{ StatusCode: 500, } if e, a := c.IsReqNoBody, request.NoBody == req.HTTPRequest.Body; e != a { t.Errorf("%d, expect request to be no body, %t, got %t, %T", i, e, a, req.HTTPRequest.Body) } if e, a := c.WillRetry, req.WillRetry(); e != a { t.Errorf("%d, expect %t willRetry, got %t", i, e, a) } if req.Error == nil { t.Fatalf("%d, expect error, got none", i) } if e, a := "some error", req.Error.Error(); !strings.Contains(a, e) { t.Errorf("%d, expect %q error in %q", i, e, a) } if e, a := 0, req.RetryCount; e != a { t.Errorf("%d, expect retry count to be %d, got %d", i, e, a) } } } func Test501NotRetrying(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 501, Body: body(`{"__type":"NotImplemented","message":"An error occurred."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } s := awstesting.NewClient(aws.NewConfig().WithMaxRetries(10)) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *request.Request) { r.HTTPResponse = &reqs[reqNum] reqNum++ }) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) err := r.Send() if err == nil { t.Fatal("expect error, but got none") } aerr := err.(awserr.Error) if e, a := "NotImplemented", aerr.Code(); e != a { t.Errorf("expected error code %q, but received %q", e, a) } if e, a := 1, r.RetryCount; e != a { t.Errorf("expect %d retry count, got %d", e, a) } } func TestRequestNoConnection(t *testing.T) { port, err := getFreePort() if err != nil { t.Fatalf("failed to get free port for test") } s := awstesting.NewClient(aws.NewConfig(). WithMaxRetries(10). WithEndpoint("https://localhost:" + strconv.Itoa(port)). WithSleepDelay(func(time.Duration) {}), ) s.Handlers.Validate.Clear() s.Handlers.Unmarshal.PushBack(unmarshal) s.Handlers.UnmarshalError.PushBack(unmarshalError) out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) if err = r.Send(); err == nil { t.Fatal("expect error, but got none") } t.Logf("Error, %v", err) awsError := err.(awserr.Error) origError := awsError.OrigErr() t.Logf("Orig Error: %#v of type %T", origError, origError) if e, a := 10, r.RetryCount; e != a { t.Errorf("expect %v retry count, got %v", e, a) } } func TestRequestBodySeekFails(t *testing.T) { s := awstesting.NewClient() s.Handlers.Validate.Clear() s.Handlers.Build.Clear() out := &testData{} r := s.NewRequest(&request.Operation{Name: "Operation"}, nil, out) r.SetReaderBody(&stubSeekFail{ Err: fmt.Errorf("failed to seek reader"), }) err := r.Send() if err == nil { t.Fatal("expect error, but got none") } aerr := err.(awserr.Error) if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { t.Errorf("expect %v error code, got %v", e, a) } } func TestRequestEndpointWithDefaultPort(t *testing.T) { s := awstesting.NewClient(&aws.Config{ Endpoint: aws.String("https://example.test:443"), }) r := s.NewRequest(&request.Operation{ Name: "FooBar", HTTPMethod: "GET", HTTPPath: "/", }, nil, nil) r.Handlers.Validate.Clear() r.Handlers.ValidateResponse.Clear() r.Handlers.Send.Clear() r.Handlers.Send.PushFront(func(r *request.Request) { req := r.HTTPRequest if e, a := "example.test", req.Host; e != a { t.Errorf("expected %v, got %v", e, a) } if e, a := "https://example.test:443/", req.URL.String(); e != a { t.Errorf("expected %v, got %v", e, a) } }) err := r.Send() if err != nil { t.Fatalf("expected no error, got %v", err) } } func TestRequestEndpointWithNonDefaultPort(t *testing.T) { s := awstesting.NewClient(&aws.Config{ Endpoint: aws.String("https://example.test:8443"), }) r := s.NewRequest(&request.Operation{ Name: "FooBar", HTTPMethod: "GET", HTTPPath: "/", }, nil, nil) r.Handlers.Validate.Clear() r.Handlers.ValidateResponse.Clear() r.Handlers.Send.Clear() r.Handlers.Send.PushFront(func(r *request.Request) { req := r.HTTPRequest // http.Request.Host should not be set for non-default ports if e, a := "", req.Host; e != a { t.Errorf("expected %v, got %v", e, a) } if e, a := "https://example.test:8443/", req.URL.String(); e != a { t.Errorf("expected %v, got %v", e, a) } }) err := r.Send() if err != nil { t.Fatalf("expected no error, got %v", err) } } func TestRequestMarshaledEndpointWithDefaultPort(t *testing.T) { s := awstesting.NewClient(&aws.Config{ Endpoint: aws.String("https://example.test:443"), }) r := s.NewRequest(&request.Operation{ Name: "FooBar", HTTPMethod: "GET", HTTPPath: "/", }, nil, nil) r.Handlers.Validate.Clear() r.Handlers.ValidateResponse.Clear() r.Handlers.Build.PushBack(func(r *request.Request) { req := r.HTTPRequest req.URL.Host = "foo." + req.URL.Host }) r.Handlers.Send.Clear() r.Handlers.Send.PushFront(func(r *request.Request) { req := r.HTTPRequest if e, a := "foo.example.test", req.Host; e != a { t.Errorf("expected %v, got %v", e, a) } if e, a := "https://foo.example.test:443/", req.URL.String(); e != a { t.Errorf("expected %v, got %v", e, a) } }) err := r.Send() if err != nil { t.Fatalf("expected no error, got %v", err) } } func TestRequestMarshaledEndpointWithNonDefaultPort(t *testing.T) { s := awstesting.NewClient(&aws.Config{ Endpoint: aws.String("https://example.test:8443"), }) r := s.NewRequest(&request.Operation{ Name: "FooBar", HTTPMethod: "GET", HTTPPath: "/", }, nil, nil) r.Handlers.Validate.Clear() r.Handlers.ValidateResponse.Clear() r.Handlers.Build.PushBack(func(r *request.Request) { req := r.HTTPRequest req.URL.Host = "foo." + req.URL.Host }) r.Handlers.Send.Clear() r.Handlers.Send.PushFront(func(r *request.Request) { req := r.HTTPRequest // http.Request.Host should not be set for non-default ports if e, a := "", req.Host; e != a { t.Errorf("expected %v, got %v", e, a) } if e, a := "https://foo.example.test:8443/", req.URL.String(); e != a { t.Errorf("expected %v, got %v", e, a) } }) err := r.Send() if err != nil { t.Fatalf("expected no error, got %v", err) } } func TestRequestCompleteWithoutHTTPResponse(t *testing.T) { s := awstesting.NewClient(aws.NewConfig().WithRegion("mock-region")) r := s.NewRequest(&request.Operation{ Name: "FooBar", HTTPMethod: "GET", HTTPPath: "/", }, nil, nil) r.Handlers.Sign.Clear() r.Handlers.Sign.PushFront(func(r *request.Request) { r.Error = fmt.Errorf("failed") }) r.Handlers.Complete.PushBack(func(r *request.Request) { if r.HTTPResponse == nil { t.Fatalf("expect HTTPResponse not to be nil") } if r.HTTPResponse.Header == nil { t.Fatalf("expect HTTPResponse.Header not to be nil") } if r.HTTPResponse.Body == nil { t.Fatalf("expect HTTPResponse.Body not to be nil") } }) err := r.Send() if err == nil { t.Fatalf("expect error, got none") } if e, a := "failed", err.Error(); !strings.Contains(a, e) { t.Errorf("expect %v error in %v", e, a) } } type stubSeekFail struct { Err error } func (f *stubSeekFail) Read(b []byte) (int, error) { return len(b), nil } func (f *stubSeekFail) ReadAt(b []byte, offset int64) (int, error) { return len(b), nil } func (f *stubSeekFail) Seek(offset int64, mode int) (int64, error) { return 0, f.Err } func getFreePort() (int, error) { l, err := net.Listen("tcp", ":0") if err != nil { return 0, err } defer l.Close() strAddr := l.Addr().String() parts := strings.Split(strAddr, ":") strPort := parts[len(parts)-1] port, err := strconv.ParseInt(strPort, 10, 32) if err != nil { return 0, err } return int(port), nil }