Skip to content

Commit 61d84d1

Browse files
committed
Coordinator can return partial results after the timeout when allow_partial_search_results is true
Signed-off-by: kkewwei <kewei.11@bytedance.com> Signed-off-by: kkewwei <kkewwei@163.com>
1 parent 078ebb4 commit 61d84d1

File tree

10 files changed

+158
-28
lines changed

10 files changed

+158
-28
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2020
- Add vertical scaling and SoftReference for snapshot repository data cache ([#16489](https://github.com/opensearch-project/OpenSearch/pull/16489))
2121
- Support prefix list for remote repository attributes([#16271](https://github.com/opensearch-project/OpenSearch/pull/16271))
2222
- Add new configuration setting `synonym_analyzer`, to the `synonym` and `synonym_graph` filters, enabling the specification of a custom analyzer for reading the synonym file ([#16488](https://github.com/opensearch-project/OpenSearch/pull/16488)).
23+
- Coordinator can return partial results after the timeout when allow_partial_search_results is true ([#16681](https://github.com/opensearch-project/OpenSearch/pull/16681)).
2324

2425
### Dependencies
2526
- Bump `com.google.cloud:google-cloud-core-http` from 2.23.0 to 2.47.0 ([#16504](https://github.com/opensearch-project/OpenSearch/pull/16504))

server/src/main/java/org/opensearch/action/search/SearchRequest.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
128128
// it's only been used in coordinator, so we don't need to serialize/deserialize it
129129
private long startTimeMills;
130130

131-
private float queryPhaseTimeoutPercentage;
131+
private float queryPhaseTimeoutPercentage = 0.8f;
132132

133133
public SearchRequest() {
134134
this.localClusterAlias = null;
@@ -358,6 +358,10 @@ public ActionRequestValidationException validate() {
358358
validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException);
359359
}
360360
}
361+
362+
if (queryPhaseTimeoutPercentage <= 0 || queryPhaseTimeoutPercentage >= 1) {
363+
validationException = addValidationError("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationException);
364+
}
361365
return validationException;
362366
}
363367

@@ -722,21 +726,27 @@ public String pipeline() {
722726
return pipeline;
723727
}
724728

725-
726729
public void setQueryPhaseTimeoutPercentage(float queryPhaseTimeoutPercentage) {
727730
if (source.timeout() == null) {
728-
throw new IllegalArgumentException("timeout must be set before setting query phase timeout percentage");
731+
throw new IllegalArgumentException("timeout must be set before setting queryPhaseTimeoutPercentage");
729732
}
730733
this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage;
731734
}
732735

733-
public float getQueryPhasePercentage() {
734-
return queryPhaseTimeoutPercentage;
735-
}
736-
737736
@Override
738737
public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
739-
return new SearchTask(id, type, action, this::buildDescription, parentTaskId, headers, cancelAfterTimeInterval, startTimeMills, source.timeout() != null? source.timeout().millis() : -1, queryPhaseTimeoutPercentage);
738+
return new SearchTask(
739+
id,
740+
type,
741+
action,
742+
this::buildDescription,
743+
parentTaskId,
744+
headers,
745+
cancelAfterTimeInterval,
746+
startTimeMills,
747+
(source != null && source.timeout() != null) ? source.timeout().millis() : -1,
748+
queryPhaseTimeoutPercentage
749+
);
740750
}
741751

742752
public final String buildDescription() {
@@ -788,7 +798,8 @@ public boolean equals(Object o) {
788798
&& ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips
789799
&& Objects.equals(cancelAfterTimeInterval, that.cancelAfterTimeInterval)
790800
&& Objects.equals(pipeline, that.pipeline)
791-
&& Objects.equals(phaseTook, that.phaseTook);
801+
&& Objects.equals(phaseTook, that.phaseTook)
802+
&& Objects.equals(queryPhaseTimeoutPercentage, that.queryPhaseTimeoutPercentage);
792803
}
793804

794805
@Override
@@ -810,7 +821,8 @@ public int hashCode() {
810821
absoluteStartMillis,
811822
ccsMinimizeRoundtrips,
812823
cancelAfterTimeInterval,
813-
phaseTook
824+
phaseTook,
825+
queryPhaseTimeoutPercentage
814826
);
815827
}
816828

@@ -855,6 +867,8 @@ public String toString() {
855867
+ pipeline
856868
+ ", phaseTook="
857869
+ phaseTook
870+
+ ", queryPhaseTimeoutPercentage="
871+
+ queryPhaseTimeoutPercentage
858872
+ "}";
859873
}
860874
}

server/src/main/java/org/opensearch/action/search/SearchTask.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public SearchTask(
8484
this.descriptionSupplier = descriptionSupplier;
8585
this.startTimeMills = startTimeMills;
8686
this.timeoutMills = timeoutMills;
87+
assert queryPhaseTimeoutPercentage > 0 && queryPhaseTimeoutPercentage <= 1;
8788
this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage;
8889
}
8990

server/src/main/java/org/opensearch/action/search/SearchTransportService.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,17 +400,17 @@ void sendExecuteMultiSearch(final MultiSearchRequest request, SearchTask task, f
400400
);
401401
}
402402

403-
public TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer<Exception> onFailure, boolean queryPhase) {
404-
if (task.timeoutMills() > 0) {
403+
static TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer<Exception> onFailure, boolean queryPhase) {
404+
if (task != null && task.timeoutMills() > 0) {
405405
long leftTimeMills;
406406
if (queryPhase) {
407-
//it's costly in query phase.
407+
// it's costly in query phase.
408408
leftTimeMills = task.queryPhaseTimeout() - (System.currentTimeMillis() - task.startTimeMills());
409409
} else {
410410
leftTimeMills = task.timeoutMills() - (System.currentTimeMillis() - task.startTimeMills());
411411
}
412412
if (leftTimeMills <= 0) {
413-
onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded"));
413+
onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded" + leftTimeMills + "ms"));
414414
return null;
415415
} else {
416416
return TransportRequestOptions.builder().withTimeout(leftTimeMills).build();

server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ public static void parseSearchRequest(
225225

226226
searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));
227227

228-
searchRequest.setQueryPhaseTimeoutPercentage(request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE));
228+
searchRequest.setQueryPhaseTimeoutPercentage(
229+
request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE)
230+
);
229231
}
230232

231233
/**

server/src/main/java/org/opensearch/search/SearchService.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -876,11 +876,6 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
876876
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
877877
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
878878
runAsync(getExecutor(readerContext.indexShard()), () -> {
879-
if (request.getShardSearchRequest().shardId().getId() == 1) {
880-
try {
881-
Thread.sleep(10000);
882-
} catch (Exception e) {}
883-
}
884879
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
885880
if (request.lastEmittedDoc() != null) {
886881
searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc();

server/src/main/java/org/opensearch/search/query/QueryPhase.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,6 @@ public void execute(SearchContext searchContext) throws QueryPhaseExecutionExcep
143143
return;
144144
}
145145

146-
if (searchContext.request().shardId().getId() == 2) {
147-
try {
148-
Thread.sleep(10000);
149-
} catch (Exception e) {}
150-
}
151-
152146
if (LOGGER.isTraceEnabled()) {
153147
LOGGER.trace("{}", new SearchContextSourcePrinter(searchContext));
154148
}

server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
import org.opensearch.action.OriginalIndices;
3737
import org.opensearch.action.support.IndicesOptions;
3838
import org.opensearch.cluster.ClusterState;
39+
import org.opensearch.cluster.node.DiscoveryNode;
3940
import org.opensearch.cluster.routing.GroupShardsIterator;
4041
import org.opensearch.common.UUIDs;
4142
import org.opensearch.common.collect.Tuple;
4243
import org.opensearch.common.settings.ClusterSettings;
4344
import org.opensearch.common.settings.Settings;
45+
import org.opensearch.common.unit.TimeValue;
4446
import org.opensearch.common.util.concurrent.AtomicArray;
4547
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
4648
import org.opensearch.common.util.set.Sets;
@@ -55,6 +57,7 @@
5557
import org.opensearch.index.shard.ShardNotFoundException;
5658
import org.opensearch.search.SearchPhaseResult;
5759
import org.opensearch.search.SearchShardTarget;
60+
import org.opensearch.search.builder.SearchSourceBuilder;
5861
import org.opensearch.search.internal.AliasFilter;
5962
import org.opensearch.search.internal.InternalSearchResponse;
6063
import org.opensearch.search.internal.ShardSearchContextId;
@@ -65,6 +68,7 @@
6568
import org.opensearch.test.OpenSearchTestCase;
6669
import org.opensearch.threadpool.TestThreadPool;
6770
import org.opensearch.threadpool.ThreadPool;
71+
import org.opensearch.transport.ReceiveTimeoutTransportException;
6872
import org.opensearch.transport.Transport;
6973
import org.junit.After;
7074
import org.junit.Before;
@@ -89,6 +93,9 @@
8993
import java.util.function.BiFunction;
9094
import java.util.stream.IntStream;
9195

96+
import org.mockito.Mockito;
97+
98+
import static org.opensearch.action.search.SearchTransportService.QUERY_ACTION_NAME;
9299
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE;
93100
import static org.hamcrest.Matchers.equalTo;
94101
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -138,6 +145,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
138145
false,
139146
expected,
140147
resourceUsage,
148+
false,
141149
new SearchShardIterator(null, null, Collections.emptyList(), null)
142150
);
143151
}
@@ -151,6 +159,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
151159
final boolean catchExceptionWhenExecutePhaseOnShard,
152160
final AtomicLong expected,
153161
final TaskResourceUsage resourceUsage,
162+
final boolean blockTheFirstQueryPhase,
154163
final SearchShardIterator... shards
155164
) {
156165

@@ -179,7 +188,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
179188
.setNodeId(randomAlphaOfLengthBetween(1, 5))
180189
.build();
181190
threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
182-
191+
AtomicBoolean firstShard = new AtomicBoolean(true);
183192
return new AbstractSearchAsyncAction<SearchPhaseResult>(
184193
"test",
185194
logger,
@@ -207,7 +216,13 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
207216
) {
208217
@Override
209218
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
210-
return null;
219+
return new SearchPhase("test") {
220+
@Override
221+
public void run() {
222+
listener.onResponse(new SearchResponse(null, null, 0, 0, 0, 0, null, null));
223+
assertingListener.onPhaseEnd(context, null);
224+
}
225+
};
211226
}
212227

213228
@Override
@@ -218,6 +233,17 @@ protected void executePhaseOnShard(
218233
) {
219234
if (failExecutePhaseOnShard) {
220235
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
236+
} else if (blockTheFirstQueryPhase && firstShard.compareAndSet(true, false)) {
237+
// Sleep and throw ReceiveTimeoutTransportException to simulate node blocked
238+
try {
239+
Thread.sleep(request.source().timeout().millis());
240+
} catch (InterruptedException e) {}
241+
;
242+
DiscoveryNode node = Mockito.mock(DiscoveryNode.class);
243+
Mockito.when(node.getName()).thenReturn("test_nodes");
244+
listener.onFailure(
245+
new ReceiveTimeoutTransportException(node, QUERY_ACTION_NAME, "request_id [171] timed out after [413ms]")
246+
);
221247
} else {
222248
if (catchExceptionWhenExecutePhaseOnShard) {
223249
try {
@@ -227,6 +253,7 @@ protected void executePhaseOnShard(
227253
}
228254
} else {
229255
listener.onResponse(new QuerySearchResult());
256+
230257
}
231258
}
232259
}
@@ -587,6 +614,7 @@ public void onFailure(Exception e) {
587614
false,
588615
new AtomicLong(),
589616
new TaskResourceUsage(randomLong(), randomLong()),
617+
false,
590618
shards
591619
);
592620
action.run();
@@ -635,6 +663,7 @@ public void onFailure(Exception e) {
635663
false,
636664
new AtomicLong(),
637665
new TaskResourceUsage(randomLong(), randomLong()),
666+
false,
638667
shards
639668
);
640669
action.run();
@@ -688,6 +717,7 @@ public void onFailure(Exception e) {
688717
catchExceptionWhenExecutePhaseOnShard,
689718
new AtomicLong(),
690719
new TaskResourceUsage(randomLong(), randomLong()),
720+
false,
691721
shards
692722
);
693723
action.run();
@@ -791,6 +821,41 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException {
791821
assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
792822
}
793823

824+
public void testExecutePhaseOnShardBlockAndRetrunPartialResult() {
825+
// on shard is blocked in query phase
826+
final Index index = new Index("test", UUID.randomUUID().toString());
827+
828+
final SearchShardIterator[] shards = IntStream.range(0, 2 + randomInt(4))
829+
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1"), null, null, null))
830+
.toArray(SearchShardIterator[]::new);
831+
832+
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
833+
searchRequest.source(new SearchSourceBuilder());
834+
long timeoutMills = 500;
835+
searchRequest.source().timeout(new TimeValue(timeoutMills, TimeUnit.MILLISECONDS));
836+
searchRequest.setMaxConcurrentShardRequests(shards.length);
837+
final AtomicBoolean successed = new AtomicBoolean(false);
838+
long current = System.currentTimeMillis();
839+
840+
final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
841+
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, queryResult, new ActionListener<>() {
842+
@Override
843+
public void onResponse(SearchResponse response) {
844+
successed.set(true);
845+
}
846+
847+
@Override
848+
public void onFailure(Exception e) {
849+
successed.set(false);
850+
}
851+
}, false, false, false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), true, shards);
852+
action.run();
853+
long s = System.currentTimeMillis() - current;
854+
assertTrue(s > timeoutMills);
855+
assertTrue(successed.get());
856+
857+
}
858+
794859
private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction(
795860
List<SearchRequestOperationsListener> searchRequestOperationsListeners
796861
) {

server/src/test/java/org/opensearch/action/search/SearchRequestTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,15 @@ public void testValidate() throws IOException {
238238
assertEquals(1, validationErrors.validationErrors().size());
239239
assertEquals("using [point in time] is not allowed in a scroll context", validationErrors.validationErrors().get(0));
240240
}
241+
242+
{
243+
// queryPhaseTimeoutPercentage must be in (0, 1)
244+
SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder().timeout(TimeValue.timeValueMillis(10)));
245+
searchRequest.setQueryPhaseTimeoutPercentage(-1);
246+
ActionRequestValidationException validationErrors = searchRequest.validate();
247+
assertNotNull(validationErrors);
248+
assertEquals("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationErrors.validationErrors().get(0));
249+
}
241250
}
242251

243252
public void testCopyConstructor() throws IOException {
@@ -261,6 +270,19 @@ public void testParseSearchRequestWithUnsupportedSearchType() throws IOException
261270
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
262271
}
263272

273+
public void testParseSearchRequestWithTimeoutAndQueryPhaseTimeoutPercentage() throws IOException {
274+
RestRequest restRequest = new FakeRestRequest();
275+
SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder());
276+
IntConsumer setSize = mock(IntConsumer.class);
277+
restRequest.params().put("query_phase_timeout_percentage", "30");
278+
279+
IllegalArgumentException exception = expectThrows(
280+
IllegalArgumentException.class,
281+
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
282+
);
283+
assertEquals("timeout must be set before setting queryPhaseTimeoutPercentage", exception.getMessage());
284+
}
285+
264286
public void testEqualsAndHashcode() throws IOException {
265287
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
266288
}

0 commit comments

Comments
 (0)