package manager_test
import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing"
"github.com/aws/aws-sdk-go-v2/internal/awstesting"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/google/go-cmp/cmp"
)
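// buf2MB and buf12MB are the package-level upload bodies the tests in this
// file assume. 12 MiB forces a multipart upload at the default 5 MiB part
// size; 2 MiB stays below it and is sent as a single PutObject.
var (
buf2MB = make([]byte, 1024*1024*2)
buf12MB = make([]byte, 1024*1024*12)
)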
// getReaderLength discards the bytes from reader and returns the length
func getReaderLength(r io.Reader) int64 {
n, _ := io.Copy(ioutil.Discard, r)
return n
}
func TestUploadOrderMulti(t *testing.T) {
c, invocations, args := s3testing.NewUploadLoggingClient(nil)
u := manager.NewUploader(c)
resp, err := u.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key - value"),
Body: bytes.NewReader(buf12MB),
ServerSideEncryption: "aws:kms",
SSEKMSKeyId: aws.String("KmsId"),
ContentType: aws.String("content/type"),
})
if err != nil {
t.Errorf("Expected no error but received %v", err)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
t.Errorf("expect %q, got %q", e, a)
}
if "UPLOAD-ID" != resp.UploadID {
t.Errorf("expect %q, got %q", "UPLOAD-ID", resp.UploadID)
}
if "VERSION-ID" != *resp.VersionID {
t.Errorf("expect %q, got %q", "VERSION-ID", *resp.VersionID)
}
// Validate input values
// UploadPart
for i := 1; i < 4; i++ {
v := aws.ToString((*args)[i].(*s3.UploadPartInput).UploadId)
if "UPLOAD-ID" != v {
t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
}
}
// CompleteMultipartUpload
v := aws.ToString((*args)[4].(*s3.CompleteMultipartUploadInput).UploadId)
if "UPLOAD-ID" != v {
t.Errorf("Expected %q, but received %q", "UPLOAD-ID", v)
}
parts := (*args)[4].(*s3.CompleteMultipartUploadInput).MultipartUpload.Parts
for i := 0; i < 3; i++ {
num := parts[i].PartNumber
etag := aws.ToString(parts[i].ETag)
if int32(i+1) != num {
t.Errorf("expect %d, got %d", i+1, num)
}
if matched, err := regexp.MatchString(`^ETAG\d+$`, etag); !matched || err != nil {
t.Errorf("Failed regexp expression `^ETAG\\d+$`")
}
}
// Custom headers
cmu := (*args)[0].(*s3.CreateMultipartUploadInput)
if e, a := types.ServerSideEncryption("aws:kms"), cmu.ServerSideEncryption; e != a {
t.Errorf("expect %q, got %q", e, a)
}
if e, a := "KmsId", aws.ToString(cmu.SSEKMSKeyId); e != a {
t.Errorf("expect %q, got %q", e, a)
}
if e, a := "content/type", aws.ToString(cmu.ContentType); e != a {
t.Errorf("expect %q, got %q", e, a)
}
}
func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
s, ops, args := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(s, func(u *manager.Uploader) {
u.PartSize = 1024 * 1024 * 7
u.Concurrency = 1
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
vals := []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}
if !reflect.DeepEqual(vals, *ops) {
t.Errorf("expect %v, got %v", vals, *ops)
}
// Part lengths
if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); 1024*1024*7 != len {
t.Errorf("expect %d, got %d", 1024*1024*7, len)
}
if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); 1024*1024*5 != len {
t.Errorf("expect %d, got %d", 1024*1024*5, len)
}
}
func TestUploadIncreasePartSize(t *testing.T) {
s, invocations, args := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(s, func(u *manager.Uploader) {
u.Concurrency = 1
u.MaxUploadParts = 2
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if int64(manager.DefaultUploadPartSize) != mgr.PartSize {
t.Errorf("expect %d, got %d", manager.DefaultUploadPartSize, mgr.PartSize)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
// Part lengths
if len := getReaderLength((*args)[1].(*s3.UploadPartInput).Body); (1024*1024*6)+1 != len {
t.Errorf("expect %d, got %d", (1024*1024*6)+1, len)
}
if len := getReaderLength((*args)[2].(*s3.UploadPartInput).Body); (1024*1024*6)-1 != len {
t.Errorf("expect %d, got %d", (1024*1024*6)-1, len)
}
}
func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
mgr := manager.NewUploader(s3.New(s3.Options{}), func(u *manager.Uploader) {
u.PartSize = 5
})
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
if resp != nil {
t.Errorf("Expected response to be nil, but received %v", resp)
}
if err == nil {
t.Errorf("Expected error, but received nil")
}
if e, a := "part size must be at least", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v to be in %v", e, a)
}
}
func TestUploadOrderSingle(t *testing.T) {
client, invocations, params := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(client)
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key - value"),
Body: bytes.NewReader(buf2MB),
ServerSideEncryption: "aws:kms",
SSEKMSKeyId: aws.String("KmsId"),
ContentType: aws.String("content/type"),
})
if err != nil {
t.Errorf("expect no error but received %v", err)
}
if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
if e, a := `https://mock.amazonaws.com/key`, resp.Location; e != a {
t.Errorf("expect %q, got %q", e, a)
}
if e := "VERSION-ID"; e != *resp.VersionID {
t.Errorf("expect %q, got %q", e, *resp.VersionID)
}
if len(resp.UploadID) > 0 {
t.Errorf("expect empty string, got %q", resp.UploadID)
}
putObjectInput := (*params)[0].(*s3.PutObjectInput)
if e, a := types.ServerSideEncryption("aws:kms"), putObjectInput.ServerSideEncryption; e != a {
t.Errorf("expect %q, got %q", e, a)
}
if e, a := "KmsId", aws.ToString(putObjectInput.SSEKMSKeyId); e != a {
t.Errorf("expect %q, got %q", e, a)
}
if e, a := "content/type", aws.ToString(putObjectInput.ContentType); e != a {
t.Errorf("Expected %q, but received %q", e, a)
}
}
func TestUploadOrderSingleFailure(t *testing.T) {
client, ops, _ := s3testing.NewUploadLoggingClient(nil)
client.PutObjectFn = func(*s3testing.UploadLoggingClient, *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
return nil, fmt.Errorf("put object failure")
}
mgr := manager.NewUploader(client)
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf2MB),
})
if err == nil {
t.Error("expect error, got nil")
}
if diff := cmp.Diff([]string{"PutObject"}, *ops); len(diff) > 0 {
t.Error(diff)
}
if resp != nil {
t.Errorf("expect response to be nil, got %v", resp)
}
}
func TestUploadOrderZero(t *testing.T) {
c, invocations, params := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c)
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 0)),
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
if len(resp.Location) == 0 {
t.Error("expect Location to not be empty")
}
if len(resp.UploadID) > 0 {
t.Errorf("expect empty string, got %q", resp.UploadID)
}
if e, a := int64(0), getReaderLength((*params)[0].(*s3.PutObjectInput).Body); e != a {
t.Errorf("Expected %d, but received %d", e, a)
}
}
func TestUploadOrderMultiFailure(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
if u.PartNum == 2 {
return nil, fmt.Errorf("an unexpected error")
}
return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
}
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
if err == nil {
t.Error("expect error, got nil")
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
}
func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
c.CompleteMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CompleteMultipartUploadInput) (*s3.CompleteMultipartUploadOutput, error) {
return nil, fmt.Errorf("complete multipart error")
}
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(buf12MB),
})
if err == nil {
t.Error("expect error, got nil")
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "UploadPart",
"CompleteMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
}
func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
c.CreateMultipartUploadFn = func(*s3testing.UploadLoggingClient, *s3.CreateMultipartUploadInput) (*s3.CreateMultipartUploadOutput, error) {
return nil, fmt.Errorf("create multipart upload failure")
}
mgr := manager.NewUploader(c)
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
if err == nil {
t.Error("expect error, got nil")
}
if diff := cmp.Diff([]string{"CreateMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
}
func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
c.UploadPartFn = func(u *s3testing.UploadLoggingClient, _ *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
if u.PartNum == 2 {
return nil, fmt.Errorf("upload part failure")
}
return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
}
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
u.LeavePartsOnError = true
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 1024*1024*12)),
})
if err == nil {
t.Error("expect error, got nil")
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
}
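// failreader is an io.Reader stub that reports successful full-buffer reads
// until the times-th call to Read, which fails with an error (as does every
// call after it).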
type failreader struct {
times int
failCount int
}
func (f *failreader) Read(b []byte) (int, error) {
f.failCount++
if f.failCount >= f.times {
return 0, fmt.Errorf("random failure")
}
return len(b), nil
}
func TestUploadOrderReadFail1(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c)
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 1},
})
if err == nil {
t.Fatalf("expect error to not be nil")
}
if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v, got %v", e, a)
}
if diff := cmp.Diff([]string(nil), *invocations); len(diff) > 0 {
t.Error(diff)
}
}
func TestUploadOrderReadFail2(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &failreader{times: 2},
})
if err == nil {
t.Fatalf("expect error to not be nil")
}
if e, a := "random failure", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v, got %q", e, a)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
}
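// sizedReader is an io.Reader stub that reports exactly size bytes read,
// then returns err (defaulting to io.EOF) on every subsequent call.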
type sizedReader struct {
size int
cur int
err error
}
func (s *sizedReader) Read(p []byte) (n int, err error) {
if s.cur >= s.size {
if s.err == nil {
s.err = io.EOF
}
return 0, s.err
}
n = len(p)
s.cur += len(p)
if s.cur > s.size {
n -= s.cur - s.size
}
return n, err
}
func TestUploadOrderMultiBufferedReader(t *testing.T) {
c, invocations, params := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c)
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
// Part lengths
var parts []int64
for i := 1; i <= 3; i++ {
parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
}
sort.Slice(parts, func(i, j int) bool {
return parts[i] < parts[j]
})
if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
t.Error(diff)
}
}
func TestUploadOrderMultiBufferedReaderPartial(t *testing.T) {
c, invocations, params := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c)
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12, err: io.EOF},
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart",
"UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
// Part lengths
var parts []int64
for i := 1; i <= 3; i++ {
parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
}
sort.Slice(parts, func(i, j int) bool {
return parts[i] < parts[j]
})
if diff := cmp.Diff([]int64{1024 * 1024 * 2, 1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
t.Error(diff)
}
}
// TestUploadOrderMultiBufferedReaderEOF tests the edge case where the
// file size is the same as part size.
func TestUploadOrderMultiBufferedReaderEOF(t *testing.T) {
c, invocations, params := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c)
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 10, err: io.EOF},
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "UploadPart", "UploadPart", "CompleteMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
// Part lengths
var parts []int64
for i := 1; i <= 2; i++ {
parts = append(parts, getReaderLength((*params)[i].(*s3.UploadPartInput).Body))
}
sort.Slice(parts, func(i, j int) bool {
return parts[i] < parts[j]
})
if diff := cmp.Diff([]int64{1024 * 1024 * 5, 1024 * 1024 * 5}, parts); len(diff) > 0 {
t.Error(diff)
}
}
func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient([]string{"UploadPart"})
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
u.MaxUploadParts = 2
})
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 12},
})
if err == nil {
t.Fatal("expect error, got nil")
}
if resp != nil {
t.Errorf("expect nil, got %v", resp)
}
if diff := cmp.Diff([]string{"CreateMultipartUpload", "AbortMultipartUpload"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
if !strings.Contains(err.Error(), "configured MaxUploadParts (2)") {
t.Errorf("expect 'configured MaxUploadParts (2)', got %q", err.Error())
}
}
func TestUploadOrderSingleBufferedReader(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c)
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &sizedReader{size: 1024 * 1024 * 2},
})
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
t.Error(diff)
}
if len(resp.Location) == 0 {
t.Error("expect a value in Location")
}
if len(resp.UploadID) > 0 {
t.Errorf("expect no value, got %q", resp.UploadID)
}
}
func TestUploadZeroLenObject(t *testing.T) {
client, invocations, _ := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(client)
resp, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: strings.NewReader(""),
})
if err != nil {
t.Errorf("expect no error but received %v", err)
}
if diff := cmp.Diff([]string{"PutObject"}, *invocations); len(diff) > 0 {
t.Errorf("expect request to have been made, but was not, %v", diff)
}
// TODO: not needed?
if len(resp.Location) == 0 {
t.Error("expect a non-empty string value for Location")
}
if len(resp.UploadID) > 0 {
t.Errorf("expect empty string, but received %q", resp.UploadID)
}
}
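// testIncompleteReader reports successful reads until Size bytes have been
// consumed, then returns io.ErrUnexpectedEOF to simulate a body that ends
// before its declared length.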
type testIncompleteReader struct {
Size int64
read int64
}
func (r *testIncompleteReader) Read(p []byte) (n int, err error) {
r.read += int64(len(p))
if r.read >= r.Size {
return int(r.read - r.Size), io.ErrUnexpectedEOF
}
return len(p), nil
}
func TestUploadUnexpectedEOF(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
u.PartSize = manager.MinUploadPartSize
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: &testIncompleteReader{
Size: manager.MinUploadPartSize + 1,
},
})
if err == nil {
t.Error("expect error, got nil")
}
// Ensure upload started.
if e, a := "CreateMultipartUpload", (*invocations)[0]; e != a {
t.Errorf("expect %q, got %q", e, a)
}
// Part may or may not be sent because of timing of sending parts and
// reading next part in upload manager. Just check for the last abort.
if e, a := "AbortMultipartUpload", (*invocations)[len(*invocations)-1]; e != a {
t.Errorf("expect %q, got %q", e, a)
}
}
func TestSSE(t *testing.T) {
client, _, _ := s3testing.NewUploadLoggingClient(nil)
client.UploadPartFn = func(u *s3testing.UploadLoggingClient, params *s3.UploadPartInput) (*s3.UploadPartOutput, error) {
if params.SSECustomerAlgorithm == nil {
t.Fatal("SSECustomerAlgoritm should not be nil")
}
if params.SSECustomerKey == nil {
t.Fatal("SSECustomerKey should not be nil")
}
return &s3.UploadPartOutput{ETag: aws.String(fmt.Sprintf("ETAG%d", u.PartNum))}, nil
}
mgr := manager.NewUploader(client, func(u *manager.Uploader) {
u.Concurrency = 5
})
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
SSECustomerAlgorithm: aws.String("AES256"),
SSECustomerKey: aws.String("foo"),
Body: bytes.NewBuffer(make([]byte, 1024*1024*10)),
})
if err != nil {
t.Fatal("Expected no error, but received" + err.Error())
}
}
func TestUploadWithContextCanceled(t *testing.T) {
u := manager.NewUploader(s3.New(s3.Options{
UsePathStyle: true,
Region: "mock-region",
}))
params := s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: bytes.NewReader(make([]byte, 0)),
}
ctx := &awstesting.FakeContext{DoneCh: make(chan struct{})}
ctx.Error = fmt.Errorf("context canceled")
close(ctx.DoneCh)
_, err := u.Upload(ctx, &params)
if err == nil {
t.Fatalf("expect error, got nil")
}
if e, a := "canceled", err.Error(); !strings.Contains(a, e) {
t.Errorf("expected error message to contain %q, but did not %q", e, a)
}
}
// S3 Uploader incorrectly fails an upload if the content being uploaded
// has a size of MinPartSize * MaxUploadParts.
// Github: aws/aws-sdk-go#2557
func TestUploadMaxPartsEOF(t *testing.T) {
c, invocations, _ := s3testing.NewUploadLoggingClient(nil)
mgr := manager.NewUploader(c, func(u *manager.Uploader) {
u.Concurrency = 1
u.PartSize = manager.DefaultUploadPartSize
u.MaxUploadParts = 2
})
f := bytes.NewReader(make([]byte, int(mgr.PartSize)*int(mgr.MaxUploadParts)))
r1 := io.NewSectionReader(f, 0, manager.DefaultUploadPartSize)
r2 := io.NewSectionReader(f, manager.DefaultUploadPartSize, 2*manager.DefaultUploadPartSize)
body := io.MultiReader(r1, r2)
_, err := mgr.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: body,
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
expectOps := []string{
"CreateMultipartUpload",
"UploadPart",
"UploadPart",
"CompleteMultipartUpload",
}
if diff := cmp.Diff(expectOps, *invocations); len(diff) > 0 {
t.Error(diff)
}
}
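// createTempFile creates a temporary file truncated to size bytes and
// returns it along with a cleanup func that closes and removes it.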
func createTempFile(t *testing.T, size int64) (*os.File, func(*testing.T), error) {
file, err := ioutil.TempFile(os.TempDir(), aws.SDKName+t.Name())
if err != nil {
return nil, nil, err
}
filename := file.Name()
if err := file.Truncate(size); err != nil {
return nil, nil, err
}
return file,
func(t *testing.T) {
if err := file.Close(); err != nil {
t.Errorf("failed to close temp file, %s, %v", filename, err)
}
if err := os.Remove(filename); err != nil {
t.Errorf("failed to remove temp file, %s, %v", filename, err)
}
},
nil
}
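// buildFailHandlers returns one failPartHandler per part, each of which
// fails the given number of times before delegating to a success handler.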
func buildFailHandlers(tb testing.TB, parts, retries int) []http.Handler {
handlers := make([]http.Handler, parts)
for i := 0; i < len(handlers); i++ {
handlers[i] = &failPartHandler{
tb: tb,
failsRemaining: retries,
successHandler: successPartHandler{tb: tb},
}
}
return handlers
}
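// TestUploadRetry verifies that uploads of several body types succeed when
// each part is failed repeatedly by a mock S3 server and retried by the SDK.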
func TestUploadRetry(t *testing.T) {
const numParts, retries = 3, 10
testFile, testFileCleanup, err := createTempFile(t, manager.DefaultUploadPartSize*numParts)
if err != nil {
t.Fatalf("failed to create test file, %v", err)
}
defer testFileCleanup(t)
cases := map[string]struct {
Body io.Reader
PartHandlers func(testing.TB) []http.Handler
}{
"bytes.Buffer": {
Body: bytes.NewBuffer(make([]byte, manager.DefaultUploadPartSize*numParts)),
PartHandlers: func(tb testing.TB) []http.Handler {
return buildFailHandlers(tb, numParts, retries)
},
},
"bytes.Reader": {
Body: bytes.NewReader(make([]byte, manager.DefaultUploadPartSize*numParts)),
PartHandlers: func(tb testing.TB) []http.Handler {
return buildFailHandlers(tb, numParts, retries)
},
},
"os.File": {
Body: testFile,
PartHandlers: func(tb testing.TB) []http.Handler {
return buildFailHandlers(tb, numParts, retries)
},
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
restoreSleep := sdk.TestingUseNopSleep()
defer restoreSleep()
mux := newMockS3UploadServer(t, c.PartHandlers(t))
server := httptest.NewServer(mux)
defer server.Close()
client := s3.New(s3.Options{
EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) {
return aws.Endpoint{
URL: server.URL,
}, nil
}),
UsePathStyle: true,
Retryer: retry.NewStandard(func(o *retry.StandardOptions) {
o.MaxAttempts = retries + 1
}),
})
uploader := manager.NewUploader(client)
_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
Body: c.Body,
})
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
})
}
}
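// TestUploadBufferStrategy verifies that a configured BufferProvider
// observes every uploaded byte and that its callback runs once per part.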
func TestUploadBufferStrategy(t *testing.T) {
cases := map[string]struct {
PartSize int64
Size int64
Strategy manager.ReadSeekerWriteToProvider
callbacks int
}{
"NoBuffer": {
PartSize: manager.DefaultUploadPartSize,
Strategy: nil,
},
"SinglePart": {
PartSize: manager.DefaultUploadPartSize,
Size: manager.DefaultUploadPartSize,
Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)},
callbacks: 1,
},
"MultiPart": {
PartSize: manager.DefaultUploadPartSize,
Size: manager.DefaultUploadPartSize * 2,
Strategy: &recordedBufferProvider{size: int(manager.DefaultUploadPartSize)},
callbacks: 2,
},
}
for name, tCase := range cases {
t.Run(name, func(t *testing.T) {
client, _, _ := s3testing.NewUploadLoggingClient(nil)
client.ConsumeBody = true
uploader := manager.NewUploader(client, func(u *manager.Uploader) {
u.PartSize = tCase.PartSize
u.BufferProvider = tCase.Strategy
u.Concurrency = 1
})
expected := s3testing.GetTestBytes(int(tCase.Size))
_, err := uploader.Upload(context.Background(), &s3.PutObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
Body: bytes.NewReader(expected),
})
if err != nil {
t.Fatalf("failed to upload file: %v", err)
}
switch strat := tCase.Strategy.(type) {
case *recordedBufferProvider:
if !bytes.Equal(expected, strat.content) {
t.Errorf("content buffered did not match expected")
}
if tCase.callbacks != strat.callbackCount {
t.Errorf("expected %v, got %v callbacks", tCase.callbacks, strat.callbackCount)
}
}
})
}
}
func TestUploaderValidARN(t *testing.T) {
cases := map[string]struct {
input s3.PutObjectInput
wantErr bool
}{
"standard bucket": {
input: s3.PutObjectInput{
Bucket: aws.String("test-bucket"),
Key: aws.String("test-key"),
Body: bytes.NewReader([]byte("test body content")),
},
},
"accesspoint": {
input: s3.PutObjectInput{
Bucket: aws.String("arn:aws:s3:us-west-2:123456789012:accesspoint/myap"),
Key: aws.String("test-key"),
Body: bytes.NewReader([]byte("test body content")),
},
},
"outpost accesspoint": {
input: s3.PutObjectInput{
Bucket: aws.String("arn:aws:s3-outposts:us-west-2:012345678901:outpost/op-1234567890123456/accesspoint/myaccesspoint"),
Key: aws.String("test-key"),
Body: bytes.NewReader([]byte("test body content")),
},
},
"s3-object-lambda accesspoint": {
input: s3.PutObjectInput{
Bucket: aws.String("arn:aws:s3-object-lambda:us-west-2:123456789012:accesspoint/myap"),
Key: aws.String("test-key"),
Body: bytes.NewReader([]byte("test body content")),
},
wantErr: true,
},
}
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
client, _, _ := s3testing.NewUploadLoggingClient(nil)
client.ConsumeBody = true
uploader := manager.NewUploader(client)
_, err := uploader.Upload(context.Background(), &tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("err: %v, wantErr: %v", err, tt.wantErr)
}
})
}
}
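// mockS3UploadServer emulates the S3 multipart upload REST API, routing
// each UploadPart request to the per-part handler at its part number.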
type mockS3UploadServer struct {
*http.ServeMux
tb testing.TB
partHandler []http.Handler
}
func newMockS3UploadServer(tb testing.TB, partHandler []http.Handler) *mockS3UploadServer {
s := &mockS3UploadServer{
ServeMux: http.NewServeMux(),
partHandler: partHandler,
tb: tb,
}
s.HandleFunc("/", s.handleRequest)
return s
}
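// handleRequest dispatches mock S3 calls by method and query string: POST
// with ?uploads is CreateMultipartUpload, PUT is UploadPart, DELETE is
// AbortMultipartUpload, and any other POST is CompleteMultipartUpload.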
func (s mockS3UploadServer) handleRequest(w http.ResponseWriter, r *http.Request) {
defer func() {
closeErr := r.Body.Close()
if closeErr != nil {
failRequest(w, 0, "BodyCloseError",
fmt.Sprintf("request body close error: %v", closeErr))
}
}()
_, hasUploads := r.URL.Query()["uploads"]
switch {
case r.Method == "POST" && hasUploads:
// CreateMultipartUpload
w.Header().Set("Content-Length", strconv.Itoa(len(createUploadResp)))
w.Write([]byte(createUploadResp))
case r.Method == "PUT":
// UploadPart
partNumStr := r.URL.Query().Get("partNumber")
id, err := strconv.Atoi(partNumStr)
if err != nil {
failRequest(w, 400, "BadRequest",
fmt.Sprintf("unable to parse partNumber, %q, %v",
partNumStr, err))
return
}
id--
if id < 0 || id >= len(s.partHandler) {
failRequest(w, 400, "BadRequest",
fmt.Sprintf("invalid partNumber %v", id))
return
}
s.partHandler[id].ServeHTTP(w, r)
case r.Method == "POST":
// CompleteMultipartUpload
w.Header().Set("Content-Length", strconv.Itoa(len(completeUploadResp)))
w.Write([]byte(completeUploadResp))
case r.Method == "DELETE":
// AbortMultipartUpload
w.Header().Set("Content-Length", strconv.Itoa(len(abortUploadResp)))
w.WriteHeader(200)
w.Write([]byte(abortUploadResp))
default:
failRequest(w, 400, "BadRequest",
fmt.Sprintf("invalid request %v %v", r.Method, r.URL))
}
}
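// failRequest writes an S3-style XML error body with the given HTTP status,
// error code, and message.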
func failRequest(w http.ResponseWriter, status int, code, msg string) {
msg = fmt.Sprintf(baseRequestErrorResp, code, msg)
w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
w.WriteHeader(status)
w.Write([]byte(msg))
}
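// successPartHandler accepts an UploadPart request, validating that the
// body length matches the Content-Length header before responding with a
// canned ETag.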
type successPartHandler struct {
tb testing.TB
}
func (h successPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
closeErr := r.Body.Close()
if closeErr != nil {
failRequest(w, 0, "BodyCloseError",
fmt.Sprintf("request body close error: %v", closeErr))
}
}()
n, err := io.Copy(ioutil.Discard, r.Body)
if err != nil {
failRequest(w, 400, "BadRequest",
fmt.Sprintf("failed to read body, %v", err))
return
}
contLenStr := r.Header.Get("Content-Length")
expectLen, err := strconv.ParseInt(contLenStr, 10, 64)
if err != nil {
h.tb.Logf("expect content-length, got %q, %v", contLenStr, err)
failRequest(w, 400, "BadRequest",
fmt.Sprintf("unable to get content-length %v", err))
return
}
if e, a := expectLen, n; e != a {
h.tb.Logf("expect %v read, got %v", e, a)
failRequest(w, 400, "BadRequest",
fmt.Sprintf(
"content-length and body do not match, %v, %v", e, a))
return
}
w.Header().Set("Content-Length", strconv.Itoa(len(uploadPartResp)))
w.Write([]byte(uploadPartResp))
}
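// failPartHandler responds with a 500 InternalException until
// failsRemaining reaches zero, then delegates to successHandler so retried
// parts eventually succeed.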
type failPartHandler struct {
tb testing.TB
failsRemaining int
successHandler http.Handler
}
func (h *failPartHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
closeErr := r.Body.Close()
if closeErr != nil {
failRequest(w, 0, "BodyCloseError",
fmt.Sprintf("request body close error: %v", closeErr))
}
}()
if h.failsRemaining == 0 && h.successHandler != nil {
h.successHandler.ServeHTTP(w, r)
return
}
io.Copy(ioutil.Discard, r.Body)
failRequest(w, 500, "InternalException",
fmt.Sprintf("mock error, partNumber %v", r.URL.Query().Get("partNumber")))
h.failsRemaining--
}
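// recordedBufferProvider implements manager.ReadSeekerWriteToProvider,
// recording the bytes buffered for each part and counting how many times
// its cleanup callback runs.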
type recordedBufferProvider struct {
content []byte
size int
callbackCount int
}
func (r *recordedBufferProvider) GetWriteTo(seeker io.ReadSeeker) (manager.ReadSeekerWriteTo, func()) {
b := make([]byte, r.size)
w := &manager.BufferedReadSeekerWriteTo{BufferedReadSeeker: manager.NewBufferedReadSeeker(seeker, b)}
return w, func() {
r.content = append(r.content, b...)
r.callbackCount++
}
}
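// Canned S3 XML response bodies served by mockS3UploadServer. The element
// names follow the S3 REST API response shapes; only the values below are
// significant to the assertions in this file.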
const createUploadResp = `<CreateMultipartUploadResult>
<Bucket>bucket</Bucket>
<Key>key</Key>
<UploadId>abc123</UploadId>
</CreateMultipartUploadResult>`

const uploadPartResp = `<UploadPartResult>
<ETag>key</ETag>
</UploadPartResult>`

const baseRequestErrorResp = `<Error>
<Code>%s</Code>
<Message>%s</Message>
<RequestId>request-id</RequestId>
<HostId>host-id</HostId>
</Error>`

const completeUploadResp = `<CompleteMultipartUploadResult>
<Bucket>bucket</Bucket>
<Key>key</Key>
<ETag>key</ETag>
<Location>https://bucket.us-west-2.amazonaws.com/key</Location>
<UploadId>abc123</UploadId>
</CompleteMultipartUploadResult>`

const abortUploadResp = `<AbortMultipartUploadResult></AbortMultipartUploadResult>`