//go:build go1.7
// +build go1.7

package s3manager

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/s3"
)

// TestBatchDeleteContext is regression coverage for bug #1790. It checks how
// BatchDelete.Delete reacts to the context it is given:
//
//   - when the context is canceled before Delete is called, the test expects
//     a *BatchError holding one canceled-request error per object and expects
//     that no request reached the test server;
//   - when the context stays live, the test expects no error and one request
//     per object (batch size is 1).
func TestBatchDeleteContext(t *testing.T) {
	cases := []struct {
		// objects is the set of deletions handed to the batcher.
		objects []BatchDeleteObject
		// batchSize is passed through as BatchDelete.BatchSize.
		batchSize int
		// expected is the number of HTTP requests the test server should see.
		expected int
		// earlyCancel cancels the context before Delete is invoked.
		earlyCancel bool
		// checkError validates the error returned by Delete; it returns a
		// non-nil error describing the mismatch when the check fails.
		checkError func(error) error
	}{
		// Case 0: the context is already canceled, so every object should
		// fail with a canceled-request error and nothing should be sent.
		0: {
			objects: []BatchDeleteObject{
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("1"),
						Bucket: aws.String("bucket1"),
					},
				},
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("2"),
						Bucket: aws.String("bucket2"),
					},
				},
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("3"),
						Bucket: aws.String("bucket3"),
					},
				},
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("4"),
						Bucket: aws.String("bucket4"),
					},
				},
			},
			batchSize:   1,
			expected:    0,
			earlyCancel: true,
			checkError: func(err error) error {
				batchErr, ok := err.(*BatchError)
				if !ok {
					return fmt.Errorf("expect BatchError, got %T, %v", err, err)
				}

				// One error is expected for each of the four objects.
				errs := batchErr.Errors
				if len(errs) != 4 {
					return fmt.Errorf("expected 4 batch errors, but received %d",
						len(errs))
				}

				for _, tempErr := range errs {
					aerr, ok := tempErr.OrigErr.(awserr.Error)
					if !ok {
						return fmt.Errorf("expect awserr.Error, got %T, %v",
							tempErr.OrigErr, tempErr.OrigErr)
					}

					if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
						return fmt.Errorf("expect %q, error code, got %q", e, a)
					}
				}
				return nil
			},
		},
		// Case 1: the context is left live, so all four deletions should be
		// sent and Delete should report no error.
		1: {
			objects: []BatchDeleteObject{
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("1"),
						Bucket: aws.String("bucket1"),
					},
				},
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("2"),
						Bucket: aws.String("bucket2"),
					},
				},
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("3"),
						Bucket: aws.String("bucket3"),
					},
				},
				{
					Object: &s3.DeleteObjectInput{
						Key:    aws.String("4"),
						Bucket: aws.String("bucket4"),
					},
				},
			},
			batchSize: 1,
			expected:  4,
			checkError: func(err error) error {
				if err != nil {
					return fmt.Errorf("Expect no error, got %v", err)
				}
				return nil
			},
		},
	}

	// count tallies the requests that reach the test server during a case.
	count := 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
		count++
	}))
	defer server.Close()

	svc := &mockS3Client{S3: buildS3SvcClient(server.URL)}
	for i, c := range cases {
		// Run each case in its own function scope so the deferred cancel
		// fires when the case finishes rather than piling up until the whole
		// test returns.
		func() {
			// Start every case from a clean request tally.
			count = 0

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			if c.earlyCancel {
				cancel()
			}

			batcher := BatchDelete{
				Client:    svc,
				BatchSize: c.batchSize,
			}

			err := batcher.Delete(ctx, &DeleteObjectsIterator{Objects: c.objects})
			if terr := c.checkError(err); terr != nil {
				t.Fatalf("%d, %s", i, terr)
			}

			if count != c.expected {
				t.Errorf("Case %d: expected %d, but received %d", i, c.expected, count)
			}
		}()
	}
}