# _simple_combine.
DUMMY_AXIS = -2

+ # Thresholds below which we will automatically rechunk to blockwise if it makes sense
+ # 1. Fractional change in number of chunks after rechunking
+ BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
+ # 2. Fractional change in max chunk size after rechunking
+ BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.25
+ # 3. If input arrays have chunk size smaller than `dask.array.chunk-size`,
+ #    then adjust chunks to meet that size first.
+ BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR = 1.25
+
logger = logging.getLogger("flox")

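To make the intent of these thresholds concrete, here is a small standalone sketch (illustrative only; `rechunk_is_cheap` is a hypothetical helper name, not part of flox) of the check a proposed rechunk must pass when it is not forced: both the fractional change in the number of chunks and the fractional change in the largest chunk must stay below 0.25.

```python
# Illustrative sketch of the threshold check introduced in this diff.
# `rechunk_is_cheap` is a hypothetical name used only for this example.
BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD = 0.25
BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD = 0.25


def rechunk_is_cheap(old_chunks: tuple[int, ...], new_chunks: tuple[int, ...]) -> bool:
    # Fractional change in the number of chunks along the reduced axis
    delta_nchunks = abs(len(new_chunks) - len(old_chunks)) / len(old_chunks)
    # Fractional change in the largest chunk along the reduced axis
    delta_size = abs(max(new_chunks) - max(old_chunks)) / max(old_chunks)
    return (
        delta_nchunks < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD
        and delta_size < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
    )


print(rechunk_is_cheap((10, 10, 10, 10), (10, 11, 9, 10)))  # True: nearly unchanged
print(rechunk_is_cheap((10, 10, 10, 10), (40,)))            # False: too big a change
```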
@@ -223,8 +232,11 @@ def identity(x: T) -> T:
    return x


- def _issorted(arr: np.ndarray) -> bool:
-     return bool((arr[:-1] <= arr[1:]).all())
+ def _issorted(arr: np.ndarray, ascending=True) -> bool:
+     if ascending:
+         return bool((arr[:-1] <= arr[1:]).all())
+     else:
+         return bool((arr[:-1] >= arr[1:]).all())


def _is_arg_reduction(func: T_Agg) -> bool:
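A quick illustration of the new `ascending` flag (a self-contained sketch: `_issorted` is copied verbatim from the hunk above so the snippet runs on its own):

```python
import numpy as np


def _issorted(arr: np.ndarray, ascending=True) -> bool:
    # Copied from the diff above for a self-contained example.
    if ascending:
        return bool((arr[:-1] <= arr[1:]).all())
    else:
        return bool((arr[:-1] >= arr[1:]).all())


by = np.array([0, 0, 1, 1, 2, 2])
print(_issorted(by))                         # True: non-decreasing
print(_issorted(by[::-1], ascending=False))  # True: non-increasing
print(_issorted(by[::-1]))                   # False
```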
@@ -325,6 +337,8 @@ def _get_optimal_chunks_for_groups(chunks, labels):
        Δl = abs(c - l)
        if c == 0 or newchunkidx[-1] > l:
            continue
+         f = f.item()  # noqa
+         l = l.item()  # noqa
        if Δf < Δl and f > newchunkidx[-1]:
            newchunkidx.append(f)
        else:
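For context (an aside, not part of the diff): `.item()` converts a NumPy scalar into a plain Python integer, so the chunk boundaries collected in `newchunkidx` stay plain `int`s rather than `np.int64`.

```python
import numpy as np

f = np.int64(7)
print(type(f), type(f.item()))  # <class 'numpy.int64'> <class 'int'>
```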
@@ -716,7 +730,9 @@ def rechunk_for_cohorts(
    return array.rechunk({axis: newchunks})


- def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
+ def rechunk_for_blockwise(
+     array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
+ ) -> tuple[T_MethodOpt, DaskArray]:
    """
    Rechunks array so that group boundaries line up with chunk boundaries, allowing
    embarrassingly parallel group reductions.
@@ -739,14 +755,43 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
    DaskArray
        Rechunked array
    """
-     # TODO: this should be unnecessary?
-     labels = factorize_((labels,), axes=())[0]
+
+     import dask
+     from dask.utils import parse_bytes
+
    chunks = array.chunks[axis]
-     newchunks = _get_optimal_chunks_for_groups(chunks, labels)
+     if len(chunks) == 1:
+         return array
+
+     factor = parse_bytes(dask.config.get("array.chunk-size")) / (
+         math.prod(array.chunksize) * array.dtype.itemsize
+     )
+     if factor > BLOCKWISE_DEFAULT_ARRAY_CHUNK_SIZE_FACTOR:
+         new_constant_chunks = math.ceil(factor) * max(chunks)
+         q, r = divmod(array.shape[axis], new_constant_chunks)
+         new_input_chunks = (new_constant_chunks,) * q + (r,)
+     else:
+         new_input_chunks = chunks
+
+     # FIXME: this should be unnecessary?
+     labels = factorize_((labels,), axes=())[0]
+     newchunks = _get_optimal_chunks_for_groups(new_input_chunks, labels)
    if newchunks == chunks:
        return array
+
+     Δn = abs(len(newchunks) - len(new_input_chunks))
+     if force or (
+         (Δn / len(new_input_chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD)
+         and (
+             abs(max(newchunks) - max(new_input_chunks)) / max(new_input_chunks)
+             < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
+         )
+     ):
+         logger.debug("Rechunking to enable blockwise.")
+         return "blockwise", array.rechunk({axis: newchunks})
    else:
-         return array.rechunk({axis: newchunks})
+         logger.debug("Didn't meet thresholds to do automatic rechunking for blockwise reductions.")
+         return None, array


def reindex_numpy(array, from_: pd.Index, to: pd.Index, fill_value, dtype, axis: int):
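A sketch of how a caller might use the two-value return introduced here (assumes flox and dask are installed; the array shape, chunking, and labels are invented for illustration, and the exact outcome depends on the thresholds above and on dask's `array.chunk-size` setting):

```python
# Hypothetical usage of the updated rechunk_for_blockwise.
import dask.array as da
import numpy as np

from flox.core import rechunk_for_blockwise

labels = np.repeat(np.arange(10), 10)        # sorted group labels, 100 elements
array = da.random.random((100,), chunks=25)  # chunk boundaries don't match groups

# With force=False, the array is only rechunked when the change is "cheap enough"
# per the thresholds; otherwise (None, array) comes back unchanged.
method, array = rechunk_for_blockwise(array, axis=0, labels=labels, force=False)
print(method, array.chunks)
```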
@@ -2712,6 +2757,17 @@ def groupby_reduce(
    has_dask = is_duck_dask_array(array) or is_duck_dask_array(by_)
    has_cubed = is_duck_cubed_array(array) or is_duck_cubed_array(by_)

+     if (
+         method is None
+         and is_duck_dask_array(array)
+         and not any_by_dask
+         and by_.ndim == 1
+         and _issorted(by_, ascending=True)
+     ):
+         # Let's try rechunking for sorted 1D by.
+         (single_axis,) = axis_
+         method, array = rechunk_for_blockwise(array, single_axis, by_, force=False)
+
    is_first_last = _is_first_last_reduction(func)
    if is_first_last:
        if has_dask and nax != 1:
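To illustrate the intended user-facing effect (a sketch, not a test from this change): with `method=None`, a dask-backed array, and a sorted one-dimensional `by`, `groupby_reduce` may now rechunk the input and pick the blockwise path on its own; passing an explicit `method` bypasses this heuristic. The sizes below are invented for illustration.

```python
# Sketch of the auto-rechunking path; assumes flox and dask are installed.
import dask.array as da
import numpy as np

import flox

by = np.repeat(np.arange(4), 25)             # sorted 1D group labels
array = da.random.random((100,), chunks=30)  # chunks not aligned with the groups

# method=None lets flox decide; with a sorted 1D `by` it may rechunk for blockwise.
result, groups = flox.groupby_reduce(array, by, func="sum", method=None)
print(groups, result.compute())
```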