// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once
#include "gmock/gmock.h"

#include <aws/core/http/HttpClient.h>
#include <aws/core/http/HttpClientFactory.h>
#include <aws/core/http/HttpRequest.h>
#include <aws/core/http/HttpResponse.h>
#include <aws/core/http/standard/StandardHttpRequest.h>
#include <aws/core/http/standard/StandardHttpResponse.h>
#include <aws/core/utils/stream/SimpleStreamBuf.h>
#include <boost/algorithm/string.hpp>

class FakeHttpRequest : public Aws::Http::HttpRequest
{
private:
    Aws::Http::HeaderValueCollection headers;
    std::shared_ptr<Aws::IOStream> bodyStream;

public:
    FakeHttpRequest(const Aws::Http::URI& uri, Aws::Http::HttpMethod method) : HttpRequest(uri, method) 
    {
        headers["host"] = uri.GetAuthority();
    }
    virtual ~FakeHttpRequest() {}
    
    virtual Aws::Http::HeaderValueCollection GetHeaders() const override
    {
        return headers;
    }
    
    virtual const Aws::String& GetHeaderValue(const char* headerName) const override
    {
        return headers.at(headerName);
    }
    
    virtual void SetHeaderValue(const char* headerName, const Aws::String& headerValue) override
    {
        headers[headerName] = headerValue;
    }
    
    virtual void SetHeaderValue(const Aws::String& headerName, const Aws::String& headerValue) override
    {
        headers[headerName] = headerValue;
    }
    
    virtual void DeleteHeader(const char* headerName) override 
    {
        auto iterator = headers.find(headerName);
        
        if (iterator != headers.end())
        {
            headers.erase(iterator);
        }
    }
    
    virtual void AddContentBody(const std::shared_ptr<Aws::IOStream>& strContent) override 
    {
        bodyStream = strContent;
    }
    
    virtual const std::shared_ptr<Aws::IOStream>& GetContentBody() const override
    {
        return bodyStream;
    }
    
    virtual bool HasHeader(const char* name) const override 
    {
        auto iterator = headers.find(name);

        return (iterator != headers.end());
    }
    
    virtual int64_t GetSize() const override { return 0; }
    
    virtual const Aws::IOStreamFactory& GetResponseStreamFactory() const override
    {
        static Aws::IOStreamFactory foo; // might need to be a member variable
        return foo;
    }
    
    virtual void SetResponseStreamFactory(const Aws::IOStreamFactory& factory) override
    {}
};

class FakeHttpResponse : public Aws::Http::HttpResponse
{
private:
    Aws::Http::HeaderValueCollection headers;
    std::string responseBody;
    std::shared_ptr<Aws::IOStream> bodyStream;
    Aws::Utils::Stream::ResponseStream dummyResponseStream;

public:
    FakeHttpResponse(const std::shared_ptr<const Aws::Http::HttpRequest>& originatingRequest) : 
        HttpResponse(originatingRequest), dummyResponseStream([]()->Aws::IOStream* { return nullptr; })
    {}
    
    virtual ~FakeHttpResponse() override {};
    
    std::shared_ptr<Aws::Http::HttpRequest> originalRequest;
    
    virtual Aws::Http::HeaderValueCollection GetHeaders() const override { return headers; }
    
    virtual bool HasHeader(const char* headerName) const override
    {
        return headers.find(headerName) != headers.end();
    }
    
    virtual const Aws::String& GetHeader(const Aws::String& headerName) const override 
    { 
        return headers.at(headerName);
    }
    
    virtual Aws::Utils::Stream::ResponseStream&& SwapResponseStreamOwnership() override 
    { 
        // Return a ResponseStream with a dummy IOStreamFactory.
        return std::move(dummyResponseStream);
    }
    
    virtual void AddHeader(const Aws::String& headerName, const Aws::String& headerValue) override
    {
        headers[headerName] = headerValue;
    }
    
    void SetResponseBody(const std::string& responseBody)
    {
        this->responseBody = responseBody;
        std::stringstream ss(responseBody);
        bodyStream = std::make_shared<Aws::StringStream>(responseBody.c_str());
    }

    virtual Aws::IOStream& GetResponseBody() const override
    {
        return *bodyStream;
    }

    FakeHttpResponse() : Aws::Http::HttpResponse(std::make_shared<FakeHttpRequest>(Aws::Http::URI(), Aws::Http::HttpMethod::HTTP_GET))
    {}
};

class FakeHttpClient : public Aws::Http::HttpClient
{
private:
    std::map<std::shared_ptr<Aws::Http::HttpRequest>, std::shared_ptr<FakeHttpResponse>> requestResponseMap;

public:
    void AddRequestAndResponse(std::shared_ptr<Aws::Http::HttpRequest> request, std::shared_ptr<FakeHttpResponse> response)
    {
        requestResponseMap[request] = response;
    }

    FakeHttpClient() = default;

    virtual std::shared_ptr<Aws::Http::HttpResponse> MakeRequest(const std::shared_ptr<Aws::Http::HttpRequest>& request,
        Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr,
        Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter = nullptr)
        const override
    {
        return requestResponseMap.at(request);
    }

    void DisableRequestProcessing() {}
    void EnableRequestProcessing() {}
    bool IsRequestProcessingEnabled() const { return true; }
    void RetryRequestSleep(std::chrono::milliseconds sleepTime) {}
    bool ContinueRequest(const Aws::Http::HttpRequest&) const { return true; }
};

class MockHttpClient : public Aws::Http::HttpClient
{
private:
    FakeHttpClient fake;

public:
    void DelegateToFake()
    {
        ON_CALL(*this, MakeRequest).WillByDefault([this](const std::shared_ptr<Aws::Http::HttpRequest>& r, Aws::Utils::RateLimits::RateLimiterInterface* rl, Aws::Utils::RateLimits::RateLimiterInterface* wl) 
            {
                return fake.MakeRequest(r, rl, wl);
            });
    }

    MockHttpClient() {}
    virtual ~MockHttpClient() override {}

    MOCK_METHOD(std::shared_ptr<Aws::Http::HttpResponse>, MakeRequest, (const std::shared_ptr<Aws::Http::HttpRequest>&, Aws::Utils::RateLimits::RateLimiterInterface*, Aws::Utils::RateLimits::RateLimiterInterface*), (const override));
    MOCK_METHOD(void, DisableRequestProcessing, ());
    MOCK_METHOD(void, EnableRequestProcessing, ());
    MOCK_METHOD(bool, IsRequestProcessingEnabled, (), (const));
    MOCK_METHOD(void, RetryRequestSleep, (std::chrono::milliseconds));
    MOCK_METHOD(bool, ContinueRequest, (const Aws::Http::HttpRequest&), (const));
};

class MockHttpClientFactory : public Aws::Http::HttpClientFactory
{
private:
    std::shared_ptr<Aws::Http::HttpClient> mockClient;

public:
    virtual std::shared_ptr<Aws::Http::HttpClient> CreateHttpClient(const Aws::Client::ClientConfiguration& clientConfiguration) const override
    {
        return mockClient;
    }

    virtual std::shared_ptr<Aws::Http::HttpRequest> CreateHttpRequest(const Aws::String& uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override
    {
        auto request = std::make_shared<Aws::Http::Standard::StandardHttpRequest>(uri, method);
        request->SetResponseStreamFactory(streamFactory);

        return request;
    }

    virtual std::shared_ptr<Aws::Http::HttpRequest> CreateHttpRequest(const Aws::Http::URI& uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override
    {
        auto request = std::make_shared<Aws::Http::Standard::StandardHttpRequest>(uri, method);
        request->SetResponseStreamFactory(streamFactory);

        return request;
    }

    inline std::shared_ptr<Aws::Http::HttpClient> GetClient() const
    { 
        return mockClient; 
    }

    inline void SetClient(const std::shared_ptr<Aws::Http::HttpClient>& client)
    { 
        mockClient = client; 
    }
};