Skip to content

Commit 941a2bf

Browse files
committed
Revert "WIP: remove test_torch.py"
This reverts commit a7f56b9.
1 parent a7f56b9 commit 941a2bf

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

tests/test_torch.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
2+
"""
3+
import itertools
4+
5+
import pytest
6+
import torch
7+
8+
from array_api_compat import torch as xp
9+
10+
11+
class TestResultType:
12+
def test_empty(self):
13+
with pytest.raises(ValueError):
14+
xp.result_type()
15+
16+
def test_one_arg(self):
17+
for x in [1, 1.0, 1j, '...', None]:
18+
with pytest.raises((ValueError, AttributeError)):
19+
xp.result_type(x)
20+
21+
for x in [xp.float32, xp.int64, torch.complex64]:
22+
assert xp.result_type(x) == x
23+
24+
for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]:
25+
assert xp.result_type(x) == x.dtype
26+
27+
def test_two_args(self):
28+
# Only include here things "unspecified" in the spec
29+
30+
# scalar, tensor or tensor,tensor
31+
for x, y in [
32+
(1., 1j),
33+
(1j, xp.arange(3)),
34+
(True, xp.asarray(3.)),
35+
(xp.ones(3) == 1, 1j*xp.ones(3)),
36+
]:
37+
assert xp.result_type(x, y) == torch.result_type(x, y)
38+
39+
# dtype, scalar
40+
for x, y in [
41+
(1j, xp.int64),
42+
(True, xp.float64),
43+
]:
44+
assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y))
45+
46+
# dtype, dtype
47+
for x, y in [
48+
(xp.bool, xp.complex64)
49+
]:
50+
xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y)
51+
assert xp.result_type(x, y) == torch.result_type(xt, yt)
52+
53+
def test_multi_arg(self):
54+
torch.set_default_dtype(torch.float32)
55+
56+
args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
57+
assert xp.result_type(*args) == torch.float16
58+
59+
args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
60+
assert xp.result_type(*args) == xp.complex64
61+
62+
args = [1, 2, 3j, xp.float64, 4, 5, 6]
63+
assert xp.result_type(*args) == xp.complex128
64+
65+
args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
66+
assert xp.result_type(*args) == xp.complex128
67+
68+
i64 = xp.ones(1, dtype=xp.int64)
69+
f16 = xp.ones(1, dtype=xp.float16)
70+
for i in itertools.permutations([i64, f16, 1.0, 1.0]):
71+
assert xp.result_type(*i) == xp.float16, f"{i}"
72+
73+
with pytest.raises(ValueError):
74+
xp.result_type(1, 2, 3, 4)

0 commit comments

Comments
 (0)