Skip to content

Commit 3b9454f

Browse files
committed
opensearch should returns partial results after the timeout in coordinate node when allow_partial_search_results is true
1 parent 9b3ee09 commit 3b9454f

File tree

7 files changed

+157
-17
lines changed

7 files changed

+157
-17
lines changed

server/src/main/java/org/opensearch/action/search/SearchRequest.java

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
128128
// it's only been used in coordinator, so we don't need to serialize/deserialize it
129129
private long startTimeMills;
130130

131-
private float queryPhaseTimeoutPercentage;
131+
private float queryPhaseTimeoutPercentage = 0.8f;
132132

133133
public SearchRequest() {
134134
this.localClusterAlias = null;
@@ -358,6 +358,10 @@ public ActionRequestValidationException validate() {
358358
validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException);
359359
}
360360
}
361+
362+
if (queryPhaseTimeoutPercentage <= 0 || queryPhaseTimeoutPercentage >= 1) {
363+
validationException = addValidationError("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationException);
364+
}
361365
return validationException;
362366
}
363367

@@ -722,21 +726,27 @@ public String pipeline() {
722726
return pipeline;
723727
}
724728

725-
726729
public void setQueryPhaseTimeoutPercentage(float queryPhaseTimeoutPercentage) {
727730
if (source.timeout() == null) {
728-
throw new IllegalArgumentException("timeout must be set before setting query phase timeout percentage");
731+
throw new IllegalArgumentException("timeout must be set before setting queryPhaseTimeoutPercentage");
729732
}
730733
this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage;
731734
}
732735

733-
public float getQueryPhasePercentage() {
734-
return queryPhaseTimeoutPercentage;
735-
}
736-
737736
@Override
738737
public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
739-
return new SearchTask(id, type, action, this::buildDescription, parentTaskId, headers, cancelAfterTimeInterval, startTimeMills, source.timeout() != null? source.timeout().millis() : -1, queryPhaseTimeoutPercentage);
738+
return new SearchTask(
739+
id,
740+
type,
741+
action,
742+
this::buildDescription,
743+
parentTaskId,
744+
headers,
745+
cancelAfterTimeInterval,
746+
startTimeMills,
747+
(source != null && source.timeout() != null) ? source.timeout().millis() : -1,
748+
queryPhaseTimeoutPercentage
749+
);
740750
}
741751

742752
public final String buildDescription() {
@@ -788,7 +798,8 @@ public boolean equals(Object o) {
788798
&& ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips
789799
&& Objects.equals(cancelAfterTimeInterval, that.cancelAfterTimeInterval)
790800
&& Objects.equals(pipeline, that.pipeline)
791-
&& Objects.equals(phaseTook, that.phaseTook);
801+
&& Objects.equals(phaseTook, that.phaseTook)
802+
&& Objects.equals(queryPhaseTimeoutPercentage, that.queryPhaseTimeoutPercentage);
792803
}
793804

794805
@Override
@@ -810,7 +821,8 @@ public int hashCode() {
810821
absoluteStartMillis,
811822
ccsMinimizeRoundtrips,
812823
cancelAfterTimeInterval,
813-
phaseTook
824+
phaseTook,
825+
queryPhaseTimeoutPercentage
814826
);
815827
}
816828

@@ -855,6 +867,8 @@ public String toString() {
855867
+ pipeline
856868
+ ", phaseTook="
857869
+ phaseTook
870+
+ ", queryPhaseTimeoutPercentage="
871+
+ queryPhaseTimeoutPercentage
858872
+ "}";
859873
}
860874
}

server/src/main/java/org/opensearch/action/search/SearchTask.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public SearchTask(
8484
this.descriptionSupplier = descriptionSupplier;
8585
this.startTimeMills = startTimeMills;
8686
this.timeoutMills = timeoutMills;
87+
assert queryPhaseTimeoutPercentage > 0 && queryPhaseTimeoutPercentage <= 1;
8788
this.queryPhaseTimeoutPercentage = queryPhaseTimeoutPercentage;
8889
}
8990

server/src/main/java/org/opensearch/action/search/SearchTransportService.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,17 +400,17 @@ void sendExecuteMultiSearch(final MultiSearchRequest request, SearchTask task, f
400400
);
401401
}
402402

403-
public TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer<Exception> onFailure, boolean queryPhase) {
404-
if (task.timeoutMills() > 0) {
403+
static TransportRequestOptions getTransportRequestOptions(SearchTask task, Consumer<Exception> onFailure, boolean queryPhase) {
404+
if (task != null && task.timeoutMills() > 0) {
405405
long leftTimeMills;
406406
if (queryPhase) {
407-
//it's costly in query phase.
407+
// it's costly in query phase.
408408
leftTimeMills = task.queryPhaseTimeout() - (System.currentTimeMillis() - task.startTimeMills());
409409
} else {
410410
leftTimeMills = task.timeoutMills() - (System.currentTimeMillis() - task.startTimeMills());
411411
}
412412
if (leftTimeMills <= 0) {
413-
onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded"));
413+
onFailure.accept(new TaskCancelledException("failed to execute fetch phase, timeout exceeded" + leftTimeMills + "ms"));
414414
return null;
415415
} else {
416416
return TransportRequestOptions.builder().withTimeout(leftTimeMills).build();

server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ public static void parseSearchRequest(
225225

226226
searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));
227227

228-
searchRequest.setQueryPhaseTimeoutPercentage(request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE));
228+
searchRequest.setQueryPhaseTimeoutPercentage(
229+
request.paramAsFloat("query_phase_timeout_percentage", SearchRequest.DEFAULT_QUERY_PHASE_TIMEOUT_PERCENTAGE)
230+
);
229231
}
230232

231233
/**

server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
import org.opensearch.action.OriginalIndices;
3737
import org.opensearch.action.support.IndicesOptions;
3838
import org.opensearch.cluster.ClusterState;
39+
import org.opensearch.cluster.node.DiscoveryNode;
3940
import org.opensearch.cluster.routing.GroupShardsIterator;
4041
import org.opensearch.common.UUIDs;
4142
import org.opensearch.common.collect.Tuple;
4243
import org.opensearch.common.settings.ClusterSettings;
4344
import org.opensearch.common.settings.Settings;
45+
import org.opensearch.common.unit.TimeValue;
4446
import org.opensearch.common.util.concurrent.AtomicArray;
4547
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
4648
import org.opensearch.common.util.set.Sets;
@@ -55,6 +57,7 @@
5557
import org.opensearch.index.shard.ShardNotFoundException;
5658
import org.opensearch.search.SearchPhaseResult;
5759
import org.opensearch.search.SearchShardTarget;
60+
import org.opensearch.search.builder.SearchSourceBuilder;
5861
import org.opensearch.search.internal.AliasFilter;
5962
import org.opensearch.search.internal.InternalSearchResponse;
6063
import org.opensearch.search.internal.ShardSearchContextId;
@@ -65,6 +68,7 @@
6568
import org.opensearch.test.OpenSearchTestCase;
6669
import org.opensearch.threadpool.TestThreadPool;
6770
import org.opensearch.threadpool.ThreadPool;
71+
import org.opensearch.transport.ReceiveTimeoutTransportException;
6872
import org.opensearch.transport.Transport;
6973
import org.junit.After;
7074
import org.junit.Before;
@@ -89,6 +93,9 @@
8993
import java.util.function.BiFunction;
9094
import java.util.stream.IntStream;
9195

96+
import org.mockito.Mockito;
97+
98+
import static org.opensearch.action.search.SearchTransportService.QUERY_ACTION_NAME;
9299
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_RESOURCE_USAGE;
93100
import static org.hamcrest.Matchers.equalTo;
94101
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -138,6 +145,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
138145
false,
139146
expected,
140147
resourceUsage,
148+
false,
141149
new SearchShardIterator(null, null, Collections.emptyList(), null)
142150
);
143151
}
@@ -151,6 +159,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
151159
final boolean catchExceptionWhenExecutePhaseOnShard,
152160
final AtomicLong expected,
153161
final TaskResourceUsage resourceUsage,
162+
final boolean blockTheFirstQueryPhase,
154163
final SearchShardIterator... shards
155164
) {
156165

@@ -179,7 +188,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
179188
.setNodeId(randomAlphaOfLengthBetween(1, 5))
180189
.build();
181190
threadPool.getThreadContext().addResponseHeader(TASK_RESOURCE_USAGE, taskResourceInfo.toString());
182-
191+
AtomicBoolean firstShard = new AtomicBoolean(true);
183192
return new AbstractSearchAsyncAction<SearchPhaseResult>(
184193
"test",
185194
logger,
@@ -207,7 +216,13 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
207216
) {
208217
@Override
209218
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
210-
return null;
219+
return new SearchPhase("test") {
220+
@Override
221+
public void run() {
222+
listener.onResponse(new SearchResponse(null, null, 0, 0, 0, 0, null, null));
223+
assertingListener.onPhaseEnd(context, null);
224+
}
225+
};
211226
}
212227

213228
@Override
@@ -218,6 +233,17 @@ protected void executePhaseOnShard(
218233
) {
219234
if (failExecutePhaseOnShard) {
220235
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
236+
} else if (blockTheFirstQueryPhase && firstShard.compareAndSet(true, false)) {
237+
// Sleep and throw ReceiveTimeoutTransportException to simulate node blocked
238+
try {
239+
Thread.sleep(request.source().timeout().millis());
240+
} catch (InterruptedException e) {}
241+
;
242+
DiscoveryNode node = Mockito.mock(DiscoveryNode.class);
243+
Mockito.when(node.getName()).thenReturn("test_nodes");
244+
listener.onFailure(
245+
new ReceiveTimeoutTransportException(node, QUERY_ACTION_NAME, "request_id [171] timed out after [413ms]")
246+
);
221247
} else {
222248
if (catchExceptionWhenExecutePhaseOnShard) {
223249
try {
@@ -227,6 +253,7 @@ protected void executePhaseOnShard(
227253
}
228254
} else {
229255
listener.onResponse(new QuerySearchResult());
256+
230257
}
231258
}
232259
}
@@ -587,6 +614,7 @@ public void onFailure(Exception e) {
587614
false,
588615
new AtomicLong(),
589616
new TaskResourceUsage(randomLong(), randomLong()),
617+
false,
590618
shards
591619
);
592620
action.run();
@@ -635,6 +663,7 @@ public void onFailure(Exception e) {
635663
false,
636664
new AtomicLong(),
637665
new TaskResourceUsage(randomLong(), randomLong()),
666+
false,
638667
shards
639668
);
640669
action.run();
@@ -688,6 +717,7 @@ public void onFailure(Exception e) {
688717
catchExceptionWhenExecutePhaseOnShard,
689718
new AtomicLong(),
690719
new TaskResourceUsage(randomLong(), randomLong()),
720+
false,
691721
shards
692722
);
693723
action.run();
@@ -791,6 +821,41 @@ public void testOnPhaseListenersWithDfsType() throws InterruptedException {
791821
assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName()));
792822
}
793823

824+
public void testExecutePhaseOnShardBlockAndRetrunPartialResult() {
825+
// on shard is blocked in query phase
826+
final Index index = new Index("test", UUID.randomUUID().toString());
827+
828+
final SearchShardIterator[] shards = IntStream.range(0, 2 + randomInt(4))
829+
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1"), null, null, null))
830+
.toArray(SearchShardIterator[]::new);
831+
832+
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
833+
searchRequest.source(new SearchSourceBuilder());
834+
long timeoutMills = 500;
835+
searchRequest.source().timeout(new TimeValue(timeoutMills, TimeUnit.MILLISECONDS));
836+
searchRequest.setMaxConcurrentShardRequests(shards.length);
837+
final AtomicBoolean successed = new AtomicBoolean(false);
838+
long current = System.currentTimeMillis();
839+
840+
final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
841+
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(searchRequest, queryResult, new ActionListener<>() {
842+
@Override
843+
public void onResponse(SearchResponse response) {
844+
successed.set(true);
845+
}
846+
847+
@Override
848+
public void onFailure(Exception e) {
849+
successed.set(false);
850+
}
851+
}, false, false, false, new AtomicLong(), new TaskResourceUsage(randomLong(), randomLong()), true, shards);
852+
action.run();
853+
long s = System.currentTimeMillis() - current;
854+
assertTrue(s > timeoutMills);
855+
assertTrue(successed.get());
856+
857+
}
858+
794859
private SearchDfsQueryThenFetchAsyncAction createSearchDfsQueryThenFetchAsyncAction(
795860
List<SearchRequestOperationsListener> searchRequestOperationsListeners
796861
) {

server/src/test/java/org/opensearch/action/search/SearchRequestTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,15 @@ public void testValidate() throws IOException {
238238
assertEquals(1, validationErrors.validationErrors().size());
239239
assertEquals("using [point in time] is not allowed in a scroll context", validationErrors.validationErrors().get(0));
240240
}
241+
242+
{
243+
// queryPhaseTimeoutPercentage must be in (0, 1)
244+
SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder().timeout(TimeValue.timeValueMillis(10)));
245+
searchRequest.setQueryPhaseTimeoutPercentage(-1);
246+
ActionRequestValidationException validationErrors = searchRequest.validate();
247+
assertNotNull(validationErrors);
248+
assertEquals("[queryPhaseTimeoutPercentage] must be in (0, 1)", validationErrors.validationErrors().get(0));
249+
}
241250
}
242251

243252
public void testCopyConstructor() throws IOException {
@@ -261,6 +270,19 @@ public void testParseSearchRequestWithUnsupportedSearchType() throws IOException
261270
assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage());
262271
}
263272

273+
public void testParseSearchRequestWithTimeoutAndQueryPhaseTimeoutPercentage() throws IOException {
274+
RestRequest restRequest = new FakeRestRequest();
275+
SearchRequest searchRequest = createSearchRequest().source(new SearchSourceBuilder());
276+
IntConsumer setSize = mock(IntConsumer.class);
277+
restRequest.params().put("query_phase_timeout_percentage", "30");
278+
279+
IllegalArgumentException exception = expectThrows(
280+
IllegalArgumentException.class,
281+
() -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize)
282+
);
283+
assertEquals("timeout must be set before setting queryPhaseTimeoutPercentage", exception.getMessage());
284+
}
285+
264286
public void testEqualsAndHashcode() throws IOException {
265287
checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate);
266288
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.action.search;
10+
11+
import org.opensearch.core.tasks.TaskCancelledException;
12+
import org.opensearch.test.OpenSearchTestCase;
13+
import org.opensearch.transport.TransportRequestOptions;
14+
15+
public class SearchTransportServiceTests extends OpenSearchTestCase {
16+
public void testGetTransportRequestOptions() {
17+
SearchTask searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1000, 0.8f);
18+
TransportRequestOptions transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, true);
19+
assertTrue(transportRequestOptions.timeout().millis() > 0);
20+
21+
TransportRequestOptions transportRequestOptions1 = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false);
22+
assertTrue(transportRequestOptions.timeout().millis() < transportRequestOptions1.timeout().millis());
23+
24+
SearchTask searchTask1 = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 1, 0.8f);
25+
26+
transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask1, exception -> {
27+
assertEquals(TaskCancelledException.class, exception.getClass());
28+
assertTrue(exception.getMessage().contains("failed to execute fetch phase, timeout exceeded"));
29+
}, true);
30+
assertNull(transportRequestOptions);
31+
32+
searchTask = new SearchTask(1, null, null, null, null, null, null, System.currentTimeMillis(), 0, 0.8f);
33+
transportRequestOptions = SearchTransportService.getTransportRequestOptions(searchTask, e -> {}, false);
34+
assertEquals(TransportRequestOptions.EMPTY, transportRequestOptions);
35+
}
36+
}

0 commit comments

Comments
 (0)