
Commit c81bfa7

Improved the inference pre-optimization.
1 parent 91fc05e commit c81bfa7

File tree

9 files changed: +678 additions, -324 deletions


x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/InferenceFunction.java

Lines changed: 16 additions & 0 deletions
@@ -11,6 +11,7 @@
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.function.Function;
 import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
 
 import java.util.List;
 
@@ -35,5 +36,20 @@ protected InferenceFunction(Source source, List<Expression> children) {
      */
     public abstract TaskType taskType();
 
+    /**
+     * Returns a new instance of the function with the specified inference resolution error.
+     */
     public abstract PlanType withInferenceResolutionError(String inferenceId, String error);
+
+    /**
+     * Returns the inference function evaluator factory.
+     */
+    public abstract InferenceFunctionEvaluator.Factory inferenceEvaluatorFactory();
+
+    /**
+     * Returns true if the function has a nested inference function.
+     */
+    public boolean hasNestedInferenceFunction() {
+        return anyMatch(e -> e instanceof InferenceFunction && e != this);
+    }
 }
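
Taken together, the two new members let the pre-optimizer decide whether a given inference call can be evaluated ahead of time and, if so, obtain an evaluator for it. A minimal caller-side sketch, assuming fn, inferenceRunner, foldContext and listener are in scope (the real logic lives in the InferenceFunctionConstantFolding rule added later in this commit):

// Sketch only (not part of the diff): fold a single inference function into a constant.
// A function qualifies when all of its arguments are constants (foldable) and it does not
// wrap another inference function that still has to be folded first.
if (fn.foldable() && fn.hasNestedInferenceFunction() == false) {
    fn.inferenceEvaluatorFactory()
        .get(inferenceRunner)         // evaluator bound to this specific function
        .eval(foldContext, listener); // the listener receives the folded Expression
}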

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/inference/TextEmbedding.java

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,8 @@
 import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
+import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingFunctionEvaluator;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
 
 import java.io.IOException;
@@ -129,6 +131,11 @@ public TaskType taskType() {
         return TaskType.TEXT_EMBEDDING;
     }
 
+    @Override
+    public InferenceFunctionEvaluator.Factory inferenceEvaluatorFactory() {
+        return inferenceRunner -> new TextEmbeddingFunctionEvaluator(this, inferenceRunner);
+    }
+
     @Override
     public TextEmbedding withInferenceResolutionError(String inferenceId, String error) {
         return new TextEmbedding(source(), inputText, new UnresolvedAttribute(inferenceId().source(), inferenceId, error));
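
Because InferenceFunctionEvaluator.Factory (introduced below) has a single abstract method, the lambda above is all that is needed; written out as an anonymous class it would read roughly like this (equivalent form, shown for illustration only):

// Equivalent, more verbose form of the factory returned above (illustration only).
return new InferenceFunctionEvaluator.Factory() {
    @Override
    public InferenceFunctionEvaluator get(InferenceRunner inferenceRunner) {
        return new TextEmbeddingFunctionEvaluator(TextEmbedding.this, inferenceRunner);
    }
};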

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceFunctionEvaluator.java

Lines changed: 2 additions & 8 deletions
@@ -10,18 +10,12 @@
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.FoldContext;
-import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
-import org.elasticsearch.xpack.esql.expression.function.inference.TextEmbedding;
-import org.elasticsearch.xpack.esql.inference.textembedding.TextEmbeddingFunctionEvaluator;
 
 public interface InferenceFunctionEvaluator {
 
     void eval(FoldContext foldContext, ActionListener<Expression> listener);
 
-    static InferenceFunctionEvaluator get(InferenceFunction<?> inferenceFunction, InferenceRunner inferenceRunner) {
-        return switch (inferenceFunction) {
-            case TextEmbedding textEmbedding -> new TextEmbeddingFunctionEvaluator(textEmbedding, inferenceRunner);
-            default -> throw new IllegalArgumentException("Unsupported inference function: " + inferenceFunction.getClass());
-        };
+    interface Factory {
+        InferenceFunctionEvaluator get(InferenceRunner inferenceRunner);
     }
 }
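
The calling convention changes accordingly: evaluator construction is no longer dispatched through a central switch, each inference function supplies its own factory instead. Both forms below appear verbatim in this commit's diffs (the first in the code removed from LogicalPlanPreOptimizer, the second in the new InferenceFunctionConstantFolding rule):

// Before: central static dispatch over the concrete function type.
InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener);

// After: each function supplies its own evaluator factory.
inferenceFunction.inferenceEvaluatorFactory().get(inferenceRunner).eval(foldContext, listener);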

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java

Lines changed: 15 additions & 56 deletions
@@ -8,19 +8,13 @@
 package org.elasticsearch.xpack.esql.optimizer;
 
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.support.CountDownActionListener;
-import org.elasticsearch.xpack.esql.core.expression.Expression;
-import org.elasticsearch.xpack.esql.core.expression.FoldContext;
-import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
-import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
-import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.action.support.SubscribableListener;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.InferenceFunctionConstantFolding;
+import org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer.PreOptimizerRule;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.plugin.TransportActionServices;
 
-import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 /**
  * The class is responsible for invoking any steps that need to be applied to the logical plan,
@@ -31,10 +25,10 @@
  */
 public class LogicalPlanPreOptimizer {
 
-    private final InferenceFunctionFolding inferenceFunctionFolding;
+    private final List<PreOptimizerRule> rules;
 
     public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) {
-        this.inferenceFunctionFolding = new InferenceFunctionFolding(services.inferenceRunner(), preOptimizerContext.foldCtx());
+        rules = List.of(new InferenceFunctionConstantFolding(services.inferenceRunner(), preOptimizerContext.foldCtx()));
     }
 
     /**
@@ -55,54 +49,19 @@ public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener)
         }));
     }
 
+    /**
+     * Loop over the rules and apply them to the logical plan.
+     *
+     * @param plan the analyzed logical plan to pre-optimize
+     * @param listener the listener returning the pre-optimized plan when pre-optimization is complete
+     */
     private void doPreOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
-        inferenceFunctionFolding.foldInferenceFunctions(plan, listener);
-    }
-
-    private static class InferenceFunctionFolding {
-        private final InferenceRunner inferenceRunner;
-        private final FoldContext foldContext;
+        SubscribableListener<LogicalPlan> rulesListener = SubscribableListener.newSucceeded(plan);
 
-        private InferenceFunctionFolding(InferenceRunner inferenceRunner, FoldContext foldContext) {
-            this.inferenceRunner = inferenceRunner;
-            this.foldContext = foldContext;
+        for (PreOptimizerRule rule : rules) {
+            rulesListener = rulesListener.andThen((l, p) -> rule.apply(p, l));
         }
 
-        private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
-            // First let's collect all the inference functions
-            List<InferenceFunction<?>> inferenceFunctions = new ArrayList<>();
-            plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add);
-
-            if (inferenceFunctions.isEmpty()) {
-                // No inference functions found. Return the original plan.
-                listener.onResponse(plan);
-                return;
-            }
-
-            // This is a map of inference functions to their results.
-            // We will use this map to replace the inference functions in the plan.
-            Map<InferenceFunction<?>, Expression> inferenceFunctionsToResults = new HashMap<>();
-
-            // Prepare a listener that will be called when all inference functions are done.
-            // This listener will replace the inference functions in the plan with their results.
-            CountDownActionListener completionListener = new CountDownActionListener(
-                inferenceFunctions.size(),
-                listener.map(
-                    ignored -> plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f))
-                )
-            );
-
-            // Try to compute the result for each inference function.
-            for (InferenceFunction<?> inferenceFunction : inferenceFunctions) {
-                foldInferenceFunction(inferenceFunction, completionListener.map(e -> {
-                    inferenceFunctionsToResults.put(inferenceFunction, e);
-                    return null;
-                }));
-            }
-        }
-
-        private void foldInferenceFunction(InferenceFunction<?> inferenceFunction, ActionListener<Expression> listener) {
-            InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener);
-        }
+        rulesListener.addListener(listener);
     }
 }
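
The SubscribableListener chain makes the rules run strictly one after another while keeping the whole pipeline asynchronous: each andThen step is invoked only once the previous step has delivered its plan. A small standalone sketch of the same pattern using plain strings instead of plans (illustrative only; ActionListener and SubscribableListener are the same classes imported in the diff above):

// Each andThen step runs only after the previous one has completed successfully;
// a failure anywhere in the chain skips the remaining steps and reaches the final listener.
SubscribableListener<String> chain = SubscribableListener.newSucceeded("plan");
chain = chain.andThen((l, value) -> l.onResponse(value + " -> rule1"));
chain = chain.andThen((l, value) -> l.onResponse(value + " -> rule2"));
chain.addListener(ActionListener.wrap(
    result -> System.out.println(result),   // prints "plan -> rule1 -> rule2"
    e -> { throw new AssertionError(e); }
));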
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/InferenceFunctionConstantFolding.java

Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.CountDownActionListener;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
+import org.elasticsearch.xpack.esql.inference.InferenceRunner;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Pre-optimizer rule that evaluates inference functions (like TEXT_EMBEDDING) into constant values.
+ * <p>
+ * This rule identifies foldable inference functions in the logical plan, executes them using the
+ * inference runner, and replaces them with their computed results. This enables downstream
+ * optimizations to work with the actual embedding values rather than the function calls.
+ * <p>
+ * The rule processes inference functions recursively, handling newly revealed functions that might
+ * appear after the first round of folding.
+ */
+public class InferenceFunctionConstantFolding implements PreOptimizerRule {
+    private final InferenceRunner inferenceRunner;
+    private final FoldContext foldContext;
+
+    /**
+     * Creates a new instance of the InferenceFunctionConstantFolding rule.
+     *
+     * @param inferenceRunner the inference runner to use for evaluating inference functions
+     * @param foldContext the fold context to use for evaluating inference functions
+     */
+    public InferenceFunctionConstantFolding(InferenceRunner inferenceRunner, FoldContext foldContext) {
+        this.inferenceRunner = inferenceRunner;
+        this.foldContext = foldContext;
+    }
+
+    /**
+     * Applies the InferenceFunctionConstantFolding rule to the given logical plan.
+     *
+     * @param plan the logical plan to apply the rule to
+     * @param listener the listener to notify when the rule has been applied
+     */
+    @Override
+    public void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
+        foldInferenceFunctions(plan, listener);
+    }
+
+    /**
+     * Recursively folds inference functions in the logical plan.
+     * <p>
+     * This method collects all foldable inference functions, evaluates them in parallel,
+     * and then replaces them with their results. If new inference functions are revealed
+     * after the first round of folding, it recursively processes them as well.
+     *
+     * @param plan the logical plan to fold inference functions in
+     * @param listener the listener to notify when the folding is complete
+     */
+    private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
+        // First let's collect all the foldable inference functions
+        List<InferenceFunction<?>> inferenceFunctions = collectFoldableInferenceFunctions(plan);
+
+        if (inferenceFunctions.isEmpty()) {
+            // No inference functions that can be evaluated at this time were found. Return the original plan.
+            listener.onResponse(plan);
+            return;
+        }
+
+        // This is a map of inference functions to their results.
+        // We will use this map to replace the inference functions in the plan.
+        Map<InferenceFunction<?>, Expression> inferenceFunctionsToResults = new HashMap<>();
+
+        // Prepare a listener that will be called when all inference functions are done.
+        // This listener will replace the inference functions in the plan with their results and then recursively fold the remaining
+        // inference functions.
+        CountDownActionListener completionListener = new CountDownActionListener(
+            inferenceFunctions.size(),
+            listener.delegateFailureIgnoreResponseAndWrap(l -> {
+                // Replace the inference functions in the plan with their results
+                LogicalPlan next = plan.transformExpressionsUp(
+                    InferenceFunction.class,
+                    f -> inferenceFunctionsToResults.getOrDefault(f, f)
+                );
+
+                // Recursively fold the remaining inference functions
+                foldInferenceFunctions(next, l);
+            })
+        );
+
+        // Try to compute the result for each inference function.
+        for (InferenceFunction<?> inferenceFunction : inferenceFunctions) {
+            foldInferenceFunction(inferenceFunction, completionListener.map(e -> {
+                inferenceFunctionsToResults.put(inferenceFunction, e);
+                return null;
+            }));
+        }
+    }
+
+    /**
+     * Collects all foldable inference functions from the logical plan.
+     * <p>
+     * A function is considered foldable if:
+     * 1. It's an instance of InferenceFunction
+     * 2. It's marked as foldable (all parameters are constants)
+     * 3. It doesn't contain nested inference functions
+     *
+     * @param plan the logical plan to collect inference functions from
+     * @return a list of foldable inference functions
+     */
+    private List<InferenceFunction<?>> collectFoldableInferenceFunctions(LogicalPlan plan) {
+        List<InferenceFunction<?>> inferenceFunctions = new ArrayList<>();
+
+        plan.forEachExpressionUp(InferenceFunction.class, f -> {
+            if (f.foldable() && f.hasNestedInferenceFunction() == false) {
+                inferenceFunctions.add(f);
+            }
+        });
+
+        return inferenceFunctions;
+    }
+
+    /**
+     * Evaluates a single inference function asynchronously.
+     * <p>
+     * Uses the inference function's evaluator factory to create an evaluator
+     * that can process the function with the given inference runner.
+     *
+     * @param inferenceFunction the inference function to evaluate
+     * @param listener the listener to notify when the evaluation is complete
+     */
+    private void foldInferenceFunction(InferenceFunction<?> inferenceFunction, ActionListener<Expression> listener) {
+        inferenceFunction.inferenceEvaluatorFactory().get(inferenceRunner).eval(foldContext, listener);
+    }
+}
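
Inside foldInferenceFunctions, the CountDownActionListener gives a fan-out/fan-in: every collected function is evaluated in parallel, and the completion listener fires exactly once, after the last evaluation has reported back, to rewrite the plan and recurse. A stripped-down sketch of that pattern with hypothetical names (Task, Result, tasks, evaluateAsync, combine and handleFailure are placeholders, not ESQL or Elasticsearch APIs):

// Hypothetical sketch of the fan-out/fan-in pattern used by the rule.
Map<Task, Result> results = new ConcurrentHashMap<>(); // callbacks may arrive on different threads
CountDownActionListener allDone = new CountDownActionListener(
    tasks.size(),
    ActionListener.wrap(
        ignored -> combine(results),   // runs once, after every task has completed
        e -> handleFailure(e)          // any single failure fails the whole group
    )
);
for (Task task : tasks) {
    // map(...) stores the per-task result, then counts down the group listener.
    evaluateAsync(task, allDone.map(result -> {
        results.put(task, result);
        return null;
    }));
}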
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/preoptimizer/PreOptimizerRule.java

Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
+
+/**
+ * A rule that can be applied to a logical plan before it is optimized.
+ */
+public interface PreOptimizerRule {
+
+    /**
+     * Apply the rule to the logical plan.
+     *
+     * @param plan the analyzed logical plan to pre-optimize
+     * @param listener the listener returning the pre-optimized plan when pre-optimization is complete
+     */
+    void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener);
+}
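
For illustration, a minimal, hypothetical rule that satisfies this contract (not part of the commit); a real rule calls listener.onResponse only once its possibly asynchronous work is finished, as InferenceFunctionConstantFolding does above:

package org.elasticsearch.xpack.esql.optimizer.rules.logical.preoptimizer;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;

// Hypothetical example: a rule that completes immediately and leaves the plan unchanged.
public class NoOpPreOptimizerRule implements PreOptimizerRule {
    @Override
    public void apply(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
        // Nothing to rewrite: hand the plan back to the caller as-is.
        listener.onResponse(plan);
    }
}

Wiring such a rule in would just mean adding it to the rules list built in the LogicalPlanPreOptimizer constructor; the SubscribableListener chain in doPreOptimize then runs it in sequence with the existing folding rule.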
