/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.knn.indices;

import lombok.SneakyThrows;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

/**
 * Unit tests for {@link ModelGraveyard}: membership operations (add/remove/contains), stream round-tripping,
 * XContent serialization/parsing (including malformed inputs), and {@link ModelGraveyard.ModelGraveyardDiff}
 * computation, serialization and application.
 */
public class ModelGraveyardTests extends OpenSearchTestCase {

    /** A freshly constructed graveyard contains an id after it is added. */
    public void testAdd() {
        ModelGraveyard testModelGraveyard = new ModelGraveyard();
        String testModelId = "test-model-id";
        testModelGraveyard.add(testModelId);
        assertTrue(testModelGraveyard.contains(testModelId));
    }

    /** An id supplied at construction time is no longer contained after it is removed. */
    public void testRemove() {
        Set<String> modelIds = new HashSet<>();
        String testModelId = "test-model-id";
        modelIds.add(testModelId);
        ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds);

        assertTrue(testModelGraveyard.contains(testModelId));
        testModelGraveyard.remove(testModelId);
        assertFalse(testModelGraveyard.contains(testModelId));
    }

    /** An id supplied at construction time is reported by {@code contains}. */
    public void testContains() {
        Set<String> modelIds = new HashSet<>();
        String testModelId = "test-model-id";
        modelIds.add(testModelId);

        ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds);
        assertTrue(testModelGraveyard.contains(testModelId));
    }

    /** Writing a graveyard to a stream and reading it back preserves its size and contents. */
    public void testStreams() throws IOException {
        Set<String> modelIds = new HashSet<>();
        String testModelId = "test-model-id";
        modelIds.add(testModelId);
        ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds);

        // BytesStreamOutput is a Closeable StreamOutput; release it once the round trip is done.
        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
            testModelGraveyard.writeTo(streamOutput);

            ModelGraveyard testModelGraveyardCopy = new ModelGraveyard(streamOutput.bytes().streamInput());

            assertEquals(testModelGraveyard.size(), testModelGraveyardCopy.size());
            assertTrue(testModelGraveyard.contains(testModelId));
            assertTrue(testModelGraveyardCopy.contains(testModelId));
        }
    }

    // Validating {model_ids: ["test-model-id1", "test-model-id2"]}
    @SneakyThrows
    public void testXContentBuilder_withModelIds_returnsModelGraveyardWithModelIds() {
        Set<String> modelIds = new HashSet<>();
        String testModelId1 = "test-model-id1";
        String testModelId2 = "test-model-id2";
        modelIds.add(testModelId1);
        modelIds.add(testModelId2);
        ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds);

        // toXContent writes only the field(s); the test supplies the enclosing object.
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
        xContentBuilder.startObject();
        XContentBuilder builder = testModelGraveyard.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS);
        builder.endObject();

        ModelGraveyard testModelGraveyard2 = ModelGraveyard.fromXContent(createParser(builder));
        assertEquals(2, testModelGraveyard2.size());
        assertTrue(testModelGraveyard2.contains(testModelId1));
        assertTrue(testModelGraveyard2.contains(testModelId2));
    }

    // Validating {model_ids:[]}
    @SneakyThrows
    public void testXContentBuilder_withoutModelIds_returnsModelGraveyardWithoutModelIds() {
        ModelGraveyard testModelGraveyard = new ModelGraveyard();
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
        xContentBuilder.startObject();
        XContentBuilder builder = testModelGraveyard.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS);
        builder.endObject();

        ModelGraveyard testModelGraveyard2 = ModelGraveyard.fromXContent(createParser(builder));
        assertEquals(0, testModelGraveyard2.size());
    }

    // Validating {test-model:"abcd"}
    @SneakyThrows
    public void testXContentBuilder_withWrongFieldName_throwsException() {
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
        xContentBuilder.startObject();
        xContentBuilder.field("test-model");
        xContentBuilder.value("abcd");
        xContentBuilder.endObject();

        OpenSearchParseException ex = expectThrows(
            OpenSearchParseException.class,
            () -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
        );
        assertMessageContains(ex, "Expecting field model_ids but got test-model");
    }

    // Validating {} -- an empty object must still parse, yielding an empty graveyard.
    @SneakyThrows
    public void testXContentBuilder_validateBackwardCompatibility_returnsEmptyModelGraveyardObject() {
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
        xContentBuilder.startObject();
        xContentBuilder.endObject();

        ModelGraveyard testModelGraveyard = ModelGraveyard.fromXContent(createParser(xContentBuilder));
        assertEquals(0, testModelGraveyard.size());
    }

    // Validating null -- nothing at all was written to the builder.
    @SneakyThrows
    public void testXContentBuilder_withNull_throwsExceptionExpectingStartObject() {
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();

        OpenSearchParseException ex = expectThrows(
            OpenSearchParseException.class,
            () -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
        );
        assertMessageContains(ex, "Expecting token start of an object but got null");
    }

    // Validating {model_ids:"abcd"}
    @SneakyThrows
    public void testXContentBuilder_withMissingStartArray_throwsExceptionExpectingStartArray() {
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
        xContentBuilder.startObject();
        xContentBuilder.field("model_ids");
        xContentBuilder.value("abcd");
        xContentBuilder.endObject();

        OpenSearchParseException ex = expectThrows(
            OpenSearchParseException.class,
            () -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
        );
        assertMessageContains(ex, "Expecting token start of an array but got VALUE_STRING");
    }

    // Validating {model_ids:["abcd"],model_ids_2:[]}
    @SneakyThrows
    public void testXContentBuilder_validateEndObject_throwsExceptionGotFieldName() {
        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
        xContentBuilder.startObject();
        xContentBuilder.startArray("model_ids");
        xContentBuilder.value("abcd");
        xContentBuilder.endArray();
        xContentBuilder.startArray("model_ids_2");
        xContentBuilder.endArray();
        xContentBuilder.endObject();

        OpenSearchParseException ex = expectThrows(
            OpenSearchParseException.class,
            () -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
        );
        assertMessageContains(ex, "Expecting token end of an object but got FIELD_NAME");
    }

    /** A diff's added/removed sets survive a write/read round trip through a stream. */
    public void testDiffStreams() throws IOException {
        Set<String> added = new HashSet<>();
        Set<String> removed = new HashSet<>();
        String testModelId = "test-model-id";
        String testModelId1 = "test-model-id-1";
        added.add(testModelId);
        removed.add(testModelId1);

        // Current holds only the "added" id and previous holds only the "removed" id.
        ModelGraveyard modelGraveyardCurrent = new ModelGraveyard(added);
        ModelGraveyard modelGraveyardPrevious = new ModelGraveyard(removed);

        ModelGraveyard.ModelGraveyardDiff modelGraveyardDiff = new ModelGraveyard.ModelGraveyardDiff(
            modelGraveyardPrevious,
            modelGraveyardCurrent
        );
        assertEquals(added, modelGraveyardDiff.getAdded());
        assertEquals(removed, modelGraveyardDiff.getRemoved());

        try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
            modelGraveyardDiff.writeTo(streamOutput);

            ModelGraveyard.ModelGraveyardDiff modelGraveyardDiffCopy = new ModelGraveyard.ModelGraveyardDiff(
                streamOutput.bytes().streamInput()
            );
            assertEquals(added, modelGraveyardDiffCopy.getAdded());
            assertEquals(removed, modelGraveyardDiffCopy.getRemoved());
        }
    }

    /** Diffs between graveyards report the right added/removed counts and reproduce the target when applied. */
    public void testDiff() {

        // nothing will have been removed in previous object, and all entries in current object are new
        ModelGraveyard modelGraveyard1 = new ModelGraveyard();

        Set<String> modelIds = new HashSet<>();
        modelIds.add("1");
        modelIds.add("2");
        ModelGraveyard modelGraveyard2 = new ModelGraveyard(modelIds);

        ModelGraveyard.ModelGraveyardDiff diff1 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard1, modelGraveyard2);
        assertEquals(0, diff1.getRemoved().size());
        assertEquals(2, diff1.getAdded().size());

        ModelGraveyard updatedGraveyard1 = diff1.apply(modelGraveyard1);
        assertEquals(2, updatedGraveyard1.size());
        assertTrue(updatedGraveyard1.contains("1"));
        assertTrue(updatedGraveyard1.contains("2"));

        // nothing will have been added to current object, and all entries in previous object are removed
        ModelGraveyard modelGraveyard3 = new ModelGraveyard();
        ModelGraveyard.ModelGraveyardDiff diff2 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard2, modelGraveyard3);
        assertEquals(2, diff2.getRemoved().size());
        assertEquals(0, diff2.getAdded().size());

        ModelGraveyard updatedGraveyard2 = diff2.apply(modelGraveyard2);
        assertEquals(0, updatedGraveyard2.size());

        // some entries in previous object are removed and few entries are added to current object
        // previous {1, 2} -> current {1, 3, 4}: "2" removed, "3" and "4" added
        modelIds = new HashSet<>();
        modelIds.add("1");
        modelIds.add("3");
        modelIds.add("4");
        ModelGraveyard modelGraveyard4 = new ModelGraveyard(modelIds);

        ModelGraveyard.ModelGraveyardDiff diff3 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard2, modelGraveyard4);
        assertEquals(1, diff3.getRemoved().size());
        assertEquals(2, diff3.getAdded().size());

        ModelGraveyard updatedGraveyard3 = diff3.apply(modelGraveyard2);
        assertEquals(3, updatedGraveyard3.size());
        assertTrue(updatedGraveyard3.contains("1"));
        assertTrue(updatedGraveyard3.contains("3"));
        assertTrue(updatedGraveyard3.contains("4"));
        assertFalse(updatedGraveyard3.contains("2"));
    }

    /**
     * Asserts that {@code ex} carries a non-null message containing {@code expectedFragment}.
     * On failure, the assertion message includes both the expected fragment and the actual message,
     * which a bare {@code assertTrue(msg.contains(...))} does not report.
     *
     * @param ex the exception whose message is checked
     * @param expectedFragment the substring the message must contain
     */
    private static void assertMessageContains(Exception ex, String expectedFragment) {
        String message = ex.getMessage();
        assertNotNull("Expected exception message containing [" + expectedFragment + "] but message was null", message);
        assertTrue(
            "Expected exception message containing [" + expectedFragment + "] but was [" + message + "]",
            message.contains(expectedFragment)
        );
    }

}