Skip to content

Commit 89ec530

Browse files
committed
feat(cubesql): Support MEASURE SQL push down
1 parent 73d9314 commit 89ec530

File tree

3 files changed

+170
-5
lines changed

3 files changed

+170
-5
lines changed

rust/cubesql/cubesql/src/compile/engine/udf/common.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2268,7 +2268,9 @@ pub fn create_measure_udaf() -> AggregateUDF {
22682268
DataType::Float64,
22692269
Arc::new(DataType::Float64),
22702270
Volatility::Immutable,
2271-
Arc::new(|| todo!("Not implemented")),
2271+
Arc::new(|| {
2272+
Err(DataFusionError::NotImplemented("MEASURE function was used in context where it's not supported. Try replacing MEASURE with the measure type-matching function (SUM/AVG/etc).".to_string()))
2273+
}),
22722274
Arc::new(vec![DataType::Float64]),
22732275
)
22742276
}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18588,4 +18588,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1858818588

1858918589
Ok(())
1859018590
}
18591+
18592+
// Checks that a `MEASURE(...)` call ends up as a concrete aggregate in the SQL
// produced for the cube scan wrapper, and that a physical plan can be built.
#[tokio::test]
18593+
async fn test_measure_func_push_down() {
// The test is a no-op (returns early, asserting nothing) when SQL push down
// is not enabled.
18594+
if !Rewriter::sql_push_down_enabled() {
18595+
return;
18596+
}
18597+
init_testing_logger();
18598+
// NOTE(review): the `lower(...)` filter presumably forces the query onto the
// wrapper / push-down path — confirm against the rewrite rules.
18599+
let query_plan = convert_select_to_query_plan(
18600+
r#"
18601+
SELECT MEASURE("sumPrice") AS "total_price"
18602+
FROM "public"."KibanaSampleDataEcommerce"
18603+
WHERE lower("customer_gender") = '123'
18604+
"#
18605+
.to_string(),
18606+
DatabaseProtocol::PostgreSQL,
18607+
)
18608+
.await;
18609+
// Pull the generated SQL out of the wrapper node; `unwrap` panics (failing the
// test) if no wrapped SQL was produced.
18610+
let logical_plan = query_plan.as_logical_plan();
18611+
let sql = logical_plan
18612+
.find_cube_scan_wrapper()
18613+
.wrapped_sql
18614+
.unwrap()
18615+
.sql;
// MEASURE must have been replaced by SUM; presumably `sumPrice` is a
// sum-typed measure in the test schema — TODO confirm.
18616+
assert!(sql.contains("SUM("));
18617+
// Building the physical plan must also succeed; the plan is only printed,
// not asserted on.
18618+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
18619+
println!(
18620+
"Physical plan: {}",
18621+
displayable(physical_plan.as_ref()).indent()
18622+
);
18623+
}
1859118624
}

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

Lines changed: 134 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
use crate::{
22
compile::rewrite::{
3-
agg_fun_expr, analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules,
4-
transforming_rewrite, wrapper_pullup_replacer, wrapper_pushdown_replacer,
5-
AggregateFunctionExprDistinct, AggregateFunctionExprFun, LogicalPlanLanguage,
6-
WrapperPullupReplacerAliasToCube,
3+
agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, column_expr, original_expr_name,
4+
rewrite, rules::wrapper::WrapperRules, transforming_chain_rewrite, transforming_rewrite,
5+
udaf_expr, wrapper_pullup_replacer, wrapper_pushdown_replacer,
6+
AggregateFunctionExprDistinct, AggregateFunctionExprFun, AggregateUDFExprFun,
7+
AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube,
78
},
9+
transport::V1CubeMetaExt,
810
var, var_iter,
911
};
1012
use datafusion::physical_plan::aggregates::AggregateFunction;
@@ -59,6 +61,32 @@ impl WrapperRules {
5961
),
6062
self.transform_agg_fun_expr("?fun", "?distinct", "?alias_to_cube"),
6163
),
64+
transforming_chain_rewrite(
65+
"wrapper-push-down-measure-aggregate-function",
66+
wrapper_pushdown_replacer(
67+
"?udaf",
68+
"?alias_to_cube",
69+
"?ungrouped",
70+
"?in_projection",
71+
"?cube_members",
72+
),
73+
vec![("?udaf", udaf_expr("?fun", vec![column_expr("?column")]))],
74+
wrapper_pushdown_replacer(
75+
alias_expr("?output", "?alias"),
76+
"?alias_to_cube",
77+
"?ungrouped",
78+
"?in_projection",
79+
"?cube_members",
80+
),
81+
self.transform_measure_udaf_expr(
82+
"?udaf",
83+
"?fun",
84+
"?column",
85+
"?alias_to_cube",
86+
"?output",
87+
"?alias",
88+
),
89+
),
6290
]);
6391
}
6492

@@ -105,4 +133,106 @@ impl WrapperRules {
105133
false
106134
}
107135
}
136+
137+
/// Builds the transformer closure for the rule that rewrites a
/// `MEASURE(<column>)` UDAF call into an aliased expression.
///
/// The `*_var` arguments are pattern-variable names. `udaf_var`, `fun_var`,
/// `column_var` and `alias_to_cube_var` are read from the substitution;
/// `output_var` and `alias_var` are written to it.
///
/// The returned closure yields `true` after inserting the replacement
/// expression and alias for the first matching function/column/cube
/// combination. It yields `false` when the original expression name cannot be
/// resolved or when no combination resolves to a measure with a supported
/// aggregation type.
fn transform_measure_udaf_expr(
138+
&self,
139+
udaf_var: &'static str,
140+
fun_var: &'static str,
141+
column_var: &'static str,
142+
alias_to_cube_var: &'static str,
143+
output_var: &'static str,
144+
alias_var: &'static str,
145+
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
146+
let udaf_var = var!(udaf_var);
147+
let fun_var = var!(fun_var);
148+
let column_var = var!(column_var);
149+
let alias_to_cube_var = var!(alias_to_cube_var);
150+
let output_var = var!(output_var);
151+
let alias_var = var!(alias_var);
152+
let meta = self.meta_context.clone();
153+
move |egraph, subst| {
// The name of the original UDAF expression is reused as the alias of the
// replacement, so bail out if it cannot be determined.
154+
let Some(original_alias) = original_expr_name(egraph, subst[udaf_var]) else {
155+
return false;
156+
};
157+
158+
for fun in var_iter!(egraph[subst[fun_var]], AggregateUDFExprFun) {
// Only the MEASURE UDAF is handled; the name check is case-insensitive.
159+
if fun.to_lowercase() != "measure" {
160+
continue;
161+
}
162+
163+
for column in var_iter!(egraph[subst[column_var]], ColumnExprColumn) {
164+
for alias_to_cube in var_iter!(
165+
egraph[subst[alias_to_cube_var]],
166+
WrapperPullupReplacerAliasToCube
167+
) {
// Resolve the column to a cube, then to a measure of that cube that
// declares an aggregation type; any miss skips this candidate.
168+
let Some((_, cube)) = meta.find_cube_by_column(alias_to_cube, column)
169+
else {
170+
continue;
171+
};
172+
173+
let Some(measure) = cube.lookup_measure(&column.name) else {
174+
continue;
175+
};
176+
177+
let Some(agg_type) = &measure.agg_type else {
178+
continue;
179+
};
180+
// Map the measure's aggregation type to (aggregate function, distinct).
// `None` means no aggregate wrapper: the bare column is used as output.
// Unrecognised types skip this candidate.
181+
let out_fun_distinct = match agg_type.as_str() {
182+
"string" | "time" | "boolean" | "number" => None,
183+
"count" => Some((AggregateFunction::Count, false)),
184+
"countDistinct" => Some((AggregateFunction::Count, true)),
185+
"countDistinctApprox" => {
186+
Some((AggregateFunction::ApproxDistinct, false))
187+
}
188+
"sum" => Some((AggregateFunction::Sum, false)),
189+
"avg" => Some((AggregateFunction::Avg, false)),
190+
"min" => Some((AggregateFunction::Min, false)),
191+
"max" => Some((AggregateFunction::Max, false)),
192+
_ => continue,
193+
};
194+
195+
let column_expr_id =
196+
egraph.add(LogicalPlanLanguage::ColumnExpr([subst[column_var]]));
197+
// Either wrap the column in the chosen aggregate function or fall back
// to the column expression itself.
198+
let output_id = out_fun_distinct
199+
.map(|(out_fun, distinct)| {
200+
let fun_id =
201+
egraph.add(LogicalPlanLanguage::AggregateFunctionExprFun(
202+
AggregateFunctionExprFun(out_fun),
203+
));
// The argument list is built as the column followed by an empty
// args node.
204+
let args_tail_id = egraph
205+
.add(LogicalPlanLanguage::AggregateFunctionExprArgs(vec![]));
206+
let args_id =
207+
egraph.add(LogicalPlanLanguage::AggregateFunctionExprArgs(
208+
vec![column_expr_id, args_tail_id],
209+
));
210+
let distinct_id =
211+
egraph.add(LogicalPlanLanguage::AggregateFunctionExprDistinct(
212+
AggregateFunctionExprDistinct(distinct),
213+
));
214+
215+
egraph.add(LogicalPlanLanguage::AggregateFunctionExpr([
216+
fun_id,
217+
args_id,
218+
distinct_id,
219+
]))
220+
})
221+
.unwrap_or(column_expr_id);
222+
223+
subst.insert(output_var, output_id);
224+
225+
subst.insert(
226+
alias_var,
227+
egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(
228+
original_alias,
229+
))),
230+
);
// The first successful match wins.
231+
return true;
232+
}
233+
}
234+
}
235+
false
236+
}
237+
}
108238
}

0 commit comments

Comments
 (0)