Skip to content

Commit f82a256

Browse files
committed
[ESQL] Complete TEXT_EMBEDDING function integration
Integrates the TEXT_EMBEDDING function with the ESQL execution pipeline:
- Update PreOptimizer to handle TEXT_EMBEDDING function evaluation
- Add TextEmbedding function definition and type validation
- Integrate with InferenceServices for model execution
- Add comprehensive tests in PreOptimizerTests
- Update session and execution components for async function support
1 parent c2e1033 commit f82a256

File tree

9 files changed

+254
-10
lines changed

9 files changed

+254
-10
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
2222
import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext;
2323
import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer;
24+
import org.elasticsearch.xpack.esql.optimizer.PreOptimizer;
2425
import org.elasticsearch.xpack.esql.planner.mapper.Mapper;
2526
import org.elasticsearch.xpack.esql.plugin.TransportActionServices;
2627
import org.elasticsearch.xpack.esql.querylog.EsqlQueryLog;
@@ -85,6 +86,7 @@ public void esql(
8586
indexResolver,
8687
enrichPolicyResolver,
8788
preAnalyzer,
89+
new PreOptimizer(services, foldContext),
8890
functionRegistry,
8991
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
9092
mapper,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,10 @@ protected InferenceFunction(Source source, List<Expression> children) {
3636
public abstract TaskType taskType();
3737

3838
public abstract PlanType withInferenceResolutionError(String inferenceId, String error);
39+
40+
@Override
41+
public boolean foldable() {
42+
// Inference functions are not foldable and need to be evaluated using an async inference call.
43+
return false;
44+
}
3945
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,6 @@ protected TypeResolution resolveType() {
119119
return TypeResolution.TYPE_RESOLVED;
120120
}
121121

122-
@Override
123-
public boolean foldable() {
124-
// The function is foldable only if both arguments are foldable
125-
return inputText.foldable() && inferenceId.foldable();
126-
}
127-
128122
@Override
129123
public TaskType taskType() {
130124
return TaskType.TEXT_EMBEDDING;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceServices.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ public InferenceServices(Client client, ThreadPool threadPool) {
3636
this.inferenceRunnerFactory = InferenceRunner.factory(client, threadPool);
3737
}
3838

39+
/**
40+
* Creates an inference runner using the default execution configuration.
41+
*
42+
* @return A configured inference runner capable of executing inference requests
43+
*/
44+
public InferenceRunner inferenceRunner() {
45+
return inferenceRunner(InferenceExecutionConfig.DEFAULT);
46+
}
47+
3948
/**
4049
* Creates an inference runner with the specified execution configuration.
4150
*

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
package org.elasticsearch.xpack.esql.optimizer;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.xpack.esql.core.expression.Expression;
12+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
13+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
14+
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
15+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1116
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
17+
import org.elasticsearch.xpack.esql.plugin.TransportActionServices;
1218

1319
/**
1420
* The class is responsible for invoking any steps that need to be applied to the logical plan,
@@ -19,11 +25,35 @@
1925
*/
2026
public class PreOptimizer {
2127

22-
public PreOptimizer() {
28+
private final InferencePreOptimizer inferencePreOptimizer;
2329

30+
public PreOptimizer(TransportActionServices services, FoldContext foldContext) {
31+
this(services.inferenceRunner(), foldContext);
32+
}
33+
34+
PreOptimizer(InferenceRunner inferenceRunner, FoldContext foldContext) {
35+
this.inferencePreOptimizer = new InferencePreOptimizer(inferenceRunner, foldContext);
2436
}
2537

2638
public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
27-
listener.onResponse(plan);
39+
inferencePreOptimizer.foldInferenceFunctions(plan, listener);
40+
}
41+
42+
private static class InferencePreOptimizer {
43+
private final InferenceRunner inferenceRunner;
44+
private final FoldContext foldContext;
45+
46+
private InferencePreOptimizer(InferenceRunner inferenceRunner, FoldContext foldContext) {
47+
this.inferenceRunner = inferenceRunner;
48+
this.foldContext = foldContext;
49+
}
50+
51+
private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
52+
plan.transformExpressionsUp(InferenceFunction.class, this::foldInferenceFunction, listener);
53+
}
54+
55+
private void foldInferenceFunction(InferenceFunction<?> inferenceFunction, ActionListener<Expression> listener) {
56+
InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener);
57+
}
2858
}
2959
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.search.SearchService;
1515
import org.elasticsearch.transport.TransportService;
1616
import org.elasticsearch.usage.UsageService;
17+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1718
import org.elasticsearch.xpack.esql.inference.InferenceServices;
1819

1920
public record TransportActionServices(
@@ -25,4 +26,8 @@ public record TransportActionServices(
2526
IndexNameExpressionResolver indexNameExpressionResolver,
2627
UsageService usageService,
2728
InferenceServices inferenceServices
28-
) {}
29+
) {
30+
public InferenceRunner inferenceRunner() {
31+
return inferenceServices.inferenceRunner();
32+
}
33+
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ public EsqlSession(
170170
IndexResolver indexResolver,
171171
EnrichPolicyResolver enrichPolicyResolver,
172172
PreAnalyzer preAnalyzer,
173+
PreOptimizer preOptimizer,
173174
EsqlFunctionRegistry functionRegistry,
174175
LogicalPlanOptimizer logicalPlanOptimizer,
175176
Mapper mapper,
@@ -192,8 +193,8 @@ public EsqlSession(
192193
this.planTelemetry = planTelemetry;
193194
this.indicesExpressionGrouper = indicesExpressionGrouper;
194195
this.inferenceResolver = inferenceResolver;
195-
this.preOptimizer = new PreOptimizer();
196196
this.preMapper = new PreMapper(services);
197+
this.preOptimizer = preOptimizer;
197198
this.remoteClusterService = services.transportService().getRemoteClusterService();
198199
}
199200

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception {
584584
null,
585585
null,
586586
null,
587+
null,
587588
functionRegistry,
588589
new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)),
589590
mapper,
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.optimizer;
9+
10+
import org.apache.lucene.util.SetOnce;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
14+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingBitResults;
15+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
17+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
18+
import org.elasticsearch.xpack.esql.core.expression.Alias;
19+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
20+
import org.elasticsearch.xpack.esql.core.expression.Literal;
21+
import org.elasticsearch.xpack.esql.core.tree.Source;
22+
import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
23+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
24+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
25+
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
26+
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
27+
import org.elasticsearch.xpack.esql.plan.logical.Eval;
28+
import org.elasticsearch.xpack.esql.plan.logical.Filter;
29+
30+
import java.util.List;
31+
32+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
33+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute;
34+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.of;
35+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation;
36+
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
37+
import static org.hamcrest.Matchers.equalTo;
38+
import static org.hamcrest.Matchers.hasSize;
39+
import static org.hamcrest.Matchers.notNullValue;
40+
41+
public class PreOptimizerTests extends ESTestCase {
42+
43+
public void testEvalFunctionEmbeddingBytes() throws Exception {
44+
testEvalFunctionEmbedding(BYTES_EMBEDDING_MODEL);
45+
}
46+
47+
public void testEvalFunctionEmbeddingBits() throws Exception {
48+
testEvalFunctionEmbedding(BIT_EMBEDDING_MODEL);
49+
}
50+
51+
public void testEvalFunctionEmbeddingFloats() throws Exception {
52+
testEvalFunctionEmbedding(FLOAT_EMBEDDING_MODEL);
53+
}
54+
55+
public void testKnnFunctionEmbeddingBytes() throws Exception {
56+
testKnnFunctionEmbedding(BYTES_EMBEDDING_MODEL);
57+
}
58+
59+
public void testKnnFunctionEmbeddingBits() throws Exception {
60+
testKnnFunctionEmbedding(BIT_EMBEDDING_MODEL);
61+
}
62+
63+
public void testKnnFunctionEmbeddingFloats() throws Exception {
64+
testKnnFunctionEmbedding(FLOAT_EMBEDDING_MODEL);
65+
}
66+
67+
private void testEvalFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception {
68+
String inferenceId = randomUUID();
69+
String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10))));
70+
String fieldName = randomIdentifier();
71+
72+
PreOptimizer preOptimizer = new PreOptimizer(mockInferenceRunner(textEmbeddingModel), FoldContext.small());
73+
EsRelation relation = relation();
74+
Eval eval = new Eval(
75+
Source.EMPTY,
76+
relation,
77+
List.of(new Alias(Source.EMPTY, fieldName, new TextEmbedding(Source.EMPTY, of(query), of(inferenceId))))
78+
);
79+
80+
SetOnce<Object> preOptimizedPlanHolder = new SetOnce<>();
81+
preOptimizer.preOptimize(eval, ActionListener.wrap(preOptimizedPlanHolder::set, ESTestCase::fail));
82+
83+
assertBusy(() -> {
84+
assertThat(preOptimizedPlanHolder.get(), notNullValue());
85+
Eval preOptimizedEval = as(preOptimizedPlanHolder.get(), Eval.class);
86+
assertThat(preOptimizedEval.fields(), hasSize(1));
87+
assertThat(preOptimizedEval.fields().get(0).name(), equalTo(fieldName));
88+
Literal preOptimizedQuery = as(preOptimizedEval.fields().get(0).child(), Literal.class);
89+
assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR));
90+
assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query)));
91+
});
92+
}
93+
94+
private void testKnnFunctionEmbedding(TextEmbeddingModelMock textEmbeddingModel) throws Exception {
95+
String inferenceId = randomUUID();
96+
String query = String.join(" ", randomArray(1, String[]::new, () -> randomAlphaOfLength(randomIntBetween(1, 10))));
97+
98+
PreOptimizer preOptimizer = new PreOptimizer(mockInferenceRunner(textEmbeddingModel), FoldContext.small());
99+
EsRelation relation = relation();
100+
Filter filter = new Filter(
101+
Source.EMPTY,
102+
relation,
103+
new Knn(Source.EMPTY, getFieldAttribute("a"), new TextEmbedding(Source.EMPTY, of(query), of(inferenceId)), of(10), null)
104+
);
105+
Knn knn = as(filter.condition(), Knn.class);
106+
107+
SetOnce<Object> preOptimizedHolder = new SetOnce<>();
108+
preOptimizer.preOptimize(filter, ActionListener.wrap(preOptimizedHolder::set, ESTestCase::fail));
109+
110+
assertBusy(() -> {
111+
assertThat(preOptimizedHolder.get(), notNullValue());
112+
Filter preOptimizedFilter = as(preOptimizedHolder.get(), Filter.class);
113+
Knn preOptimizedKnn = as(preOptimizedFilter.condition(), Knn.class);
114+
assertThat(preOptimizedKnn.field(), equalTo(knn.field()));
115+
assertThat(preOptimizedKnn.k(), equalTo(knn.k()));
116+
assertThat(preOptimizedKnn.options(), equalTo(knn.options()));
117+
118+
Literal preOptimizedQuery = as(preOptimizedKnn.query(), Literal.class);
119+
assertThat(preOptimizedQuery.dataType(), equalTo(DENSE_VECTOR));
120+
assertThat(preOptimizedQuery.value(), equalTo(textEmbeddingModel.embedding(query)));
121+
});
122+
}
123+
124+
private InferenceRunner mockInferenceRunner(TextEmbeddingModelMock textEmbeddingModel) {
125+
return new InferenceRunner() {
126+
@Override
127+
public void execute(InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
128+
listener.onResponse(new InferenceAction.Response(textEmbeddingModel.embeddingResults(request.getInput().getFirst())));
129+
}
130+
131+
@Override
132+
public void executeBulk(BulkInferenceRequestIterator requests, ActionListener<List<InferenceAction.Response>> listener) {
133+
listener.onFailure(
134+
new UnsupportedOperationException("executeBulk should not be invoked for plans without inference functions")
135+
);
136+
}
137+
};
138+
}
139+
140+
private interface TextEmbeddingModelMock {
141+
TextEmbeddingResults<?> embeddingResults(String input);
142+
143+
float[] embedding(String input);
144+
}
145+
146+
private static final TextEmbeddingModelMock FLOAT_EMBEDDING_MODEL = new TextEmbeddingModelMock() {
147+
public TextEmbeddingResults<?> embeddingResults(String input) {
148+
TextEmbeddingFloatResults.Embedding embedding = new TextEmbeddingFloatResults.Embedding(embedding(input));
149+
return new TextEmbeddingFloatResults(List.of(embedding));
150+
}
151+
152+
public float[] embedding(String input) {
153+
String[] tokens = input.split("\\s+");
154+
float[] embedding = new float[tokens.length];
155+
for (int i = 0; i < tokens.length; i++) {
156+
embedding[i] = tokens[i].length();
157+
}
158+
return embedding;
159+
}
160+
};
161+
162+
private static final TextEmbeddingModelMock BYTES_EMBEDDING_MODEL = new TextEmbeddingModelMock() {
163+
public TextEmbeddingResults<?> embeddingResults(String input) {
164+
TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input));
165+
return new TextEmbeddingBitResults(List.of(embedding));
166+
}
167+
168+
private byte[] bytes(String input) {
169+
return input.getBytes();
170+
}
171+
172+
public float[] embedding(String input) {
173+
return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray();
174+
}
175+
};
176+
177+
private static final TextEmbeddingModelMock BIT_EMBEDDING_MODEL = new TextEmbeddingModelMock() {
178+
public TextEmbeddingResults<?> embeddingResults(String input) {
179+
TextEmbeddingByteResults.Embedding embedding = new TextEmbeddingByteResults.Embedding(bytes(input));
180+
return new TextEmbeddingBitResults(List.of(embedding));
181+
}
182+
183+
private byte[] bytes(String input) {
184+
String[] tokens = input.split("\\s+");
185+
byte[] embedding = new byte[tokens.length];
186+
for (int i = 0; i < tokens.length; i++) {
187+
embedding[i] = (byte) (tokens[i].length() % 2);
188+
}
189+
return embedding;
190+
}
191+
192+
public float[] embedding(String input) {
193+
return new TextEmbeddingByteResults.Embedding(bytes(input)).toFloatArray();
194+
}
195+
};
196+
}

0 commit comments

Comments
 (0)