Skip to content

Commit 74e2240

Browse files
committed
ENH: size() to return None on dask instead of nan
1 parent beac55b commit 74e2240

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

array_api_compat/common/_helpers.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -788,19 +788,24 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788788
return x.to_device(device, stream=stream)
789789

790790

791-
def size(x):
791+
def size(x: Array) -> int | None:
792792
"""
793793
Return the total number of elements of x.
794794
795795
This is equivalent to `x.size` according to the `standard
796796
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
797+
797798
This helper is included because PyTorch defines `size` in an
798799
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
799-
800+
It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
801+
the standard requires None.
800802
"""
801-
if None in x.shape:
803+
if None in x.shape: # this happens e.g. in ndonnx
804+
return None
805+
out = math.prod(x.shape)
806+
if math.isnan(out): # this happens e.g. in dask
802807
return None
803-
return math.prod(x.shape)
808+
return out
804809

805810

806811
def is_writeable_array(x) -> bool:

tests/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
all_libraries.append('sparse')
1212

1313
def import_(library, wrapper=False):
14-
if library == 'cupy':
14+
if library in ('cupy', 'ndonnx'):
1515
pytest.importorskip(library)
1616
if wrapper:
1717
if 'jax' in library:

tests/test_common.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
9-
8+
from array_api_compat import (
9+
device, is_array_api_obj, is_writeable_array, size, to_device
10+
)
1011
from ._helpers import import_, wrapped_libraries, all_libraries
1112

1213
import pytest
@@ -92,6 +93,27 @@ def test_is_writeable_array_numpy():
9293
assert not is_writeable_array(x)
9394

9495

96+
@pytest.mark.parametrize("library", all_libraries)
97+
def test_size(library):
98+
xp = import_(library)
99+
x = xp.asarray([1, 2, 3])
100+
assert size(x) == 3
101+
102+
103+
def test_size_nan():
104+
xp = import_("dask.array")
105+
x = xp.arange(10)
106+
x = x[x < 5]
107+
assert size(x) is None # NaNs in the shape have been special-cased
108+
109+
110+
def test_size_none():
111+
xp = import_("ndonnx")
112+
x = xp.arange(10)
113+
x = x[x < 5]
114+
assert size(x) is None
115+
116+
95117
@pytest.mark.parametrize("library", all_libraries)
96118
def test_device(library):
97119
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)