package manager_test import ( "bytes" "context" "fmt" "io" "io/ioutil" "reflect" "regexp" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" managertesting "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing" "github.com/aws/aws-sdk-go-v2/internal/awstesting" "github.com/aws/aws-sdk-go-v2/internal/sdkio" "github.com/aws/aws-sdk-go-v2/service/s3" ) type downloadCaptureClient struct { GetObjectFn func(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error) GetObjectInvocations int RetrievedRanges []string lock sync.Mutex } func (c *downloadCaptureClient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { c.lock.Lock() defer c.lock.Unlock() c.GetObjectInvocations++ if params.Range != nil { c.RetrievedRanges = append(c.RetrievedRanges, aws.ToString(params.Range)) } return c.GetObjectFn(ctx, params, optFns...) } var rangeValueRegex = regexp.MustCompile(`bytes=(\d+)-(\d+)`) func parseRange(rangeValue string) (start, fin int64) { rng := rangeValueRegex.FindStringSubmatch(rangeValue) start, _ = strconv.ParseInt(rng[1], 10, 64) fin, _ = strconv.ParseInt(rng[2], 10, 64) return start, fin } func newDownloadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) { capture := &downloadCaptureClient{} capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { start, fin := parseRange(aws.ToString(params.Range)) fin++ if fin >= int64(len(data)) { fin = int64(len(data)) } bodyBytes := data[start:fin] return &s3.GetObjectOutput{ Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", start, fin-1, len(data))), ContentLength: int64(len(bodyBytes)), }, nil } return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges } func newDownloadNonRangeClient(data []byte) (*downloadCaptureClient, *int) { capture := &downloadCaptureClient{} capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { return &s3.GetObjectOutput{ Body: ioutil.NopCloser(bytes.NewReader(data[:])), ContentLength: int64(len(data)), }, nil } return capture, &capture.GetObjectInvocations } type mockHTTPStatusError struct { StatusCode int } func (m *mockHTTPStatusError) Error() string { return fmt.Sprintf("http status code: %v", m.StatusCode) } func (m *mockHTTPStatusError) HTTPStatusCode() int { return m.StatusCode } func newDownloadContentRangeTotalAnyClient(data []byte) (*downloadCaptureClient, *int) { capture := &downloadCaptureClient{} completed := false capture.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { if completed { return nil, &mockHTTPStatusError{StatusCode: 416} } start, fin := parseRange(aws.ToString(params.Range)) fin++ if fin >= int64(len(data)) { fin = int64(len(data)) completed = true } bodyBytes := data[start:fin] return &s3.GetObjectOutput{ Body: ioutil.NopCloser(bytes.NewReader(bodyBytes)), ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/*", start, fin-1)), }, nil } return capture, &capture.GetObjectInvocations } func newDownloadWithErrReaderClient(cases []testErrReader) (*downloadCaptureClient, *int) { var index int c := &downloadCaptureClient{} c.GetObjectFn = func(_ context.Context, params *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) { c := cases[index] out := &s3.GetObjectOutput{ Body: ioutil.NopCloser(&c), ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", 0, c.Len-1, c.Len)), ContentLength: c.Len, } index++ return out, nil } return c, &c.GetObjectInvocations } func TestDownloadOrder(t *testing.T) { c, invocations, ranges := newDownloadRangeClient(buf12MB) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 }) w := manager.NewWriteAtBuffer(make([]byte, len(buf12MB))) n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(len(buf12MB)), n; e != a { t.Errorf("expect %d buffer length, got %d", e, a) } if e, a := 3, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } expectRngs := []string{"bytes=0-5242879", "bytes=5242880-10485759", "bytes=10485760-15728639"} if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { t.Errorf("expect %v ranges, got %v", e, a) } } func TestDownloadZero(t *testing.T) { c, invocations, ranges := newDownloadRangeClient([]byte{}) d := manager.NewDownloader(c) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if n != 0 { t.Errorf("expect 0 bytes read, got %d", n) } if e, a := 1, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } expectRngs := []string{"bytes=0-5242879"} if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { t.Errorf("expect %v ranges, got %v", e, a) } } func TestDownloadSetPartSize(t *testing.T) { c, invocations, ranges := newDownloadRangeClient([]byte{1, 2, 3}) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 d.PartSize = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(3), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 3, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } expectRngs := []string{"bytes=0-0", "bytes=1-1", "bytes=2-2"} if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { t.Errorf("expect %v ranges, got %v", e, a) } expectBytes := []byte{1, 2, 3} if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { t.Errorf("expect %v bytes, got %v", e, a) } } func TestDownloadError(t *testing.T) { c, invocations, _ := newDownloadRangeClient([]byte{1, 2, 3}) num := 0 orig := c.GetObjectFn c.GetObjectFn = func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { out, err := orig(ctx, params, optFns...) num++ if num > 1 { return &s3.GetObjectOutput{}, fmt.Errorf("s3 service error") } return out, err } d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 d.PartSize = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err == nil { t.Fatalf("expect error, got none") } if e, a := "s3 service error", err.Error(); e != a { t.Errorf("expect %s error code, got %s", e, a) } if e, a := int64(1), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 2, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } expectBytes := []byte{1} if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { t.Errorf("expect %v bytes, got %v", e, a) } } func TestDownloadNonChunk(t *testing.T) { c, invocations := newDownloadNonRangeClient(buf2MB) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(len(buf2MB)), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 1, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } count := 0 for _, b := range w.Bytes() { count += int(b) } if count != 0 { t.Errorf("expect 0 count, got %d", count) } } func TestDownloadNoContentRangeLength(t *testing.T) { s, invocations, _ := newDownloadRangeClient(buf2MB) d := manager.NewDownloader(s, func(d *manager.Downloader) { d.Concurrency = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(len(buf2MB)), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 1, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } count := 0 for _, b := range w.Bytes() { count += int(b) } if count != 0 { t.Errorf("expect 0 count, got %d", count) } } func TestDownloadContentRangeTotalAny(t *testing.T) { s, invocations := newDownloadContentRangeTotalAnyClient(buf2MB) d := manager.NewDownloader(s, func(d *manager.Downloader) { d.Concurrency = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(len(buf2MB)), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 2, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } count := 0 for _, b := range w.Bytes() { count += int(b) } if count != 0 { t.Errorf("expect 0 count, got %d", count) } } func TestDownloadPartBodyRetry_SuccessRetry(t *testing.T) { c, invocations := newDownloadWithErrReaderClient([]testErrReader{ {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, {Buf: []byte("123"), Len: 3, Err: io.EOF}, }) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(3), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 2, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } if e, a := "123", string(w.Bytes()); e != a { t.Errorf("expect %q response, got %q", e, a) } } func TestDownloadPartBodyRetry_SuccessNoRetry(t *testing.T) { c, invocations := newDownloadWithErrReaderClient([]testErrReader{ {Buf: []byte("abc"), Len: 3, Err: io.EOF}, }) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(3), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 1, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } if e, a := "abc", string(w.Bytes()); e != a { t.Errorf("expect %q response, got %q", e, a) } } func TestDownloadPartBodyRetry_FailRetry(t *testing.T) { c, invocations := newDownloadWithErrReaderClient([]testErrReader{ {Buf: []byte("ab"), Len: 3, Err: io.ErrUnexpectedEOF}, }) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 1 d.PartBodyMaxRetries = 0 }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err == nil { t.Fatalf("expect error, got none") } if e, a := "unexpected EOF", err.Error(); !strings.Contains(a, e) { t.Errorf("expect %q error message to be in %q", e, a) } if e, a := int64(2), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 1, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } if e, a := "ab", string(w.Bytes()); e != a { t.Errorf("expect %q response, got %q", e, a) } } func TestDownloadWithContextCanceled(t *testing.T) { d := manager.NewDownloader(s3.New(s3.Options{ Region: "mock-region", })) params := s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("Key"), } ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})} ctx.Error = fmt.Errorf("context canceled") close(ctx.DoneCh) w := &manager.WriteAtBuffer{} _, err := d.Download(ctx, w, ¶ms) if err == nil { t.Fatalf("expected error, did not get one") } if e, a := "canceled", err.Error(); !strings.Contains(a, e) { t.Errorf("expected error message to contain %q, but did not %q", e, a) } } func TestDownload_WithRange(t *testing.T) { c, invocations, ranges := newDownloadRangeClient([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) d := manager.NewDownloader(c, func(d *manager.Downloader) { d.Concurrency = 10 // should be ignored d.PartSize = 1 // should be ignored }) w := &manager.WriteAtBuffer{} n, err := d.Download(context.Background(), w, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), Range: aws.String("bytes=2-6"), }) if err != nil { t.Fatalf("expect no error, got %v", err) } if e, a := int64(5), n; e != a { t.Errorf("expect %d bytes read, got %d", e, a) } if e, a := 1, *invocations; e != a { t.Errorf("expect %v API calls, got %v", e, a) } expectRngs := []string{"bytes=2-6"} if e, a := expectRngs, *ranges; !reflect.DeepEqual(e, a) { t.Errorf("expect %v ranges, got %v", e, a) } expectBytes := []byte{2, 3, 4, 5, 6} if e, a := expectBytes, w.Bytes(); !reflect.DeepEqual(e, a) { t.Errorf("expect %v bytes, got %v", e, a) } } type mockDownloadCLient func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) func (m mockDownloadCLient) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { return m(ctx, params, optFns...) } func TestDownload_WithFailure(t *testing.T) { reqCount := int64(0) startingByte := 0 client := mockDownloadCLient(func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (out *s3.GetObjectOutput, err error) { switch atomic.LoadInt64(&reqCount) { case 1: // Give a chance for the multipart chunks to be queued up time.Sleep(1 * time.Second) err = fmt.Errorf("some connection error") default: body := bytes.NewReader(make([]byte, manager.DefaultDownloadPartSize)) out = &s3.GetObjectOutput{ Body: ioutil.NopCloser(body), ContentLength: int64(body.Len()), ContentRange: aws.String(fmt.Sprintf("bytes %d-%d/%d", startingByte, body.Len()-1, body.Len()*10)), } startingByte += body.Len() if reqCount > 0 { // sleep here to ensure context switching between goroutines time.Sleep(25 * time.Millisecond) } } atomic.AddInt64(&reqCount, 1) return out, err }) d := manager.NewDownloader(client, func(d *manager.Downloader) { d.Concurrency = 2 }) w := &manager.WriteAtBuffer{} params := s3.GetObjectInput{ Bucket: aws.String("Bucket"), Key: aws.String("Key"), } // Expect this request to exit quickly after failure _, err := d.Download(context.Background(), w, ¶ms) if err == nil { t.Fatalf("expect error, got none") } if atomic.LoadInt64(&reqCount) > 3 { t.Errorf("expect no more than 3 requests, but received %d", reqCount) } } func TestDownloadBufferStrategy(t *testing.T) { cases := map[string]struct { partSize int64 strategy *recordedWriterReadFromProvider expectedSize int64 }{ "no strategy": { partSize: manager.DefaultDownloadPartSize, expectedSize: 10 * sdkio.MebiByte, }, "partSize modulo bufferSize == 0": { partSize: 5 * sdkio.MebiByte, strategy: &recordedWriterReadFromProvider{ WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(int(sdkio.MebiByte)), // 1 MiB }, expectedSize: 10 * sdkio.MebiByte, // 10 MiB }, "partSize modulo bufferSize > 0": { partSize: 5 * 1024 * 1204, // 5 MiB strategy: &recordedWriterReadFromProvider{ WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(2 * int(sdkio.MebiByte)), // 2 MiB }, expectedSize: 10 * sdkio.MebiByte, // 10 MiB }, } for name, tCase := range cases { t.Run(name, func(t *testing.T) { expected := managertesting.GetTestBytes(int(tCase.expectedSize)) client, _, _ := newDownloadRangeClient(expected) d := manager.NewDownloader(client, func(d *manager.Downloader) { d.PartSize = tCase.partSize if tCase.strategy != nil { d.BufferProvider = tCase.strategy } }) buffer := manager.NewWriteAtBuffer(make([]byte, len(expected))) n, err := d.Download(context.Background(), buffer, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Errorf("failed to download: %v", err) } if e, a := len(expected), int(n); e != a { t.Errorf("expected %v, got %v downloaded bytes", e, a) } if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) { t.Errorf("downloaded bytes did not match expected") } if tCase.strategy != nil { if e, a := tCase.strategy.callbacksVended, tCase.strategy.callbacksExecuted; e != a { t.Errorf("expected %v, got %v", e, a) } } }) } } type testErrReader struct { Buf []byte Err error Len int64 off int } func (r *testErrReader) Read(p []byte) (int, error) { to := len(r.Buf) - r.off n := copy(p, r.Buf[r.off:to]) r.off += n if n < len(p) { return n, r.Err } return n, nil } func TestDownloadBufferStrategy_Errors(t *testing.T) { expected := managertesting.GetTestBytes(int(10 * sdkio.MebiByte)) client, _, _ := newDownloadRangeClient(expected) strat := &recordedWriterReadFromProvider{ WriterReadFromProvider: manager.NewPooledBufferedWriterReadFromProvider(int(2 * sdkio.MebiByte)), } seenOps := make(map[string]struct{}) orig := client.GetObjectFn client.GetObjectFn = func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { out, err := orig(ctx, params, optFns...) fingerPrint := fmt.Sprintf("%s/%s/%s", *params.Bucket, *params.Key, *params.Range) if _, ok := seenOps[fingerPrint]; ok { return out, err } seenOps[fingerPrint] = struct{}{} _, _ = io.Copy(ioutil.Discard, out.Body) out.Body = ioutil.NopCloser(&badReader{err: io.ErrUnexpectedEOF}) return out, err } d := manager.NewDownloader(client, func(d *manager.Downloader) { d.PartSize = 5 * sdkio.MebiByte d.BufferProvider = strat d.Concurrency = 1 }) buffer := manager.NewWriteAtBuffer(make([]byte, len(expected))) n, err := d.Download(context.Background(), buffer, &s3.GetObjectInput{ Bucket: aws.String("bucket"), Key: aws.String("key"), }) if err != nil { t.Errorf("failed to download: %v", err) } if e, a := len(expected), int(n); e != a { t.Errorf("expected %v, got %v downloaded bytes", e, a) } if e, a := expected, buffer.Bytes(); !bytes.Equal(e, a) { t.Errorf("downloaded bytes did not match expected") } if e, a := strat.callbacksVended, strat.callbacksExecuted; e != a { t.Errorf("expected %v, got %v", e, a) } } func TestDownloaderValidARN(t *testing.T) { cases := map[string]struct { input s3.GetObjectInput wantErr bool }{ "standard bucket": { input: s3.GetObjectInput{ Bucket: aws.String("test-bucket"), Key: aws.String("test-key"), }, }, "accesspoint": { input: s3.GetObjectInput{ Bucket: aws.String("arn:aws:s3:us-west-2:123456789012:accesspoint/myap"), Key: aws.String("test-key"), }, }, "outpost accesspoint": { input: s3.GetObjectInput{ Bucket: aws.String("arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint"), Key: aws.String("test-key"), }, }, "s3-object-lambda accesspoint": { input: s3.GetObjectInput{ Bucket: aws.String("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint/myap"), }, wantErr: true, }, } for name, tt := range cases { t.Run(name, func(t *testing.T) { client, _ := newDownloadNonRangeClient(buf2MB) downloader := manager.NewDownloader(client, func(downloader *manager.Downloader) { downloader.Concurrency = 1 }) _, err := downloader.Download(context.Background(), &awstesting.DiscardAt{}, &tt.input) if (err != nil) != tt.wantErr { t.Errorf("err: %v, wantErr: %v", err, tt.wantErr) } }) } } type recordedWriterReadFromProvider struct { callbacksVended uint32 callbacksExecuted uint32 manager.WriterReadFromProvider } func (r *recordedWriterReadFromProvider) GetReadFrom(writer io.Writer) (manager.WriterReadFrom, func()) { w, cleanup := r.WriterReadFromProvider.GetReadFrom(writer) atomic.AddUint32(&r.callbacksVended, 1) return w, func() { atomic.AddUint32(&r.callbacksExecuted, 1) cleanup() } } type badReader struct { err error } func (b *badReader) Read(p []byte) (int, error) { tb := managertesting.GetTestBytes(len(p)) copy(p, tb) return len(p), b.err }