/* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.knn.index.codec; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import lombok.Builder; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SegmentInfo; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.store.IOContext; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; import java.util.Set; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNVectorSerializer; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.jni.JNIService; import java.io.IOException; import java.nio.file.Paths; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; import static com.carrotsearch.randomizedtesting.RandomizedTest.randomFloat; import static org.junit.Assert.assertTrue; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.test.OpenSearchTestCase.randomByteArrayOfLength; public class KNNCodecTestUtil { // Utility class to help build FieldInfo public static class FieldInfoBuilder { private final String fieldName; private int fieldNumber; private boolean storeTermVector; private boolean omitNorms; private boolean storePayloads; private IndexOptions indexOptions; private DocValuesType docValuesType; private long dvGen; private final Map<String, String> attributes; private int pointDimensionCount; private int pointIndexDimensionCount; private int pointNumBytes; private int vectorDimension; private VectorSimilarityFunction vectorSimilarityFunction; private boolean softDeletes; public static FieldInfoBuilder builder(String fieldName) { return new FieldInfoBuilder(fieldName); } private FieldInfoBuilder(String fieldName) { this.fieldName = fieldName; this.fieldNumber = 0; this.storeTermVector = false; this.omitNorms = true; this.storePayloads = true; this.indexOptions = IndexOptions.NONE; this.docValuesType = DocValuesType.BINARY; this.dvGen = 0; this.attributes = new HashMap<>(); this.pointDimensionCount = 0; this.pointIndexDimensionCount = 0; this.pointNumBytes = 0; this.vectorDimension = 0; this.vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN; this.softDeletes = false; } public FieldInfoBuilder fieldNumber(int fieldNumber) { this.fieldNumber = fieldNumber; return this; } public FieldInfoBuilder storeTermVector(boolean storeTermVector) { this.storeTermVector = storeTermVector; return this; } public FieldInfoBuilder omitNorms(boolean omitNorms) { this.omitNorms = omitNorms; return this; } public FieldInfoBuilder storePayloads(boolean storePayloads) { this.storePayloads = storePayloads; return this; } public FieldInfoBuilder indexOptions(IndexOptions indexOptions) { this.indexOptions = indexOptions; return this; } public FieldInfoBuilder docValuesType(DocValuesType docValuesType) { this.docValuesType = docValuesType; return this; } public FieldInfoBuilder dvGen(long dvGen) { this.dvGen = dvGen; return this; } public FieldInfoBuilder addAttribute(String key, String value) { this.attributes.put(key, value); return this; } public FieldInfoBuilder pointDimensionCount(int pointDimensionCount) { this.pointDimensionCount = pointDimensionCount; return this; } public FieldInfoBuilder pointIndexDimensionCount(int pointIndexDimensionCount) { this.pointIndexDimensionCount = pointIndexDimensionCount; return this; } public FieldInfoBuilder pointNumBytes(int pointNumBytes) { this.pointNumBytes = pointNumBytes; return this; } public FieldInfoBuilder vectorDimension(int vectorDimension) { this.vectorDimension = vectorDimension; return this; } public FieldInfoBuilder vectorSimilarityFunction(VectorSimilarityFunction vectorSimilarityFunction) { this.vectorSimilarityFunction = vectorSimilarityFunction; return this; } public FieldInfoBuilder softDeletes(boolean softDeletes) { this.softDeletes = softDeletes; return this; } public FieldInfo build() { return new FieldInfo( fieldName, fieldNumber, storeTermVector, omitNorms, storePayloads, indexOptions, docValuesType, dvGen, attributes, pointDimensionCount, pointIndexDimensionCount, pointNumBytes, vectorDimension, VectorEncoding.FLOAT32, vectorSimilarityFunction, softDeletes ); } } public static abstract class VectorDocValues extends BinaryDocValues { final int count; final int dimension; int current; KNNVectorSerializer knnVectorSerializer; public VectorDocValues(int count, int dimension) { this.count = count; this.dimension = dimension; this.current = -1; this.knnVectorSerializer = KNNVectorSerializerFactory.getDefaultSerializer(); } @Override public boolean advanceExact(int target) throws IOException { return false; } @Override public int docID() { if (this.current > this.count) { return BinaryDocValues.NO_MORE_DOCS; } return this.current; } @Override public int nextDoc() throws IOException { return advance(current + 1); } @Override public int advance(int target) throws IOException { current = target; if (current >= count) { current = NO_MORE_DOCS; } return current; } @Override public long cost() { return 0; } } public static class ConstantVectorBinaryDocValues extends VectorDocValues { private final BytesRef value; public ConstantVectorBinaryDocValues(int count, int dimension, float value) { super(count, dimension); float[] array = new float[dimension]; Arrays.fill(array, value); this.value = new BytesRef(knnVectorSerializer.floatToByteArray(array)); } @Override public BytesRef binaryValue() throws IOException { return value; } } public static class RandomVectorBinaryDocValues extends VectorDocValues { public RandomVectorBinaryDocValues(int count, int dimension) { super(count, dimension); } @Override public BytesRef binaryValue() throws IOException { return new BytesRef(knnVectorSerializer.floatToByteArray(getRandomVector(dimension))); } } public static class RandomVectorDocValuesProducer extends DocValuesProducer { final RandomVectorBinaryDocValues randomBinaryDocValues; public RandomVectorDocValuesProducer(int count, int dimension) { this.randomBinaryDocValues = new RandomVectorBinaryDocValues(count, dimension); } @Override public NumericDocValues getNumeric(FieldInfo field) { return null; } @Override public BinaryDocValues getBinary(FieldInfo field) throws IOException { return randomBinaryDocValues; } @Override public SortedDocValues getSorted(FieldInfo field) { return null; } @Override public SortedNumericDocValues getSortedNumeric(FieldInfo field) { return null; } @Override public SortedSetDocValues getSortedSet(FieldInfo field) { return null; } @Override public void checkIntegrity() { } @Override public void close() throws IOException { } } public static void assertFileInCorrectLocation(SegmentWriteState state, String expectedFile) throws IOException { assertTrue(Set.of(state.directory.listAll()).contains(expectedFile)); } public static void assertValidFooter(Directory dir, String filename) throws IOException { ChecksumIndexInput indexInput = dir.openChecksumInput(filename, IOContext.DEFAULT); indexInput.seek(indexInput.length() - CodecUtil.footerLength()); CodecUtil.checkFooter(indexInput); indexInput.close(); } public static void assertLoadableByEngine( SegmentWriteState state, String fileName, KNNEngine knnEngine, SpaceType spaceType, int dimension ) { String filePath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), fileName) .toString(); long indexPtr = JNIService.loadIndex( filePath, Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), knnEngine.getName() ); int k = 2; float[] queryVector = new float[dimension]; KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine.getName()); } public static float[][] getRandomVectors(int count, int dimension) { float[][] data = new float[count][dimension]; for (int i = 0; i < count; i++) { data[i] = getRandomVector(dimension); } return data; } public static float[] getRandomVector(int dimension) { float[] data = new float[dimension]; for (int i = 0; i < dimension; i++) { data[i] = randomFloat(); } return data; } @Builder(builderMethodName = "segmentInfoBuilder") public static SegmentInfo newSegmentInfo(final Directory directory, final String segmentName, int docsInSegment, final Codec codec) { return new SegmentInfo( directory, Version.LATEST, Version.LATEST, segmentName, docsInSegment, false, codec, Collections.emptyMap(), randomByteArrayOfLength(StringHelper.ID_LENGTH), ImmutableMap.of(), Sort.INDEXORDER ); } }