diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py index e8cbf0c3f25..df039d89283 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py @@ -48,6 +48,11 @@ class FillNullWithStrategyOp(UnaryOp): policy: plc.replace.ReplacePolicy = plc.replace.ReplacePolicy.PRECEDING +@dataclass(frozen=True) +class CumSumOp(UnaryOp): + pass + + def to_request( value: expr.Expr, orderby: Column, df: DataFrame ) -> plc.rolling.RollingRequest: @@ -241,7 +246,8 @@ def __init__( isinstance(named_expr.value, (expr.Len, expr.Agg)) or ( isinstance(named_expr.value, expr.UnaryFunction) - and named_expr.value.name in {"rank", "fill_null_with_strategy"} + and named_expr.value.name + in {"rank", "fill_null_with_strategy", "cum_sum"} ) ) ] @@ -265,7 +271,7 @@ def __init__( if isinstance(v, expr.Agg) or ( isinstance(v, expr.UnaryFunction) - and v.name in {"rank", "fill_null_with_strategy"} + and v.name in {"rank", "fill_null_with_strategy", "cum_sum"} ) ] self.by_count = len(by_expr) @@ -393,6 +399,41 @@ def _( dtypes = [ne.value.dtype for ne in named_exprs] return names, dtypes, tables + @_apply_unary_op.register + def _( + self, + op: CumSumOp, + df: DataFrame, + _: plc.groupby.GroupBy, + ) -> tuple[list[str], list[DataType], list[plc.Table]]: + cum_named = op.named_exprs + order_index = op.order_index + + requests: list[plc.groupby.GroupByRequest] = [] + out_names: list[str] = [] + out_dtypes: list[DataType] = [] + + val_cols = self._gather_columns( + [ + ne.value.children[0].evaluate(df, context=ExecutionContext.FRAME).obj + for ne in cum_named + ], + order_index, + cudf_polars_column=False, + ) + agg = plc.aggregation.sum() + + for ne, val_col in zip(cum_named, val_cols, strict=True): + requests.append(plc.groupby.GroupByRequest(val_col, [agg])) + out_names.append(ne.name) + out_dtypes.append(ne.value.dtype) + + lg = op.local_grouper + assert isinstance(lg, plc.groupby.GroupBy) + _, tables = lg.scan(requests) + + return out_names, out_dtypes, tables + def _reorder_to_input( self, row_id: plc.Column, @@ -444,6 +485,7 @@ def _split_named_expr( unary_window_ops: dict[str, list[expr.NamedExpr]] = { "rank": [], "fill_null_with_strategy": [], + "cum_sum": [], } for ne in self.named_aggs: @@ -733,6 +775,40 @@ def do_evaluate( # noqa: D102 ) ) + if cum_named := unary_window_ops["cum_sum"]: + order_index = self._build_window_order_index( + by_cols, + row_id=row_id, + order_by_col=order_by_col if self._order_by_expr is not None else None, + ob_desc=self.options[2] if self._order_by_expr is not None else False, + ob_nulls_last=self.options[3] + if self._order_by_expr is not None + else False, + ) + by_cols_for_scan = self._gather_columns(by_cols, order_index) + local = self._sorted_grouper(by_cols_for_scan) + names, dtypes, tables = self._apply_unary_op( + CumSumOp( + named_exprs=cum_named, + order_index=order_index, + by_cols_for_scan=by_cols_for_scan, + local_grouper=local, + ), + df, + grouper, + ) + broadcasted_cols.extend( + self._reorder_to_input( + row_id, + by_cols, + df.num_rows, + tables, + names, + dtypes, + order_index=order_index, + ) + ) + # Create a temporary DataFrame with the broadcasted columns named by their # placeholder names from agg decomposition, then evaluate the post-expression. df = DataFrame(broadcasted_cols) diff --git a/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py b/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py index 6e1eefb298f..5699dc639bd 100644 --- a/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py +++ b/python/cudf_polars/cudf_polars/dsl/utils/aggregations.py @@ -92,6 +92,7 @@ def decompose_single_agg( if isinstance(agg, expr.UnaryFunction) and agg.name in { "rank", "fill_null_with_strategy", + "cum_sum", }: if context != ExecutionContext.WINDOW: raise NotImplementedError( diff --git a/python/cudf_polars/tests/expressions/test_rolling.py b/python/cudf_polars/tests/expressions/test_rolling.py index 6abe7af1a14..dd01511c85e 100644 --- a/python/cudf_polars/tests/expressions/test_rolling.py +++ b/python/cudf_polars/tests/expressions/test_rolling.py @@ -398,3 +398,29 @@ def test_fill_over( def test_fill_null_with_mean_over_unsupported(df: pl.LazyFrame) -> None: q = df.select(pl.col("x").fill_null(strategy="mean").over("g")) assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize( + "expr,group_key", + [ + (pl.col("x"), "g"), + (pl.when((pl.col("x") % 4) == 1).then(None).otherwise(pl.col("x")), "g"), + (pl.col("x"), "g_null"), + ], +) +@pytest.mark.parametrize( + "order_by", + [ + None, + ["g2", pl.col("x2") * 2], + ], +) +def test_cum_sum_over( + df: pl.LazyFrame, + *, + expr: pl.Expr, + group_key: str, + order_by: None | list[str | pl.Expr], +) -> None: + q = df.select(expr.cum_sum().over(group_key, order_by=order_by)) + assert_gpu_result_equal(q)