/* * SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ /* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ package org.opensearch.action.search; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.tests.store.MockDirectoryWrapper; import org.opensearch.action.OriginalIndices; import org.opensearch.common.UUIDs; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.index.shard.ShardId; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.fetch.QueryFetchSearchResult; import org.opensearch.search.fetch.ShardFetchSearchRequest; import org.opensearch.search.internal.ShardSearchContextId; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.transport.Transport; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; public class FetchSearchPhaseTests extends OpenSearchTestCase { public void testShortcutQueryAndFetchOptimization() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder() ); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {} ); boolean hasHits = randomBoolean(); final int numHits; if (hasHits) { QuerySearchResult queryResult = new QuerySearchResult(); queryResult.setSearchShardTarget(new SearchShardTarget("node0", new ShardId("index", "index", 0), null, OriginalIndices.NONE)); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), 1.0F ), new DocValueFormat[0] ); queryResult.size(1); FetchSearchResult fetchResult = new FetchSearchResult(); fetchResult.hits(new SearchHits(new SearchHit[] { new SearchHit(42) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F)); QueryFetchSearchResult fetchSearchResult = new QueryFetchSearchResult(queryResult, fetchResult); fetchSearchResult.setShardIndex(0); results.consumeResult(fetchSearchResult, () -> {}); numHits = 1; } else { numHits = 0; } FetchSearchPhase phase = new FetchSearchPhase( results, controller, null, mockSearchPhaseContext, (searchResponse, scrollId) -> new SearchPhase("test") { @Override public void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); } } ); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); assertNotNull(searchResponse); assertEquals(numHits, searchResponse.getHits().getTotalHits().value); if (numHits != 0) { assertEquals(42, searchResponse.getHits().getAt(0).docId()); } assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); } public void testFetchTwoDocument() { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder() ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {} ); int resultSetSize = randomIntBetween(2, 10); ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult( ctx1, new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); results.consumeResult(queryResult, () -> {}); final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 321); queryResult = new QuerySearchResult( ctx2, new SearchShardTarget("node2", new ShardId("test", "na", 1), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); queryResult.setShardIndex(1); results.consumeResult(queryResult, () -> {}); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, SearchTask task, SearchActionListener<FetchSearchResult> listener ) { FetchSearchResult fetchResult = new FetchSearchResult(); if (request.contextId().equals(ctx2)) { fetchResult.hits( new SearchHits(new SearchHit[] { new SearchHit(84) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 2.0F) ); } else { assertEquals(ctx1, request.contextId()); fetchResult.hits( new SearchHits(new SearchHit[] { new SearchHit(42) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F) ); } listener.onResponse(fetchResult); } }; FetchSearchPhase phase = new FetchSearchPhase( results, controller, null, mockSearchPhaseContext, (searchResponse, scrollId) -> new SearchPhase("test") { @Override public void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); } } ); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); assertNotNull(searchResponse); assertEquals(2, searchResponse.getHits().getTotalHits().value); assertEquals(84, searchResponse.getHits().getAt(0).docId()); assertEquals(42, searchResponse.getHits().getAt(1).docId()); assertEquals(0, searchResponse.getFailedShards()); assertEquals(2, searchResponse.getSuccessfulShards()); assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); } public void testFailFetchOneDoc() { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder() ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {} ); int resultSetSize = randomIntBetween(2, 10); final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult( ctx, new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); results.consumeResult(queryResult, () -> {}); queryResult = new QuerySearchResult( new ShardSearchContextId("", 321), new SearchShardTarget("node2", new ShardId("test", "na", 1), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); queryResult.setShardIndex(1); results.consumeResult(queryResult, () -> {}); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, SearchTask task, SearchActionListener<FetchSearchResult> listener ) { if (request.contextId().getId() == 321) { FetchSearchResult fetchResult = new FetchSearchResult(); fetchResult.hits( new SearchHits(new SearchHit[] { new SearchHit(84) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 2.0F) ); listener.onResponse(fetchResult); } else { listener.onFailure(new MockDirectoryWrapper.FakeIOException()); } } }; FetchSearchPhase phase = new FetchSearchPhase( results, controller, null, mockSearchPhaseContext, (searchResponse, scrollId) -> new SearchPhase("test") { @Override public void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); } } ); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); assertNotNull(searchResponse); assertEquals(2, searchResponse.getHits().getTotalHits().value); assertEquals(84, searchResponse.getHits().getAt(0).docId()); assertEquals(1, searchResponse.getFailedShards()); assertEquals(1, searchResponse.getSuccessfulShards()); assertEquals(1, searchResponse.getShardFailures().length); assertTrue(searchResponse.getShardFailures()[0].getCause() instanceof MockDirectoryWrapper.FakeIOException); assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(ctx)); } public void testFetchDocsConcurrently() throws InterruptedException { int resultSetSize = randomIntBetween(0, 100); // we use at least 2 hits otherwise this is subject to single shard optimization and we trip an assert... int numHits = randomIntBetween(2, 100); // also numshards --> 1 hit per shard SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder() ); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(numHits); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), numHits, exc -> {} ); for (int i = 0; i < numHits; i++) { QuerySearchResult queryResult = new QuerySearchResult( new ShardSearchContextId("", i), new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(i + 1, i) }), i ), new DocValueFormat[0] ); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(i); results.consumeResult(queryResult, () -> {}); } mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, SearchTask task, SearchActionListener<FetchSearchResult> listener ) { new Thread(() -> { FetchSearchResult fetchResult = new FetchSearchResult(); fetchResult.hits( new SearchHits( new SearchHit[] { new SearchHit((int) (request.contextId().getId() + 1)) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 100F ) ); listener.onResponse(fetchResult); }).start(); } }; CountDownLatch latch = new CountDownLatch(1); FetchSearchPhase phase = new FetchSearchPhase( results, controller, null, mockSearchPhaseContext, (searchResponse, scrollId) -> new SearchPhase("test") { @Override public void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); latch.countDown(); } } ); assertEquals("fetch", phase.getName()); phase.run(); latch.await(); mockSearchPhaseContext.assertNoFailure(); SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); assertNotNull(searchResponse); assertEquals(numHits, searchResponse.getHits().getTotalHits().value); assertEquals(Math.min(numHits, resultSetSize), searchResponse.getHits().getHits().length); SearchHit[] hits = searchResponse.getHits().getHits(); for (int i = 0; i < hits.length; i++) { assertNotNull(hits[i]); assertEquals("index: " + i, numHits - i, hits[i].docId()); assertEquals("index: " + i, numHits - 1 - i, (int) hits[i].getScore()); } assertEquals(0, searchResponse.getFailedShards()); assertEquals(numHits, searchResponse.getSuccessfulShards()); int sizeReleasedContexts = Math.max(0, numHits - resultSetSize); // all non fetched results will be freed assertEquals( mockSearchPhaseContext.releasedSearchContexts.toString(), sizeReleasedContexts, mockSearchPhaseContext.releasedSearchContexts.size() ); } public void testExceptionFailsPhase() { MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder() ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {} ); int resultSetSize = randomIntBetween(2, 10); QuerySearchResult queryResult = new QuerySearchResult( new ShardSearchContextId("", 123), new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); results.consumeResult(queryResult, () -> {}); queryResult = new QuerySearchResult( new ShardSearchContextId("", 321), new SearchShardTarget("node2", new ShardId("test", "na", 1), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); queryResult.setShardIndex(1); results.consumeResult(queryResult, () -> {}); AtomicInteger numFetches = new AtomicInteger(0); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, SearchTask task, SearchActionListener<FetchSearchResult> listener ) { FetchSearchResult fetchResult = new FetchSearchResult(); if (numFetches.incrementAndGet() == 1) { throw new RuntimeException("BOOM"); } if (request.contextId().getId() == 321) { fetchResult.hits( new SearchHits(new SearchHit[] { new SearchHit(84) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 2.0F) ); } else { assertEquals(request, 123); fetchResult.hits( new SearchHits(new SearchHit[] { new SearchHit(42) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0F) ); } listener.onResponse(fetchResult); } }; FetchSearchPhase phase = new FetchSearchPhase( results, controller, null, mockSearchPhaseContext, (searchResponse, scrollId) -> new SearchPhase("test") { @Override public void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); } } ); assertEquals("fetch", phase.getName()); phase.run(); assertNotNull(mockSearchPhaseContext.phaseFailure.get()); assertEquals(mockSearchPhaseContext.phaseFailure.get().getMessage(), "BOOM"); assertNull(mockSearchPhaseContext.searchResponse.get()); assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); } public void testCleanupIrrelevantContexts() { // contexts that are not fetched should be cleaned up MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder() ); QueryPhaseResultConsumer results = controller.newSearchPhaseResults( OpenSearchExecutors.newDirectExecutorService(), new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {} ); int resultSetSize = 1; final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult( ctx1, new SearchShardTarget("node1", new ShardId("test", "na", 0), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(42, 1.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); // the size of the result set queryResult.setShardIndex(0); results.consumeResult(queryResult, () -> {}); final ShardSearchContextId ctx2 = new ShardSearchContextId(UUIDs.base64UUID(), 321); queryResult = new QuerySearchResult( ctx2, new SearchShardTarget("node2", new ShardId("test", "na", 1), null, OriginalIndices.NONE), null ); queryResult.topDocs( new TopDocsAndMaxScore( new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(84, 2.0F) }), 2.0F ), new DocValueFormat[0] ); queryResult.size(resultSetSize); queryResult.setShardIndex(1); results.consumeResult(queryResult, () -> {}); mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null) { @Override public void sendExecuteFetch( Transport.Connection connection, ShardFetchSearchRequest request, SearchTask task, SearchActionListener<FetchSearchResult> listener ) { FetchSearchResult fetchResult = new FetchSearchResult(); if (request.contextId().getId() == 321) { fetchResult.hits( new SearchHits(new SearchHit[] { new SearchHit(84) }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 2.0F) ); } else { fail("requestID 123 should not be fetched but was"); } listener.onResponse(fetchResult); } }; FetchSearchPhase phase = new FetchSearchPhase( results, controller, null, mockSearchPhaseContext, (searchResponse, scrollId) -> new SearchPhase("test") { @Override public void run() { mockSearchPhaseContext.sendSearchResponse(searchResponse, null); } } ); assertEquals("fetch", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); SearchResponse searchResponse = mockSearchPhaseContext.searchResponse.get(); assertNotNull(searchResponse); assertEquals(2, searchResponse.getHits().getTotalHits().value); assertEquals(1, searchResponse.getHits().getHits().length); assertEquals(84, searchResponse.getHits().getAt(0).docId()); assertEquals(0, searchResponse.getFailedShards()); assertEquals(2, searchResponse.getSuccessfulShards()); assertEquals(1, mockSearchPhaseContext.releasedSearchContexts.size()); assertTrue(mockSearchPhaseContext.releasedSearchContexts.contains(ctx1)); } }