/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.knn;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import com.google.common.primitives.Floats;
import org.apache.commons.lang.StringUtils;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.common.Strings;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.junit.AfterClass;
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.script.Script;
import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;

import javax.management.MBeanServerInvocationHandler;
import javax.management.MalformedObjectNameException;
import javax.management.ObjectName;
import javax.management.remote.JMXConnector;
import javax.management.remote.JMXConnectorFactory;
import javax.management.remote.JMXServiceURL;

import java.io.IOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;

import static org.opensearch.knn.TestUtils.NUMBER_OF_REPLICAS;
import static org.opensearch.knn.TestUtils.NUMBER_OF_SHARDS;
import static org.opensearch.knn.TestUtils.INDEX_KNN;
import static org.opensearch.knn.TestUtils.PROPERTIES;
import static org.opensearch.knn.TestUtils.VECTOR_TYPE;
import static org.opensearch.knn.TestUtils.KNN_VECTOR;
import static org.opensearch.knn.TestUtils.FIELD;
import static org.opensearch.knn.TestUtils.QUERY_VALUE;
import static org.opensearch.knn.TestUtils.computeGroundTruthValues;

import static org.opensearch.knn.index.SpaceType.L2;
import static org.opensearch.knn.index.memory.NativeMemoryCacheManager.GRAPH_COUNT;
import static org.opensearch.knn.index.util.KNNEngine.FAISS;
import static org.opensearch.knn.plugin.stats.StatNames.INDICES_IN_CACHE;

/**
 * Base class for integration tests for KNN plugin. Contains several methods for testing KNN ES functionality.
 */
public class KNNRestTestCase extends ODFERestTestCase {
    public static final String INDEX_NAME = "test_index";
    public static final String FIELD_NAME = "test_field";
    private static final String DOCUMENT_FIELD_SOURCE = "_source";
    private static final String DOCUMENT_FIELD_FOUND = "found";
    protected static final int DELAY_MILLI_SEC = 1000;
    protected static final int NUM_OF_ATTEMPTS = 30;
    private static final String SYSTEM_INDEX_PREFIX = ".opendistro";

    /**
     * Dump JaCoCo coverage data from the test cluster into {@code <jacoco.dir>/integTest.exec}.
     * <p>
     * Does nothing when the {@code jacoco.dir} system property is unset or empty. Otherwise it connects over JMX to the URL in
     * the {@code jmx.serviceUrl} system property. It then reads the execution data from the {@code org.jacoco:type=Runtime}
     * MBean and writes it to the file.
     *
     * @throws RuntimeException if connecting, reading the execution data, or writing the file fails; the original
     *                          exception is attached as the cause
     */
    @AfterClass
    public static void dumpCoverage() throws IOException, MalformedObjectNameException {
        // jacoco.dir is set in esplugin-coverage.gradle, if it doesn't exist we don't
        // want to collect coverage so we can return early
        String jacocoBuildPath = System.getProperty("jacoco.dir");
        if (org.opensearch.core.common.Strings.isNullOrEmpty(jacocoBuildPath)) {
            return;
        }

        String serverUrl = System.getProperty("jmx.serviceUrl");
        try (JMXConnector connector = JMXConnectorFactory.connect(new JMXServiceURL(serverUrl))) {
            IProxy proxy = MBeanServerInvocationHandler.newProxyInstance(
                connector.getMBeanServerConnection(),
                new ObjectName("org.jacoco:type=Runtime"),
                IProxy.class,
                false
            );

            Path path = Paths.get(jacocoBuildPath, "integTest.exec");
            Files.write(path, proxy.getExecutionData(false));
        } catch (Exception ex) {
            // Chain the original exception so the real failure's stack trace is not lost.
            throw new RuntimeException("Failed to dump coverage: " + ex, ex);
        }
    }

    /**
     * Runs before each test and delegates to {@link #clearCache()}, so cached k-NN state from a previous test
     * does not carry over.
     */
    @Before
    public void cleanUpCache() throws Exception {
        this.clearCache();
    }

    /**
     * Create an index using {@link #getKNNDefaultIndexSettings()} and then apply the given mapping to it.
     *
     * @param index   name of the index to create
     * @param mapping JSON body for the put-mapping request
     */
    protected void createKnnIndex(String index, String mapping) throws IOException {
        final Settings defaultSettings = getKNNDefaultIndexSettings();
        createIndex(index, defaultSettings);
        putMappingRequest(index, mapping);
    }

    /**
     * Create an index with caller-supplied settings and then apply the given mapping to it.
     *
     * @param index    name of the index to create
     * @param settings index settings used at creation time
     * @param mapping  JSON body for the put-mapping request
     */
    protected void createKnnIndex(String index, Settings settings, String mapping) throws IOException {
        // The index has to exist before its mapping can be updated.
        createIndex(index, settings);
        putMappingRequest(index, mapping);
    }

    /**
     * Create an index with {@link Settings#EMPTY} settings and a single {@code knn_vector} field.
     *
     * @param index     name of the index to create
     * @param fieldName name of the vector field
     * @param dimension vector dimension (written into the mapping as a string)
     */
    protected void createBasicKnnIndex(String index, String fieldName, int dimension) throws IOException {
        XContentBuilder mappingBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties");
        mappingBuilder.startObject(fieldName).field("type", "knn_vector").field("dimension", Integer.toString(dimension)).endObject();
        mappingBuilder.endObject().endObject();

        String wrappedMapping = Strings.toString(mappingBuilder);
        // The outermost braces are stripped before calling createIndex
        // (presumably createIndex wraps the mapping itself — confirm against the base class).
        String unwrappedMapping = wrappedMapping.substring(1, wrappedMapping.length() - 1);
        createIndex(index, Settings.EMPTY, unwrappedMapping);
    }

    /**
     * Send a k-NN search to an index and assert that it returned 200 OK.
     * <p>
     * The request uses {@code explain=true} and {@code search_type=query_then_fetch}.
     *
     * @param index           index to search
     * @param knnQueryBuilder query serialized under the {@code query} key of the body
     * @param resultSize      value of the {@code size} request parameter
     * @return the raw search response
     */
    protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException {
        XContentBuilder body = XContentFactory.jsonBuilder();
        body.startObject();
        body.startObject("query");
        knnQueryBuilder.doXContent(body, ToXContent.EMPTY_PARAMS);
        body.endObject();
        body.endObject();

        Request searchRequest = new Request("POST", "/" + index + "/_search");
        searchRequest.addParameter("size", Integer.toString(resultSize));
        searchRequest.addParameter("explain", Boolean.toString(true));
        searchRequest.addParameter("search_type", "query_then_fetch");
        searchRequest.setJsonEntity(Strings.toString(body));

        Response searchResponse = client().performRequest(searchRequest);
        RestStatus status = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
        assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK, status);
        return searchResponse;
    }

    /**
     * Send an {@code exists} query to an index and assert that it returned 200 OK.
     *
     * @param index              index to search
     * @param existsQueryBuilder query serialized under the {@code query} key of the body
     * @param resultSize         value of the {@code size} request parameter
     * @return the raw search response
     */
    protected Response searchExists(String index, ExistsQueryBuilder existsQueryBuilder, int resultSize) throws IOException {
        Request request = new Request("POST", "/" + index + "/_search");

        // Request body: {"query": <exists query>}
        XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
        builder.field("query", existsQueryBuilder);
        builder.endObject();

        request.addParameter("size", Integer.toString(resultSize));
        request.setJsonEntity(Strings.toString(builder));

        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

        return response;
    }

    /**
     * Parse the hits of a search response into a list of {@link KNNResult}s.
     *
     * @param responseBody JSON body of a search response
     * @param fieldName    {@code _source} field that holds each hit's vector
     * @return one result per hit, in response order, built from the hit's {@code _id} and vector
     */
    protected List<KNNResult> parseSearchResponse(String responseBody, String fieldName) throws IOException {
        @SuppressWarnings("unchecked")
        List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
            MediaTypeRegistry.getDefaultMediaType().xContent(),
            responseBody
        ).map().get("hits")).get("hits");

        @SuppressWarnings("unchecked")
        List<KNNResult> knnSearchResponses = hits.stream().map(hit -> {
            Map<String, Object> hitMap = (Map<String, Object>) hit;
            Map<String, Object> source = (Map<String, Object>) hitMap.get("_source");
            // The parser returns the vector as a list of boxed JSON numbers, not necessarily Float and not necessarily an
            // ArrayList. Read it as List<?> and convert each element through its string form.
            List<?> rawVector = (List<?>) source.get(fieldName);
            Float[] vector = rawVector.stream().map(Object::toString).map(Float::valueOf).toArray(Float[]::new);
            return new KNNResult((String) hitMap.get("_id"), vector);
        }).collect(Collectors.toList());

        return knnSearchResponses;
    }

    /**
     * Parse the {@code _score} of every hit in a search response.
     *
     * @param responseBody JSON body of a search response
     * @param fieldName    not used by this method
     * @return the hit scores as floats, in response order
     */
    protected List<Float> parseSearchResponseScore(String responseBody, String fieldName) throws IOException {
        @SuppressWarnings("unchecked")
        List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
            MediaTypeRegistry.getDefaultMediaType().xContent(),
            responseBody
        ).map().get("hits")).get("hits");

        // Cast through Number rather than Double, so a score the parser boxes as another numeric type
        // does not fail with ClassCastException.
        @SuppressWarnings("unchecked")
        List<Float> knnSearchResponses = hits.stream()
            .map(hit -> ((Number) ((Map<String, Object>) hit).get("_score")).floatValue())
            .collect(Collectors.toList());

        return knnSearchResponses;
    }

    /**
     * Read the {@code value} entry of a named aggregation from a search response.
     *
     * @param responseBody    JSON body of a search response with an {@code aggregations} section
     * @param aggregationName name of the aggregation to read
     * @return the aggregation's {@code value}, converted to {@link Double} via its string form
     */
    @SuppressWarnings("unchecked")
    protected Double parseAggregationResponse(String responseBody, String aggregationName) throws IOException {
        Map<String, Object> parsed = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
        Map<String, Object> aggregations = (Map<String, Object>) parsed.get("aggregations");
        Map<String, Object> aggregation = (Map<String, Object>) aggregations.get(aggregationName);
        Object rawValue = aggregation.get("value");
        return Double.valueOf(String.valueOf(rawValue));
    }

    /**
     * Parse the score from the KNN search response
     */

    /**
     * Delete an index and assert that the request returned 200 OK.
     *
     * @param index name of the index to delete
     */
    protected void deleteKNNIndex(String index) throws IOException {
        final Request deleteRequest = new Request("DELETE", "/" + index);
        final Response deleteResponse = client().performRequest(deleteRequest);
        final int statusCode = deleteResponse.getStatusLine().getStatusCode();
        assertEquals(deleteRequest.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(statusCode));
    }

    /**
     * Send a put-mapping request for an index and assert that it returned 200 OK.
     *
     * @param index   name of the index to update
     * @param mapping JSON mapping body
     */
    protected void putMappingRequest(String index, String mapping) throws IOException {
        final Request mappingRequest = new Request("PUT", "/" + index + "/_mapping");
        mappingRequest.setJsonEntity(mapping);

        final Response mappingResponse = client().performRequest(mappingRequest);
        final RestStatus status = RestStatus.fromCode(mappingResponse.getStatusLine().getStatusCode());
        assertEquals(mappingRequest.getEndpoint() + ": failed", RestStatus.OK, status);
    }

    /**
     * Build a mapping with a single {@code knn_vector} field.
     *
     * @param fieldName  name of the vector field
     * @param dimensions vector dimension (written into the mapping as a string)
     * @return the mapping JSON, shaped as {@code {"properties":{<fieldName>:{...}}}}
     */
    protected String createKnnIndexMapping(String fieldName, Integer dimensions) throws IOException {
        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");
        mapping.startObject(fieldName);
        mapping.field("type", "knn_vector");
        mapping.field("dimension", dimensions.toString());
        mapping.endObject();
        // close "properties" and the root object
        mapping.endObject().endObject();
        return Strings.toString(mapping);
    }

    /**
     * Build a mapping with a single {@code knn_vector} field that also declares a {@code method}
     * (algorithm name and engine).
     *
     * @param fieldName  name of the vector field
     * @param dimensions vector dimension (written into the mapping as a string)
     * @param algoName   value of {@code method.name}
     * @param knnEngine  value of {@code method.engine}
     * @return the mapping JSON
     */
    protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine) throws IOException {
        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties").startObject(fieldName);
        mapping.field("type", "knn_vector").field("dimension", dimensions.toString());
        mapping.startObject("method").field("name", algoName).field("engine", knnEngine).endObject();
        // close the field object, "properties", and the root object
        mapping.endObject().endObject().endObject();
        return Strings.toString(mapping);
    }

    /**
     * Build a mapping that declares several {@code knn_vector} fields.
     *
     * @param fieldNames names of the vector fields; must be non-empty
     * @param dimensions dimension of each field, position-matched with {@code fieldNames}; must be the same size
     * @return the mapping JSON
     */
    protected String createKnnIndexMapping(List<String> fieldNames, List<Integer> dimensions) throws IOException {
        assertNotEquals(0, fieldNames.size());
        assertEquals(fieldNames.size(), dimensions.size());

        XContentBuilder mapping = XContentFactory.jsonBuilder().startObject();
        mapping.startObject("properties");
        int position = 0;
        for (String name : fieldNames) {
            String dimension = dimensions.get(position++).toString();
            mapping.startObject(name).field("type", "knn_vector").field("dimension", dimension).endObject();
        }
        mapping.endObject();
        mapping.endObject();

        return Strings.toString(mapping);
    }

    /**
     * Fetch an index's mapping and return its {@code mappings} section as a map.
     *
     * @param index name of the index to fetch
     * @return the {@code mappings} object of the get-mapping response for {@code index}
     */
    @SuppressWarnings("unchecked")
    public Map<String, Object> getIndexMappingAsMap(String index) throws Exception {
        final Request mappingRequest = new Request("GET", "/" + index + "/_mapping");
        final Response mappingResponse = client().performRequest(mappingRequest);
        final RestStatus status = RestStatus.fromCode(mappingResponse.getStatusLine().getStatusCode());
        assertEquals(mappingRequest.getEndpoint() + ": failed", RestStatus.OK, status);

        final String body = EntityUtils.toString(mappingResponse.getEntity());
        final Map<String, Object> parsed = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), body).map();
        final Map<String, Object> indexEntry = (Map<String, Object>) parsed.get(index);
        return (Map<String, Object>) indexEntry.get("mappings");
    }

    /**
     * Return the {@code count} reported by the {@code _count} API for an index, after asserting 200 OK.
     *
     * @param indexName index to count documents in
     * @return the document count
     */
    public int getDocCount(String indexName) throws Exception {
        final Request countRequest = new Request("GET", "/" + indexName + "/_count");
        final Response countResponse = client().performRequest(countRequest);
        final RestStatus status = RestStatus.fromCode(countResponse.getStatusLine().getStatusCode());
        assertEquals(countRequest.getEndpoint() + ": failed", RestStatus.OK, status);

        final String body = EntityUtils.toString(countResponse.getEntity());
        final Object count = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), body).map().get("count");
        return (Integer) count;
    }

    /**
     * Refresh an index, then force-merge it down to one segment with {@code flush=true}. Both requests must return 200 OK.
     * Afterwards the method sleeps for five seconds.
     *
     * @param index index to refresh and merge
     */
    protected void forceMergeKnnIndex(String index) throws Exception {
        final Request refreshRequest = new Request("POST", "/" + index + "/_refresh");
        final Response refreshResponse = client().performRequest(refreshRequest);
        final RestStatus refreshStatus = RestStatus.fromCode(refreshResponse.getStatusLine().getStatusCode());
        assertEquals(refreshRequest.getEndpoint() + ": failed", RestStatus.OK, refreshStatus);

        final Request mergeRequest = new Request("POST", "/" + index + "/_forcemerge");
        mergeRequest.addParameter("max_num_segments", "1");
        mergeRequest.addParameter("flush", "true");
        final Response mergeResponse = client().performRequest(mergeRequest);
        final RestStatus mergeStatus = RestStatus.fromCode(mergeResponse.getStatusLine().getStatusCode());
        assertEquals(mergeRequest.getEndpoint() + ": failed", RestStatus.OK, mergeStatus);

        // Wait so the force merge has time to complete before callers inspect the index.
        TimeUnit.SECONDS.sleep(5);
    }

    /**
     * Index a document with a single vector field and then refresh the index.
     * <p>
     * The index request uses {@code refresh=true}. Its status is not asserted; only the follow-up
     * {@code _refresh} is required to return 200 OK.
     *
     * @param index     target index
     * @param docId     document id
     * @param fieldName vector field name
     * @param vector    vector values
     */
    protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException {
        final Request indexRequest = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
        final XContentBuilder doc = XContentFactory.jsonBuilder().startObject();
        doc.field(fieldName, vector);
        doc.endObject();
        indexRequest.setJsonEntity(Strings.toString(doc));
        client().performRequest(indexRequest);

        final Request refreshRequest = new Request("POST", "/" + index + "/_refresh");
        final Response refreshResponse = client().performRequest(refreshRequest);
        final RestStatus refreshStatus = RestStatus.fromCode(refreshResponse.getStatusLine().getStatusCode());
        assertEquals(refreshRequest.getEndpoint() + ": failed", RestStatus.OK, refreshStatus);
    }

    /**
     * Index a single document with several vector fields and assert that it was created (201).
     * The document is indexed with {@code refresh=true}.
     *
     * @param index      target index
     * @param docId      document id
     * @param fieldNames vector field names
     * @param vectors    vector values, position-matched with {@code fieldNames}; must be the same size
     */
    protected void addKnnDoc(String index, String docId, List<String> fieldNames, List<Object[]> vectors) throws IOException {
        // Fail fast on mismatched inputs. Without this, too few vectors throws IndexOutOfBoundsException
        // and extra vectors are silently ignored.
        assertEquals("fieldNames and vectors must have the same size", fieldNames.size(), vectors.size());

        Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

        XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
        for (int i = 0; i < fieldNames.size(); i++) {
            builder.field(fieldNames.get(i), vectors.get(i));
        }
        builder.endObject();

        request.setJsonEntity(Strings.toString(builder));
        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Index a document with a single numeric field and assert that it was created (201).
     * The document is indexed with {@code refresh=true}.
     *
     * @param index     target index
     * @param docId     document id
     * @param fieldName numeric field name
     * @param value     field value
     */
    protected void addDocWithNumericField(String index, String docId, String fieldName, long value) throws IOException {
        final String endpoint = "/" + index + "/_doc/" + docId + "?refresh=true";
        final Request indexRequest = new Request("POST", endpoint);
        final XContentBuilder doc = XContentFactory.jsonBuilder().startObject().field(fieldName, value).endObject();
        indexRequest.setJsonEntity(Strings.toString(doc));

        final Response indexResponse = client().performRequest(indexRequest);
        final RestStatus actual = RestStatus.fromCode(indexResponse.getStatusLine().getStatusCode());
        assertEquals(indexRequest.getEndpoint() + ": failed", RestStatus.CREATED, actual);
    }

    /**
     * Index a document with a single binary field, supplied as a Base64 string, and assert that it was created (201).
     * The document is indexed with {@code refresh=true}.
     *
     * @param index        target index
     * @param docId        document id
     * @param fieldName    binary field name
     * @param base64String Base64-encoded field value, written as-is
     */
    protected void addDocWithBinaryField(String index, String docId, String fieldName, String base64String) throws IOException {
        Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");

        XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, base64String).endObject();

        request.setJsonEntity(Strings.toString(builder));

        Response response = client().performRequest(request);

        assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Re-index an existing document with a new vector for {@code fieldName}, and assert that the request returned 200 OK.
     * The document is indexed with {@code refresh=true}.
     *
     * @param index     target index
     * @param docId     id of the document to overwrite
     * @param fieldName vector field name
     * @param vector    new vector values
     */
    protected void updateKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException {
        final Request updateRequest = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true");
        final XContentBuilder doc = XContentFactory.jsonBuilder();
        doc.startObject();
        doc.field(fieldName, vector);
        doc.endObject();
        updateRequest.setJsonEntity(Strings.toString(doc));

        final Response updateResponse = client().performRequest(updateRequest);
        final RestStatus status = RestStatus.fromCode(updateResponse.getStatusLine().getStatusCode());
        assertEquals(updateRequest.getEndpoint() + ": failed", RestStatus.OK, status);
    }

    /**
     * Delete a single document (with the {@code refresh} parameter set) and assert that the request returned 200 OK.
     *
     * @param index index containing the document
     * @param docId id of the document to delete
     */
    protected void deleteKnnDoc(String index, String docId) throws IOException {
        // Delete the document and refresh so the deletion is reflected in later requests
        Request request = new Request("DELETE", "/" + index + "/_doc/" + docId + "?refresh");

        Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Fetch a document by id and return its {@code _source}.
     * <p>
     * Asserts that the response parsed to a non-null map, that {@code found} is true, and that {@code _source} is present.
     *
     * @param index index containing the document
     * @param docId document id
     * @return the document's {@code _source} as a map
     */
    @SuppressWarnings("unchecked")
    protected Map<String, Object> getKnnDoc(final String index, final String docId) throws Exception {
        final Response getResponse = client().performRequest(new Request("GET", "/" + index + "/_doc/" + docId));
        final String body = EntityUtils.toString(getResponse.getEntity());
        final Map<String, Object> parsed = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), body).map();

        assertNotNull(parsed);
        assertTrue((Boolean) parsed.get(DOCUMENT_FIELD_FOUND));
        final Object source = parsed.get(DOCUMENT_FIELD_SOURCE);
        assertNotNull(source);
        return (Map<String, Object>) source;
    }

    /**
     * Set a persistent cluster setting and assert that the request returned 200 OK.
     * <p>
     * {@code value} is written as-is. Callers in this class (e.g. {@code clearCache}) pass {@code null} to unset a setting.
     *
     * @param settingKey fully qualified setting name
     * @param value      value to write under {@code persistent.<settingKey>}
     */
    protected void updateClusterSettings(String settingKey, Object value) throws Exception {
        final XContentBuilder body = XContentFactory.jsonBuilder().startObject();
        body.startObject("persistent").field(settingKey, value).endObject();
        body.endObject();

        final Request settingsRequest = new Request("PUT", "_cluster/settings");
        settingsRequest.setJsonEntity(Strings.toString(body));
        final Response settingsResponse = client().performRequest(settingsRequest);
        assertEquals(RestStatus.OK, RestStatus.fromCode(settingsResponse.getStatusLine().getStatusCode()));
    }

    /**
     * Default settings for k-NN test indices: one shard, zero replicas, {@code index.knn=true}.
     */
    protected Settings getKNNDefaultIndexSettings() {
        Settings.Builder defaults = Settings.builder();
        defaults.put("number_of_shards", 1);
        defaults.put("number_of_replicas", 0);
        defaults.put("index.knn", true);
        return defaults.build();
    }

    /**
     * Fetch k-NN stats under {@link KNNPlugin#KNN_BASE_URI}.
     *
     * @param nodeIds node ids to put in the URL; empty to omit the node segment
     * @param stats   stat names to put in the URL; empty to omit the stats segment
     * @return the stats response
     */
    protected Response getKnnStats(List<String> nodeIds, List<String> stats) throws IOException {
        final String baseUri = KNNPlugin.KNN_BASE_URI;
        return executeKnnStatRequest(nodeIds, stats, baseUri);
    }

    /**
     * Send {@code GET <baseURI>[/<nodeIds>]/stats[/<stats>]} and assert that it returned 200 OK.
     * Each list is comma-joined, and its URL segment is omitted when the list is empty.
     *
     * @param nodeIds node ids for the URL
     * @param stats   stat names for the URL
     * @param baseURI plugin base URI to prefix the path with
     * @return the stats response
     */
    protected Response executeKnnStatRequest(List<String> nodeIds, List<String> stats, final String baseURI) throws IOException {
        final String nodePrefix = nodeIds.isEmpty() ? "" : "/" + String.join(",", nodeIds);
        final String statsSuffix = stats.isEmpty() ? "" : "/" + String.join(",", stats);

        final Request statsRequest = new Request("GET", baseURI + nodePrefix + "/stats" + statsSuffix);
        final Response statsResponse = client().performRequest(statsRequest);
        assertEquals(RestStatus.OK, RestStatus.fromCode(statsResponse.getStatusLine().getStatusCode()));
        return statsResponse;
    }

    /**
     * Request warmup of the given indices through {@link KNNPlugin#KNN_BASE_URI}. The response status is not asserted.
     *
     * @param indices indices to warm up
     * @return the raw warmup response
     */
    protected Response knnWarmup(List<String> indices) throws IOException {
        final String baseUri = KNNPlugin.KNN_BASE_URI;
        return executeWarmupRequest(indices, baseUri);
    }

    /**
     * Send {@code GET <baseURI>/warmup/<comma-joined indices>} and return the response without checking its status.
     *
     * @param indices indices to warm up
     * @param baseURI plugin base URI to prefix the path with
     * @return the raw warmup response
     */
    protected Response executeWarmupRequest(List<String> indices, final String baseURI) throws IOException {
        final String endpoint = baseURI + "/warmup/" + String.join(",", indices);
        return client().performRequest(new Request("GET", endpoint));
    }

    /**
     * Parse a k-NN stats response and keep only the cluster-level entries.
     * The {@code cluster_name}, {@code _nodes} and {@code nodes} keys are removed.
     *
     * @param responseBody JSON body of a stats response
     * @return the remaining top-level entries
     */
    protected Map<String, Object> parseClusterStatsResponse(String responseBody) throws IOException {
        final Map<String, Object> clusterStats = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
        Arrays.asList("cluster_name", "_nodes", "nodes").forEach(clusterStats::remove);
        return clusterStats;
    }

    /**
     * Parse a k-NN stats response into one map per node, taken from the values of its {@code nodes} object.
     *
     * @param responseBody JSON body of a stats response
     * @return per-node stats, in the iteration order of the parsed {@code nodes} map
     */
    @SuppressWarnings("unchecked")
    protected List<Map<String, Object>> parseNodeStatsResponse(String responseBody) throws IOException {
        final Map<String, Object> nodes = (Map<String, Object>) createParser(
            MediaTypeRegistry.getDefaultMediaType().xContent(),
            responseBody
        ).map().get("nodes");

        final List<Map<String, Object>> perNodeStats = new ArrayList<>();
        for (Object nodeStats : nodes.values()) {
            perNodeStats.add((Map<String, Object>) nodeStats);
        }
        return perNodeStats;
    }

    /**
     * Read {@code hits.total.value} from a search response.
     *
     * @param searchResponseBody JSON body of a search response
     * @return the total hit count
     */
    @SuppressWarnings("unchecked")
    protected int parseTotalSearchHits(String searchResponseBody) throws IOException {
        final Map<String, Object> parsed = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), searchResponseBody).map();
        final Map<String, Object> hits = (Map<String, Object>) parsed.get("hits");
        final Map<String, Object> total = (Map<String, Object>) hits.get("total");
        return (int) total.get("value");
    }

    /**
     * Sum the graph counts of every cached index across all nodes.
     * <p>
     * For each node, every entry under the {@code INDICES_IN_CACHE} stat contributes its {@code GRAPH_COUNT}.
     * Nodes without that stat are skipped. The parsed node stats are logged at info level.
     *
     * @return the total number of cached graphs
     */
    @SuppressWarnings("unchecked")
    protected int getTotalGraphsInCache() throws Exception {
        final Response statsResponse = getKnnStats(Collections.emptyList(), Collections.emptyList());
        final List<Map<String, Object>> nodesStats = parseNodeStatsResponse(EntityUtils.toString(statsResponse.getEntity()));

        logger.info("[KNN] Node stats:  " + nodesStats);

        int totalGraphs = 0;
        for (Map<String, Object> nodeStats : nodesStats) {
            final Object indicesInCache = nodeStats.get(INDICES_IN_CACHE.getName());
            if (indicesInCache == null) {
                continue;
            }
            for (Map<String, Object> indexStats : ((Map<String, Map<String, Object>>) indicesInCache).values()) {
                totalGraphs += (int) indexStats.get(GRAPH_COUNT);
            }
        }
        return totalGraphs;
    }

    /**
     * Read one setting of an index from the index-settings response.
     *
     * @param indexName   index whose settings to read
     * @param settingName fully qualified setting name
     * @return the setting value as a string, or {@code null} if absent
     */
    protected String getIndexSettingByName(String indexName, String settingName) throws IOException {
        @SuppressWarnings("unchecked")
        final Map<String, Object> indexEntry = (Map<String, Object>) getIndexSettings(indexName).get(indexName);
        @SuppressWarnings("unchecked")
        final Map<String, Object> settings = (Map<String, Object>) indexEntry.get("settings");
        return (String) settings.get(settingName);
    }

    /**
     * Create the model system index from the mapping resource bundled with {@link ModelDao}, unless it already exists.
     * The index is created with one shard and zero replicas.
     *
     * @throws IllegalStateException if the mapping resource cannot be found
     */
    protected void createModelSystemIndex() throws IOException {
        final URL mappingUrl = ModelDao.class.getClassLoader().getResource(MODEL_INDEX_MAPPING_PATH);
        if (mappingUrl == null) {
            throw new IllegalStateException("Unable to retrieve mapping for \"" + MODEL_INDEX_NAME + "\"");
        }

        final String fullMapping = Resources.toString(mappingUrl, Charsets.UTF_8);
        // Drop the outermost braces before handing the mapping to createIndex (same convention as createBasicKnnIndex).
        final String mappingBody = fullMapping.substring(1, fullMapping.length() - 1);

        if (systemIndexExists(MODEL_INDEX_NAME)) {
            return;
        }
        final Settings indexSettings = Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).build();
        createIndex(MODEL_INDEX_NAME, indexSettings, mappingBody);
    }

    /**
     * Write a model document directly into the model system index and assert that it was created (201).
     * <p>
     * The model bytes are stored Base64-encoded next to the fields copied from {@code modelMetadata}.
     * The document is indexed with {@code refresh=true}.
     *
     * @param modelId       document id; must not be null or empty
     * @param modelMetadata source of the state, engine, space type, dimension, timestamp, description and error fields
     * @param model         raw model bytes
     */
    protected void addModelToSystemIndex(String modelId, ModelMetadata modelMetadata, byte[] model) throws IOException {
        assertFalse(org.opensearch.core.common.Strings.isNullOrEmpty(modelId));
        final String encodedModel = Base64.getEncoder().encodeToString(model);

        final Request indexRequest = new Request("POST", "/" + MODEL_INDEX_NAME + "/_doc/" + modelId + "?refresh=true");

        final XContentBuilder document = XContentFactory.jsonBuilder().startObject();
        document.field(MODEL_ID, modelId);
        document.field(MODEL_STATE, modelMetadata.getState().getName());
        document.field(KNN_ENGINE, modelMetadata.getKnnEngine().getName());
        document.field(METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue());
        document.field(DIMENSION, modelMetadata.getDimension());
        document.field(MODEL_BLOB_PARAMETER, encodedModel);
        document.field(MODEL_TIMESTAMP, modelMetadata.getTimestamp());
        document.field(MODEL_DESCRIPTION, modelMetadata.getDescription());
        document.field(MODEL_ERROR, modelMetadata.getError());
        document.endObject();

        indexRequest.setJsonEntity(Strings.toString(document));

        final Response indexResponse = client().performRequest(indexRequest);
        final RestStatus status = RestStatus.fromCode(indexResponse.getStatusLine().getStatusCode());
        assertEquals(indexRequest.getEndpoint() + ": failed", RestStatus.CREATED, status);
    }

    /**
     * Clear the k-NN cache.
     * <p>
     * Temporary workaround: there is currently no way to clear the cache other than deleting an index or changing
     * k-NN settings. This method therefore bounces {@link KNNSettings#KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES}: it sets it
     * to {@code 1m} and then back to {@code null}.
     */
    protected void clearCache() throws Exception {
        final String expirySetting = KNNSettings.KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES;
        updateClusterSettings(expirySetting, "1m");
        updateClusterSettings(expirySetting, null);
    }

    /**
     * Clear the script cache so the k-NN script must be recompiled.
     * <p>
     * It does this by setting {@code script.context.score.cache_expire} to {@code "0"} and then back to {@code null}.
     */
    protected void clearScriptCache() throws Exception {
        final String cacheExpireSetting = "script.context.score.cache_expire";
        updateClusterSettings(cacheExpireSetting, "0");
        updateClusterSettings(cacheExpireSetting, null);
    }

    /**
     * Create a {@link Script} of {@link Script#DEFAULT_SCRIPT_TYPE} from the given source, language and parameters.
     */
    private Script buildScript(String source, String language, Map<String, Object> params) {
        final Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, language, source, params);
        return script;
    }

    /**
     * Build a {@code scripted_metric} aggregation from four script sources, each with no parameters.
     *
     * @param language script language, or {@code null} to use {@link Script#DEFAULT_SCRIPT_LANG}
     * @param aggName  aggregation name
     */
    private ScriptedMetricAggregationBuilder getScriptedMetricAggregationBuilder(
        String initScriptSource,
        String mapScriptSource,
        String combineScriptSource,
        String reduceScriptSource,
        String language,
        String aggName
    ) {
        final String lang = (language == null) ? Script.DEFAULT_SCRIPT_LANG : language;
        final Map<String, Object> noParams = Collections.emptyMap();

        final Script init = buildScript(initScriptSource, lang, noParams);
        final Script map = buildScript(mapScriptSource, lang, noParams);
        final Script combine = buildScript(combineScriptSource, lang, noParams);
        final Script reduce = buildScript(reduceScriptSource, lang, noParams);

        final ScriptedMetricAggregationBuilder aggregation = new ScriptedMetricAggregationBuilder(aggName);
        return aggregation.mapScript(map).combineScript(combine).reduceScript(reduce).initScript(init);
    }

    /**
     * Build (but do not send) a search request against {@link #INDEX_NAME}. The request runs a {@code scripted_metric}
     * aggregation over a {@code match_all} query.
     * <p>
     * The endpoint carries {@code size=0&filter_path=aggregations}. The body also contains a {@code size} field set
     * from {@code size}.
     *
     * @param aggName             aggregation name
     * @param language            script language, or {@code null} for {@link Script#DEFAULT_SCRIPT_LANG}
     * @param initScriptSource    init script source
     * @param mapScriptSource     map script source
     * @param combineScriptSource combine script source
     * @param reduceScriptSource  reduce script source
     * @param size                value of the body's {@code size} field
     * @return the prepared request
     */
    protected Request constructScriptedMetricAggregationSearchRequest(
        String aggName,
        String language,
        String initScriptSource,
        String mapScriptSource,
        String combineScriptSource,
        String reduceScriptSource,
        int size
    ) throws Exception {

        XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("size", size).startObject("query");
        builder.startObject("match_all");
        builder.endObject();
        builder.endObject();
        builder.startObject("aggs");
        final ScriptedMetricAggregationBuilder scriptedMetricAggregationBuilder = getScriptedMetricAggregationBuilder(
            initScriptSource,
            mapScriptSource,
            combineScriptSource,
            reduceScriptSource,
            language,
            aggName
        );
        scriptedMetricAggregationBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
        builder.endObject();
        builder.endObject();
        // Locale.ROOT keeps formatting independent of the JVM's default locale.
        String endpoint = String.format(Locale.ROOT, "/%s/_search?size=0&filter_path=aggregations", INDEX_NAME);
        Request request = new Request("POST", endpoint);
        request.setJsonEntity(Strings.toString(builder));
        return request;
    }

    /**
     * Builds a {@code script_score} search request that wraps {@code qb} with a script.
     *
     * @param indexName index to search
     * @param qb        inner query whose matches are re-scored by the script
     * @param params    script parameters
     * @param language  script language
     * @param source    script source
     * @param size      number of hits requested in the body
     * @return the prepared (not yet executed) request
     * @throws Exception if the request body cannot be built
     */
    protected Request constructScriptScoreContextSearchRequest(
        String indexName,
        QueryBuilder qb,
        Map<String, Object> params,
        String language,
        String source,
        int size
    ) throws Exception {
        final Script script = buildScript(source, language, params);
        // Constructing the query builder is kept for the argument validation it performs on qb/script.
        final ScriptScoreQueryBuilder scriptScoreQuery = new ScriptScoreQueryBuilder(qb, script);

        final XContentBuilder body = XContentFactory.jsonBuilder().startObject();
        body.field("size", size);
        body.startObject("query").startObject("script_score");
        body.field("query");
        scriptScoreQuery.query().toXContent(body, ToXContent.EMPTY_PARAMS);
        body.field("script", script);
        body.endObject().endObject();
        body.endObject();

        final Request request = new Request("POST", "/" + indexName + "/_search");
        request.setJsonEntity(Strings.toString(body));
        return request;
    }

    /**
     * Builds a {@code script_score} search request that scores the matches of {@code qb} with the k-NN scoring script.
     * No {@code size} is written to the body.
     *
     * @param indexName index to search
     * @param qb        inner query
     * @param params    parameters for the k-NN scoring script
     * @return the prepared (not yet executed) request
     * @throws Exception if the request body cannot be built
     */
    protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map<String, Object> params) throws Exception {
        final Script knnScript = buildScript(KNNScoringScriptEngine.SCRIPT_SOURCE, KNNScoringScriptEngine.NAME, params);
        final ScriptScoreQueryBuilder scriptScoreQuery = new ScriptScoreQueryBuilder(qb, knnScript);

        final XContentBuilder body = XContentFactory.jsonBuilder().startObject();
        body.startObject("query").startObject("script_score");
        body.field("query");
        scriptScoreQuery.query().toXContent(body, ToXContent.EMPTY_PARAMS);
        body.field("script", knnScript);
        body.endObject().endObject();
        body.endObject();

        final Request request = new Request("POST", "/" + indexName + "/_search");
        request.setJsonEntity(Strings.toString(body));
        return request;
    }

    /**
     * Builds a {@code script_score} search request using the k-NN scoring script and an explicit result size.
     *
     * @param indexName index to search
     * @param qb        inner query
     * @param params    parameters for the k-NN scoring script
     * @param size      number of hits requested in the body
     * @return the prepared (not yet executed) request
     * @throws Exception if the request body cannot be built
     */
    protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map<String, Object> params, int size)
        throws Exception {
        final String knnLanguage = KNNScoringScriptEngine.NAME;
        final String knnSource = KNNScoringScriptEngine.SCRIPT_SOURCE;
        return constructScriptScoreContextSearchRequest(indexName, qb, params, knnLanguage, knnSource, size);
    }

    /**
     * Converts the content of an {@link XContentBuilder} into an ordered map.
     *
     * @param xContentBuilder builder whose bytes are converted
     * @return the parsed map
     */
    public Map<String, Object> xContentBuilderToMap(XContentBuilder xContentBuilder) {
        final BytesReference content = BytesReference.bytes(xContentBuilder);
        final boolean preserveOrder = true;
        return XContentHelper.convertToMap(content, preserveOrder, xContentBuilder.contentType()).v2();
    }

    /**
     * Indexes {@code numVectors} random vectors, one {@code addKnnDoc} call per document, with ids
     * {@code 1..numVectors}.
     *
     * @param indexName target index
     * @param fieldName vector field
     * @param numVectors number of documents to add
     * @param dimension dimension of each random vector
     * @throws IOException if a request cannot be performed
     */
    public void bulkIngestRandomVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException {
        for (int docNumber = 1; docNumber <= numVectors; docNumber++) {
            final float[] randomVector = new float[dimension];
            for (int d = 0; d < randomVector.length; d++) {
                randomVector[d] = randomFloat();
            }
            addKnnDoc(indexName, Integer.toString(docNumber), fieldName, Floats.asList(randomVector).toArray());
        }
    }

    /**
     * Indexes the first {@code docCount} rows of {@code indexVectors} into {@code index} with a single
     * {@code _bulk} request that also refreshes the index. Document ids are the row positions
     * {@code 0 .. docCount - 1}.
     * <p>
     * NOTE(review): only the HTTP status is checked; per-item failures in the bulk response body are not inspected.
     *
     * @param index        target index
     * @param fieldName    vector field to populate
     * @param indexVectors vectors to index, one row per document (must have at least {@code docCount} rows)
     * @param docCount     number of rows to index
     * @throws IOException if the request cannot be performed
     */
    public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException {
        Request request = new Request("POST", "/_bulk");
        request.addParameter("refresh", "true");

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < docCount; i++) {
            // NDJSON pair: the action/metadata line followed by the document source line.
            sb.append("{ \"index\" : { \"_index\" : \"")
                .append(index)
                .append("\", \"_id\" : \"")
                .append(i)
                .append("\" } }\n")
                .append("{ \"")
                .append(fieldName)
                .append("\" : ")
                .append(Arrays.toString(indexVectors[i]))
                .append(" }\n");
        }
        request.setJsonEntity(sb.toString());

        Response response = client().performRequest(request);
        // Expected value first so that a failure reports the statuses the right way round.
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
    }

    /**
     * Reads vectors back from {@code testIndex} with a match_all search of size {@code docCount}.
     * <p>
     * Rows follow the order of the search hits (not necessarily the document ids); rows for which
     * no hit was returned stay all-zero.
     *
     * @param testIndex  index to read from
     * @param testField  vector field to extract
     * @param docCount   number of hits requested and number of rows in the result
     * @param dimensions number of columns in the result
     * @return a {@code docCount x dimensions} matrix of the returned vectors
     * @throws Exception if the search fails or the response cannot be parsed
     */
    public float[][] getIndexVectorsFromIndex(String testIndex, String testField, int docCount, int dimensions) throws Exception {
        final float[][] vectors = new float[docCount][dimensions];

        final QueryBuilder matchAll = new MatchAllQueryBuilder();
        final Request searchRequest = new Request("POST", "/" + testIndex + "/_search");
        searchRequest.addParameter("size", Integer.toString(docCount));
        final XContentBuilder body = XContentFactory.jsonBuilder().startObject();
        body.field("query", matchAll);
        body.endObject();
        searchRequest.setJsonEntity(Strings.toString(body));

        final Response searchResponse = client().performRequest(searchRequest);
        assertEquals(
            searchRequest.getEndpoint() + ": failed",
            RestStatus.OK,
            RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode())
        );

        final List<KNNResult> hits = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), testField);
        int row = 0;
        for (KNNResult hit : hits) {
            vectors[row++] = Floats.toArray(Arrays.asList(hit.getVector()));
        }
        return vectors;
    }

    /**
     * Runs one k-NN search per query vector and collects the returned document ids.
     * Each search must return exactly {@code k} hits.
     *
     * @param testIndex    index to search
     * @param testField    vector field to search on
     * @param queryVectors query vectors, one search per row
     * @param k            number of neighbours requested per query
     * @return for each query (in input order), the ids of its hits in response order
     * @throws Exception if a search fails or a response cannot be parsed
     */
    public List<List<String>> bulkSearch(String testIndex, String testField, float[][] queryVectors, int k) throws Exception {
        List<List<String>> searchResults = new ArrayList<>(queryVectors.length);

        for (float[] queryVector : queryVectors) {
            KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(testField, queryVector, k);
            Response response = searchKNNIndex(testIndex, knnQueryBuilder, k);
            List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);

            // Expected value first so a failure reports "expected:<k> but was:<actual>".
            assertEquals("unexpected number of k-NN hits", k, results.size());
            searchResults.add(results.stream().map(KNNResult::getDocId).collect(Collectors.toList()));
        }

        return searchResults;
    }

    /**
     * Calls the cluster health API with {@code wait_for_nodes} and {@code wait_for_status=green}.
     * The response itself is not inspected.
     *
     * @param numOfNodes value passed as {@code wait_for_nodes}
     * @throws IOException if the request cannot be performed
     */
    public void waitForClusterHealthGreen(String numOfNodes) throws IOException {
        final Request healthRequest = new Request("GET", "/_cluster/health");
        healthRequest.addParameter("wait_for_nodes", numOfNodes);
        healthRequest.addParameter("wait_for_status", "green");
        client().performRequest(healthRequest);
    }

    /**
     * Adds {@code numDocs} documents with ids {@code firstDocID .. firstDocID + numDocs - 1}. Every component
     * of a document's vector equals its id.
     *
     * @param testIndex  target index
     * @param testField  vector field
     * @param dimension  vector dimension
     * @param firstDocID id of the first document
     * @param numDocs    number of documents to add
     * @throws IOException if a request cannot be performed
     */
    public void addKNNDocs(String testIndex, String testField, int dimension, int firstDocID, int numDocs) throws IOException {
        final int endExclusive = firstDocID + numDocs;
        for (int docId = firstDocID; docId < endExclusive; docId++) {
            final Float[] vector = new Float[dimension];
            Arrays.fill(vector, (float) docId);
            addKnnDoc(testIndex, Integer.toString(docId), testField, vector);
        }
    }

    /**
     * Runs a k-NN search with a query vector whose components all equal {@code numDocs}. It asserts that
     * exactly {@code k} hits come back with ids {@code numDocs - 1, numDocs - 2, ...} in that order.
     *
     * @param testIndex index to search
     * @param testField vector field
     * @param dimension vector dimension
     * @param numDocs   number of documents in the index
     * @param k         number of neighbours requested
     * @throws Exception if the search fails or the response cannot be parsed
     */
    public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k) throws Exception {
        final float[] queryVector = new float[dimension];
        Arrays.fill(queryVector, (float) numDocs);

        final KNNQueryBuilder knnQuery = new KNNQueryBuilder(testField, queryVector, k);
        final Response searchResponse = searchKNNIndex(testIndex, knnQuery, k);
        final String responseBody = EntityUtils.toString(searchResponse.getEntity());
        final List<KNNResult> results = parseSearchResponse(responseBody, testField);

        assertEquals(k, results.size());
        int expectedDocId = numDocs - 1;
        for (KNNResult result : results) {
            assertEquals(expectedDocId, Integer.parseInt(result.getDocId()));
            expectedDocId--;
        }
    }

    /**
     * Index settings for a single-shard, zero-replica k-NN index that carries the legacy
     * index-level space type, {@code m} and {@code ef_construction} settings.
     *
     * @param spaceType       space type written to the index settings
     * @param m               value for the legacy {@code m} setting
     * @param ef_construction value for the legacy {@code ef_construction} setting
     * @return the built settings
     */
    protected Settings createKNNIndexCustomLegacyFieldMappingSettings(SpaceType spaceType, Integer m, Integer ef_construction) {
        final Settings.Builder settings = Settings.builder();
        settings.put(NUMBER_OF_SHARDS, 1);
        settings.put(NUMBER_OF_REPLICAS, 0);
        settings.put(INDEX_KNN, true);
        settings.put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue());
        settings.put(KNNSettings.KNN_ALGO_PARAM_M, m);
        settings.put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, ef_construction);
        return settings.build();
    }

    /**
     * Builds a mapping JSON with one k-NN vector field whose method is {@code METHOD_HNSW}, with no further
     * method options.
     *
     * @param fieldName  name of the vector field
     * @param dimensions vector dimension (written as a string value)
     * @return the mapping as a JSON string
     * @throws IOException if the JSON cannot be built
     */
    public String createKNNIndexMethodFieldMapping(String fieldName, Integer dimensions) throws IOException {
        final XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject(PROPERTIES);
        mapping.startObject(fieldName)
            .field(VECTOR_TYPE, KNN_VECTOR)
            .field(DIMENSION, dimensions.toString())
            .startObject(KNN_METHOD)
            .field(NAME, METHOD_HNSW)
            .endObject()
            .endObject();
        mapping.endObject().endObject();
        return Strings.toString(mapping);
    }

    /**
     * Builds a mapping JSON with one k-NN vector field whose method is {@code METHOD_HNSW}, with an explicit
     * space type, engine, {@code ef_construction} and {@code m}.
     *
     * @param fieldName       name of the vector field
     * @param dimensions      vector dimension (written as a string value)
     * @param spaceType       space type of the method
     * @param engine          engine name
     * @param m               method parameter {@code m}
     * @param ef_construction method parameter {@code ef_construction}
     * @return the mapping as a JSON string
     * @throws IOException if the JSON cannot be built
     */
    public String createKNNIndexCustomMethodFieldMapping(
        String fieldName,
        Integer dimensions,
        SpaceType spaceType,
        String engine,
        Integer m,
        Integer ef_construction
    ) throws IOException {
        final XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject(PROPERTIES);
        mapping.startObject(fieldName).field(VECTOR_TYPE, KNN_VECTOR).field(DIMENSION, dimensions.toString());

        mapping.startObject(KNN_METHOD)
            .field(NAME, METHOD_HNSW)
            .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
            .field(KNN_ENGINE, engine);
        mapping.startObject(PARAMETERS)
            .field(METHOD_PARAMETER_EF_CONSTRUCTION, ef_construction)
            .field(METHOD_PARAMETER_M, m)
            .endObject();
        mapping.endObject(); // method

        mapping.endObject(); // field
        mapping.endObject(); // properties
        mapping.endObject(); // root
        return Strings.toString(mapping);
    }

    /**
     * Index settings used by the script-score tests: one shard, no replicas, and {@code INDEX_KNN} disabled.
     *
     * @return the built settings
     */
    protected Settings createKNNDefaultScriptScoreSettings() {
        final Settings.Builder settings = Settings.builder()
            .put(NUMBER_OF_SHARDS, 1)
            .put(NUMBER_OF_REPLICAS, 0)
            .put(INDEX_KNN, false);
        return settings.build();
    }

    /**
     * Validates a k-NN script_score search (used for the "l2", "l1" and "linf" space types). The test runs a
     * match_all query re-scored by the k-NN scoring script and compares the returned ids against
     * {@code computeGroundTruthValues}.
     * <p>
     * NOTE(review): the queue is polled while walking the hits from last to first, so it is assumed to yield the
     * k-th best document first — confirm against computeGroundTruthValues.
     *
     * @param testIndex index to search
     * @param testField vector field
     * @param dimension vector dimension
     * @param numDocs   number of documents in the index; also selects the query vector
     * @param k         number of hits requested
     * @param spaceType space type passed to the scoring script
     * @throws Exception if the search fails or the response cannot be parsed
     */
    protected void validateKNNScriptScoreSearch(String testIndex, String testField, int dimension, int numDocs, int k, SpaceType spaceType)
        throws Exception {
        final IDVectorProducer vectorProducer = new IDVectorProducer(dimension, numDocs);
        final float[] queryVector = vectorProducer.getVector(numDocs);

        final QueryBuilder matchAll = new MatchAllQueryBuilder();
        final Map<String, Object> scriptParams = new HashMap<>();
        scriptParams.put(FIELD, testField);
        scriptParams.put(QUERY_VALUE, queryVector);
        scriptParams.put(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue());

        final Request request = constructKNNScriptQueryRequest(testIndex, matchAll, scriptParams, k);
        final Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

        final List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);
        assertEquals(k, results.size());

        final PriorityQueue<DistVector> groundTruth = computeGroundTruthValues(k, spaceType, vectorProducer);
        int position = k;
        while (position-- > 0) {
            final int expectedDocId = Integer.parseInt(groundTruth.poll().getDocID());
            final int actualDocId = Integer.parseInt(results.get(position).getDocId());
            assertEquals(expectedDocId, actualDocId);
        }
    }

    /**
     * Validates a painless script_score search (used for the "l2" and "l1" space types). The search runs a
     * match_all query re-scored by {@code source} and asserts that exactly {@code k} hits come back with ids
     * {@code numDocs - 1, numDocs - 2, ...} in that order.
     *
     * @param testIndex index to search
     * @param testField vector field
     * @param source    painless script source
     * @param numDocs   number of documents in the index
     * @param k         number of hits requested
     * @throws Exception if the search fails or the response cannot be parsed
     */
    protected void validateKNNPainlessScriptScoreSearch(String testIndex, String testField, String source, int numDocs, int k)
        throws Exception {
        final QueryBuilder matchAll = new MatchAllQueryBuilder();
        final Map<String, Object> noParams = Collections.emptyMap();
        final Request request = constructScriptScoreContextSearchRequest(testIndex, matchAll, noParams, Script.DEFAULT_SCRIPT_LANG, source, k);

        final Response response = client().performRequest(request);
        assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

        final List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);
        assertEquals(k, results.size());

        for (int rank = 0; rank < k; rank++) {
            final int actualDocId = Integer.parseInt(results.get(rank).getDocId());
            assertEquals(numDocs - rank - 1, actualDocId);
        }
    }

    /**
     * Builds the painless source {@code 1/(1 + l2Squared(<query>, doc['<field>']))} for the "l2" space type.
     * The query vector is the {@link IDVectorProducer} vector selected by {@code numDocs}.
     *
     * @param testField vector field referenced by the script
     * @param dimension vector dimension
     * @param numDocs   number of documents; also selects the query vector
     * @return the script source
     */
    protected String createL2PainlessScriptSource(String testField, int dimension, int numDocs) {
        final float[] queryVector = new IDVectorProducer(dimension, numDocs).getVector(numDocs);
        final String vectorLiteral = Arrays.toString(queryVector);
        return String.format("1/(1 + l2Squared(%s, doc['%s']))", vectorLiteral, testField);
    }

    /**
     * Builds the painless source {@code 1/(1 + l1Norm(<query>, doc['<field>']))} for the "l1" space type.
     * The query vector is the {@link IDVectorProducer} vector selected by {@code numDocs}.
     *
     * @param testField vector field referenced by the script
     * @param dimension vector dimension
     * @param numDocs   number of documents; also selects the query vector
     * @return the script source
     */
    protected String createL1PainlessScriptSource(String testField, int dimension, int numDocs) {
        final float[] queryVector = new IDVectorProducer(dimension, numDocs).getVector(numDocs);
        final String vectorLiteral = Arrays.toString(queryVector);
        return String.format("1/(1 + l1Norm(%s, doc['%s']))", vectorLiteral, testField);
    }

    /**
     * Calls the train API to produce a trained model.
     *
     * @param modelId to identify the model. If null, the request is sent without a model id in the path
     * @param trainingIndexName index to pull training data from
     * @param trainingFieldName field to pull training data from
     * @param dimension dimension of model
     * @param method method definition for model
     * @param description description of model
     * @return Response returned by the cluster
     * @throws IOException if request cannot be performed
     */
    public Response trainModel(
        String modelId,
        String trainingIndexName,
        String trainingFieldName,
        int dimension,
        Map<String, Object> method,
        String description
    ) throws IOException {
        final XContentBuilder body = XContentFactory.jsonBuilder().startObject();
        body.field(TRAIN_INDEX_PARAMETER, trainingIndexName);
        body.field(TRAIN_FIELD_PARAMETER, trainingFieldName);
        body.field(DIMENSION, dimension);
        body.field(KNN_METHOD, method);
        body.field(MODEL_DESCRIPTION, description);
        body.endObject();

        final String modelPath = modelId == null ? "" : "/" + modelId;
        final Request trainRequest = new Request("POST", "/_plugins/_knn/models" + modelPath + "/_train");
        trainRequest.setJsonEntity(Strings.toString(body));
        return client().performRequest(trainRequest);
    }

    /**
     * Retrieve the model
     *
     * @param modelId Id of model to be retrieved. If null, no model id is added to the path
     * @param filters response paths to keep, sent as the {@code filter_path} query parameter.
     *                If null or empty, the response is not filtered
     * @return Response from cluster
     * @throws IOException if request cannot be performed
     */
    public Response getModel(String modelId, List<String> filters) throws IOException {
        final String endpoint = modelId == null ? "/_plugins/_knn/models" : "/_plugins/_knn/models/" + modelId;
        final Request request = new Request("GET", endpoint);

        if (filters != null && !filters.isEmpty()) {
            // Must be a real query parameter: appending "&filter_path=..." to the path (with no '?')
            // would make it part of the model id path segment.
            request.addParameter("filter_path", StringUtils.join(filters, ","));
        }

        return client().performRequest(request);
    }

    /**
     * Polls the model up to {@code attempts} times, sleeping {@code delayInMillis} before each poll. The helper
     * returns as soon as the model state is {@code CREATED}, fails immediately if it is {@code FAILED}, and fails
     * once all attempts are used up.
     *
     * @param modelId       id of the model to poll
     * @param attempts      maximum number of polls
     * @param delayInMillis sleep before each poll
     * @throws InterruptedException if the sleep is interrupted
     * @throws Exception            if a poll fails or its response cannot be parsed
     */
    public void assertTrainingSucceeds(String modelId, int attempts, int delayInMillis) throws InterruptedException, Exception {
        for (int attempt = 0; attempt < attempts; attempt++) {
            Thread.sleep(delayInMillis);

            final Response modelResponse = getModel(modelId, null);
            final Map<String, Object> modelMap = createParser(
                MediaTypeRegistry.getDefaultMediaType().xContent(),
                EntityUtils.toString(modelResponse.getEntity())
            ).map();

            final ModelState state = ModelState.getModelState((String) modelMap.get(MODEL_STATE));
            if (state == ModelState.CREATED) {
                return;
            }
            assertNotEquals(ModelState.FAILED, state);
        }

        fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms.");
    }

    /**
     * Polls the model up to {@code attempts} times, sleeping {@code delayInMillis} before each poll. The helper
     * returns as soon as the model state is {@code FAILED}, fails immediately if it is {@code CREATED}, and fails
     * once all attempts are used up.
     *
     * @param modelId       id of the model to poll
     * @param attempts      maximum number of polls
     * @param delayInMillis sleep before each poll
     * @throws Exception if a poll fails, its response cannot be parsed, or the sleep is interrupted
     */
    public void assertTrainingFails(String modelId, int attempts, int delayInMillis) throws Exception {
        for (int attempt = 0; attempt < attempts; attempt++) {
            Thread.sleep(delayInMillis);

            Response response = getModel(modelId, null);
            Map<String, Object> responseMap = createParser(
                MediaTypeRegistry.getDefaultMediaType().xContent(),
                EntityUtils.toString(response.getEntity())
            ).map();

            ModelState modelState = ModelState.getModelState((String) responseMap.get(MODEL_STATE));
            if (modelState == ModelState.FAILED) {
                return;
            }

            assertNotEquals(ModelState.CREATED, modelState);
        }

        // This helper waits for a failure, so the timeout message must say training did not fail.
        fail("Training did not fail after " + attempts + " attempts with a delay of " + delayInMillis + " ms.");
    }

    /**
     * Checks for an index with a {@code HEAD} request sent through the admin client.
     *
     * @param indexName index to check
     * @return {@code true} only if the response status is 200 OK
     * @throws IOException if the request cannot be performed
     */
    protected boolean systemIndexExists(final String indexName) throws IOException {
        final Request headRequest = new Request("HEAD", "/" + indexName);
        final int statusCode = adminClient().performRequest(headRequest).getStatusLine().getStatusCode();
        return statusCode == RestStatus.OK.getStatus();
    }

    /**
     * Client settings with {@code strictDeprecationMode} disabled. The {@code tests.rest.client_path_prefix}
     * system property, when set, is forwarded as {@code CLIENT_PATH_PREFIX}.
     *
     * @return a settings builder that callers may extend
     */
    protected Settings.Builder noStrictDeprecationModeSettingsBuilder() {
        final Settings.Builder settings = Settings.builder().put("strictDeprecationMode", false);
        final String clientPathPrefix = System.getProperty("tests.rest.client_path_prefix");
        if (clientPathPrefix != null) {
            settings.put(CLIENT_PATH_PREFIX, clientPathPrefix);
        }
        return settings;
    }

    /**
     * Ingests random training vectors and starts training with a fixed method definition: "ivf" on "faiss"
     * with space type "l2" and nlist 1, plus a "pq" encoder with code size 2 and m 2.
     *
     * @param modelId           id of the model to train
     * @param trainingIndexName index to fill with training data
     * @param trainingFieldName vector field holding the training data
     * @param dimension         vector dimension
     * @param modelDescription  description stored with the model
     * @throws Exception if ingestion or the train request fails
     */
    protected void ingestDataAndTrainModel(
        String modelId,
        String trainingIndexName,
        String trainingFieldName,
        int dimension,
        String modelDescription
    ) throws Exception {
        final XContentBuilder methodBuilder = XContentFactory.jsonBuilder().startObject();
        methodBuilder.field(NAME, "ivf").field(KNN_ENGINE, "faiss").field(METHOD_PARAMETER_SPACE_TYPE, "l2");
        methodBuilder.startObject(PARAMETERS).field(METHOD_PARAMETER_NLIST, 1);
        methodBuilder.startObject(METHOD_ENCODER_PARAMETER).field(NAME, "pq");
        methodBuilder.startObject(PARAMETERS).field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2).field(ENCODER_PARAMETER_PQ_M, 2).endObject();
        methodBuilder.endObject(); // encoder
        methodBuilder.endObject(); // method parameters
        methodBuilder.endObject(); // method

        final Map<String, Object> methodDefinition = xContentBuilderToMap(methodBuilder);
        ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, methodDefinition);
    }

    /**
     * Ingests 40 random training vectors and sends a train request with the given method definition.
     * Only the acceptance of the train request (200 OK) is asserted; training completion is not awaited here.
     *
     * @param modelId           id of the model to train
     * @param trainingIndexName index to fill with training data
     * @param trainingFieldName vector field holding the training data
     * @param dimension         vector dimension
     * @param modelDescription  description stored with the model
     * @param method            method definition sent to the train API
     * @throws Exception if ingestion or the train request fails
     */
    protected void ingestDataAndTrainModel(
        String modelId,
        String trainingIndexName,
        String trainingFieldName,
        int dimension,
        String modelDescription,
        Map<String, Object> method
    ) throws Exception {
        final int trainingDataCount = 40;
        bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

        final Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, modelDescription);
        final int statusCode = trainResponse.getStatusLine().getStatusCode();
        assertEquals(RestStatus.OK, RestStatus.fromCode(statusCode));
    }

    /**
     * Returns a closed builder holding an "ivf" method definition for the FAISS engine with the L2 space type,
     * nlist 1, and a "pq" encoder with code size 2 and m 2.
     *
     * @return the method definition builder
     * @throws IOException if the JSON cannot be built
     */
    protected XContentBuilder getModelMethodBuilder() throws IOException {
        return XContentFactory.jsonBuilder()
            .startObject()
            .field(NAME, "ivf")
            .field(KNN_ENGINE, FAISS.getName())
            .field(METHOD_PARAMETER_SPACE_TYPE, L2.getValue())
            .startObject(PARAMETERS)
            .field(METHOD_PARAMETER_NLIST, 1)
            .startObject(METHOD_ENCODER_PARAMETER)
            .field(NAME, "pq")
            .startObject(PARAMETERS)
            .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2)
            .field(ENCODER_PARAMETER_PQ_M, 2)
            .endObject() // encoder parameters
            .endObject() // encoder
            .endObject() // method parameters
            .endObject(); // method
    }

    /**
     * Client-side view of a code-coverage agent, used so that JaCoCo coverage can be dumped before the cluster
     * is shut down (the newer internal testing framework removed the gradle tasks previously used for this).
     * <p>
     * NOTE(review): presumably obtained as a JMX proxy (e.g. via {@code MBeanServerInvocationHandler}) over the
     * JaCoCo runtime MBean — confirm against the caller.
     * TODO: execution data currently overwrites integTest.exec on every dump without a reset; this could write one
     * exec file per test or let jacoco append to the file instead.
     */
    public interface IProxy {
        /** Presumably resets the collected execution data. */
        void reset();

        /**
         * Presumably dumps the collected execution data.
         *
         * @param reset whether the data should be reset after dumping
         */
        void dump(boolean reset);

        /**
         * Presumably returns the collected execution data.
         *
         * @param reset whether the data should be reset after reading
         * @return the serialized execution data
         */
        byte[] getExecutionData(boolean reset);
    }

    /**
     * Lists every index via {@code _cat/indices} (all wildcard states) with the admin client, then refreshes
     * each one whose name does not start with {@code SYSTEM_INDEX_PREFIX}.
     *
     * @throws Exception if listing or refreshing fails
     */
    protected void refreshAllNonSystemIndices() throws Exception {
        final Request catRequest = new Request("GET", "/_cat/indices?format=json&expand_wildcards=all");
        final Response catResponse = adminClient().performRequest(catRequest);
        final MediaType mediaType = MediaType.fromMediaType(catResponse.getEntity().getContentType());

        try (
            XContentParser parser = mediaType.xContent()
                .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, catResponse.getEntity().getContent())
        ) {
            // The cat API normally answers with an array of rows, but a single object is handled as well.
            final List<Map<String, Object>> catRows;
            if (parser.nextToken() != XContentParser.Token.START_ARRAY) {
                catRows = Collections.singletonList(parser.mapOrdered());
            } else {
                catRows = parser.listOrderedMap().stream().map(row -> (Map<String, Object>) row).collect(Collectors.toList());
            }

            final Set<String> userIndices = catRows.stream()
                .map(row -> (String) row.get("index"))
                .filter(name -> !name.startsWith(SYSTEM_INDEX_PREFIX))
                .collect(Collectors.toSet());
            for (String userIndex : userIndices) {
                refreshIndex(userIndex);
            }
        }
    }

    /**
     * Refreshes {@code index} and asserts that the refresh returned 200 OK.
     *
     * @param index index to refresh
     * @throws IOException if the request cannot be performed
     */
    protected void refreshIndex(final String index) throws IOException {
        final Request refreshRequest = new Request("POST", "/" + index + "/_refresh");
        final Response refreshResponse = client().performRequest(refreshRequest);
        final RestStatus status = RestStatus.fromCode(refreshResponse.getStatusLine().getStatusCode());
        assertEquals(refreshRequest.getEndpoint() + ": failed", RestStatus.OK, status);
    }

    /**
     * Indexes a document into {@code INDEX_NAME} with the vector stored in {@code FIELD_NAME} plus the given
     * string attributes, then refreshes the index and asserts that the refresh returned 200 OK.
     *
     * @param docId       document id
     * @param vector      vector value
     * @param fieldValues additional string fields to store on the document
     * @throws IOException if a request cannot be performed
     */
    protected void addKnnDocWithAttributes(String docId, float[] vector, Map<String, String> fieldValues) throws IOException {
        addKnnDocWithAttributes(INDEX_NAME, docId, FIELD_NAME, vector, fieldValues);
    }

    /**
     * Indexes a document with a vector field plus string attributes (with {@code refresh=true}), then issues
     * an explicit refresh and asserts that the refresh returned 200 OK.
     *
     * @param indexName       target index
     * @param docId           document id
     * @param vectorFieldName field holding the vector
     * @param vector          vector value
     * @param fieldValues     additional string fields to store on the document
     * @throws IOException if a request cannot be performed
     */
    protected void addKnnDocWithAttributes(
        String indexName,
        String docId,
        String vectorFieldName,
        float[] vector,
        Map<String, String> fieldValues
    ) throws IOException {
        final Request indexRequest = new Request("POST", "/" + indexName + "/_doc/" + docId + "?refresh=true");
        final XContentBuilder document = XContentFactory.jsonBuilder().startObject().field(vectorFieldName, vector);
        for (Map.Entry<String, String> attribute : fieldValues.entrySet()) {
            document.field(attribute.getKey(), attribute.getValue());
        }
        document.endObject();
        indexRequest.setJsonEntity(Strings.toString(document));
        client().performRequest(indexRequest);

        final Request refreshRequest = new Request("POST", "/" + indexName + "/_refresh");
        final Response refreshResponse = client().performRequest(refreshRequest);
        assertEquals(
            refreshRequest.getEndpoint() + ": failed",
            RestStatus.OK,
            RestStatus.fromCode(refreshResponse.getStatusLine().getStatusCode())
        );
    }
}