Skip to content

Commit eac5105

Browse files
fjetterpre-commit-ci[bot]dcherian
authored
Avoid local functions in push (#9856)
* Avoid local functions in push * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent fab900c commit eac5105

File tree

1 file changed

+43
-28
lines changed

1 file changed

+43
-28
lines changed

xarray/core/dask_array_ops.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import math
4+
from functools import partial
45

56
from xarray.core import dtypes, nputils
67

@@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
7576
return coeffs, residuals
7677

7778

79+
def _fill_with_last_one(a, b):
80+
import numpy as np
81+
82+
# cumreduction apply the push func over all the blocks first so,
83+
# the only missing part is filling the missing values using the
84+
# last data of the previous chunk
85+
return np.where(np.isnan(b), a, b)
86+
87+
88+
def _dtype_push(a, axis, dtype=None):
89+
from xarray.core.duck_array_ops import _push
90+
91+
# Not sure why the blelloch algorithm force to receive a dtype
92+
return _push(a, axis=axis)
93+
94+
95+
def _reset_cumsum(a, axis, dtype=None):
96+
import numpy as np
97+
98+
cumsum = np.cumsum(a, axis=axis)
99+
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
100+
return cumsum - reset_points
101+
102+
103+
def _last_reset_cumsum(a, axis, keepdims=None):
104+
import numpy as np
105+
106+
# Take the last cumulative sum taking into account the reset
107+
# This is useful for blelloch method
108+
return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
109+
110+
111+
def _combine_reset_cumsum(a, b, axis):
112+
import numpy as np
113+
114+
# It is going to sum the previous result until the first
115+
# non nan value
116+
bitmask = np.cumprod(b != 0, axis=axis)
117+
return np.where(bitmask, b + a, b)
118+
119+
78120
def push(array, n, axis, method="blelloch"):
79121
"""
80122
Dask-aware bottleneck.push
@@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"):
91133
# TODO: Replace all this function
92134
# once https://github.com/pydata/xarray/issues/9229 being implemented
93135

94-
def _fill_with_last_one(a, b):
95-
# cumreduction apply the push func over all the blocks first so,
96-
# the only missing part is filling the missing values using the
97-
# last data of the previous chunk
98-
return np.where(np.isnan(b), a, b)
99-
100-
def _dtype_push(a, axis, dtype=None):
101-
# Not sure why the blelloch algorithm force to receive a dtype
102-
return _push(a, axis=axis)
103-
104136
pushed_array = da.reductions.cumreduction(
105137
func=_dtype_push,
106138
binop=_fill_with_last_one,
@@ -113,26 +145,9 @@ def _dtype_push(a, axis, dtype=None):
113145
)
114146

115147
if n is not None and 0 < n < array.shape[axis] - 1:
116-
117-
def _reset_cumsum(a, axis, dtype=None):
118-
cumsum = np.cumsum(a, axis=axis)
119-
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
120-
return cumsum - reset_points
121-
122-
def _last_reset_cumsum(a, axis, keepdims=None):
123-
# Take the last cumulative sum taking into account the reset
124-
# This is useful for blelloch method
125-
return np.take(_reset_cumsum(a, axis=axis), axis=axis, indices=[-1])
126-
127-
def _combine_reset_cumsum(a, b):
128-
# It is going to sum the previous result until the first
129-
# non nan value
130-
bitmask = np.cumprod(b != 0, axis=axis)
131-
return np.where(bitmask, b + a, b)
132-
133148
valid_positions = da.reductions.cumreduction(
134149
func=_reset_cumsum,
135-
binop=_combine_reset_cumsum,
150+
binop=partial(_combine_reset_cumsum, axis=axis),
136151
ident=0,
137152
x=da.isnan(array, dtype=int),
138153
axis=axis,

0 commit comments

Comments
 (0)