1
1
from __future__ import annotations
2
2
3
3
import math
4
+ from functools import partial
4
5
5
6
from xarray .core import dtypes , nputils
6
7
@@ -75,6 +76,47 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
75
76
return coeffs , residuals
76
77
77
78
79
+ def _fill_with_last_one (a , b ):
80
+ import numpy as np
81
+
82
+ # cumreduction apply the push func over all the blocks first so,
83
+ # the only missing part is filling the missing values using the
84
+ # last data of the previous chunk
85
+ return np .where (np .isnan (b ), a , b )
86
+
87
+
88
+ def _dtype_push (a , axis , dtype = None ):
89
+ from xarray .core .duck_array_ops import _push
90
+
91
+ # Not sure why the blelloch algorithm force to receive a dtype
92
+ return _push (a , axis = axis )
93
+
94
+
95
+ def _reset_cumsum (a , axis , dtype = None ):
96
+ import numpy as np
97
+
98
+ cumsum = np .cumsum (a , axis = axis )
99
+ reset_points = np .maximum .accumulate (np .where (a == 0 , cumsum , 0 ), axis = axis )
100
+ return cumsum - reset_points
101
+
102
+
103
+ def _last_reset_cumsum (a , axis , keepdims = None ):
104
+ import numpy as np
105
+
106
+ # Take the last cumulative sum taking into account the reset
107
+ # This is useful for blelloch method
108
+ return np .take (_reset_cumsum (a , axis = axis ), axis = axis , indices = [- 1 ])
109
+
110
+
111
+ def _combine_reset_cumsum (a , b , axis ):
112
+ import numpy as np
113
+
114
+ # It is going to sum the previous result until the first
115
+ # non nan value
116
+ bitmask = np .cumprod (b != 0 , axis = axis )
117
+ return np .where (bitmask , b + a , b )
118
+
119
+
78
120
def push (array , n , axis , method = "blelloch" ):
79
121
"""
80
122
Dask-aware bottleneck.push
@@ -91,16 +133,6 @@ def push(array, n, axis, method="blelloch"):
91
133
# TODO: Replace all this function
92
134
# once https://github.com/pydata/xarray/issues/9229 being implemented
93
135
94
- def _fill_with_last_one (a , b ):
95
- # cumreduction apply the push func over all the blocks first so,
96
- # the only missing part is filling the missing values using the
97
- # last data of the previous chunk
98
- return np .where (np .isnan (b ), a , b )
99
-
100
- def _dtype_push (a , axis , dtype = None ):
101
- # Not sure why the blelloch algorithm force to receive a dtype
102
- return _push (a , axis = axis )
103
-
104
136
pushed_array = da .reductions .cumreduction (
105
137
func = _dtype_push ,
106
138
binop = _fill_with_last_one ,
@@ -113,26 +145,9 @@ def _dtype_push(a, axis, dtype=None):
113
145
)
114
146
115
147
if n is not None and 0 < n < array .shape [axis ] - 1 :
116
-
117
- def _reset_cumsum (a , axis , dtype = None ):
118
- cumsum = np .cumsum (a , axis = axis )
119
- reset_points = np .maximum .accumulate (np .where (a == 0 , cumsum , 0 ), axis = axis )
120
- return cumsum - reset_points
121
-
122
- def _last_reset_cumsum (a , axis , keepdims = None ):
123
- # Take the last cumulative sum taking into account the reset
124
- # This is useful for blelloch method
125
- return np .take (_reset_cumsum (a , axis = axis ), axis = axis , indices = [- 1 ])
126
-
127
- def _combine_reset_cumsum (a , b ):
128
- # It is going to sum the previous result until the first
129
- # non nan value
130
- bitmask = np .cumprod (b != 0 , axis = axis )
131
- return np .where (bitmask , b + a , b )
132
-
133
148
valid_positions = da .reductions .cumreduction (
134
149
func = _reset_cumsum ,
135
- binop = _combine_reset_cumsum ,
150
+ binop = partial ( _combine_reset_cumsum , axis = axis ) ,
136
151
ident = 0 ,
137
152
x = da .isnan (array , dtype = int ),
138
153
axis = axis ,
0 commit comments