/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.knn.index.util;

import com.google.common.collect.ImmutableMap;
import org.opensearch.common.ValidationException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethod;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponent;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.NAME;

/**
 * Tests for the behaviour shared through {@link AbstractKNNLibrary}: version reporting, method lookup,
 * method validation and conversion of a method context into a parameter map.
 */
public class AbstractKNNLibraryTests extends KNNTestCase {

    /** The version string handed to the constructor must be returned unchanged by {@code getVersion()}. */
    public void testGetVersion() {
        String testVersion = "test-version";
        TestAbstractKNNLibrary testAbstractKNNLibrary = new TestAbstractKNNLibrary(Collections.emptyMap(), testVersion);
        assertEquals(testVersion, testAbstractKNNLibrary.getVersion());
    }

    /**
     * {@code getMethod} must return the registered {@link KNNMethod} for each known name and throw
     * {@link IllegalArgumentException} for a name that was never registered.
     */
    public void testGetMethod() {
        String methodName1 = "test-method-1";
        KNNMethod knnMethod1 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName1).build()).build();

        String methodName2 = "test-method-2";
        KNNMethod knnMethod2 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName2).build()).build();

        Map<String, KNNMethod> knnMethodMap = ImmutableMap.of(methodName1, knnMethod1, methodName2, knnMethod2);

        TestAbstractKNNLibrary testAbstractKNNLibrary = new TestAbstractKNNLibrary(knnMethodMap, "");
        assertEquals(knnMethod1, testAbstractKNNLibrary.getMethod(methodName1));
        assertEquals(knnMethod2, testAbstractKNNLibrary.getMethod(methodName2));
        expectThrows(IllegalArgumentException.class, () -> testAbstractKNNLibrary.getMethod("invalid"));
    }

    /**
     * Covers the two failure paths of {@code validateMethod}:
     * <ol>
     *   <li>the context names a method the library does not know, which is expected to throw
     *       {@link IllegalArgumentException};</li>
     *   <li>the method is known but its own {@code validate} reports a problem, which is expected to be
     *       surfaced as a non-null {@link ValidationException}.</li>
     * </ol>
     */
    public void testValidateMethod() throws IOException {
        // Invalid - method not supported
        String methodName1 = "test-method-1";
        KNNMethod knnMethod1 = KNNMethod.Builder.builder(MethodComponent.Builder.builder(methodName1).build()).build();

        Map<String, KNNMethod> methodMap = ImmutableMap.of(methodName1, knnMethod1);
        TestAbstractKNNLibrary testAbstractKNNLibrary1 = new TestAbstractKNNLibrary(methodMap, "");

        XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, "invalid").endObject();
        Map<String, Object> in = xContentBuilderToMap(xContentBuilder);
        KNNMethodContext knnMethodContext1 = KNNMethodContext.parse(in);

        expectThrows(IllegalArgumentException.class, () -> testAbstractKNNLibrary1.validateMethod(knnMethodContext1));

        // Invalid - method validation
        String methodName2 = "test-method-2";
        // Anonymous subclass whose validate() always reports an error, so the library must propagate it.
        KNNMethod knnMethod2 = new KNNMethod(MethodComponent.Builder.builder(methodName2).build(), Collections.emptySet()) {
            @Override
            public ValidationException validate(KNNMethodContext knnMethodContext) {
                return new ValidationException();
            }
        };

        methodMap = ImmutableMap.of(methodName2, knnMethod2);
        TestAbstractKNNLibrary testAbstractKNNLibrary2 = new TestAbstractKNNLibrary(methodMap, "");
        xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, methodName2).endObject();
        in = xContentBuilderToMap(xContentBuilder);
        KNNMethodContext knnMethodContext2 = KNNMethodContext.parse(in);

        assertNotNull(testAbstractKNNLibrary2.validateMethod(knnMethodContext2));
    }

    /**
     * {@code getMethodAsMap} must return the map produced by the method component's map generator plus the
     * space type entry. For an unknown method name it must throw {@link IllegalArgumentException}.
     */
    public void testGetMethodAsMap() {
        String methodName = "test-method-1";
        SpaceType spaceType = SpaceType.DEFAULT;
        Map<String, Object> generatedMap = ImmutableMap.of("test-key", "test-param");
        MethodComponent methodComponent = MethodComponent.Builder.builder(methodName)
            .setMapGenerator((methodComponent1, methodComponentContext) -> generatedMap)
            .build();

        KNNMethod knnMethod = KNNMethod.Builder.builder(methodComponent).build();

        TestAbstractKNNLibrary testAbstractKNNLibrary = new TestAbstractKNNLibrary(ImmutableMap.of(methodName, knnMethod), "");

        // Check that map is expected
        Map<String, Object> expectedMap = new HashMap<>(generatedMap);
        expectedMap.put(KNNConstants.SPACE_TYPE, spaceType.getValue());
        KNNMethodContext knnMethodContext = new KNNMethodContext(
            KNNEngine.DEFAULT,
            spaceType,
            new MethodComponentContext(methodName, Collections.emptyMap())
        );
        assertEquals(expectedMap, testAbstractKNNLibrary.getMethodAsMap(knnMethodContext));

        // Check when invalid method is passed in
        KNNMethodContext invalidKnnMethodContext = new KNNMethodContext(
            KNNEngine.DEFAULT,
            spaceType,
            new MethodComponentContext("invalid", Collections.emptyMap())
        );
        expectThrows(IllegalArgumentException.class, () -> testAbstractKNNLibrary.getMethodAsMap(invalidKnnMethodContext));
    }

    /**
     * Minimal concrete {@link AbstractKNNLibrary} used to exercise the base-class logic. Every abstract hook
     * returns a placeholder value (null or 0) because these tests do not call them.
     */
    private static class TestAbstractKNNLibrary extends AbstractKNNLibrary {
        public TestAbstractKNNLibrary(Map<String, KNNMethod> methods, String currentVersion) {
            super(methods, currentVersion);
        }

        @Override
        public String getExtension() {
            return null;
        }

        @Override
        public String getCompoundExtension() {
            return null;
        }

        @Override
        public float score(float rawScore, SpaceType spaceType) {
            return 0;
        }

        @Override
        public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension) {
            return 0;
        }

        @Override
        public Boolean isInitialized() {
            return null;
        }

        @Override
        public void setInitialized(Boolean isInitialized) {
            // Intentionally a no-op: initialization state is irrelevant to these tests.
        }
    }
}