Skip to content

Commit 4968959

Browse files
committed
feat(cubesql): Support MEASURE SQL push down
1 parent 73d9314 commit 4968959

File tree

3 files changed

+173
-5
lines changed

3 files changed

+173
-5
lines changed

rust/cubesql/cubesql/src/compile/engine/udf/common.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2268,7 +2268,9 @@ pub fn create_measure_udaf() -> AggregateUDF {
22682268
DataType::Float64,
22692269
Arc::new(DataType::Float64),
22702270
Volatility::Immutable,
2271-
Arc::new(|| todo!("Not implemented")),
2271+
Arc::new(|| {
2272+
Err(DataFusionError::NotImplemented("MEASURE function was used in context where it's not supported. Try replacing MEASURE with the measure type-matching function (SUM/AVG/etc).".to_string()))
2273+
}),
22722274
Arc::new(vec![DataType::Float64]),
22732275
)
22742276
}

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18588,4 +18588,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1858818588

1858918589
Ok(())
1859018590
}
18591+
18592+
#[tokio::test]
18593+
async fn test_measure_func_push_down() {
18594+
if !Rewriter::sql_push_down_enabled() {
18595+
return;
18596+
}
18597+
init_testing_logger();
18598+
18599+
let query_plan = convert_select_to_query_plan(
18600+
r#"
18601+
SELECT MEASURE("sumPrice") AS "total_price"
18602+
FROM "public"."KibanaSampleDataEcommerce"
18603+
WHERE lower("customer_gender") = '123'
18604+
"#
18605+
.to_string(),
18606+
DatabaseProtocol::PostgreSQL,
18607+
)
18608+
.await;
18609+
18610+
let logical_plan = query_plan.as_logical_plan();
18611+
let sql = logical_plan
18612+
.find_cube_scan_wrapper()
18613+
.wrapped_sql
18614+
.unwrap()
18615+
.sql;
18616+
assert!(sql.contains("SUM("));
18617+
18618+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
18619+
println!(
18620+
"Physical plan: {}",
18621+
displayable(physical_plan.as_ref()).indent()
18622+
);
18623+
}
1859118624
}

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use crate::{
22
compile::rewrite::{
3-
agg_fun_expr, analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules,
4-
transforming_rewrite, wrapper_pullup_replacer, wrapper_pushdown_replacer,
5-
AggregateFunctionExprDistinct, AggregateFunctionExprFun, LogicalPlanLanguage,
6-
WrapperPullupReplacerAliasToCube,
3+
agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, column_expr, original_expr_name,
4+
rewrite, rules::wrapper::WrapperRules, transforming_chain_rewrite, transforming_rewrite,
5+
udaf_expr, wrapper_pullup_replacer, wrapper_pushdown_replacer,
6+
AggregateFunctionExprDistinct, AggregateFunctionExprFun, AggregateUDFExprFun,
7+
AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube,
78
},
9+
transport::V1CubeMetaExt,
810
var, var_iter,
911
};
1012
use datafusion::physical_plan::aggregates::AggregateFunction;
@@ -59,6 +61,35 @@ impl WrapperRules {
5961
),
6062
self.transform_agg_fun_expr("?fun", "?distinct", "?alias_to_cube"),
6163
),
64+
transforming_chain_rewrite(
65+
"wrapper-push-down-measure-aggregate-function",
66+
wrapper_pushdown_replacer(
67+
"?udaf",
68+
"?alias_to_cube",
69+
"?ungrouped",
70+
"?in_projection",
71+
"?cube_members",
72+
),
73+
vec![("?udaf", udaf_expr("?fun", vec![column_expr("?column")]))],
74+
alias_expr(
75+
wrapper_pushdown_replacer(
76+
"?output",
77+
"?alias_to_cube",
78+
"?ungrouped",
79+
"?in_projection",
80+
"?cube_members",
81+
),
82+
"?alias",
83+
),
84+
self.transform_measure_udaf_expr(
85+
"?udaf",
86+
"?fun",
87+
"?column",
88+
"?alias_to_cube",
89+
"?output",
90+
"?alias",
91+
),
92+
),
6293
]);
6394
}
6495

@@ -105,4 +136,106 @@ impl WrapperRules {
105136
false
106137
}
107138
}
139+
140+
fn transform_measure_udaf_expr(
141+
&self,
142+
udaf_var: &'static str,
143+
fun_var: &'static str,
144+
column_var: &'static str,
145+
alias_to_cube_var: &'static str,
146+
output_var: &'static str,
147+
alias_var: &'static str,
148+
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
149+
let udaf_var = var!(udaf_var);
150+
let fun_var = var!(fun_var);
151+
let column_var = var!(column_var);
152+
let alias_to_cube_var = var!(alias_to_cube_var);
153+
let output_var = var!(output_var);
154+
let alias_var = var!(alias_var);
155+
let meta = self.meta_context.clone();
156+
move |egraph, subst| {
157+
let Some(original_alias) = original_expr_name(egraph, subst[udaf_var]) else {
158+
return false;
159+
};
160+
161+
for fun in var_iter!(egraph[subst[fun_var]], AggregateUDFExprFun) {
162+
if fun.to_lowercase() != "measure" {
163+
continue;
164+
}
165+
166+
for column in var_iter!(egraph[subst[column_var]], ColumnExprColumn) {
167+
for alias_to_cube in var_iter!(
168+
egraph[subst[alias_to_cube_var]],
169+
WrapperPullupReplacerAliasToCube
170+
) {
171+
let Some((_, cube)) = meta.find_cube_by_column(alias_to_cube, column)
172+
else {
173+
continue;
174+
};
175+
176+
let Some(measure) = cube.lookup_measure(&column.name) else {
177+
continue;
178+
};
179+
180+
let Some(agg_type) = &measure.agg_type else {
181+
continue;
182+
};
183+
184+
let out_fun_distinct = match agg_type.as_str() {
185+
"string" | "time" | "boolean" | "number" => None,
186+
"count" => Some((AggregateFunction::Count, false)),
187+
"countDistinct" => Some((AggregateFunction::Count, true)),
188+
"countDistinctApprox" => {
189+
Some((AggregateFunction::ApproxDistinct, false))
190+
}
191+
"sum" => Some((AggregateFunction::Sum, false)),
192+
"avg" => Some((AggregateFunction::Avg, false)),
193+
"min" => Some((AggregateFunction::Min, false)),
194+
"max" => Some((AggregateFunction::Max, false)),
195+
_ => continue,
196+
};
197+
198+
let column_expr_id =
199+
egraph.add(LogicalPlanLanguage::ColumnExpr([subst[column_var]]));
200+
201+
let output_id = out_fun_distinct
202+
.map(|(out_fun, distinct)| {
203+
let fun_id =
204+
egraph.add(LogicalPlanLanguage::AggregateFunctionExprFun(
205+
AggregateFunctionExprFun(out_fun),
206+
));
207+
let args_tail_id = egraph
208+
.add(LogicalPlanLanguage::AggregateFunctionExprArgs(vec![]));
209+
let args_id =
210+
egraph.add(LogicalPlanLanguage::AggregateFunctionExprArgs(
211+
vec![column_expr_id, args_tail_id],
212+
));
213+
let distinct_id =
214+
egraph.add(LogicalPlanLanguage::AggregateFunctionExprDistinct(
215+
AggregateFunctionExprDistinct(distinct),
216+
));
217+
218+
egraph.add(LogicalPlanLanguage::AggregateFunctionExpr([
219+
fun_id,
220+
args_id,
221+
distinct_id,
222+
]))
223+
})
224+
.unwrap_or(column_expr_id);
225+
226+
subst.insert(output_var, output_id);
227+
228+
subst.insert(
229+
alias_var,
230+
egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(
231+
original_alias,
232+
))),
233+
);
234+
return true;
235+
}
236+
}
237+
}
238+
false
239+
}
240+
}
108241
}

0 commit comments

Comments
 (0)