Skip to content

Commit 25da0e4

Browse files
committed
Wrap trace for dask
1 parent 0734064 commit 25da0e4

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

array_api_compat/dask/array/linalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# Exports
77
from dask.array.linalg import * # noqa: F403
8-
from dask.array import trace, outer
8+
from dask.array import outer
99

1010
# These functions are in both the main and linalg namespaces
1111
from dask.array import matmul, tensordot
@@ -42,6 +42,7 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced',
4242
if mode != "reduced":
4343
raise ValueError("dask arrays only support using mode='reduced'")
4444
return QRResult(*da.linalg.qr(x, **kwargs))
45+
trace = get_xp(da)(_linalg.trace)
4546
cholesky = get_xp(da)(_linalg.cholesky)
4647
matrix_rank = get_xp(da)(_linalg.matrix_rank)
4748
matrix_norm = get_xp(da)(_linalg.matrix_norm)

dask-xfails.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ array_api_tests/test_linalg.py::test_svdvals
8181
array_api_tests/test_linalg.py::test_cholesky
8282
# dtype mismatch got uint64, but should be uint8, NPY_PROMOTION_STATE=weak doesn't help :(
8383
array_api_tests/test_linalg.py::test_tensordot
84-
# probably same reason for failing as numpy
85-
array_api_tests/test_linalg.py::test_trace
8684

8785
# AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)]
8886
array_api_tests/test_linalg.py::test_linalg_tensordot

0 commit comments

Comments
 (0)