8
8
package org .elasticsearch .xpack .esql .optimizer ;
9
9
10
10
import org .elasticsearch .action .ActionListener ;
11
+ import org .elasticsearch .action .support .CountDownActionListener ;
12
+ import org .elasticsearch .xpack .esql .core .expression .Expression ;
13
+ import org .elasticsearch .xpack .esql .core .expression .FoldContext ;
14
+ import org .elasticsearch .xpack .esql .expression .function .inference .InferenceFunction ;
15
+ import org .elasticsearch .xpack .esql .inference .InferenceFunctionEvaluator ;
16
+ import org .elasticsearch .xpack .esql .inference .InferenceRunner ;
11
17
import org .elasticsearch .xpack .esql .plan .logical .LogicalPlan ;
18
+ import org .elasticsearch .xpack .esql .plugin .TransportActionServices ;
19
+
20
+ import java .util .ArrayList ;
21
+ import java .util .HashMap ;
22
+ import java .util .List ;
23
+ import java .util .Map ;
12
24
13
25
/**
14
26
* The class is responsible for invoking any steps that need to be applied to the logical plan,
19
31
*/
20
32
public class LogicalPlanPreOptimizer {
21
33
22
- private final LogicalPreOptimizerContext preOptimizerContext ;
34
+ private final InferenceFunctionFolding inferenceFunctionFolding ;
23
35
24
- public LogicalPlanPreOptimizer (LogicalPreOptimizerContext preOptimizerContext ) {
25
- this .preOptimizerContext = preOptimizerContext ;
36
+ public LogicalPlanPreOptimizer (TransportActionServices services , LogicalPreOptimizerContext preOptimizerContext ) {
37
+ this .inferenceFunctionFolding = new InferenceFunctionFolding ( services . inferenceRunner (), preOptimizerContext . foldCtx ()) ;
26
38
}
27
39
28
40
/**
@@ -44,7 +56,53 @@ public void preOptimize(LogicalPlan plan, ActionListener<LogicalPlan> listener)
44
56
}
45
57
46
58
private void doPreOptimize (LogicalPlan plan , ActionListener <LogicalPlan > listener ) {
47
- // this is where we will be executing async tasks
48
- listener .onResponse (plan );
59
+ inferenceFunctionFolding .foldInferenceFunctions (plan , listener );
60
+ }
61
+
62
+ private static class InferenceFunctionFolding {
63
+ private final InferenceRunner inferenceRunner ;
64
+ private final FoldContext foldContext ;
65
+
66
+ private InferenceFunctionFolding (InferenceRunner inferenceRunner , FoldContext foldContext ) {
67
+ this .inferenceRunner = inferenceRunner ;
68
+ this .foldContext = foldContext ;
69
+ }
70
+
71
+ private void foldInferenceFunctions (LogicalPlan plan , ActionListener <LogicalPlan > listener ) {
72
+ // First let's collect all the inference functions
73
+ List <InferenceFunction <?>> inferenceFunctions = new ArrayList <>();
74
+ plan .forEachExpressionUp (InferenceFunction .class , inferenceFunctions ::add );
75
+
76
+ if (inferenceFunctions .isEmpty ()) {
77
+ // No inference functions found. Return the original plan.
78
+ listener .onResponse (plan );
79
+ return ;
80
+ }
81
+
82
+ // This is a map of inference functions to their results.
83
+ // We will use this map to replace the inference functions in the plan.
84
+ Map <InferenceFunction <?>, Expression > inferenceFunctionsToResults = new HashMap <>();
85
+
86
+ // Prepare a listener that will be called when all inference functions are done.
87
+ // This listener will replace the inference functions in the plan with their results.
88
+ CountDownActionListener completionListener = new CountDownActionListener (
89
+ inferenceFunctions .size (),
90
+ listener .map (
91
+ ignored -> plan .transformExpressionsUp (InferenceFunction .class , f -> inferenceFunctionsToResults .getOrDefault (f , f ))
92
+ )
93
+ );
94
+
95
+ // Try to compute the result for each inference function.
96
+ for (InferenceFunction <?> inferenceFunction : inferenceFunctions ) {
97
+ foldInferenceFunction (inferenceFunction , completionListener .map (e -> {
98
+ inferenceFunctionsToResults .put (inferenceFunction , e );
99
+ return null ;
100
+ }));
101
+ }
102
+ }
103
+
104
+ private void foldInferenceFunction (InferenceFunction <?> inferenceFunction , ActionListener <Expression > listener ) {
105
+ InferenceFunctionEvaluator .get (inferenceFunction , inferenceRunner ).eval (foldContext , listener );
106
+ }
49
107
}
50
108
}
0 commit comments