Skip to content

Commit 111f37a

Browse files
committed
Renamed PreOptimizer into LogicalPlanPreOptimizer
1 parent dd9237b commit 111f37a

File tree

7 files changed

+262
-328
lines changed

7 files changed

+262
-328
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void esql(
8787
indexResolver,
8888
enrichPolicyResolver,
8989
preAnalyzer,
90-
new LogicalPlanPreOptimizer(new LogicalPreOptimizerContext(foldContext)),
90+
new LogicalPlanPreOptimizer(services, new LogicalPreOptimizerContext(foldContext)),
9191
functionRegistry,
9292
new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)),
9393
mapper,

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanPreOptimizer.java

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,19 @@
88
package org.elasticsearch.xpack.esql.optimizer;
99

1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.CountDownActionListener;
12+
import org.elasticsearch.xpack.esql.core.expression.Expression;
13+
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
14+
import org.elasticsearch.xpack.esql.expression.function.inference.InferenceFunction;
15+
import org.elasticsearch.xpack.esql.inference.InferenceFunctionEvaluator;
16+
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1117
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
18+
import org.elasticsearch.xpack.esql.plugin.TransportActionServices;
19+
20+
import java.util.ArrayList;
21+
import java.util.HashMap;
22+
import java.util.List;
23+
import java.util.Map;
1224

1325
/**
1426
* The class is responsible for invoking any steps that need to be applied to the logical plan,
@@ -19,10 +31,10 @@
1931
*/
2032
public class LogicalPlanPreOptimizer {
2133

22-
private final LogicalPreOptimizerContext preOptimizerContext;
34+
private final InferenceFunctionFolding inferenceFunctionFolding;
2335

24-
public LogicalPlanPreOptimizer(LogicalPreOptimizerContext preOptimizerContext) {
25-
this.preOptimizerContext = preOptimizerContext;
36+
public LogicalPlanPreOptimizer(TransportActionServices services, LogicalPreOptimizerContext preOptimizerContext) {
37+
this.inferenceFunctionFolding = new InferenceFunctionFolding(services.inferenceRunner(), preOptimizerContext.foldCtx());
2638
}
2739

2840
/**
@@ -44,7 +56,53 @@ public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener)
4456
}
4557

4658
private void doPreOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
47-
// this is where we will be executing async tasks
48-
listener.onResponse(plan);
59+
inferenceFunctionFolding.foldInferenceFunctions(plan, listener);
60+
}
61+
62+
private static class InferenceFunctionFolding {
63+
private final InferenceRunner inferenceRunner;
64+
private final FoldContext foldContext;
65+
66+
private InferenceFunctionFolding(InferenceRunner inferenceRunner, FoldContext foldContext) {
67+
this.inferenceRunner = inferenceRunner;
68+
this.foldContext = foldContext;
69+
}
70+
71+
private void foldInferenceFunctions(LogicalPlan plan, ActionListener<LogicalPlan> listener) {
72+
// First let's collect all the inference functions
73+
List<InferenceFunction<?>> inferenceFunctions = new ArrayList<>();
74+
plan.forEachExpressionUp(InferenceFunction.class, inferenceFunctions::add);
75+
76+
if (inferenceFunctions.isEmpty()) {
77+
// No inference functions found. Return the original plan.
78+
listener.onResponse(plan);
79+
return;
80+
}
81+
82+
// This is a map of inference functions to their results.
83+
// We will use this map to replace the inference functions in the plan.
84+
Map<InferenceFunction<?>, Expression> inferenceFunctionsToResults = new HashMap<>();
85+
86+
// Prepare a listener that will be called when all inference functions are done.
87+
// This listener will replace the inference functions in the plan with their results.
88+
CountDownActionListener completionListener = new CountDownActionListener(
89+
inferenceFunctions.size(),
90+
listener.map(
91+
ignored -> plan.transformExpressionsUp(InferenceFunction.class, f -> inferenceFunctionsToResults.getOrDefault(f, f))
92+
)
93+
);
94+
95+
// Try to compute the result for each inference function.
96+
for (InferenceFunction<?> inferenceFunction : inferenceFunctions) {
97+
foldInferenceFunction(inferenceFunction, completionListener.map(e -> {
98+
inferenceFunctionsToResults.put(inferenceFunction, e);
99+
return null;
100+
}));
101+
}
102+
}
103+
104+
private void foldInferenceFunction(InferenceFunction<?> inferenceFunction, ActionListener<Expression> listener) {
105+
InferenceFunctionEvaluator.get(inferenceFunction, inferenceRunner).eval(foldContext, listener);
106+
}
49107
}
50108
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PreOptimizer.java

Lines changed: 0 additions & 98 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,6 @@ public boolean optimized() {
6565
return stage.ordinal() >= Stage.OPTIMIZED.ordinal();
6666
}
6767

68-
public void setPreOptimized() {
69-
stage = Stage.PRE_OPTIMIZED;
70-
}
71-
72-
public boolean preOptimized() {
73-
return stage.ordinal() >= Stage.PRE_OPTIMIZED.ordinal();
74-
}
75-
7668
public void setOptimized() {
7769
stage = Stage.OPTIMIZED;
7870
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportActionServices.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ public record TransportActionServices(
2828
InferenceServices inferenceServices
2929
) {
3030
public InferenceRunner inferenceRunner() {
31-
return inferenceServices.inferenceRunner();
31+
return inferenceServices == null ? null : inferenceServices.inferenceRunner();
3232
}
3333
}

0 commit comments

Comments
 (0)