package middleware

import (
	"context"
	smithymiddleware "github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
	"os"
	"testing"
)

func TestRecursionDetection(t *testing.T) {
	cases := map[string]struct {
		LambdaFuncName string
		TraceID        string
		HeaderBefore   string
		HeaderAfter    string
	}{
		"non lambda env and no trace ID header before": {},
		"with lambda env but no trace ID env variable, no trace ID header before": {
			LambdaFuncName: "some-function1",
		},
		"with lambda env and trace ID env variable, no trace ID header before": {
			LambdaFuncName: "some-function2",
			TraceID:        "traceID1",
			HeaderAfter:    "traceID1",
		},
		"with lambda env and trace ID env variable, has trace ID header before": {
			LambdaFuncName: "some-function3",
			TraceID:        "traceID2",
			HeaderBefore:   "traceID1",
			HeaderAfter:    "traceID1",
		},
		"with lambda env and trace ID (needs encoding) env variable, no trace ID header before": {
			LambdaFuncName: "some-function4",
			TraceID:        "traceID3\n",
			HeaderAfter:    "traceID3%0A",
		},
		"with lambda env and trace ID (contains chars must not be encoded) env variable, no trace ID header before": {
			LambdaFuncName: "some-function5",
			TraceID:        "traceID4-=;:+&[]{}\"'",
			HeaderAfter:    "traceID4-=;:+&[]{}\"'",
		},
	}

	for name, c := range cases {
		t.Run(name, func(t *testing.T) {
			// clear current case's environment variables and restore them at the end of the test func goroutine
			restoreEnv := clearEnv()
			defer restoreEnv()

			setEnvVar(t, envAwsLambdaFunctionName, c.LambdaFuncName)
			setEnvVar(t, envAmznTraceID, c.TraceID)

			req := smithyhttp.NewStackRequest().(*smithyhttp.Request)
			if c.HeaderBefore != "" {
				req.Header.Set(amznTraceIDHeader, c.HeaderBefore)
			}
			var updatedRequest *smithyhttp.Request
			m := RecursionDetection{}
			_, _, err := m.HandleBuild(context.Background(),
				smithymiddleware.BuildInput{Request: req},
				smithymiddleware.BuildHandlerFunc(func(ctx context.Context, input smithymiddleware.BuildInput) (
					out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) {
					updatedRequest = input.Request.(*smithyhttp.Request)
					return out, metadata, nil
				}),
			)
			if err != nil {
				t.Fatalf("expect no error, got %v", err)
			}

			if e, a := c.HeaderAfter, updatedRequest.Header.Get(amznTraceIDHeader); e != a {
				t.Errorf("expect header value %v found, got %v", e, a)
			}
		})
	}
}

// check if test case has environment variable and set to os if it has
func setEnvVar(t *testing.T, key, value string) {
	if value != "" {
		err := os.Setenv(key, value)
		if err != nil {
			t.Fatalf("expect no error, got %v", err)
		}
	}
}