//go:build unit
// +build unit

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
//	http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package audit

import (
	"fmt"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"testing"

	"github.com/aws/amazon-ecs-agent/agent/config"
	mock_infologger "github.com/aws/amazon-ecs-agent/agent/logger/audit/mocks"
	"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
	auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit"
	"github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
)

// Fixtures shared by the audit log tests in this file. The request-side values
// are fed into the audit logger, and the verify* helpers compare the tokens of
// the resulting log line against these same values.
const (
	// Request, identity and response values used to build the audit entries.
	// dummyURLV2 uses the V2 credentials path followed by taskARN as the
	// trailing path segment.
	dummyContainerInstanceArn = "containerInstanceArn"
	dummyCluster              = "cluster"
	dummyEventType            = "someEvent"
	dummyRemoteAddress        = "rAddr"
	dummyURL                  = "http://foo.com" + dummyURLPath + "?id=foo"
	dummyURLPath              = "/urlPath"
	dummyURLV2                = "http://foo.com" + credentials.V2CredentialsPath + "/" + taskARN
	dummyUserAgent            = "userAgent"
	dummyResponseCode         = 400
	dummyRoleType             = "TaskExecution"
	taskARN                   = "task-arn-1"

	// Expected number of space-separated tokens in each part of an audit log
	// line: the common request fields first, then the GetCredentials fields.
	commonAuditLogEntryFieldCount = 6
	getCredentialsEntryFieldCount = 4
)

// TestWritingToAuditLog verifies that logging a credentials request with a
// task ARN emits exactly one line to the info logger. That line must contain
// the common request fields (path dummyURLPath, ARN taskARN) followed by the
// GetCredentials fields.
func TestWritingToAuditLog(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockInfoLogger := mock_infologger.NewMockInfoLogger(ctrl)

	req, err := http.NewRequest(http.MethodGet, "foo", nil)
	if err != nil {
		t.Fatalf("error creating request: %v", err)
	}
	req.RemoteAddr = dummyRemoteAddress
	// Swap the placeholder URL for the parsed fixture so the entry records
	// dummyURLPath as the request path.
	parsedURL, err := url.Parse(dummyURL)
	if err != nil {
		t.Fatalf("error parsing dummyURL: %v", err)
	}
	req.URL = parsedURL
	req.Header.Set("User-Agent", dummyUserAgent)

	cfg := &config.Config{
		Cluster:                 dummyCluster,
		CredentialsAuditLogFile: "foo.txt",
	}

	auditLogger := NewAuditLog(dummyContainerInstanceArn, cfg, mockInfoLogger)
	assert.Equal(t, dummyCluster, auditLogger.GetCluster(), "Cluster is not initialized properly")
	assert.Equal(t, dummyContainerInstanceArn, auditLogger.GetContainerInstanceArn(), "ContainerInstanceArn is not initialized properly")

	// The callback receives the line handed to the info logger and validates it.
	mockInfoLogger.EXPECT().Info(gomock.Any()).Do(func(logLine string) {
		verifyAuditLogEntryResult(logLine, taskARN, dummyURLPath, t)
	})

	auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode,
		auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

// TestWritingToAuditLogV2 verifies audit logging for a request to the V2
// credentials endpoint. The URL path carries a trailing segment after
// credentials.V2CredentialsPath. The expected path in the logged entry is
// the bare credentials.V2CredentialsPath, so the trailing segment must not
// appear in the entry.
func TestWritingToAuditLogV2(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockInfoLogger := mock_infologger.NewMockInfoLogger(ctrl)

	req, err := http.NewRequest(http.MethodGet, "foo", nil)
	if err != nil {
		t.Fatalf("error creating request: %v", err)
	}
	req.RemoteAddr = dummyRemoteAddress
	// Swap the placeholder URL for the V2 fixture URL.
	parsedURL, err := url.Parse(dummyURLV2)
	if err != nil {
		t.Fatalf("error parsing dummyURLV2: %v", err)
	}
	req.URL = parsedURL
	req.Header.Set("User-Agent", dummyUserAgent)

	cfg := &config.Config{
		Cluster:                 dummyCluster,
		CredentialsAuditLogFile: "foo.txt",
	}

	auditLogger := NewAuditLog(dummyContainerInstanceArn, cfg, mockInfoLogger)
	assert.Equal(t, dummyCluster, auditLogger.GetCluster(), "Cluster is not initialized properly")
	assert.Equal(t, dummyContainerInstanceArn, auditLogger.GetContainerInstanceArn(), "ContainerInstanceArn is not initialized properly")

	// Expect the bare V2 path, without the trailing segment, in the entry.
	mockInfoLogger.EXPECT().Info(gomock.Any()).Do(func(logLine string) {
		verifyAuditLogEntryResult(logLine, taskARN, credentials.V2CredentialsPath, t)
	})

	auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode,
		auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

// TestWritingErrorsToAuditLog verifies that a request logged with an empty
// ARN still produces a well-formed entry. The ARN field of that entry is
// expected to be "-".
func TestWritingErrorsToAuditLog(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockInfoLogger := mock_infologger.NewMockInfoLogger(ctrl)

	req, err := http.NewRequest(http.MethodGet, "foo", nil)
	if err != nil {
		t.Fatalf("error creating request: %v", err)
	}
	req.RemoteAddr = dummyRemoteAddress
	// Swap the placeholder URL for the parsed fixture URL.
	parsedURL, err := url.Parse(dummyURL)
	if err != nil {
		t.Fatalf("error parsing dummyURL: %v", err)
	}
	req.URL = parsedURL
	req.Header.Set("User-Agent", dummyUserAgent)

	cfg := &config.Config{
		Cluster:                 dummyCluster,
		CredentialsAuditLogFile: "foo.txt",
	}

	auditLogger := NewAuditLog(dummyContainerInstanceArn, cfg, mockInfoLogger)
	assert.Equal(t, dummyCluster, auditLogger.GetCluster(), "Cluster is not initialized properly")
	assert.Equal(t, dummyContainerInstanceArn, auditLogger.GetContainerInstanceArn(), "ContainerInstanceArn is not initialized properly")

	// An empty ARN is expected to be rendered as "-" in the entry.
	mockInfoLogger.EXPECT().Info(gomock.Any()).Do(func(logLine string) {
		verifyAuditLogEntryResult(logLine, "-", dummyURLPath, t)
	})

	auditLogger.Log(request.LogRequest{Request: req, ARN: ""}, dummyResponseCode,
		auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

// TestWritingToAuditLogWhenDisabled verifies that when
// CredentialsAuditLogDisabled is set, Log never calls the info logger.
func TestWritingToAuditLogWhenDisabled(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockInfoLogger := mock_infologger.NewMockInfoLogger(ctrl)

	req, err := http.NewRequest(http.MethodGet, "foo", nil)
	if err != nil {
		t.Fatalf("error creating request: %v", err)
	}

	cfg := &config.Config{
		Cluster:                     dummyCluster,
		CredentialsAuditLogFile:     "foo.txt",
		CredentialsAuditLogDisabled: true,
	}

	auditLogger := NewAuditLog(dummyContainerInstanceArn, cfg, mockInfoLogger)
	assert.Equal(t, dummyCluster, auditLogger.GetCluster(), "Cluster is not initialized properly")
	assert.Equal(t, dummyContainerInstanceArn, auditLogger.GetContainerInstanceArn(), "ContainerInstanceArn is not initialized properly")

	// Any call to Info fails the test.
	mockInfoLogger.EXPECT().Info(gomock.Any()).Times(0)

	auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode,
		auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

// TestConstructCommonAuditLogEntryFields checks the request-derived portion of
// an audit entry in isolation from the logger.
func TestConstructCommonAuditLogEntryFields(t *testing.T) {
	req, err := http.NewRequest(http.MethodGet, "foo", nil)
	if err != nil {
		t.Fatalf("error creating request: %v", err)
	}
	req.RemoteAddr = dummyRemoteAddress
	// Swap the placeholder URL for the parsed fixture URL.
	parsedURL, err := url.Parse(dummyURL)
	if err != nil {
		t.Fatalf("error parsing dummyURL: %v", err)
	}
	req.URL = parsedURL
	req.Header.Set("User-Agent", dummyUserAgent)

	result := constructCommonAuditLogEntryFields(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode)

	verifyCommonAuditLogEntryFieldResult(result, taskARN, dummyURLPath, t)
}

func TestConstructAuditLogEntryByTypeGetCredentials(t *testing.T) {
	result := constructAuditLogEntryByType(
		auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType), dummyCluster,
		dummyContainerInstanceArn)
	verifyConstructAuditLogEntryGetCredentialsResult(result, t)
}

// verifyAuditLogEntryResult checks a complete audit log line as passed to the
// info logger. The line must have commonAuditLogEntryFieldCount common
// request tokens followed by getCredentialsEntryFieldCount GetCredentials
// tokens, all separated by single spaces.
func verifyAuditLogEntryResult(logLine string, expectedTaskArn string, expectedURLPath string, t *testing.T) {
	t.Helper()
	tokens := strings.Split(logLine, " ")
	// Stop on a count mismatch: slicing at commonAuditLogEntryFieldCount below
	// would panic on a short line and mask the real assertion failure.
	if !assert.Equal(t, commonAuditLogEntryFieldCount+getCredentialsEntryFieldCount, len(tokens), "Incorrect number of tokens in audit log entry") {
		return
	}
	verifyCommonAuditLogEntryFieldResult(strings.Join(tokens[:commonAuditLogEntryFieldCount], " "), expectedTaskArn, expectedURLPath, t)
	verifyConstructAuditLogEntryGetCredentialsResult(strings.Join(tokens[commonAuditLogEntryFieldCount:], " "), t)
}

// verifyCommonAuditLogEntryFieldResult checks the common, request-derived
// audit fields. After splitting on spaces, tokens[1:] are expected to be:
// response code, remote address, quoted URL path, quoted User-Agent and ARN.
// tokens[0] is not checked here (NOTE(review): presumably a timestamp —
// confirm against constructCommonAuditLogEntryFields).
func verifyCommonAuditLogEntryFieldResult(result string, expectedTaskArn string, expectedURLPath string, t *testing.T) {
	t.Helper()
	tokens := strings.Split(result, " ")
	// The indexing below needs exactly commonAuditLogEntryFieldCount tokens;
	// bail out instead of panicking with index out of range.
	if !assert.Equal(t, commonAuditLogEntryFieldCount, len(tokens), "Incorrect number of tokens in common audit log entry") {
		return
	}

	respCode, err := strconv.Atoi(tokens[1])
	assert.NoError(t, err, "response code is not an integer")
	assert.Equal(t, dummyResponseCode, respCode, "response code does not match")
	assert.Equal(t, dummyRemoteAddress, tokens[2], "remote address does not match")
	assert.Equal(t, fmt.Sprintf(`"%s"`, expectedURLPath), tokens[3], "URL path does not match")
	assert.Equal(t, fmt.Sprintf(`"%s"`, dummyUserAgent), tokens[4], "User Agent does not match")
	assert.Equal(t, expectedTaskArn, tokens[5], "ARN for credentials does not match")
}

// verifyConstructAuditLogEntryGetCredentialsResult checks the GetCredentials
// audit fields. After splitting on spaces, the tokens are expected to be:
// event type, audit log version, cluster and container instance ARN.
func verifyConstructAuditLogEntryGetCredentialsResult(result string, t *testing.T) {
	t.Helper()
	tokens := strings.Split(result, " ")

	// The indexing below needs exactly getCredentialsEntryFieldCount tokens;
	// bail out instead of panicking with index out of range.
	if !assert.Equal(t, getCredentialsEntryFieldCount, len(tokens), "Incorrect number of tokens in GetCredentials audit log entry") {
		return
	}
	assert.Equal(t, auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType),
		tokens[0], "event type does not match")

	auditLogVersion, err := strconv.Atoi(tokens[1])
	assert.NoError(t, err, "audit log version is not an integer")
	assert.Equal(t, getCredentialsAuditLogVersion, auditLogVersion, "version does not match")
	assert.Equal(t, dummyCluster, tokens[2], "cluster does not match")
	assert.Equal(t, dummyContainerInstanceArn, tokens[3], "containerInstanceArn does not match")
}

// TestConstructAuditLogEntryByTypeUnknownType checks that an event type the
// audit logger does not recognize produces an empty entry.
func TestConstructAuditLogEntryByTypeUnknownType(t *testing.T) {
	const unknownEventType = "unknownEvent"
	entry := constructAuditLogEntryByType(unknownEventType, dummyCluster, dummyContainerInstanceArn)
	assert.Equal(t, "", entry, "unknown event type should not return an entry")
}