/* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.neuralsearch.plugin; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Supplier; import lombok.extern.log4j.Log4j2; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.FeatureFlags; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; import com.google.common.annotations.VisibleForTesting; /** * Neural Search plugin class */ @Log4j2 public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin { /** * Gates the functionality of hybrid search * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. * Once that problem is resolved this feature flag can be removed. */ @VisibleForTesting public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled"; private MLCommonsClientAccessor clientAccessor; private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();; @Override public Collection createComponents( final Client client, final ClusterService clusterService, final ThreadPool threadPool, final ResourceWatcherService resourceWatcherService, final ScriptService scriptService, final NamedXContentRegistry xContentRegistry, final Environment environment, final NodeEnvironment nodeEnvironment, final NamedWriteableRegistry namedWriteableRegistry, final IndexNameExpressionResolver indexNameExpressionResolver, final Supplier repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } @Override public List> getQueries() { return Arrays.asList( new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent), new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent) ); } @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); } @Override public Optional getQueryPhaseSearcher() { if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) { log.info("Registering hybrid query phase searcher with feature flag [%]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); return Optional.of(new HybridQueryPhaseSearcher()); } log.info("Not registering hybrid query phase searcher because feature flag [%] is disabled", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core return Optional.empty(); } @Override public Map> getSearchPhaseResultsProcessors( Parameters parameters ) { return Map.of( NormalizationProcessor.TYPE, new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory) ); } }