Skip to content

Commit 8c4ae55

Browse files
authored
Rechunk where dict has missing axes (#546)
When rechunking with a dict that doesn't contain all axes, then the chunking should be unchanged for those axes that are missing. In particular, `a.rechunk({})` should be a no-op. This is consistent with Dask (dask/dask#11261) and Xarray (pydata/xarray#9286)
1 parent 1633431 commit 8c4ae55

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

cubed/core/ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,16 @@ def wrap(*a, block_id=None, **kw):
756756

757757

758758
def rechunk(x, chunks, target_store=None):
759+
if isinstance(chunks, dict):
760+
chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()}
761+
for i in range(x.ndim):
762+
if i not in chunks:
763+
chunks[i] = x.chunks[i]
764+
elif chunks[i] is None:
765+
chunks[i] = x.chunks[i]
766+
if isinstance(chunks, (tuple, list)):
767+
chunks = tuple(lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks))
768+
759769
normalized_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
760770
if x.chunks == normalized_chunks:
761771
return x

cubed/tests/test_core.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,19 @@ def test_multiple_ops(spec, executor):
246246
)
247247

248248

249-
@pytest.mark.parametrize("new_chunks", [(1, 2), {0: 1, 1: 2}])
250-
def test_rechunk(spec, executor, new_chunks):
249+
@pytest.mark.parametrize(
250+
("new_chunks", "expected_chunks"),
251+
[
252+
((1, 2), ((1, 1, 1), (2, 1))),
253+
({0: 1, 1: 2}, ((1, 1, 1), (2, 1))),
254+
({1: 2}, ((2, 1), (2, 1))), # dim 0 unchanged
255+
({}, ((2, 1), (1, 1, 1))), # unchanged
256+
],
257+
)
258+
def test_rechunk(spec, executor, new_chunks, expected_chunks):
251259
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
252260
b = a.rechunk(new_chunks)
261+
assert b.chunks == expected_chunks
253262
assert_array_equal(
254263
b.compute(executor=executor),
255264
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),

0 commit comments

Comments
 (0)