Skip to content

Commit 5d98aa8

Browse files
committed
Merge branch 'main' into typ_v4
2 parents 924fc3d + 2b5e289 commit 5d98aa8

File tree

8 files changed

+67
-5
lines changed

8 files changed

+67
-5
lines changed

array_api_compat/cupy/_aliases.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ def count_nonzero(
118118
return result
119119

120120

121+
# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
122+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
123+
return cp.take_along_axis(x, indices, axis=axis)
124+
125+
121126
# These functions are completely new here. If the library already has them
122127
# (i.e., numpy 2.0), use the library version instead of our wrapper.
123128
if hasattr(cp, 'vecdot'):
@@ -139,6 +144,7 @@ def count_nonzero(
139144
'acos', 'acosh', 'asin', 'asinh', 'atan',
140145
'atan2', 'atanh', 'bitwise_left_shift',
141146
'bitwise_invert', 'bitwise_right_shift',
142-
'bool', 'concat', 'count_nonzero', 'pow', 'sign']
147+
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
148+
'take_along_axis']
143149

144150
_all_ignore = ['cp', 'get_xp']

array_api_compat/numpy/_aliases.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def count_nonzero(
125125
return result
126126

127127

128+
# take_along_axis: axis defaults to -1 but in numpy axis is a required arg
129+
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
130+
return np.take_along_axis(x, indices, axis=axis)
131+
132+
128133
# These functions are completely new here. If the library already has them
129134
# (i.e., numpy 2.0), use the library version instead of our wrapper.
130135
if hasattr(np, "vecdot"):
@@ -160,6 +165,7 @@ def count_nonzero(
160165
"concat",
161166
"count_nonzero",
162167
"pow",
168+
"take_along_axis"
163169
]
164170
__all__ += _aliases.__all__
165171
_all_ignore = ["np", "get_xp"]

array_api_compat/torch/_aliases.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Sequence
44
from functools import reduce as _reduce, wraps as _wraps
55
from builtins import all as _builtin_all, any as _builtin_any
6-
from typing import Any
6+
from typing import Any, List, Optional, Sequence, Tuple, Union, Literal
77

88
import torch
99

@@ -547,8 +547,12 @@ def count_nonzero(
547547
) -> Array:
548548
result = torch.count_nonzero(x, dim=axis)
549549
if keepdims:
550-
if axis is not None:
550+
if isinstance(axis, int):
551551
return result.unsqueeze(axis)
552+
elif isinstance(axis, tuple):
553+
n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis]
554+
sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)]
555+
return torch.reshape(result, sh)
552556
return _axis_none_keepdims(result, x.ndim, keepdims)
553557
else:
554558
return result
@@ -820,6 +824,12 @@ def sign(x: Array, /) -> Array:
820824
return out
821825

822826

827+
def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]:
828+
# enforce the default of 'xy'
829+
# TODO: is the return type a list or a tuple
830+
return list(torch.meshgrid(*arrays, indexing='xy'))
831+
832+
823833
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
824834
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
825835
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
@@ -836,6 +846,6 @@ def sign(x: Array, /) -> Array:
836846
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
837847
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
838848
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
839-
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat']
849+
'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid']
840850

841851
_all_ignore = ['torch', 'get_xp']

cupy-xfails.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub
3434
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
3535
# floating point inaccuracy
3636
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
37+
# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1)
38+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
39+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
40+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
41+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
3742

3843
# cupy (arg)min/max wrong with infinities
3944
# https://github.com/cupy/cupy/issues/7424
@@ -182,7 +187,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
182187
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
183188

184189
# 2024.12 support
185-
array_api_tests/test_signatures.py::test_func_signature[count_nonzero]
186190
array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
187191
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
188192
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]

dask-xfails.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace
2424
# Shape mismatch
2525
array_api_tests/test_indexing_functions.py::test_take
2626

27+
# missing `take_along_axis`, https://github.com/dask/dask/issues/3663
28+
array_api_tests/test_indexing_functions.py::test_take_along_axis
29+
2730
# Array methods and attributes not already on da.Array cannot be wrapped
2831
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
2932
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]

numpy-1-22-xfails.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,20 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtr
123123
array_api_tests/test_searching_functions.py::test_where
124124
array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0]
125125

126+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add]
127+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
128+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
129+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
130+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
131+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
132+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
133+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
134+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
135+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
136+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
137+
138+
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
139+
126140
# 2023.12 support
127141
array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
128142
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]

numpy-1-26-xfails.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
5050
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
5151
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
5252

53+
array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
54+
5355
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
5456
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
5557
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

tests/test_torch.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b):
100100
assert dtype_1 == dtype_2
101101
finally:
102102
torch.set_default_dtype(prev_default)
103+
104+
105+
def test_meshgrid():
106+
"""Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'."""
107+
108+
x, y = xp.asarray([1, 2]), xp.asarray([4])
109+
110+
X, Y = xp.meshgrid(x, y)
111+
112+
# output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different
113+
X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]])
114+
115+
assert X.shape == X_xy.shape
116+
assert xp.all(X == X_xy)
117+
118+
assert Y.shape == Y_xy.shape
119+
assert xp.all(Y == Y_xy)

0 commit comments

Comments
 (0)