Skip to content

Commit 52241c7

Browse files
author
Luke Shaw
committed
Make behaviour of compute consistent for slicing
1 parent 3f2f722 commit 52241c7

File tree

2 files changed

+65
-24
lines changed

2 files changed

+65
-24
lines changed

src/blosc2/lazyexpr.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,7 @@ def slices_eval( # noqa: C901
13781378
for i, (c, s) in enumerate(zip(coords, chunks, strict=True))
13791379
)
13801380
# Check whether current slice_ intersects with _slice
1381-
if _slice is not None and _slice != ():
1381+
if _slice is not None and _slice != (): # can't use != when _slice is np.int
13821382
# Ensure that _slice is of type slice
13831383
key = ndindex.ndindex(_slice).expand(shape).raw
13841384
_slice = tuple(k if isinstance(k, slice) else slice(k, k + 1, None) for k in key)
@@ -1508,19 +1508,7 @@ def slices_eval( # noqa: C901
15081508
else:
15091509
raise ValueError("The where condition must be a tuple with one or two elements")
15101510

1511-
if orig_slice is not None:
1512-
if isinstance(out, np.ndarray):
1513-
out = out[orig_slice]
1514-
if _order is not None:
1515-
indices_ = indices_[orig_slice]
1516-
elif isinstance(out, blosc2.NDArray):
1517-
# It *seems* better to choose an automatic chunks and blocks for the output array
1518-
# out = out.slice(orig_slice, chunks=out.chunks, blocks=out.blocks)
1519-
out = out.slice(orig_slice)
1520-
else:
1521-
raise ValueError("The output array is not a NumPy array or a NDArray")
1522-
1523-
if where is not None and len(where) < 2:
1511+
if where is not None and len(where) < 2: # No need to apply orig_slice here, since out was filled starting from index 0
15241512
if _order is not None:
15251513
# argsort the result following _order
15261514
new_order = np.argsort(out[:lenout])
@@ -1532,6 +1520,19 @@ def slices_eval( # noqa: C901
15321520
else:
15331521
out.resize((lenout,))
15341522

1523+
else: # Need to apply orig_slice, since the array was filled according to slice_ for each chunk
1524+
if orig_slice is not None:
1525+
if isinstance(out, np.ndarray):
1526+
out = out[orig_slice]
1527+
if _order is not None:
1528+
indices_ = indices_[orig_slice]
1529+
elif isinstance(out, blosc2.NDArray):
1530+
# It *seems* better to choose an automatic chunks and blocks for the output array
1531+
# out = out.slice(orig_slice, chunks=out.chunks, blocks=out.blocks)
1532+
out = out.slice(orig_slice)
1533+
else:
1534+
raise ValueError("The output array is not a NumPy array or a NDArray")
1535+
15351536
return out
15361537

15371538

tests/ndarray/test_lazyexpr_fields.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,19 @@ def test_where_one_param(array_fixture):
279279
res = np.sort(res)
280280
nres = np.sort(nres)
281281
np.testing.assert_allclose(res[:], nres)
282+
282283
# Test with getitem
283284
sl = slice(100)
284285
res = expr.where(a1)[sl]
286+
nres = na1[sl][ne_evaluate("na1**2 + na2**2 > 2 * na1 * na2 + 1")[sl]]
285287
if len(a1.shape) == 1 or a1.chunks == a1.shape:
286288
# TODO: fix this, as it seems that this is not working well for numexpr?
287289
if blosc2.IS_WASM:
288290
return
289-
np.testing.assert_allclose(res, nres[sl])
291+
np.testing.assert_allclose(res, nres)
290292
else:
291293
# In this case, we cannot compare results, only the length
292-
assert len(res) == len(nres[sl])
294+
assert len(res) == len(nres)
293295

294296

295297
# Test where indirectly via a condition in getitem in a NDArray
@@ -330,25 +332,26 @@ def test_where_getitem(array_fixture):
330332
# Test with partial slice
331333
sl = slice(100)
332334
res = sa1[a1**2 + a2**2 > 2 * a1 * a2 + 1][sl]
335+
nres = nsa1[sl][ne_evaluate("na1**2 + na2**2 > 2 * na1 * na2 + 1")[sl]]
333336
if len(a1.shape) == 1 or a1.chunks == a1.shape:
334337
# TODO: fix this, as it seems that this is not working well for numexpr?
335338
if blosc2.IS_WASM:
336339
return
337-
np.testing.assert_allclose(res["a"], nres[sl]["a"])
338-
np.testing.assert_allclose(res["b"], nres[sl]["b"])
340+
np.testing.assert_allclose(res["a"], nres["a"])
341+
np.testing.assert_allclose(res["b"], nres["b"])
339342
else:
340343
# In this case, we cannot compare results, only the length
341-
assert len(res["a"]) == len(nres[sl]["a"])
342-
assert len(res["b"]) == len(nres[sl]["b"])
344+
assert len(res["a"]) == len(nres["a"])
345+
assert len(res["b"]) == len(nres["b"])
343346
# string version
344347
res = sa1["a**2 + b**2 > 2 * a * b + 1"][sl]
345348
if len(a1.shape) == 1 or a1.chunks == a1.shape:
346-
np.testing.assert_allclose(res["a"], nres[sl]["a"])
347-
np.testing.assert_allclose(res["b"], nres[sl]["b"])
349+
np.testing.assert_allclose(res["a"], nres["a"])
350+
np.testing.assert_allclose(res["b"], nres["b"])
348351
else:
349352
# We cannot compare the results here, other than the length
350-
assert len(res["a"]) == len(nres[sl]["a"])
351-
assert len(res["b"]) == len(nres[sl]["b"])
353+
assert len(res["a"]) == len(nres["a"])
354+
assert len(res["b"]) == len(nres["b"])
352355

353356

354357
# Test where indirectly via a condition in getitem in a NDField
@@ -631,3 +634,40 @@ def test_col_reduction(reduce_op):
631634
ns = nreduc(nC[nC > 0])
632635
np.testing.assert_allclose(s, ns)
633636
np.testing.assert_allclose(s2, ns)
637+
638+
639+
def test_fields_indexing():
640+
N = 1000
641+
it = ((-x + 1, x - 2, 0.1 * x) for x in range(N))
642+
sa = blosc2.fromiter(
643+
it, dtype=[("A", "i4"), ("B", "f4"), ("C", "f8")], shape=(N,), urlpath="sa-1M.b2nd", mode="w"
644+
)
645+
expr = sa["(A < B)"]
646+
A = sa["A"][:]
647+
B = sa["B"][:]
648+
C = sa["C"][:]
649+
temp = sa[:]
650+
indices = A < B
651+
idx = np.argmax(indices)
652+
653+
# Returns fewer than 10 elements in general
654+
sliced = expr.compute(slice(0, 10))
655+
gotitem = expr[:10]
656+
np.testing.assert_array_equal(sliced[:], gotitem)
657+
np.testing.assert_array_equal(gotitem, temp[:10][indices[:10]])
658+
# Actually this makes sense since one can understand this as a request to compute on a portion of operands.
659+
# If one desires a portion of the result, one should compute the whole expression and then slice it.
660+
# For a general slice it is quite difficult to simply stop when the desired slice has been obtained, or
661+
# to try to optimise the chunk computation order.
662+
663+
# Get first true element
664+
sliced = expr.compute(idx)
665+
gotitem = expr[idx]
666+
np.testing.assert_array_equal(sliced[()], gotitem)
667+
np.testing.assert_array_equal(gotitem, temp[idx])
668+
669+
# Should return void arrays here.
670+
sliced = expr.compute(0) # typically gives array of zeros
671+
gotitem = expr[0] # gives an error
672+
np.testing.assert_array_equal(sliced[()], gotitem)
673+
np.testing.assert_array_equal(gotitem, temp[0])

0 commit comments

Comments
 (0)