
Fix Semantic Query Rewrite Interception Drops Boosts #129282

Open
wants to merge 63 commits into
base: main
ccd64ae
fix boosting for knn
Samiul-TheSoccerFan Jun 11, 2025
9338cd5
Fixing for match query
Samiul-TheSoccerFan Jun 11, 2025
370931d
fixing for match subquery
Samiul-TheSoccerFan Jun 11, 2025
b85abda
fix for sparse vector query boost
Samiul-TheSoccerFan Jun 11, 2025
5db2686
fix linting issues
Samiul-TheSoccerFan Jun 11, 2025
2ce691e
Update docs/changelog/129282.yaml
Samiul-TheSoccerFan Jun 11, 2025
4100200
update changelog
Samiul-TheSoccerFan Jun 11, 2025
3406ae1
Copy constructor with match query
Samiul-TheSoccerFan Jun 12, 2025
d07952a
util function to create sparseVectorBuilder for sparse query
Samiul-TheSoccerFan Jun 12, 2025
f133632
util function for knn query to support boost
Samiul-TheSoccerFan Jun 12, 2025
a9048f0
adding unit tests for all intercepted query terms
Samiul-TheSoccerFan Jun 12, 2025
5a1dab9
Adding yaml test for match,sparse, and knn
Samiul-TheSoccerFan Jun 13, 2025
6cef441
Adding queryname support for nested query
Samiul-TheSoccerFan Jun 13, 2025
faa35ea
fix code styles
Samiul-TheSoccerFan Jun 13, 2025
675fb22
merge from main
Samiul-TheSoccerFan Jun 13, 2025
13e791e
Fix failed yaml tests
Samiul-TheSoccerFan Jun 13, 2025
3a5a30f
Update docs/changelog/129282.yaml
Samiul-TheSoccerFan Jun 13, 2025
016e448
update yaml tests to expand test scenarios
Samiul-TheSoccerFan Jun 16, 2025
70b228e
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jun 16, 2025
d5e7caa
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jun 16, 2025
efcf9c4
Updating knn to copy constructor
Samiul-TheSoccerFan Jun 25, 2025
00eb6ad
merge from main
Samiul-TheSoccerFan Jun 25, 2025
f449299
adding yaml tests for multiple indices
Samiul-TheSoccerFan Jun 27, 2025
6db0abf
refactoring match query to adjust boost and queryname and move to cop…
Samiul-TheSoccerFan Jun 27, 2025
daf2cb4
refactoring sparse query to adjust boost and queryname and move to co…
Samiul-TheSoccerFan Jun 27, 2025
b88b077
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 27, 2025
9e725cb
Refactor sparse vector to adjust boost and queryname in the top level
Samiul-TheSoccerFan Jul 2, 2025
651ee2b
Refactor knn vector to adjust boost and queryname in the top level
Samiul-TheSoccerFan Jul 2, 2025
a356b44
merge from main
Samiul-TheSoccerFan Jul 2, 2025
71eac8d
fix knn combined query
Samiul-TheSoccerFan Jul 2, 2025
d71bf2c
fix unit tests
Samiul-TheSoccerFan Jul 2, 2025
675463c
fix lint issues
Samiul-TheSoccerFan Jul 2, 2025
201d27c
remove unused code
Samiul-TheSoccerFan Jul 2, 2025
daf2f6e
Update inference feature name
Samiul-TheSoccerFan Jul 3, 2025
2521b48
Remove double boosting issue from match
Samiul-TheSoccerFan Jul 3, 2025
61f9445
Fix double boosting in match test yaml file
Samiul-TheSoccerFan Jul 3, 2025
f4cadaa
move to bool level for match semantic boost
Samiul-TheSoccerFan Jul 3, 2025
08909de
fix double boosting for sparse vector
Samiul-TheSoccerFan Jul 3, 2025
37bfc43
fix double boosting for sparse vector in yaml test
Samiul-TheSoccerFan Jul 3, 2025
fa5cfe7
fix knn combined query
Samiul-TheSoccerFan Jul 3, 2025
0640631
fix knn combined query
Samiul-TheSoccerFan Jul 3, 2025
404efcf
fix sparse combined query
Samiul-TheSoccerFan Jul 3, 2025
f73285d
fix knn yaml test for combined query
Samiul-TheSoccerFan Jul 3, 2025
96f5aa6
refactoring unit tests
Samiul-TheSoccerFan Jul 4, 2025
3065e5b
linting
Samiul-TheSoccerFan Jul 4, 2025
828d8c2
fix match query unit test
Samiul-TheSoccerFan Jul 4, 2025
6a2e0a5
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 4, 2025
bde54df
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 7, 2025
d08dbdd
adding copy constructor for match query
Samiul-TheSoccerFan Jul 8, 2025
fa955c3
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 8, 2025
916b1cc
refactor copy match builder to intercepter
Samiul-TheSoccerFan Jul 8, 2025
d9ef867
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 8, 2025
cde55d1
resolve conflicts from main
Samiul-TheSoccerFan Jul 9, 2025
8ddda3c
[CI] Auto commit changes from spotless
elasticsearchmachine Jul 9, 2025
873efdb
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 9, 2025
44b8aa9
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 10, 2025
104f16b
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 11, 2025
2a96d52
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 11, 2025
5dcfc1b
fix unit tests
Samiul-TheSoccerFan Jul 11, 2025
469f598
update yaml tests
Samiul-TheSoccerFan Jul 11, 2025
375ae36
fix match yaml test
Samiul-TheSoccerFan Jul 11, 2025
c81f184
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 14, 2025
394f43a
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 14, 2025
6 changes: 6 additions & 0 deletions docs/changelog/129282.yaml
@@ -0,0 +1,6 @@
pr: 129282
summary: "Fix query rewrite logic to preserve `boosts` and `queryName` for `match`,\
\ `knn`, and `sparse_vector` queries on semantic_text fields"
area: Search
type: bug
issues: []
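As a rough illustration of the fix the changelog describes, the pattern is: when an interceptor replaces the user's query with a new builder, copy `boost` and `queryName` from the original onto the outermost rewritten query. The sketch below uses a hypothetical `SketchQuery` class and `rewriteToNested` method as stand-ins. They are not Elasticsearch classes, and only the idea of propagating the two values comes from this PR.

```java
// Hypothetical stand-in for a query builder; not the Elasticsearch API.
class SketchQuery {
    static final float DEFAULT_BOOST = 1.0f;

    final String type;
    final SketchQuery inner; // null for leaf queries
    private float boost = DEFAULT_BOOST;
    private String queryName; // null when unset

    SketchQuery(String type, SketchQuery inner) {
        this.type = type;
        this.inner = inner;
    }

    float boost() { return boost; }
    String queryName() { return queryName; }

    SketchQuery boost(float boost) { this.boost = boost; return this; }
    SketchQuery queryName(String queryName) { this.queryName = queryName; return this; }

    // Wraps a fresh copy of the leaf query in a "nested" query, similar in
    // spirit to an interceptor replacing the user's query. The inner query
    // starts from defaults; boost and name are set once, on the outer query.
    static SketchQuery rewriteToNested(SketchQuery original) {
        SketchQuery innerCopy = new SketchQuery(original.type, null);
        SketchQuery rewritten = new SketchQuery("nested", innerCopy);
        rewritten.boost(original.boost());
        rewritten.queryName(original.queryName());
        return rewritten;
    }
}
```

Without the two propagation calls, the rewritten query would silently fall back to the default boost and no name, which is the behavior this PR's title refers to.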
@@ -37,6 +37,9 @@ public class InferenceFeatures implements FeatureSpecification {
     private static final NodeFeature TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS = new NodeFeature(
         "test_rule_retriever.with_indices_that_dont_return_rank_docs"
     );
+    private static final NodeFeature SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX = new NodeFeature(
+        "semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
+    );
     private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
     private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");

@@ -68,7 +71,8 @@ public Set<NodeFeature> getTestFeatures() {
             SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
             SEMANTIC_TEXT_INDEX_OPTIONS,
             COHERE_V2_API,
-            SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS
+            SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS,
+            SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX
         );
     }
 }
@@ -52,16 +52,20 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
         assert (queryBuilder instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
+        QueryBuilder finalQueryBuilder;
         if (inferenceIdsIndices.size() == 1) {
             // Simple case, everything uses the same inference ID
             Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
             String searchInferenceId = inferenceIdIndex.getKey();
             List<String> indices = inferenceIdIndex.getValue();
-            return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
+            finalQueryBuilder = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
         } else {
             // Multiple inference IDs, construct a boolean query
-            return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
+            finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
         }
+        finalQueryBuilder.boost(queryBuilder.boost());
+        finalQueryBuilder.queryName(queryBuilder.queryName());
+        return finalQueryBuilder;
     }

     private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@@ -102,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
                 )
             );
         }
+        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }

@@ -36,7 +36,10 @@ protected String getQuery(QueryBuilder queryBuilder) {

     @Override
     protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
-        return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
+        SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
+        semanticQueryBuilder.boost(queryBuilder.boost());
+        semanticQueryBuilder.queryName(queryBuilder.queryName());
+        return semanticQueryBuilder;
     }

     @Override
@@ -45,7 +48,10 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
         InferenceIndexInformationForField indexInformation
     ) {
         assert (queryBuilder instanceof MatchQueryBuilder);
-        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
+        MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
+        // Create a copy for non-inference fields without boost and _name
+        MatchQueryBuilder matchQueryBuilder = copyMatchQueryBuilder(originalMatchQueryBuilder);
+
         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
         boolQueryBuilder.should(
             createSemanticSubQuery(
@@ -55,11 +61,33 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
             )
         );
         boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
+        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }

     @Override
     public String getQueryName() {
         return MatchQueryBuilder.NAME;
     }
+
+    private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) {
Review comment (Member): Non-blocking feedback: This PR has been around for a while, and I'm OK with this as written if we just want to get it in. However, future maintainability would be a concern of mine (what if we add a new param to match, such as prefiltering?). I'd almost suggest creating a new constructor in MatchQueryBuilder instead of this private method.
+        MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value());
+        matchQueryBuilder.operator(queryBuilder.operator());
+        matchQueryBuilder.prefixLength(queryBuilder.prefixLength());
+        matchQueryBuilder.maxExpansions(queryBuilder.maxExpansions());
+        matchQueryBuilder.fuzzyTranspositions(queryBuilder.fuzzyTranspositions());
+        matchQueryBuilder.lenient(queryBuilder.lenient());
+        matchQueryBuilder.zeroTermsQuery(queryBuilder.zeroTermsQuery());
+        matchQueryBuilder.analyzer(queryBuilder.analyzer());
+        matchQueryBuilder.minimumShouldMatch(queryBuilder.minimumShouldMatch());
+        matchQueryBuilder.fuzzyRewrite(queryBuilder.fuzzyRewrite());
+
+        if (queryBuilder.fuzziness() != null) {
+            matchQueryBuilder.fuzziness(queryBuilder.fuzziness());
+        }
+
+        matchQueryBuilder.autoGenerateSynonymsPhraseQuery(queryBuilder.autoGenerateSynonymsPhraseQuery());
+        return matchQueryBuilder;
+    }
}
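The review comment on `copyMatchQueryBuilder` suggests a constructor on the builder itself instead of a private helper in the interceptor. A minimal sketch of that shape follows, using a hypothetical `SketchMatchQuery` class with only two of the parameters. It is not the real `MatchQueryBuilder`, the defaults are illustrative, and no such constructor appears in this diff.

```java
// Hypothetical, trimmed-down stand-in for a match query builder.
final class SketchMatchQuery {
    private final String fieldName;
    private final Object value;
    private String operator = "OR";
    private int maxExpansions = 50;
    private float boost = 1.0f;
    private String queryName; // null when unset

    SketchMatchQuery(String fieldName, Object value) {
        this.fieldName = fieldName;
        this.value = value;
    }

    // Copy constructor: copies the match-specific parameters but leaves
    // boost and queryName at their defaults, so a caller can apply them
    // once on an enclosing query.
    SketchMatchQuery(SketchMatchQuery other) {
        this(other.fieldName, other.value);
        this.operator = other.operator;
        this.maxExpansions = other.maxExpansions;
    }

    SketchMatchQuery operator(String operator) { this.operator = operator; return this; }
    SketchMatchQuery maxExpansions(int maxExpansions) { this.maxExpansions = maxExpansions; return this; }
    SketchMatchQuery boost(float boost) { this.boost = boost; return this; }
    SketchMatchQuery queryName(String queryName) { this.queryName = queryName; return this; }

    String fieldName() { return fieldName; }
    Object value() { return value; }
    String operator() { return operator; }
    int maxExpansions() { return maxExpansions; }
    float boost() { return boost; }
    String queryName() { return queryName; }
}
```

Keeping the copy logic next to the fields means a newly added parameter is declared and copied in the same file, which is the maintainability concern the reviewer raises.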
@@ -43,14 +43,18 @@ protected String getQuery(QueryBuilder queryBuilder) {
     @Override
     protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
+        QueryBuilder finalQueryBuilder;
         if (inferenceIdsIndices.size() == 1) {
             // Simple case, everything uses the same inference ID
             String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
-            return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
+            finalQueryBuilder = buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
         } else {
             // Multiple inference IDs, construct a boolean query
-            return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
+            finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
         }
+        finalQueryBuilder.queryName(queryBuilder.queryName());
+        finalQueryBuilder.boost(queryBuilder.boost());
+        return finalQueryBuilder;
     }

     private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
@@ -79,7 +83,19 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
         Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();

         BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-        boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder));
+        boolQueryBuilder.should(
+            createSubQueryForIndices(
+                indexInformation.nonInferenceIndices(),
+                new SparseVectorQueryBuilder(
+                    sparseVectorQueryBuilder.getFieldName(),
+                    sparseVectorQueryBuilder.getQueryVectors(),
+                    sparseVectorQueryBuilder.getInferenceId(),
+                    sparseVectorQueryBuilder.getQuery(),
+                    sparseVectorQueryBuilder.shouldPruneTokens(),
+                    sparseVectorQueryBuilder.getTokenPruningConfig()
+                )
+            )
+        );
         // We always perform nested subqueries on semantic_text fields, to support
         // sparse_vector queries using query vectors.
         for (String inferenceId : inferenceIdsIndices.keySet()) {
@@ -90,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
                 )
             );
         }
+        boolQueryBuilder.boost(queryBuilder.boost());
+        boolQueryBuilder.queryName(queryBuilder.queryName());
         return boolQueryBuilder;
     }

@@ -54,24 +54,32 @@ public void cleanup() {
     }

     public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException {
+        float boost = randomFloatBetween(1, 10, true);
+        String queryName = randomAlphaOfLength(5);
         Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
             FIELD_NAME,
             new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null)
         );
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
         KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
+        original.boost(boost);
+        original.queryName(queryName);
         testRewrittenInferenceQuery(context, original);
     }

     public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
+        float boost = randomFloatBetween(1, 10, true);
+        String queryName = randomAlphaOfLength(5);
         Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
             FIELD_NAME,
             new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null)
         );
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY);
         KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
+        original.boost(boost);
+        original.queryName(queryName);
         testRewrittenInferenceQuery(context, original);
     }

@@ -82,14 +90,23 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ
             rewritten instanceof InterceptedQueryBuilderWrapper
         );
         InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertEquals(original.boost(), intercepted.boost(), 0.0f);
+        assertEquals(original.queryName(), intercepted.queryName());
         assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
+
         NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
+        assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
+        assertEquals(original.queryName(), nestedQueryBuilder.queryName());
         assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
+
         QueryBuilder innerQuery = nestedQueryBuilder.query();
         assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
         KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
+        assertEquals(1.0f, knnVectorQueryBuilder.boost(), 0.0f);
+        assertNull(knnVectorQueryBuilder.queryName());
         assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
         assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder);
+
         TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder
             .queryVectorBuilder();
         assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText());
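The knn test helper expects the original boost on the outer nested query and the default `1.0f` on the inner `knn` query. Several commits in this PR are titled as fixes for "double boosting", and one plausible reading is that a boost set on two levels of a wrapped query is applied twice. The sketch below shows that arithmetic with a hypothetical `effectiveBoost` helper. It does not call Lucene or Elasticsearch and simply assumes that boosts on enclosing and enclosed queries multiply.

```java
// Hypothetical helper, assuming multiplicative boost semantics: multiplies
// the boosts found on each level of a wrapped query, outermost first.
final class BoostSketch {
    private BoostSketch() {}

    static float effectiveBoost(float... boostsPerLevel) {
        float product = 1.0f;
        for (float boost : boostsPerLevel) {
            product *= boost;
        }
        return product;
    }
}
```

Under that assumption, a user boost of `5.0f` copied to both the outer and the inner query gives `effectiveBoost(5.0f, 5.0f)`, which is `25.0f`, while setting it only on the outer query gives `effectiveBoost(5.0f, 1.0f)`, which is `5.0f`, the value the user asked for.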
@@ -36,6 +36,8 @@ public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {

     private static final String FIELD_NAME = "fieldName";
     private static final String VALUE = "value";
+    private static final String QUERY_NAME = "match_query";
+    private static final float BOOST = 5.0f;

     @Before
     public void setup() {
@@ -79,10 +81,38 @@ public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOExcept
         assertEquals(original, rewritten);
     }

+    public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException {
+        Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
+            FIELD_NAME,
+            new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
+        );
+        QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
+        QueryBuilder original = createTestQueryBuilderWithBoostAndQueryName();
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof InterceptedQueryBuilderWrapper
+        );
+        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertEquals(BOOST, intercepted.boost(), 0.0f);
+        assertEquals(QUERY_NAME, intercepted.queryName());
+        assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder);
+        SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder;
+        assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName());
+        assertEquals(VALUE, semanticQueryBuilder.getQuery());
+    }
+
     private MatchQueryBuilder createTestQueryBuilder() {
         return new MatchQueryBuilder(FIELD_NAME, VALUE);
     }

+    private MatchQueryBuilder createTestQueryBuilderWithBoostAndQueryName() {
Review comment (Member): Non-blocking nitpick: Since this is pretty simple and only used once, it doesn't need to be its own method. Maybe more readable to include the test query builder inside the test itself.
+        MatchQueryBuilder queryBuilder = createTestQueryBuilder();
+        queryBuilder.boost(BOOST);
+        queryBuilder.queryName(QUERY_NAME);
+        return queryBuilder;
+    }
+
     private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
         IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
             .settings(
@@ -52,62 +52,68 @@ public void cleanup() {
     }

     public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException {
+        float boost = randomFloatBetween(1, 10, true);
Review comment (Member): Can we do a randomBoolean() check to determine whether to set these boosts and names at all?
+        String queryName = randomAlphaOfLength(5);
         Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
             FIELD_NAME,
             new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
         );
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
-        QueryBuilder rewritten = original.rewrite(context);
-        assertTrue(
-            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
-            rewritten instanceof InterceptedQueryBuilderWrapper
-        );
-        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
-        assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
-        NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
-        assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
-        QueryBuilder innerQuery = nestedQueryBuilder.query();
-        assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
-        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
-        assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
-        assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
-        assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
+        original.boost(boost);
+        original.queryName(queryName);
+        testRewrittenInferenceQuery(context, original);
Review comment (Member): Nice consolidation 🙌
     }

     public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
+        float boost = randomFloatBetween(1, 10, true);
Review comment (Member): Same note on randomBoolean()
+        String queryName = randomAlphaOfLength(5);
         Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
             FIELD_NAME,
             new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
         );
         QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
         QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
+        original.boost(boost);
+        original.queryName(queryName);
+        testRewrittenInferenceQuery(context, original);
+    }
+
+    public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
+        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
+        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
+        QueryBuilder rewritten = original.rewrite(context);
+        assertTrue(
+            "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
+            rewritten instanceof SparseVectorQueryBuilder
+        );
+        assertEquals(original, rewritten);
+    }
+
+    private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException {
         QueryBuilder rewritten = original.rewrite(context);
         assertTrue(
             "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
             rewritten instanceof InterceptedQueryBuilderWrapper
         );
         InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
+        assertEquals(original.boost(), intercepted.boost(), 0.0f);
+        assertEquals(original.queryName(), intercepted.queryName());
+
         assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
         NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
         assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
+        assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
+        assertEquals(original.queryName(), nestedQueryBuilder.queryName());
+
         QueryBuilder innerQuery = nestedQueryBuilder.query();
         assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
         SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
         assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
         assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
         assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
-    }
-
-    public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
-        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
-        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
-        QueryBuilder rewritten = original.rewrite(context);
-        assertTrue(
-            "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
-            rewritten instanceof SparseVectorQueryBuilder
-        );
-        assertEquals(original, rewritten);
+        assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f);
+        assertNull(sparseVectorQueryBuilder.queryName());
     }

private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
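The reviewer's `randomBoolean()` suggestion on the sparse vector tests could look roughly like the sketch below. `java.util.Random` stands in for the test framework's randomness helpers, and `SketchBuilder` is a hypothetical placeholder for the query builder under test. Neither is taken from the PR.

```java
import java.util.Random;

final class RandomizedBoostSketch {
    private RandomizedBoostSketch() {}

    // Hypothetical placeholder for a query builder with boost and name.
    static final class SketchBuilder {
        float boost = 1.0f;
        String queryName; // null when unset
    }

    // Sets boost and name only some of the time, so repeated runs cover
    // both the "explicitly set" and the "left at defaults" paths.
    static SketchBuilder randomlyConfigured(Random random) {
        SketchBuilder builder = new SketchBuilder();
        if (random.nextBoolean()) {
            // Roughly between 1 and 10, mirroring the range used in the tests.
            builder.boost = 1.0f + random.nextFloat() * 9.0f;
        }
        if (random.nextBoolean()) {
            builder.queryName = "query_" + random.nextInt(1000);
        }
        return builder;
    }
}
```

Because the shared helper in the diff asserts against `original.boost()` and `original.queryName()` rather than fixed values, the same assertions would cover both the configured and the default case.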