/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 *
 * Modifications Copyright OpenSearch Contributors. See
 * GitHub history for details.
 */

#include "faiss_wrapper.h"

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "jni_util.h"
#include "test_util.h"

using ::testing::NiceMock;
using ::testing::Return;

// Unit tests for knn_jni::faiss_wrapper.
//
// The tests run without a JVM: jniEnv is always nullptr and every "Java"
// handle passed to the wrapper is the address of a C++ object cast to the
// corresponding JNI type.
// NOTE(review): presumably test_util::MockJNIUtil casts these handles back to
// the C++ containers they point at -- confirm against test_util.h.

// Builds an index from 200 random 2-d vectors through CreateIndex and checks
// that the file written to indexPath can be loaded again.
TEST(FaissCreateIndexTest, BasicAssertions) {
    // Define the data
    faiss::idx_t numIds = 200;
    std::vector<faiss::idx_t> ids;
    std::vector<std::vector<float>> vectors;
    int dim = 2;
    for (int64_t i = 0; i < numIds; ++i) {
        ids.push_back(i);

        std::vector<float> vect;
        vect.reserve(dim);
        for (int j = 0; j < dim; ++j) {
            vect.push_back(test_util::RandomFloat(-500.0, 500.0));
        }
        vectors.push_back(vect);
    }

    std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
    std::string spaceType = knn_jni::L2;
    std::string index_description = "Flat"; // TODO: Revert back to HNSW32,Flat

    std::unordered_map<std::string, jobject> parametersMap;
    parametersMap[knn_jni::SPACE_TYPE] = reinterpret_cast<jobject>(&spaceType);
    parametersMap[knn_jni::INDEX_DESCRIPTION] =
            reinterpret_cast<jobject>(&index_description);

    // Set up jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    EXPECT_CALL(mockJNIUtil,
                GetJavaObjectArrayLength(
                        jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
            .WillRepeatedly(Return(vectors.size()));

    // Create the index
    // NOTE(review): jintArray for the id handle is inferred -- confirm
    // against the CreateIndex signature in faiss_wrapper.h.
    knn_jni::faiss_wrapper::CreateIndex(
            &mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
            reinterpret_cast<jobjectArray>(&vectors),
            reinterpret_cast<jstring>(&indexPath),
            reinterpret_cast<jobject>(&parametersMap));

    // Make sure index can be loaded
    std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));

    // Clean up
    std::remove(indexPath.c_str());
}

// Serializes an empty "Flat" index as a template, builds an index from 100
// random 2-d vectors through CreateIndexFromTemplate and checks that the file
// written to indexPath can be loaded again.
TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
    // Define the data
    faiss::idx_t numIds = 100;
    std::vector<faiss::idx_t> ids;
    std::vector<std::vector<float>> vectors;
    int dim = 2;
    for (int64_t i = 0; i < numIds; ++i) {
        ids.push_back(i);

        std::vector<float> vect;
        vect.reserve(dim);
        for (int j = 0; j < dim; ++j) {
            vect.push_back(test_util::RandomFloat(-500.0, 500.0));
        }
        vectors.push_back(vect);
    }

    std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "Flat"; // TODO: Revert back to HNSW32,Flat

    // Serialized form of the untrained/empty index acts as the template
    std::unique_ptr<faiss::Index> createdIndex(
            test_util::FaissCreateIndex(dim, method, metricType));
    auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get());

    // Setup jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    EXPECT_CALL(mockJNIUtil,
                GetJavaObjectArrayLength(
                        jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
            .WillRepeatedly(Return(vectors.size()));

    std::string spaceType = knn_jni::L2;
    std::unordered_map<std::string, jobject> parametersMap;
    parametersMap[knn_jni::SPACE_TYPE] = reinterpret_cast<jobject>(&spaceType);

    // NOTE(review): jintArray / jbyteArray handle types are inferred --
    // confirm against the CreateIndexFromTemplate signature.
    knn_jni::faiss_wrapper::CreateIndexFromTemplate(
            &mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
            reinterpret_cast<jobjectArray>(&vectors),
            reinterpret_cast<jstring>(&indexPath),
            reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
            reinterpret_cast<jobject>(&parametersMap));

    // Make sure index can be loaded
    std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));

    // Clean up
    std::remove(indexPath.c_str());
}

// Writes an index with 100 random 2-d vectors to disk, loads it back through
// LoadIndex and checks that both serialize to identical bytes.
TEST(FaissLoadIndexTest, BasicAssertions) {
    // Define the data (vectors are stored flattened: numIds * dim floats)
    faiss::idx_t numIds = 100;
    std::vector<faiss::idx_t> ids;
    std::vector<float> vectors;
    int dim = 2;
    for (int64_t i = 0; i < numIds; i++) {
        ids.push_back(i);
        for (int j = 0; j < dim; j++) {
            vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
        }
    }

    std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss");
    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "Flat"; // TODO: Revert back to HNSW32,Flat

    // Create the index
    std::unique_ptr<faiss::Index> createdIndex(
            test_util::FaissCreateIndex(dim, method, metricType));
    auto createdIndexWithData =
            test_util::FaissAddData(createdIndex.get(), ids, vectors);

    test_util::FaissWriteIndex(&createdIndexWithData, indexPath);

    // Setup jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    // LoadIndex hands back the index address as an integer handle; the
    // unique_ptr takes ownership so the loaded index is released on exit.
    std::unique_ptr<faiss::Index> loadedIndexPointer(
            reinterpret_cast<faiss::Index *>(knn_jni::faiss_wrapper::LoadIndex(
                    &mockJNIUtil, jniEnv,
                    reinterpret_cast<jstring>(&indexPath))));

    // Compare serialized versions
    auto createIndexSerialization =
            test_util::FaissGetSerializedIndex(&createdIndexWithData);
    auto loadedIndexSerialization = test_util::FaissGetSerializedIndex(
            reinterpret_cast<faiss::Index *>(loadedIndexPointer.get()));

    ASSERT_NE(0, loadedIndexSerialization.data.size());
    ASSERT_EQ(createIndexSerialization.data.size(),
              loadedIndexSerialization.data.size());

    // size_t index avoids a signed/unsigned comparison against size()
    for (size_t i = 0; i < loadedIndexSerialization.data.size(); ++i) {
        ASSERT_EQ(createIndexSerialization.data[i],
                  loadedIndexSerialization.data[i]);
    }

    // Clean up
    std::remove(indexPath.c_str());
}

// Runs 100 random queries against an in-memory index of 100 random vectors
// and checks that each query yields exactly k results.
TEST(FaissQueryIndexTest, BasicAssertions) {
    // Define the index data (vectors are stored flattened: numIds * dim floats)
    faiss::idx_t numIds = 100;
    std::vector<faiss::idx_t> ids;
    std::vector<float> vectors;
    int dim = 16;
    for (int64_t i = 0; i < numIds; i++) {
        ids.push_back(i);
        for (int j = 0; j < dim; j++) {
            vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
        }
    }

    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "Flat"; // TODO: Revert back to HNSW32,Flat

    // Define query data
    int k = 10;
    int numQueries = 100;
    std::vector<std::vector<float>> queries;

    for (int i = 0; i < numQueries; i++) {
        std::vector<float> query;
        query.reserve(dim);
        for (int j = 0; j < dim; j++) {
            query.push_back(test_util::RandomFloat(-500.0, 500.0));
        }
        queries.push_back(query);
    }

    // Create the index. The index dimension must match the dimension of the
    // generated data and queries; this was previously hard-coded to 2.
    std::unique_ptr<faiss::Index> createdIndex(
            test_util::FaissCreateIndex(dim, method, metricType));
    auto createdIndexWithData =
            test_util::FaissAddData(createdIndex.get(), ids, vectors);

    // Setup jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    // Bind by reference: the query's address is the handle passed to the
    // wrapper, so there is no need to copy each vector first.
    for (auto &query : queries) {
        // NOTE(review): the result element type std::pair<int, float> is
        // inferred -- confirm against what the mock's result array holds.
        std::unique_ptr<std::vector<std::pair<int, float> *>> results(
                reinterpret_cast<std::vector<std::pair<int, float> *> *>(
                        knn_jni::faiss_wrapper::QueryIndex(
                                &mockJNIUtil, jniEnv,
                                reinterpret_cast<jlong>(&createdIndexWithData),
                                reinterpret_cast<jfloatArray>(&query), k)));

        ASSERT_EQ(static_cast<size_t>(k), results->size());

        // Need to free up each result
        for (auto *result : *results) {
            delete result;
        }
    }
}

// Creates an index and releases it through Free.
TEST(FaissFreeTest, BasicAssertions) {
    // Define the data
    int dim = 2;
    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "Flat"; // TODO: Revert back to HNSW32,Flat

    // Create the index. Deliberately a raw pointer: ownership is handed to
    // Free below.
    faiss::Index *createdIndex(
            test_util::FaissCreateIndex(dim, method, metricType));

    // Free created index --> memory check should catch failure
    knn_jni::faiss_wrapper::Free(reinterpret_cast<jlong>(createdIndex));
}

// Calls InitLibrary; the test has no assertions of its own.
TEST(FaissInitLibraryTest, BasicAssertions) {
    knn_jni::faiss_wrapper::InitLibrary();
}

// Trains an "IVF4,Flat" index on 256 random 2-d vectors through TrainIndex
// and checks that the deserialized result reports is_trained.
TEST(FaissTrainIndexTest, BasicAssertions) {
    // Define the index configuration
    int dim = 2;
    std::string spaceType = knn_jni::L2;
    std::string index_description = "IVF4,Flat";

    std::unordered_map<std::string, jobject> parametersMap;
    parametersMap[knn_jni::SPACE_TYPE] = reinterpret_cast<jobject>(&spaceType);
    parametersMap[knn_jni::INDEX_DESCRIPTION] =
            reinterpret_cast<jobject>(&index_description);

    // Define training data (flattened: numTrainingVectors * dim floats)
    int numTrainingVectors = 256;
    std::vector<float> trainingVectors;

    for (int i = 0; i < numTrainingVectors; ++i) {
        for (int j = 0; j < dim; ++j) {
            trainingVectors.push_back(test_util::RandomFloat(-500.0, 500.0));
        }
    }

    // Setup jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    // Perform training
    // NOTE(review): std::vector<uint8_t> as the serialized-index type is
    // inferred -- confirm against test_util::FaissLoadFromSerializedIndex.
    std::unique_ptr<std::vector<uint8_t>> trainedIndexSerialization(
            reinterpret_cast<std::vector<uint8_t> *>(
                    knn_jni::faiss_wrapper::TrainIndex(
                            &mockJNIUtil, jniEnv,
                            reinterpret_cast<jobject>(&parametersMap), dim,
                            reinterpret_cast<jlong>(&trainingVectors))));

    std::unique_ptr<faiss::Index> trainedIndex(
            test_util::FaissLoadFromSerializedIndex(
                    trainedIndexSerialization.get()));

    // Confirm that training succeeded
    ASSERT_TRUE(trainedIndex->is_trained);
}