Skip to content

Commit f0af7ee

Browse files
committed
Auto rechunk to enable blockwise reduction
This is done when: (1) `method` is None, and (2) we are grouping and reducing by a 1D array. We gate this on the fractional change in the number of chunks and on the change in size of the largest chunk after rechunking. Closes #359.
1 parent 23f1e49 commit f0af7ee

File tree

1 file changed

+63
-7
lines changed

1 file changed

+63
-7
lines changed

flox/core.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@
119119
# _simple_combine.
120120
DUMMY_AXIS = -2
121121

122+
# Thresholds below which we will automatically rechunk to blockwise if it makes sense
123+
# 1. Fractional change in number of chunks after rechunking
124+
BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
125+
# 2. Fractional change in max chunk size after rechunking
126+
BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.25
127+
# 3. If input arrays have chunk size smaller than `dask.array.chunk-size`,
128+
# then adjust chunks to meet that size first.
129+
BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR = 1.25
130+
122131
logger = logging.getLogger("flox")
123132

124133

@@ -223,8 +232,11 @@ def identity(x: T) -> T:
223232
return x
224233

225234

226-
def _issorted(arr: np.ndarray) -> bool:
227-
return bool((arr[:-1] <= arr[1:]).all())
235+
def _issorted(arr: np.ndarray, ascending=True) -> bool:
236+
if ascending:
237+
return bool((arr[:-1] <= arr[1:]).all())
238+
else:
239+
return bool((arr[:-1] >= arr[1:]).all())
228240

229241

230242
def _is_arg_reduction(func: T_Agg) -> bool:
@@ -325,6 +337,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
325337
Δl = abs(c - l)
326338
if c == 0 or newchunkidx[-1] > l:
327339
continue
340+
f = f.item() # noqa
341+
l = l.item() # noqa
328342
if Δf < Δl and f > newchunkidx[-1]:
329343
newchunkidx.append(f)
330344
else:
@@ -716,7 +730,9 @@ def rechunk_for_cohorts(
716730
return array.rechunk({axis: newchunks})
717731

718732

719-
def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
733+
def rechunk_for_blockwise(
734+
array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
735+
) -> tuple[T_MethodOpt, DaskArray]:
720736
"""
721737
Rechunks array so that group boundaries line up with chunk boundaries, allowing
722738
embarrassingly parallel group reductions.
@@ -739,14 +755,43 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
739755
DaskArray
740756
Rechunked array
741757
"""
742-
# TODO: this should be unnecessary?
743-
labels = factorize_((labels,), axes=())[0]
758+
759+
import dask
760+
from dask.utils import parse_bytes
761+
744762
chunks = array.chunks[axis]
745-
newchunks = _get_optimal_chunks_for_groups(chunks, labels)
763+
if len(chunks) == 1:
764+
return array
765+
766+
factor = parse_bytes(dask.config.get("array.chunk-size")) / (
767+
math.prod(array.chunksize) * array.dtype.itemsize
768+
)
769+
if factor > BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR:
770+
new_constant_chunks = math.ceil(factor) * max(chunks)
771+
q, r = divmod(array.shape[axis], new_constant_chunks)
772+
new_input_chunks = (new_constant_chunks,) * q + (r,)
773+
else:
774+
new_input_chunks = chunks
775+
776+
# FIXME: this should be unnecessary?
777+
labels = factorize_((labels,), axes=())[0]
778+
newchunks = _get_optimal_chunks_for_groups(new_input_chunks, labels)
746779
if newchunks == chunks:
747780
return array
781+
782+
Δn = abs(len(newchunks) - len(new_input_chunks))
783+
if force or (
784+
(Δn / len(new_input_chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD)
785+
and (
786+
abs(max(newchunks) - max(new_input_chunks)) / max(new_input_chunks)
787+
< BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
788+
)
789+
):
790+
logger.debug("Rechunking to enable blockwise.")
791+
return "blockwise", array.rechunk({axis: newchunks})
748792
else:
749-
return array.rechunk({axis: newchunks})
793+
logger.debug("Didn't meet thresholds to do automatic rechunking for blockwise reductions.")
794+
return None, array
750795

751796

752797
def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
@@ -2712,6 +2757,17 @@ def groupby_reduce(
27122757
has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
27132758
has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)
27142759

2760+
if (
2761+
method is None
2762+
and is_duck_dask_array(array)
2763+
and not any_by_dask
2764+
and by_.ndim == 1
2765+
and _issorted(by_, ascending=True)
2766+
):
2767+
# Let's try rechunking for sorted 1D by.
2768+
(single_axis,) = axis_
2769+
method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)
2770+
27152771
is_first_last = _is_first_last_reduction(func)
27162772
if is_first_last:
27172773
if has_dask and nax != 1:

0 commit comments

Comments
 (0)