Skip to content

Commit 2cde121

Browse files
authored
Allow to get the search request from the QueryCoordinatorContext (#17890)
Signed-off-by: Bo Zhang <bzhangam@amazon.com>
1 parent 2ba6aac commit 2cde121

File tree

5 files changed

+124
-13
lines changed

5 files changed

+124
-13
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1313
- Add composite directory factory ([#17988](https://github.com/opensearch-project/OpenSearch/pull/17988))
1414
- Add pull-based ingestion error metrics and make internal queue size configurable ([#18088](https://github.com/opensearch-project/OpenSearch/pull/18088))
1515
- Enabled Async Shard Batch Fetch by default ([#18139](https://github.com/opensearch-project/OpenSearch/pull/18139))
16+
- Allow to get the search request from the QueryCoordinatorContext ([#17818](https://github.com/opensearch-project/OpenSearch/pull/17818))
1617

1718
### Changed
1819

server/src/main/java/org/opensearch/action/explain/TransportExplainAction.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,7 @@ protected void doExecute(Task task, ExplainRequest request, ActionListener<Expla
112112
request.query(rewrittenQuery);
113113
super.doExecute(task, request, listener);
114114
}, listener::onFailure);
115-
Rewriteable.rewriteAndFetch(
116-
request.query(),
117-
searchService.getIndicesService().getRewriteContext(() -> request.nowInMillis),
118-
rewriteListener
119-
);
115+
Rewriteable.rewriteAndFetch(request.query(), searchService.getRewriteContext(() -> request.nowInMillis, request), rewriteListener);
120116
}
121117

122118
@Override

server/src/main/java/org/opensearch/index/query/QueryCoordinatorContext.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88

99
package org.opensearch.index.query;
1010

11+
import org.opensearch.action.IndicesRequest;
1112
import org.opensearch.common.annotation.PublicApi;
1213
import org.opensearch.core.action.ActionListener;
1314
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
1415
import org.opensearch.core.xcontent.NamedXContentRegistry;
1516
import org.opensearch.search.pipeline.PipelinedRequest;
1617
import org.opensearch.transport.client.Client;
1718

19+
import java.util.Collections;
1820
import java.util.HashMap;
1921
import java.util.Map;
2022
import java.util.function.BiConsumer;
@@ -31,9 +33,9 @@
3133
@PublicApi(since = "2.19.0")
3234
public class QueryCoordinatorContext implements QueryRewriteContext {
3335
private final QueryRewriteContext rewriteContext;
34-
private final PipelinedRequest searchRequest;
36+
private final IndicesRequest searchRequest;
3537

36-
public QueryCoordinatorContext(QueryRewriteContext rewriteContext, PipelinedRequest searchRequest) {
38+
public QueryCoordinatorContext(QueryRewriteContext rewriteContext, IndicesRequest searchRequest) {
3739
this.rewriteContext = rewriteContext;
3840
this.searchRequest = searchRequest;
3941
}
@@ -84,10 +86,14 @@ public QueryCoordinatorContext convertToCoordinatorContext() {
8486
}
8587

8688
public Map<String, Object> getContextVariables() {
89+
if (searchRequest instanceof PipelinedRequest) {
90+
return new HashMap<>(((PipelinedRequest) searchRequest).getPipelineProcessingContext().getAttributes());
91+
} else {
92+
return Collections.emptyMap();
93+
}
94+
}
8795

88-
// Read from pipeline context
89-
Map<String, Object> contextVariables = new HashMap<>(searchRequest.getPipelineProcessingContext().getAttributes());
90-
91-
return contextVariables;
96+
public IndicesRequest getSearchRequest() {
97+
return searchRequest;
9298
}
9399
}

server/src/main/java/org/opensearch/search/SearchService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.lucene.search.TopDocs;
4040
import org.opensearch.OpenSearchException;
4141
import org.opensearch.action.ActionRunnable;
42+
import org.opensearch.action.IndicesRequest;
4243
import org.opensearch.action.OriginalIndices;
4344
import org.opensearch.action.search.DeletePitInfo;
4445
import org.opensearch.action.search.DeletePitResponse;
@@ -127,7 +128,6 @@
127128
import org.opensearch.search.internal.ShardSearchContextId;
128129
import org.opensearch.search.internal.ShardSearchRequest;
129130
import org.opensearch.search.lookup.SearchLookup;
130-
import org.opensearch.search.pipeline.PipelinedRequest;
131131
import org.opensearch.search.profile.Profilers;
132132
import org.opensearch.search.query.QueryPhase;
133133
import org.opensearch.search.query.QuerySearchRequest;
@@ -1785,7 +1785,7 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re
17851785
/**
17861786
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider
17871787
*/
1788-
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, PipelinedRequest searchRequest) {
1788+
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, IndicesRequest searchRequest) {
17891789
return new QueryCoordinatorContext(indicesService.getRewriteContext(nowInMillis), searchRequest);
17901790
}
17911791

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.index.query;
10+
11+
import org.opensearch.action.IndicesRequest;
12+
import org.opensearch.action.search.SearchRequest;
13+
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
14+
import org.opensearch.cluster.service.ClusterService;
15+
import org.opensearch.common.settings.Settings;
16+
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
17+
import org.opensearch.common.util.concurrent.ThreadContext;
18+
import org.opensearch.plugins.SearchPipelinePlugin;
19+
import org.opensearch.search.pipeline.PipelinedRequest;
20+
import org.opensearch.search.pipeline.Processor;
21+
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
22+
import org.opensearch.search.pipeline.SearchPipelineService;
23+
import org.opensearch.search.pipeline.SearchRequestProcessor;
24+
import org.opensearch.search.pipeline.SearchResponseProcessor;
25+
import org.opensearch.test.OpenSearchTestCase;
26+
import org.opensearch.threadpool.ThreadPool;
27+
import org.opensearch.transport.client.Client;
28+
import org.junit.Before;
29+
30+
import java.util.Collections;
31+
import java.util.Map;
32+
import java.util.concurrent.ExecutorService;
33+
34+
import static org.mockito.ArgumentMatchers.anyString;
35+
import static org.mockito.Mockito.mock;
36+
import static org.mockito.Mockito.when;
37+
38+
public class QueryCoordinatorContextTests extends OpenSearchTestCase {
39+
40+
private IndexNameExpressionResolver indexNameExpressionResolver;
41+
42+
@Before
43+
public void setup() {
44+
indexNameExpressionResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY));
45+
}
46+
47+
public void testGetContextVariables_whenPipelinedSearchRequest_thenReturnVariables() {
48+
final PipelinedRequest searchRequest = createDummyPipelinedRequest();
49+
searchRequest.getPipelineProcessingContext().setAttribute("key", "value");
50+
51+
final QueryCoordinatorContext queryCoordinatorContext = new QueryCoordinatorContext(mock(QueryRewriteContext.class), searchRequest);
52+
53+
assertEquals(Map.of("key", "value"), queryCoordinatorContext.getContextVariables());
54+
}
55+
56+
private PipelinedRequest createDummyPipelinedRequest() {
57+
final Client client = mock(Client.class);
58+
final ThreadPool threadPool = mock(ThreadPool.class);
59+
final ExecutorService executorService = OpenSearchExecutors.newDirectExecutorService();
60+
when(threadPool.generic()).thenReturn(executorService);
61+
when(threadPool.executor(anyString())).thenReturn(executorService);
62+
final SearchPipelineService searchPipelineService = new SearchPipelineService(
63+
mock(ClusterService.class),
64+
threadPool,
65+
null,
66+
null,
67+
null,
68+
null,
69+
this.writableRegistry(),
70+
Collections.singletonList(new SearchPipelinePlugin() {
71+
@Override
72+
public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
73+
return Collections.emptyMap();
74+
}
75+
76+
@Override
77+
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
78+
return Collections.emptyMap();
79+
}
80+
81+
@Override
82+
public Map<String, Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(Parameters parameters) {
83+
return Collections.emptyMap();
84+
}
85+
86+
}),
87+
client
88+
);
89+
final SearchRequest searchRequest = new SearchRequest();
90+
return searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver);
91+
}
92+
93+
public void testGetContextVariables_whenNotPipelinedSearchRequest_thenReturnEmpty() {
94+
final IndicesRequest searchRequest = mock(IndicesRequest.class);
95+
96+
final QueryCoordinatorContext queryCoordinatorContext = new QueryCoordinatorContext(mock(QueryRewriteContext.class), searchRequest);
97+
98+
assertTrue(queryCoordinatorContext.getContextVariables().isEmpty());
99+
}
100+
101+
public void testGetSearchRequest() {
102+
final IndicesRequest searchRequest = mock(IndicesRequest.class);
103+
104+
final QueryCoordinatorContext queryCoordinatorContext = new QueryCoordinatorContext(mock(QueryRewriteContext.class), searchRequest);
105+
106+
assertEquals(searchRequest, queryCoordinatorContext.getSearchRequest());
107+
}
108+
}

0 commit comments

Comments
 (0)