// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package middleware

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/go-chi/chi"
	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"go.amzn.com/lambda/appctx"
	"go.amzn.com/lambda/extensions"
	"go.amzn.com/lambda/rapi/handler"
	"go.amzn.com/lambda/rapi/model"
)

// mockHandler is a no-op http.Handler that the tests in this file register
// as the final handler behind the middleware under test.
type mockHandler struct{}

// ServeHTTP deliberately does nothing: it writes no status, header or body,
// and never touches the request.
func (*mockHandler) ServeHTTP(http.ResponseWriter, *http.Request) {}

// TestRuntimeReleaseMiddleware sends a request with a User-Agent header
// through RuntimeReleaseMiddleware and expects the same value to be loadable
// from the application context under appctx.AppCtxRuntimeReleaseKey.
func TestRuntimeReleaseMiddleware(t *testing.T) {
	const wantRelease = "foobar"

	applicationCtx := appctx.NewApplicationContext()

	mux := chi.NewRouter()
	mux.Use(RuntimeReleaseMiddleware())
	mux.Get("/", (&mockHandler{}).ServeHTTP)

	// Build the request and attach the application context to it.
	req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(make([]byte, 100)))
	req.Header.Set("User-Agent", wantRelease)
	req = appctx.RequestWithAppCtx(req, applicationCtx)

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)

	gotRelease, found := applicationCtx.Load(appctx.AppCtxRuntimeReleaseKey)
	assert.True(t, found)
	assert.Equal(t, wantRelease, gotRelease)
}

// TestAgentUniqueIdentifierHeaderValidatorForbidden expects
// AgentUniqueIdentifierHeaderValidator to answer 403 Forbidden with a JSON
// error body in two cases: the agent identifier header is absent, and the
// header carries the value "invalid-unique-identifier".
func TestAgentUniqueIdentifierHeaderValidatorForbidden(t *testing.T) {
	router := chi.NewRouter()
	router.Get("/", AgentUniqueIdentifierHeaderValidator(&mockHandler{}).ServeHTTP)

	request := httptest.NewRequest("GET", "/", bytes.NewReader(make([]byte, 100)))

	// Case 1: no identifier header on the request.
	responseRecorder := httptest.NewRecorder()
	router.ServeHTTP(responseRecorder, request)
	assert.Equal(t, http.StatusForbidden, responseRecorder.Code)
	respBody, err := io.ReadAll(responseRecorder.Body)
	assert.NoError(t, err)
	// Decode into a fresh value per case so fields left over from one
	// response can never satisfy the assertion for the other.
	var missingResponse model.ErrorResponse
	assert.NoError(t, json.Unmarshal(respBody, &missingResponse))
	assert.Equal(t, handler.ErrAgentIdentifierMissing, missingResponse.ErrorType)

	// Case 2: header present, but with the value "invalid-unique-identifier".
	responseRecorder = httptest.NewRecorder()
	request.Header.Set(handler.LambdaAgentIdentifier, "invalid-unique-identifier")
	router.ServeHTTP(responseRecorder, request)
	assert.Equal(t, http.StatusForbidden, responseRecorder.Code)
	respBody, err = io.ReadAll(responseRecorder.Body)
	assert.NoError(t, err)
	var invalidResponse model.ErrorResponse
	assert.NoError(t, json.Unmarshal(respBody, &invalidResponse))
	assert.Equal(t, handler.ErrAgentIdentifierInvalid, invalidResponse.ErrorType)
}

func TestAgentUniqueIdentifierHeaderValidatorSuccess(t *testing.T) {
	router := chi.NewRouter()
	mockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		val, ok := r.Context().Value(handler.AgentIDCtxKey).(uuid.UUID)
		if !ok {
			assert.FailNow(t, "expected key not in request context")
		}
		assert.Equal(t, "85083764-ff1e-476f-ada1-d51f26e4f6be", val.String())
	})
	router.Get("/", AgentUniqueIdentifierHeaderValidator(mockHandler).ServeHTTP)
	responseBody := make([]byte, 100)
	request := httptest.NewRequest("GET", "/", bytes.NewReader(responseBody))
	ctx := context.Background()
	request = request.WithContext(ctx)

	responseRecorder := httptest.NewRecorder()
	responseRecorder.Code = http.StatusOK
	request.Header.Set(handler.LambdaAgentIdentifier, "85083764-ff1e-476f-ada1-d51f26e4f6be")
	router.ServeHTTP(responseRecorder, request)
	assert.Equal(t, http.StatusOK, responseRecorder.Code)
}

// TestAllowIfExtensionsEnabledPositive expects a request routed through
// AllowIfExtensionsEnabled to return 200 OK while extensions are enabled.
func TestAllowIfExtensionsEnabledPositive(t *testing.T) {
	mux := chi.NewRouter()
	mux.Use(AllowIfExtensionsEnabled)
	mux.Get("/", (&mockHandler{}).ServeHTTP)

	payload := make([]byte, 100)
	rec := httptest.NewRecorder()

	// Enable extensions for this test only and disable them again on return.
	extensions.Enable()
	defer extensions.Disable()

	req := httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(payload))
	mux.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
}

func TestAllowIfExtensionsEnabledNegative(t *testing.T) {
	router := chi.NewRouter()
	handler := &mockHandler{}
	router.Use(AllowIfExtensionsEnabled)
	router.Get("/", handler.ServeHTTP)

	responseRecorder := httptest.NewRecorder()
	responseBody := make([]byte, 100)
	router.ServeHTTP(responseRecorder, httptest.NewRequest("GET", "/", bytes.NewReader(responseBody)))

	assert.Equal(t, http.StatusNotFound, responseRecorder.Code)
}