Skip to content

[8.19] Fix Semantic Query Rewrite Interception Drops Boosts (#129282) #131472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/129282.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 129282
summary: "Fix query rewrite logic to preserve `boost` and `queryName` for `match`,\
\ `knn`, and `sparse_vector` queries on semantic_text fields"
area: Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ public Set<NodeFeature> getFeatures() {
private static final NodeFeature TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE = new NodeFeature(
"test_reranking_service.parse_text_as_score"
);
private static final NodeFeature SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX = new NodeFeature(
"semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
);
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");

Expand Down Expand Up @@ -74,7 +77,8 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
SEMANTIC_TEXT_INDEX_OPTIONS,
COHERE_V2_API
COHERE_V2_API,
SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,20 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
QueryBuilder finalQueryBuilder;
if (inferenceIdsIndices.size() == 1) {
// Simple case, everything uses the same inference ID
Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
String searchInferenceId = inferenceIdIndex.getKey();
List<String> indices = inferenceIdIndex.getValue();
return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
finalQueryBuilder = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
} else {
// Multiple inference IDs, construct a boolean query
return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
}
finalQueryBuilder.boost(queryBuilder.boost());
finalQueryBuilder.queryName(queryBuilder.queryName());
return finalQueryBuilder;
}

private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
Expand Down Expand Up @@ -102,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
}
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ protected String getQuery(QueryBuilder queryBuilder) {

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
semanticQueryBuilder.boost(queryBuilder.boost());
semanticQueryBuilder.queryName(queryBuilder.queryName());
return semanticQueryBuilder;
}

@Override
Expand All @@ -45,7 +48,10 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
// Create a copy for non-inference fields without boost and _name
MatchQueryBuilder matchQueryBuilder = copyMatchQueryBuilder(originalMatchQueryBuilder);

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSemanticSubQuery(
Expand All @@ -55,11 +61,33 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}

/**
 * Returns the query name handled by this interceptor, which is {@link MatchQueryBuilder#NAME}.
 * NOTE(review): presumably used to register this interceptor against {@code match} queries; confirm against the caller.
 */
@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}

/**
 * Builds a new {@link MatchQueryBuilder} with the same field name, value and match options as
 * {@code queryBuilder}. Boost and query name are deliberately not transferred to the copy, so the
 * caller can apply them elsewhere (for example on an enclosing query).
 *
 * @param queryBuilder the match query whose settings are duplicated
 * @return a separate builder instance carrying the copied settings
 */
private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) {
    final MatchQueryBuilder copy = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value());

    // Analysis and term-combination settings.
    copy.analyzer(queryBuilder.analyzer());
    copy.operator(queryBuilder.operator());
    copy.minimumShouldMatch(queryBuilder.minimumShouldMatch());
    copy.zeroTermsQuery(queryBuilder.zeroTermsQuery());
    copy.autoGenerateSynonymsPhraseQuery(queryBuilder.autoGenerateSynonymsPhraseQuery());
    copy.lenient(queryBuilder.lenient());

    // Fuzzy-matching settings; fuzziness is forwarded only when the source query has one.
    if (queryBuilder.fuzziness() != null) {
        copy.fuzziness(queryBuilder.fuzziness());
    }
    copy.fuzzyRewrite(queryBuilder.fuzzyRewrite());
    copy.fuzzyTranspositions(queryBuilder.fuzzyTranspositions());
    copy.prefixLength(queryBuilder.prefixLength());
    copy.maxExpansions(queryBuilder.maxExpansions());

    return copy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,18 @@ protected String getQuery(QueryBuilder queryBuilder) {
@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
QueryBuilder finalQueryBuilder;
if (inferenceIdsIndices.size() == 1) {
// Simple case, everything uses the same inference ID
String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
finalQueryBuilder = buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
} else {
// Multiple inference IDs, construct a boolean query
return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
}
finalQueryBuilder.queryName(queryBuilder.queryName());
finalQueryBuilder.boost(queryBuilder.boost());
return finalQueryBuilder;
}

private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
Expand Down Expand Up @@ -79,7 +83,19 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder));
boolQueryBuilder.should(
createSubQueryForIndices(
indexInformation.nonInferenceIndices(),
new SparseVectorQueryBuilder(
sparseVectorQueryBuilder.getFieldName(),
sparseVectorQueryBuilder.getQueryVectors(),
sparseVectorQueryBuilder.getInferenceId(),
sparseVectorQueryBuilder.getQuery(),
sparseVectorQueryBuilder.shouldPruneTokens(),
sparseVectorQueryBuilder.getTokenPruningConfig()
)
)
);
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors.
for (String inferenceId : inferenceIdsIndices.keySet()) {
Expand All @@ -90,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
}
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

Expand All @@ -72,6 +80,14 @@ public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten()
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY);
KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

Expand All @@ -82,14 +98,23 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertEquals(original.boost(), intercepted.boost(), 0.0f);
assertEquals(original.queryName(), intercepted.queryName());
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);

NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
assertEquals(original.queryName(), nestedQueryBuilder.queryName());
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());

QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
assertEquals(1.0f, knnVectorQueryBuilder.boost(), 0.0f);
assertNull(knnVectorQueryBuilder.queryName());
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder);

TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder
.queryVectorBuilder();
assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {

private static final String FIELD_NAME = "fieldName";
private static final String VALUE = "value";
private static final String QUERY_NAME = "match_query";
private static final float BOOST = 5.0f;

@Before
public void setup() {
Expand Down Expand Up @@ -79,6 +81,29 @@ public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOExcept
assertEquals(original, rewritten);
}

/**
 * Rewrites a match query that has an explicit boost and query name against an inference field and
 * checks that the intercepted wrapper reports the same boost and name, and that it wraps a
 * {@link SemanticQueryBuilder} targeting the original field and query text.
 */
public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException {
    // Register FIELD_NAME as an inference field so the rewrite is intercepted.
    InferenceFieldMetadata fieldMetadata = new InferenceFieldMetadata(
        index.getName(),
        "inferenceId",
        new String[] { FIELD_NAME },
        null
    );
    QueryRewriteContext rewriteContext = createQueryRewriteContext(Map.of(FIELD_NAME, fieldMetadata));

    QueryBuilder boostedMatch = createTestQueryBuilder();
    boostedMatch.boost(BOOST);
    boostedMatch.queryName(QUERY_NAME);

    QueryBuilder result = boostedMatch.rewrite(rewriteContext);
    assertTrue(
        "Expected query to be intercepted, but was [" + result.getClass().getName() + "]",
        result instanceof InterceptedQueryBuilderWrapper
    );

    // Boost and name must be visible on the wrapper returned by the rewrite.
    InterceptedQueryBuilderWrapper wrapper = (InterceptedQueryBuilderWrapper) result;
    assertEquals(BOOST, wrapper.boost(), 0.0f);
    assertEquals(QUERY_NAME, wrapper.queryName());

    // The wrapped query is the semantic equivalent of the original match query.
    assertTrue(wrapper.queryBuilder instanceof SemanticQueryBuilder);
    SemanticQueryBuilder semantic = (SemanticQueryBuilder) wrapper.queryBuilder;
    assertEquals(FIELD_NAME, semantic.getFieldName());
    assertEquals(VALUE, semantic.getQuery());
}

/**
 * Creates the match query used as rewrite input by the tests: a {@link MatchQueryBuilder} on
 * {@code FIELD_NAME} with {@code VALUE} as the query text. No boost or query name is set here;
 * tests that need them set them on the returned builder.
 */
private MatchQueryBuilder createTestQueryBuilder() {
return new MatchQueryBuilder(FIELD_NAME, VALUE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,15 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
Expand All @@ -82,32 +76,52 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof SparseVectorQueryBuilder
);
assertEquals(original, rewritten);
}

private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException {
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertEquals(original.boost(), intercepted.boost(), 0.0f);
assertEquals(original.queryName(), intercepted.queryName());

assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
assertEquals(original.queryName(), nestedQueryBuilder.queryName());

QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
}

public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof SparseVectorQueryBuilder
);
assertEquals(original, rewritten);
assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f);
assertNull(sparseVectorQueryBuilder.queryName());
}

private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
Expand Down
Loading