/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.extensions.rest;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.junit.After;
import org.junit.Before;
import org.mockito.Mockito;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.ExtensionsSettings;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.rest.RestRequest;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.rest.FakeRestChannel;
import org.opensearch.test.rest.FakeRestRequest;
import org.opensearch.test.transport.MockTransportService;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.nio.MockNioTransport;

public class RestInitializeExtensionActionTests extends OpenSearchTestCase {

    private TransportService transportService;
    private MockNioTransport transport;
    private final ThreadPool threadPool = new TestThreadPool(RestInitializeExtensionActionTests.class.getSimpleName());

    @Before
    public void setup() throws Exception {
        Settings settings = Settings.builder().put("cluster.name", "test").build();
        transport = new MockNioTransport(
            settings,
            Version.CURRENT,
            threadPool,
            new NetworkService(Collections.emptyList()),
            PageCacheRecycler.NON_RECYCLING_INSTANCE,
            new NamedWriteableRegistry(Collections.emptyList()),
            new NoneCircuitBreakerService()
        );
        transportService = new MockTransportService(
            settings,
            transport,
            threadPool,
            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
            (boundAddress) -> new DiscoveryNode(
                "test_node",
                "test_node",
                boundAddress.publishAddress(),
                emptyMap(),
                emptySet(),
                Version.CURRENT
            ),
            null,
            Collections.emptySet()
        );

    }

    @Override
    @After
    public void tearDown() throws Exception {
        super.tearDown();
        transportService.close();
        ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS);
    }

    public void testRestInitializeExtensionActionResponse() throws Exception {
        ExtensionsManager extensionsManager = mock(ExtensionsManager.class);
        RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(extensionsManager);
        final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
            + "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
            + Version.CURRENT.toString()
            + "\","
            + "\"minimumCompatibleVersion\":\""
            + Version.CURRENT.minimumCompatibilityVersion().toString()
            + "\"}";
        RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withContent(new BytesArray(content), XContentType.JSON)
            .withMethod(RestRequest.Method.POST)
            .build();

        FakeRestChannel channel = new FakeRestChannel(request, false, 0);
        restInitializeExtensionAction.handleRequest(request, channel, null);

        assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
        assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));
    }

    public void testRestInitializeExtensionActionFailure() throws Exception {
        ExtensionsManager extensionsManager = new ExtensionsManager(Set.of());
        RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(extensionsManager);

        final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"\",\"hostAddress\":\"127.0.0.1\","
            + "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
            + Version.CURRENT.toString()
            + "\","
            + "\"minimumCompatibleVersion\":\""
            + Version.CURRENT.minimumCompatibilityVersion().toString()
            + "\"}";
        RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withContent(new BytesArray(content), XContentType.JSON)
            .withMethod(RestRequest.Method.POST)
            .build();

        FakeRestChannel channel = new FakeRestChannel(request, false, 0);
        restInitializeExtensionAction.handleRequest(request, channel, null);

        assertEquals(1, channel.errors().get());
        assertTrue(
            channel.capturedResponse().content().utf8ToString().contains("Required field [extension uniqueId] is missing in the request")
        );
    }

    public void testRestInitializeExtensionActionResponseWithAdditionalSettings() throws Exception {
        Setting boolSetting = Setting.boolSetting("boolSetting", false, Setting.Property.ExtensionScope);
        Setting stringSetting = Setting.simpleString("stringSetting", "default", Setting.Property.ExtensionScope);
        Setting intSetting = Setting.intSetting("intSetting", 0, Setting.Property.ExtensionScope);
        Setting listSetting = Setting.listSetting(
            "listSetting",
            List.of("first", "second", "third"),
            Function.identity(),
            Setting.Property.ExtensionScope
        );
        ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(boolSetting, stringSetting, intSetting, listSetting));
        ExtensionsManager spy = spy(extensionsManager);

        // optionally, you can stub out some methods:
        when(spy.getAdditionalSettings()).thenCallRealMethod();
        Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
        Mockito.doNothing().when(spy).initialize();
        RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
        final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
            + "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
            + Version.CURRENT.toString()
            + "\","
            + "\"minimumCompatibleVersion\":\""
            + Version.CURRENT.minimumCompatibilityVersion().toString()
            + "\",\"boolSetting\":true,\"stringSetting\":\"customSetting\",\"intSetting\":5,\"listSetting\":[\"one\",\"two\",\"three\"]}";
        RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withContent(new BytesArray(content), XContentType.JSON)
            .withMethod(RestRequest.Method.POST)
            .build();

        FakeRestChannel channel = new FakeRestChannel(request, false, 0);
        restInitializeExtensionAction.handleRequest(request, channel, null);

        assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
        assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));

        Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
        assertTrue(extension.isPresent());
        assertEquals(true, extension.get().getAdditionalSettings().get(boolSetting));
        assertEquals("customSetting", extension.get().getAdditionalSettings().get(stringSetting));
        assertEquals(5, extension.get().getAdditionalSettings().get(intSetting));

        List<String> listSettingValue = (List<String>) extension.get().getAdditionalSettings().get(listSetting);
        assertTrue(listSettingValue.contains("one"));
        assertTrue(listSettingValue.contains("two"));
        assertTrue(listSettingValue.contains("three"));
    }

    public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsingDefault() throws Exception {
        Setting boolSetting = Setting.boolSetting("boolSetting", false, Setting.Property.ExtensionScope);
        Setting stringSetting = Setting.simpleString("stringSetting", "default", Setting.Property.ExtensionScope);
        Setting intSetting = Setting.intSetting("intSetting", 0, Setting.Property.ExtensionScope);
        Setting listSetting = Setting.listSetting(
            "listSetting",
            List.of("first", "second", "third"),
            Function.identity(),
            Setting.Property.ExtensionScope
        );
        ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(boolSetting, stringSetting, intSetting, listSetting));
        ExtensionsManager spy = spy(extensionsManager);

        // optionally, you can stub out some methods:
        when(spy.getAdditionalSettings()).thenCallRealMethod();
        Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
        Mockito.doNothing().when(spy).initialize();
        RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
        final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
            + "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
            + Version.CURRENT.toString()
            + "\","
            + "\"minimumCompatibleVersion\":\""
            + Version.CURRENT.minimumCompatibilityVersion().toString()
            + "\"}";
        RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withContent(new BytesArray(content), XContentType.JSON)
            .withMethod(RestRequest.Method.POST)
            .build();

        FakeRestChannel channel = new FakeRestChannel(request, false, 0);
        restInitializeExtensionAction.handleRequest(request, channel, null);

        assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
        assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));

        Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
        assertTrue(extension.isPresent());
        assertEquals(false, extension.get().getAdditionalSettings().get(boolSetting));
        assertEquals("default", extension.get().getAdditionalSettings().get(stringSetting));
        assertEquals(0, extension.get().getAdditionalSettings().get(intSetting));

        List<String> listSettingValue = (List<String>) extension.get().getAdditionalSettings().get(listSetting);
        assertTrue(listSettingValue.contains("first"));
        assertTrue(listSettingValue.contains("second"));
        assertTrue(listSettingValue.contains("third"));
    }

}