//go:build go1.7
// +build go1.7

package stscreds

import (
	"fmt"
	"net/http"
	"reflect"
	"strings"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/client"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/sts"
	"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

func TestWebIdentityProviderRetrieve(t *testing.T) {
	cases := map[string]struct {
		roleARN           string
		tokenPath         string
		sessionName       string
		newClient         func(t *testing.T) stsiface.STSAPI
		duration          time.Duration
		expectedError     string
		expectedCredValue credentials.Value
	}{
		"session name case": {
			roleARN:     "arn01234567890123456789",
			tokenPath:   "testdata/token.jwt",
			sessionName: "foo",
			newClient: func(t *testing.T) stsiface.STSAPI {
				return mockAssumeRoleWithWebIdentityClient{
					t: t,
					doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) (
						*sts.AssumeRoleWithWebIdentityOutput, error,
					) {
						if e, a := "foo", *input.RoleSessionName; e != a {
							t.Errorf("expected %v, but received %v", e, a)
						}
						if input.DurationSeconds != nil {
							t.Errorf("expect no duration, got %v", *input.DurationSeconds)
						}

						return &sts.AssumeRoleWithWebIdentityOutput{
							Credentials: &sts.Credentials{
								Expiration:      aws.Time(time.Now()),
								AccessKeyId:     aws.String("access-key-id"),
								SecretAccessKey: aws.String("secret-access-key"),
								SessionToken:    aws.String("session-token"),
							},
						}, nil
					},
				}
			},
			expectedCredValue: credentials.Value{
				AccessKeyID:     "access-key-id",
				SecretAccessKey: "secret-access-key",
				SessionToken:    "session-token",
				ProviderName:    WebIdentityProviderName,
			},
		},
		"with duration": {
			roleARN:     "arn01234567890123456789",
			tokenPath:   "testdata/token.jwt",
			sessionName: "foo",
			duration:    15 * time.Minute,
			newClient: func(t *testing.T) stsiface.STSAPI {
				return mockAssumeRoleWithWebIdentityClient{
					t: t,
					doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) (
						*sts.AssumeRoleWithWebIdentityOutput, error,
					) {
						if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
							t.Errorf("expect %v duration, got %v", e, a)
						}

						return &sts.AssumeRoleWithWebIdentityOutput{
							Credentials: &sts.Credentials{
								Expiration:      aws.Time(time.Now()),
								AccessKeyId:     aws.String("access-key-id"),
								SecretAccessKey: aws.String("secret-access-key"),
								SessionToken:    aws.String("session-token"),
							},
						}, nil
					},
				}
			},
			expectedCredValue: credentials.Value{
				AccessKeyID:     "access-key-id",
				SecretAccessKey: "secret-access-key",
				SessionToken:    "session-token",
				ProviderName:    WebIdentityProviderName,
			},
		},
	}

	for name, c := range cases {
		t.Run(name, func(t *testing.T) {
			p := NewWebIdentityRoleProvider(c.newClient(t), c.roleARN, c.sessionName, c.tokenPath)
			p.Duration = c.duration

			credValue, err := p.Retrieve()
			if len(c.expectedError) != 0 {
				if err == nil {
					t.Fatalf("expect error, got none")
				}
				if e, a := c.expectedError, err.Error(); !strings.Contains(a, e) {
					t.Fatalf("expect error to contain %v, got %v", e, a)
				}
				return
			}
			if err != nil {
				t.Fatalf("expect no error, got %v", err)
			}

			if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
				t.Errorf("expected %v, but received %v", e, a)
			}
		})
	}
}

// mockAssumeRoleWithWebIdentityClient is a test double for stsiface.STSAPI.
// It embeds the interface so the type satisfies it, while the only method
// defined on the mock itself is AssumeRoleWithWebIdentityRequest.
//
// t is the test that failures detected by the mock are reported against.
// doRequest is the per-case callback that receives the operation input and
// supplies the output and error the mocked request will carry.
type mockAssumeRoleWithWebIdentityClient struct {
	stsiface.STSAPI

	t         *testing.T
	doRequest func(*testing.T, *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

// AssumeRoleWithWebIdentityRequest builds a stub request for the given input.
//
// doRequest is called immediately, while the request is being constructed,
// and its result is captured. The returned request carries a single Send
// handler that copies that result onto the request and fails the test when
// sts.ErrCodeInvalidIdentityTokenException is missing from the request's
// RetryErrorCodes at send time.
func (c mockAssumeRoleWithWebIdentityClient) AssumeRoleWithWebIdentityRequest(input *sts.AssumeRoleWithWebIdentityInput) (
	*request.Request, *sts.AssumeRoleWithWebIdentityOutput,
) {
	out, callErr := c.doRequest(c.t, input)

	stubReq := &request.Request{
		HTTPRequest: &http.Request{},
		Retryer:     client.DefaultRetryer{},
	}

	// tokenErrRetryable reports whether the invalid identity token error code
	// has been registered as retry-able on the stub request.
	tokenErrRetryable := func() bool {
		for _, code := range stubReq.RetryErrorCodes {
			if code == sts.ErrCodeInvalidIdentityTokenException {
				return true
			}
		}
		return false
	}

	stubReq.Handlers.Send.PushBack(func(r *request.Request) {
		// Simulate a completed round trip using the captured doRequest result.
		r.HTTPResponse, r.Data, r.Error = &http.Response{}, out, callErr

		if tokenErrRetryable() {
			return
		}
		c.t.Errorf("expect ErrCodeInvalidIdentityTokenException error code to be retry-able")
	})

	return stubReq, out
}

// TestNewWebIdentityRoleProviderWithOptions checks that the constructor stores
// its client, token fetcher, role ARN and session name on the provider, and
// that changes made by the supplied option functions show up on the result.
func TestNewWebIdentityRoleProviderWithOptions(t *testing.T) {
	const (
		roleARN         = "a-role-arn"
		roleSessionName = "a-session-name"
	)

	type testCase struct {
		options []func(*WebIdentityRoleProvider)
		expect  WebIdentityRoleProvider
	}

	// withTimings is the option used by the "with options" case; it overrides
	// both the credential duration and the expiry window.
	withTimings := func(o *WebIdentityRoleProvider) {
		o.Duration = 10 * time.Minute
		o.ExpiryWindow = time.Minute
	}

	cases := map[string]testCase{
		"no options": {
			expect: WebIdentityRoleProvider{
				client:          stubClient{},
				tokenFetcher:    stubTokenFetcher{},
				roleARN:         roleARN,
				roleSessionName: roleSessionName,
			},
		},
		"with options": {
			options: []func(*WebIdentityRoleProvider){withTimings},
			expect: WebIdentityRoleProvider{
				client:          stubClient{},
				tokenFetcher:    stubTokenFetcher{},
				roleARN:         roleARN,
				roleSessionName: roleSessionName,
				Duration:        10 * time.Minute,
				ExpiryWindow:    time.Minute,
			},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			got := NewWebIdentityRoleProviderWithOptions(
				stubClient{}, roleARN, roleSessionName,
				stubTokenFetcher{}, tc.options...,
			)

			if reflect.DeepEqual(tc.expect, *got) {
				return
			}
			t.Errorf("expect:\n%v\nactual:\n%v", tc.expect, *got)
		})
	}
}

// stubClient is a placeholder stsiface.STSAPI value. It embeds the interface
// and defines no methods of its own; the tests in this file only use it as a
// comparable value when constructing and comparing providers.
type stubClient struct {
	stsiface.STSAPI
}
// stubTokenFetcher is a placeholder token fetcher for tests that construct a
// provider without retrieving credentials.
type stubTokenFetcher struct{}

// FetchToken always returns a nil token and an error stating that the stub
// should not have been called.
func (stubTokenFetcher) FetchToken(credentials.Context) ([]byte, error) {
	err := fmt.Errorf("stubTokenFetcher should not be called")
	return nil, err
}