Skip to content

Commit b5bc0dd

Browse files
alexeykudinkin and Dhakshin Suriakannu
authored and committed
[Data] Adding more ops to BlockColumnAccessor (ray-project#51571)
## Why are these changes needed? 1. Adding more ops to `BlockColumnAccessor` 2. Fixing circular imports in Ray Data 3. Fixing `AggregateFnV2` to be a proper ABC 4. Simplifying the `accumulate_block` op --------- Signed-off-by: Alexey Kudinkin <ak@anyscale.com> Signed-off-by: Dhakshin Suriakannu <d_suriakannu@apple.com>
1 parent 123cec3 commit b5bc0dd

File tree

8 files changed

+98
-113
lines changed

8 files changed

+98
-113
lines changed

python/ray/air/util/tensor_extensions/arrow.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,13 @@ def _convert_to_pyarrow_native_array(
204204
if len(column_values) > 0 and isinstance(column_values[0], datetime):
205205
column_values = _convert_datetime_to_np_datetime(column_values)
206206

207+
# To avoid deserialization penalty of converting Arrow arrays (`Array` and `ChunkedArray`)
208+
# to Python objects and then back to Arrow, we instead combine them into ListArray manually
209+
if len(column_values) > 0 and isinstance(
210+
column_values[0], (pa.Array, pa.ChunkedArray)
211+
):
212+
return _combine_as_list_array(column_values)
213+
207214
# NOTE: We explicitly infer PyArrow `DataType` so that
208215
# we can perform upcasting to be able to accommodate
209216
# blocks that are larger than 2Gb in size (limited
@@ -238,6 +245,27 @@ def _convert_to_pyarrow_native_array(
238245
raise ArrowConversionError(str(column_values)) from e
239246

240247

248+
def _combine_as_list_array(column_values: List[Union[pa.Array, pa.ChunkedArray]]):
249+
"""Combines list of Arrow arrays into a single `ListArray`"""
250+
251+
# First, compute respective offsets in the resulting array
252+
lens = [len(v) for v in column_values]
253+
offsets = pa.array(np.concatenate([[0], np.cumsum(lens)]), type=pa.int32())
254+
255+
# Concat all the chunks into a single contiguous array
256+
combined = pa.concat_arrays(
257+
itertools.chain(
258+
*[
259+
v.chunks if isinstance(v, pa.ChunkedArray) else [v]
260+
for v in column_values
261+
]
262+
)
263+
)
264+
265+
# TODO support null masking
266+
return pa.ListArray.from_arrays(offsets, combined, pa.list_(combined.type))
267+
268+
241269
def _coerce_np_datetime_to_pa_timestamp_precision(
242270
column_values: np.ndarray, dtype: pa.TimestampType, column_name: str
243271
):

python/ray/data/BUILD

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -268,20 +268,6 @@ py_test(
268268
],
269269
)
270270

271-
py_test(
272-
name = "test_aggregate",
273-
size = "small",
274-
srcs = ["tests/test_aggregate.py"],
275-
tags = [
276-
"exclusive",
277-
"team:data",
278-
],
279-
deps = [
280-
":conftest",
281-
"//:ray_lib",
282-
],
283-
)
284-
285271
py_test(
286272
name = "test_avro",
287273
size = "small",

python/ray/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from packaging.version import parse as parse_version
55

66
from ray._private.arrow_utils import get_pyarrow_version
7+
78
from ray.data._internal.compute import ActorPoolStrategy
89
from ray.data._internal.datasource.tfrecords_datasource import TFXReadOptions
910
from ray.data._internal.execution.interfaces import (

python/ray/data/_internal/arrow_block.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
BlockType,
3535
U,
3636
BlockColumnAccessor,
37+
BlockColumn,
3738
)
3839
from ray.data.context import DataContext
3940

@@ -427,5 +428,28 @@ def sum_of_squared_diffs_from_mean(
427428
)
428429
return res.as_py() if as_py else res
429430

430-
def to_pylist(self):
431+
def quantile(
432+
self, *, q: float, ignore_nulls: bool, as_py: bool = True
433+
) -> Optional[U]:
434+
import pyarrow.compute as pac
435+
436+
array = pac.quantile(self._column, q=q, skip_nulls=ignore_nulls)
437+
# NOTE: That quantile method still returns an array
438+
res = array[0]
439+
return res.as_py() if as_py else res
440+
441+
def unique(self) -> BlockColumn:
442+
import pyarrow.compute as pac
443+
444+
return pac.unique(self._column)
445+
446+
def flatten(self) -> BlockColumn:
447+
import pyarrow.compute as pac
448+
449+
return pac.list_flatten(self._column)
450+
451+
def to_pylist(self) -> List[Any]:
431452
return self._column.to_pylist()
453+
454+
def _as_arrow_compatible(self) -> Union[List[Any], "pyarrow.Array"]:
455+
return self._column

python/ray/data/_internal/pandas_block.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BlockType,
3030
U,
3131
BlockColumnAccessor,
32+
BlockColumn,
3233
)
3334
from ray.data.context import DataContext
3435

@@ -150,6 +151,18 @@ def mean(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
150151
sum_ / self.count(ignore_nulls=ignore_nulls) if not is_null(sum_) else sum_
151152
)
152153

154+
def quantile(
155+
self, *, q: float, ignore_nulls: bool, as_py: bool = True
156+
) -> Optional[U]:
157+
return self._column.quantile(q=q)
158+
159+
def unique(self) -> BlockColumn:
160+
pd = lazy_import_pandas()
161+
return pd.Series(self._column.unique())
162+
163+
def flatten(self) -> BlockColumn:
164+
return self._column.list.flatten()
165+
153166
def sum_of_squared_diffs_from_mean(
154167
self,
155168
ignore_nulls: bool,
@@ -164,9 +177,12 @@ def sum_of_squared_diffs_from_mean(
164177

165178
return ((self._column - mean) ** 2).sum(skipna=ignore_nulls)
166179

167-
def to_pylist(self):
180+
def to_pylist(self) -> List[Any]:
168181
return self._column.to_list()
169182

183+
def _as_arrow_compatible(self) -> Union[List[Any], "pyarrow.Array"]:
184+
return self.to_pylist()
185+
170186
def _is_all_null(self):
171187
return not self._column.notna().any()
172188

python/ray/data/aggregate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import abc
22
import math
3-
from typing import TYPE_CHECKING, Any, Callable, List, Optional
3+
from typing import TYPE_CHECKING, Callable, List, Optional, Any
44

55
import numpy as np
66

7-
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
87
from ray.data._internal.util import is_null
98
from ray.data.block import AggType, Block, BlockAccessor, KeyType, T, U
109
from ray.util.annotations import PublicAPI, Deprecated
@@ -105,7 +104,7 @@ def _validate(self, schema: Optional["Schema"]) -> None:
105104

106105

107106
@PublicAPI(stability="alpha")
108-
class AggregateFnV2(AggregateFn):
107+
class AggregateFnV2(AggregateFn, abc.ABC):
109108
"""Provides an interface to implement efficient aggregations to be applied
110109
to the dataset.
111110
@@ -148,9 +147,7 @@ def __init__(
148147
name=name,
149148
init=_safe_zero_factory,
150149
merge=_safe_combine,
151-
accumulate_block=(
152-
lambda acc, block: _safe_combine(acc, _safe_aggregate(block))
153-
),
150+
accumulate_block=lambda _, block: _safe_aggregate(block),
154151
finalize=_safe_finalize,
155152
)
156153

@@ -177,6 +174,8 @@ def _finalize(self, accumulator: AggType) -> Optional[U]:
177174

178175
def _validate(self, schema: Optional["Schema"]) -> None:
179176
if self._target_col_name:
177+
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
178+
180179
SortKey(self._target_col_name).validate_schema(schema)
181180

182181

@@ -540,12 +539,13 @@ class Unique(AggregateFnV2):
540539
def __init__(
541540
self,
542541
on: Optional[str] = None,
542+
ignore_nulls: bool = True,
543543
alias_name: Optional[str] = None,
544544
):
545545
super().__init__(
546546
alias_name if alias_name else f"unique({str(on)})",
547547
on=on,
548-
ignore_nulls=False,
548+
ignore_nulls=ignore_nulls,
549549
zero_factory=set,
550550
)
551551

python/ray/data/block.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,21 @@ def mean(self, *, ignore_nulls: bool, as_py: bool = True) -> Optional[U]:
573573
"""Returns a mean of the values in the column"""
574574
raise NotImplementedError()
575575

576+
def quantile(
577+
self, *, q: float, ignore_nulls: bool, as_py: bool = True
578+
) -> Optional[U]:
579+
"""Returns requested quantile of the given column"""
580+
raise NotImplementedError()
581+
582+
def unique(self) -> BlockColumn:
583+
"""Returns new column holding only distinct values of the current one"""
584+
raise NotImplementedError()
585+
586+
def flatten(self) -> BlockColumn:
587+
"""Flattens nested lists merging them into top-level container"""
588+
589+
raise NotImplementedError()
590+
576591
def sum_of_squared_diffs_from_mean(
577592
self,
578593
*,
@@ -583,10 +598,14 @@ def sum_of_squared_diffs_from_mean(
583598
"""Returns a sum of diffs (from mean) squared for the column"""
584599
raise NotImplementedError()
585600

586-
def to_pylist(self):
601+
def to_pylist(self) -> List[Any]:
587602
"""Converts block column to a list of Python native objects"""
588603
raise NotImplementedError()
589604

605+
def _as_arrow_compatible(self) -> Union[List[Any], "pyarrow.Array"]:
606+
"""Converts block column into a representation compatible with Arrow"""
607+
raise NotImplementedError()
608+
590609
@staticmethod
591610
def for_column(col: BlockColumn) -> "BlockColumnAccessor":
592611
"""Create a column accessor for the given column"""

python/ray/data/tests/test_aggregate.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

0 commit comments

Comments (0)