package client

import (
	"context"
	"errors"
	"fmt"
	"net"
	"reflect"
	"runtime"
	"strings"
	"testing"
	"time"

	"github.com/aws/aws-dax-go/dax/internal/cbor"
	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// unEncryptedConnConfig is the connConfig passed to newSingleClientWithOptions
// by the tests in this file; it sets isEncrypted to false.
var unEncryptedConnConfig = connConfig{isEncrypted: false}

// TestExecuteErrorHandling runs executeWithContext once per table entry against
// a scripted mockConn and checks two things: the error that is returned, and
// the per-method call counts the mockConn recorded. The expected counts include
// a "Close" entry for the cases labelled "discard tube" and omit it for the
// cases labelled "do not discard tube".
func TestExecuteErrorHandling(t *testing.T) {
	cases := []struct {
		conn *mockConn                       // connection handed out by the dialer
		enc  func(writer *cbor.Writer) error // encoder passed to executeWithContext (may be nil)
		dec  func(reader *cbor.Reader) error // decoder passed to executeWithContext (may be nil)
		ee   error                           // expected error
		ec   map[string]int                  // expected call counts recorded in conn.cc
	}{
		{ // write error, discard tube
			conn: &mockConn{we: errors.New("io")},
			ee:   errors.New("io"),
			ec:   map[string]int{"Write": 1, "Close": 1},
		},
		{ // encoding error, discard tube
			conn: &mockConn{},
			enc:  func(writer *cbor.Writer) error { return errors.New("ser") },
			ee:   errors.New("ser"),
			ec:   map[string]int{"Write": 2, "SetDeadline": 1, "Close": 1},
		},
		{ // read error, discard tube
			conn: &mockConn{re: errors.New("IO")},
			enc:  func(writer *cbor.Writer) error { return nil },
			ee:   errors.New("IO"),
			ec:   map[string]int{"Write": 2, "Read": 1, "SetDeadline": 1, "Close": 1},
		},
		{ // serialization error, discard tube
			conn: &mockConn{rd: []byte{cbor.NegInt}},
			enc:  func(writer *cbor.Writer) error { return nil },
			ee:   awserr.New(request.ErrCodeSerialization, fmt.Sprintf("cbor: expected major type %d, got %d", cbor.Array, cbor.NegInt), nil),
			ec:   map[string]int{"Write": 2, "Read": 1, "SetDeadline": 1, "Close": 1},
		},
		{ // decode error, discard tube
			conn: &mockConn{rd: []byte{cbor.Array + 0}},
			enc:  func(writer *cbor.Writer) error { return nil },
			dec:  func(reader *cbor.Reader) error { return errors.New("IO") },
			ee:   errors.New("IO"),
			ec:   map[string]int{"Write": 2, "Read": 1, "SetDeadline": 1, "Close": 1},
		},
		{ // dax error, do not discard tube
			conn: &mockConn{rd: []byte{cbor.Array + 3, cbor.PosInt + 4, cbor.PosInt + 0, cbor.PosInt + 0, cbor.Utf, cbor.Nil}},
			enc:  func(writer *cbor.Writer) error { return nil },
			ee:   newDaxRequestFailure([]int{4, 0, 0}, "", "", "", 400),
			ec:   map[string]int{"Write": 2, "Read": 1, "SetDeadline": 1},
		},
		{ // no error, do not discard tube
			conn: &mockConn{rd: []byte{cbor.Array + 0}},
			enc:  func(writer *cbor.Writer) error { return nil },
			dec:  func(reader *cbor.Reader) error { return nil },
			ec:   map[string]int{"Write": 2, "Read": 1, "SetDeadline": 1},
		},
	}

	for i, c := range cases {
		// Pin this iteration's connection in its own variable so the dialer
		// closure never observes a later table entry, even on toolchains where
		// the range variable is shared across iterations.
		conn := c.conn
		cli, err := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, func(ctx context.Context, a, n string) (net.Conn, error) {
			return conn, nil
		})
		if err != nil {
			t.Fatalf("case[%d] unexpected error %v", i, err)
		}
		// NOTE(review): presumably makes the pool close discarded tubes right
		// away so the "Close" count is observable below — confirm against the
		// pool implementation.
		cli.pool.closeTubeImmediately = true

		err = cli.executeWithContext(aws.BackgroundContext(), OpGetItem, c.enc, c.dec, RequestOptions{})
		if !reflect.DeepEqual(c.ee, err) {
			t.Errorf("case[%d] expected error %v, got error %v", i, c.ee, err)
		}
		if !reflect.DeepEqual(c.ec, conn.cc) {
			t.Errorf("case[%d] expected %v calls, got %v", i, c.ec, conn.cc)
		}
		cli.Close()
	}
}

// TestRetryPropogatesContextError checks that executeWithRetries, when given a
// RequestOptions whose Context has already been cancelled, returns an
// awserr.Error with code request.CanceledErrorCode and context.Canceled as its
// original error.
func TestRetryPropogatesContextError(t *testing.T) {
	client, clientErr := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, func(ctx context.Context, a, n string) (net.Conn, error) {
		return &mockConn{rd: []byte{cbor.Array + 0}}, nil
	})
	if clientErr != nil {
		t.Fatalf("unexpected error %v", clientErr)
	}
	// Register Close only once the client is known to be usable: t.Fatalf
	// still runs deferred calls, so deferring earlier would invoke Close on a
	// failed (possibly nil) client and could panic over the real failure.
	defer client.Close()

	client.pool.closeTubeImmediately = true

	ctx, cancel := context.WithCancel(aws.BackgroundContext())
	requestOptions := RequestOptions{
		MaxRetries: 2,
		Context:    ctx,
	}

	writer := func(writer *cbor.Writer) error { return nil }
	reader := func(reader *cbor.Reader) error { return nil }

	// Cancel the context before executing so the call fails on it.
	cancel()
	err := client.executeWithRetries(OpGetItem, requestOptions, writer, reader)

	// The context-related error should be returned.
	awsError, ok := err.(awserr.Error)
	if !ok {
		t.Fatal("Error type is not awserr.Error")
	}

	if awsError.Code() != request.CanceledErrorCode || awsError.OrigErr() != context.Canceled {
		t.Errorf("aws error doesn't match expected. %v", awsError)
	}
}

// TestRetryPropogatesOtherErrors checks that when the decoder keeps failing
// with a plain "IO" error, executeWithRetries (MaxRetries: 1) returns an
// awserr.Error with code "UnknownError" whose original error carries the same
// message.
func TestRetryPropogatesOtherErrors(t *testing.T) {
	client, clientErr := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, func(ctx context.Context, a, n string) (net.Conn, error) {
		return &mockConn{rd: []byte{cbor.Array + 0}}, nil
	})
	if clientErr != nil {
		t.Fatalf("unexpected error %v", clientErr)
	}
	// Register Close only once the client is known to be usable: t.Fatalf
	// still runs deferred calls, so deferring earlier would invoke Close on a
	// failed (possibly nil) client and could panic over the real failure.
	defer client.Close()

	client.pool.closeTubeImmediately = true

	requestOptions := RequestOptions{
		MaxRetries: 1,
	}
	expectedError := errors.New("IO")

	writer := func(writer *cbor.Writer) error { return nil }
	// A fresh error value per call; the assertion below compares messages,
	// not identity.
	reader := func(reader *cbor.Reader) error { return errors.New("IO") }

	err := client.executeWithRetries(OpGetItem, requestOptions, writer, reader)

	// The IO error should be returned, wrapped in an awserr.Error.
	awsError, ok := err.(awserr.Error)
	if !ok {
		t.Fatal("Error type is not awserr.Error")
	}

	if awsError.OrigErr() == nil {
		t.Fatal("Original error is empty")
	}

	if awsError.Code() != "UnknownError" || awsError.OrigErr().Error() != expectedError.Error() {
		t.Errorf("aws error doesn't match expected. %v", awsError)
	}
}

// TestRetryPropogatesOtherErrorsWithDelay is the RetryDelay-enabled variant of
// TestRetryPropogatesOtherErrors: with MaxRetries 1 and RetryDelay 1, a decoder
// that always returns the same "IO" error must surface as an awserr.Error with
// code "UnknownError" and that message on its original error.
func TestRetryPropogatesOtherErrorsWithDelay(t *testing.T) {
	client, clientErr := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, func(ctx context.Context, a, n string) (net.Conn, error) {
		return &mockConn{rd: []byte{cbor.Array + 0}}, nil
	})
	if clientErr != nil {
		t.Fatalf("unexpected error %v", clientErr)
	}
	// Register Close only once the client is known to be usable: t.Fatalf
	// still runs deferred calls, so deferring earlier would invoke Close on a
	// failed (possibly nil) client and could panic over the real failure.
	defer client.Close()

	client.pool.closeTubeImmediately = true

	requestOptions := RequestOptions{
		MaxRetries: 1,
		RetryDelay: 1,
	}
	expectedError := errors.New("IO")

	writer := func(writer *cbor.Writer) error { return nil }
	reader := func(reader *cbor.Reader) error { return expectedError }

	err := client.executeWithRetries(OpGetItem, requestOptions, writer, reader)

	// The IO error should be returned, wrapped in an awserr.Error.
	awsError, ok := err.(awserr.Error)
	if !ok {
		t.Fatal("Error type is not awserr.Error")
	}

	if awsError.OrigErr() == nil {
		t.Fatal("Original error is empty")
	}

	if awsError.Code() != "UnknownError" || awsError.OrigErr().Error() != expectedError.Error() {
		t.Errorf("aws error doesn't match expected. %v", awsError)
	}
}

// TestRetrySleepCycleCount counts invocations of RequestOptions.SleepDelayFn
// across two failing executeWithRetries calls: with MaxRetries 0 and
// RetryDelay 0 the sleep function must not be called at all, and with
// MaxRetries 3 and RetryDelay 1 it must be called exactly MaxRetries times.
func TestRetrySleepCycleCount(t *testing.T) {
	client, clientErr := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, func(ctx context.Context, a, n string) (net.Conn, error) {
		return &mockConn{rd: []byte{cbor.Array + 0}}, nil
	})
	if clientErr != nil {
		t.Fatalf("unexpected error %v", clientErr)
	}
	// Register Close only once the client is known to be usable: t.Fatalf
	// still runs deferred calls, so deferring earlier would invoke Close on a
	// failed (possibly nil) client and could panic over the real failure.
	defer client.Close()

	client.pool.closeTubeImmediately = true

	sleepCallCount := 0
	requestOptions := RequestOptions{
		MaxRetries:   0,
		RetryDelay:   0,
		SleepDelayFn: func(d time.Duration) { sleepCallCount++ },
	}

	writer := func(writer *cbor.Writer) error { return nil }
	reader := func(reader *cbor.Reader) error { return errors.New("IO") }

	// The returned error is intentionally not inspected here and below; this
	// test only cares about how often the sleep function fired.
	client.executeWithRetries(OpGetItem, requestOptions, writer, reader)

	if sleepCallCount != 0 {
		t.Fatalf("Sleep was called %d times, but expected none", sleepCallCount)
	}

	requestOptions.MaxRetries = 3
	requestOptions.RetryDelay = 1
	client.executeWithRetries(OpGetItem, requestOptions, writer, reader)

	if sleepCallCount != requestOptions.MaxRetries {
		t.Fatalf("Sleep was called %d times, but expected %d", sleepCallCount, requestOptions.MaxRetries)
	}
}

// TestRetryLastError checks that executeWithRetries reports the error from the
// final attempt. The decoder returns "IO" only while exactly one sleep has
// happened and "LastError" otherwise; with MaxRetries 2 the test expects an
// awserr.Error with code "UnknownError" whose original error is "LastError".
func TestRetryLastError(t *testing.T) {
	client, clientErr := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, func(ctx context.Context, a, n string) (net.Conn, error) {
		return &mockConn{rd: []byte{cbor.Array + 0}}, nil
	})
	if clientErr != nil {
		t.Fatalf("unexpected error %v", clientErr)
	}
	// Register Close only once the client is known to be usable: t.Fatalf
	// still runs deferred calls, so deferring earlier would invoke Close on a
	// failed (possibly nil) client and could panic over the real failure.
	defer client.Close()

	client.pool.closeTubeImmediately = true

	var sleepCallCount uint
	requestOptions := RequestOptions{
		MaxRetries:   2,
		RetryDelay:   1,
		SleepDelayFn: func(d time.Duration) { sleepCallCount++ },
	}

	writer := func(writer *cbor.Writer) error { return nil }
	// The sleep counter doubles as an attempt marker: only the attempt made
	// after the first sleep yields "IO"; every other attempt yields "LastError".
	reader := func(reader *cbor.Reader) error {
		if sleepCallCount == 1 {
			return errors.New("IO")
		}
		return errors.New("LastError")
	}
	err := client.executeWithRetries(OpGetItem, requestOptions, writer, reader)
	awsError, ok := err.(awserr.Error)
	if !ok {
		t.Fatal("Error type is not awserr.Error")
	}

	if awsError.OrigErr() == nil {
		t.Fatal("Original error is empty")
	}

	if awsError.Code() != "UnknownError" || awsError.OrigErr().Error() != "LastError" {
		t.Fatalf("aws error doesn't match expected. %v", awsError)
	}
}

// TestSingleClient_customDialer checks that the dial function passed to
// newSingleClientWithOptions is the one stored on the client's pool: invoking
// client.pool.dialContext must succeed and hand back the very mockConn the
// custom dialer returns.
func TestSingleClient_customDialer(t *testing.T) {
	conn := &mockConn{}
	var dialContextFn dialContext = func(ctx context.Context, address string, network string) (net.Conn, error) {
		return conn, nil
	}
	client, err := newSingleClientWithOptions(":9121", unEncryptedConnConfig, "us-west-2", credentials.NewStaticCredentials("id", "secret", "tok"), 1, dialContextFn)
	require.NoError(t, err)
	defer client.Close()

	// Check the dial error instead of discarding it, so a broken dialer is
	// reported as such rather than as a confusing connection mismatch.
	c, err := client.pool.dialContext(context.TODO(), "address", "network")
	require.NoError(t, err)
	assert.Equal(t, conn, c)
}

// mockConn is a scripted net.Conn used by the tests in this file. Read, Write,
// Close and SetDeadline each record one call in cc under their own method name;
// the remaining overridden methods are silent no-ops. Any net.Conn method not
// overridden here falls through to the embedded (nil) net.Conn.
type mockConn struct {
	net.Conn
	we error          // when non-nil, returned by every Write
	re error          // when non-nil, returned by every Read
	wd []byte         // optional buffer that Write copies into and consumes
	rd []byte         // bytes still to be handed out by Read
	cc map[string]int // call counts keyed by method name, created lazily
}

// Read records the call, then returns re if it is set. Otherwise it copies as
// much of the remaining rd as fits into b and advances rd past what was copied.
// Once rd is exhausted it reports 0 bytes and a nil error.
func (m *mockConn) Read(b []byte) (n int, err error) {
	m.register()
	switch {
	case m.re != nil:
		return 0, m.re
	case len(m.rd) == 0:
		return 0, nil
	}
	n = copy(b, m.rd)
	m.rd = m.rd[n:]
	return n, nil
}

// Write records the call, then returns we if it is set. With no wd buffer (or
// an exhausted one) it reports all of b as written. Otherwise it copies b into
// wd, advances wd, and reports only the number of bytes that fit.
func (m *mockConn) Write(b []byte) (n int, err error) {
	m.register()
	if m.we != nil {
		return 0, m.we
	}
	if len(m.wd) == 0 {
		return len(b), nil
	}
	n = copy(m.wd, b)
	m.wd = m.wd[n:]
	return n, nil
}

// Close records the call and always succeeds.
func (m *mockConn) Close() error {
	m.register()
	return nil
}

// SetDeadline records the call and always succeeds; the deadline is ignored.
func (m *mockConn) SetDeadline(t time.Time) error {
	m.register()
	return nil
}

// register bumps the call counter for whichever method invoked it. The key is
// the last dot-separated component of the immediate caller's function name, so
// register must be called directly from the method being counted.
func (m *mockConn) register() {
	if m.cc == nil {
		m.cc = make(map[string]int)
	}
	pc, _, _, _ := runtime.Caller(1)
	name := runtime.FuncForPC(pc).Name()
	// Keep only what follows the final '.', i.e. the bare method name.
	m.cc[name[strings.LastIndex(name, ".")+1:]]++
}

// LocalAddr is a stub that returns nil and is not counted.
func (m *mockConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr is a stub that returns nil and is not counted.
func (m *mockConn) RemoteAddr() net.Addr {
	return nil
}

// SetReadDeadline is a stub that returns nil and is not counted.
func (m *mockConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a stub that returns nil and is not counted.
func (m *mockConn) SetWriteDeadline(t time.Time) error {
	return nil
}