Skip to content

Commit 2561136

Browse files
authored
Don't hardcode cuda in test files (#160)
1 parent d507981 commit 2561136

File tree

4 files changed

+33
-32
lines changed

4 files changed

+33
-32
lines changed

test/test_control_flow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ def mul_relu_block_backward_kernel(
248248
return dx, dy
249249
return dx, dy.sum(axis=-1)
250250

251-
x = torch.randn(512, 1024, device="cuda", requires_grad=True)
252-
y = torch.randn(512, device="cuda", requires_grad=True)
253-
dz = torch.randn(512, 1024, device="cuda")
251+
x = torch.randn(512, 1024, device=DEVICE, requires_grad=True)
252+
y = torch.randn(512, device=DEVICE, requires_grad=True)
253+
dz = torch.randn(512, 1024, device=DEVICE)
254254
expected = mul_relu_block_back_spec(x, y, dz)
255255
torch.testing.assert_close(
256256
mul_relu_block_backward_kernel(x, y, dz, False),

test/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def arange(length: int, device: torch.device) -> torch.Tensor:
2424

2525
code, result = code_and_output(
2626
arange,
27-
(100, torch.device("cuda")),
27+
(100, DEVICE),
2828
block_size=32,
2929
)
3030
torch.testing.assert_close(

test/test_reductions.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88

99
import helion
10+
from helion._testing import DEVICE
1011
from helion._testing import code_and_output
1112
import helion.language as hl
1213

@@ -59,7 +60,7 @@ class TestReductions(TestCase):
5960
maxDiff = 16384
6061

6162
def test_sum(self):
62-
args = (torch.randn([512, 512], device="cuda"),)
63+
args = (torch.randn([512, 512], device=DEVICE),)
6364
code, output = code_and_output(sum_kernel, args, block_size=1)
6465
torch.testing.assert_close(output, args[0].sum(-1), rtol=1e-04, atol=1e-04)
6566
self.assertExpectedInline(
@@ -98,7 +99,7 @@ def _sum_kernel_make_precompiler(x: torch.Tensor):
9899
)
99100

100101
def test_sum_keepdims(self):
101-
args = (torch.randn([512, 512], device="cuda"),)
102+
args = (torch.randn([512, 512], device=DEVICE),)
102103
code, output = code_and_output(
103104
sum_kernel_keepdims, args, block_size=16, indexing="block_ptr"
104105
)
@@ -141,7 +142,7 @@ def _sum_kernel_keepdims_make_precompiler(x: torch.Tensor):
141142

142143
def test_argmin_argmax(self):
143144
for fn in (torch.argmin, torch.argmax):
144-
args = (torch.randn([512, 512], device="cuda"), fn, torch.int64)
145+
args = (torch.randn([512, 512], device=DEVICE), fn, torch.int64)
145146
code, output = code_and_output(
146147
reduce_kernel, args, block_size=16, indexing="block_ptr"
147148
)
@@ -197,7 +198,7 @@ def test_reduction_functions(self):
197198
torch.sum,
198199
torch.mean,
199200
):
200-
args = (torch.randn([512, 512], device="cuda"), fn)
201+
args = (torch.randn([512, 512], device=DEVICE), fn)
201202
_, output = code_and_output(
202203
reduce_kernel,
203204
args,
@@ -210,46 +211,46 @@ def test_reduction_functions(self):
210211
)
211212

212213
def test_mean(self):
213-
args = (torch.randn([512, 512], device="cuda"), torch.mean, torch.float32)
214+
args = (torch.randn([512, 512], device=DEVICE), torch.mean, torch.float32)
214215
self.assertExpectedInline(
215216
reduce_kernel.bind(args)._debug_str(),
216217
"""\
217218
def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
218-
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:47>)
219+
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
219220
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
220221
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
221222
n, _m = x.size()
222-
# Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
223+
# Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:49>)
223224
# Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty')
224225
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
225-
# List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:49>)
226-
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47>), key=0)
226+
# List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:50>)
227+
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:48>), key=0)
227228
# Name: LiteralType(torch.float32) ArgumentOrigin(name='out_dtype')
228229
# Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
229230
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
230231
# For: loop_type=GRID
231232
out = torch.empty([n], dtype=out_dtype, device=x.device)
232-
# Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:53>)
233+
# Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:54>)
233234
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
234235
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
235-
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47>), key=0)
236+
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:48>), key=0)
236237
for tile_n in hl.tile(n):
237-
# Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
238-
# Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
239-
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53>)
240-
# Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
238+
# Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
239+
# Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:49>)
240+
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:54>)
241+
# Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
241242
# Name: CallableType(_VariableFunctionsClass.mean) ArgumentOrigin(name='fn')
242-
# Subscript: TensorType([block_size_0, rdim_1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
243+
# Subscript: TensorType([block_size_0, rdim_1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
243244
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
244-
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53>)
245-
# Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
246-
# UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
247-
# Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
245+
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:54>)
246+
# Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
247+
# UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
248+
# Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:55>)
248249
out[tile_n] = fn(x[tile_n, :], dim=-1)
249250
return out
250251
251252
def root_graph_0():
252-
# File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
253+
# File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
253254
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
254255
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
255256
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)], None); x = None
@@ -260,15 +261,15 @@ def root_graph_0():
260261
return None
261262
262263
def reduction_loop_1():
263-
# File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
264+
# File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
264265
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
265266
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
266267
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)], None); x = block_size_0 = None
267268
mean_extra: "f32[u0]" = helion_language__tracing_ops__inductor_lowering_extra([load]); load = None
268269
return [mean_extra]
269270
270271
def root_graph_2():
271-
# File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
272+
# File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
272273
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
273274
_get_symnode: "Sym(s27)" = helion_language__tracing_ops__get_symnode('rdim1')
274275
_for_loop = helion_language__tracing_ops__for_loop(1, [0], [_get_symnode], []); _get_symnode = None
@@ -318,7 +319,7 @@ def _reduce_kernel_make_precompiler(x: torch.Tensor, fn: Callable[[torch.Tensor]
318319
)
319320

320321
def test_sum_looped(self):
321-
args = (torch.randn([512, 512], device="cuda"),)
322+
args = (torch.randn([512, 512], device=DEVICE),)
322323
code, output = code_and_output(
323324
sum_kernel, args, block_size=2, reduction_loop=64
324325
)
@@ -367,7 +368,7 @@ def _sum_kernel_make_precompiler(x: torch.Tensor):
367368

368369
def test_argmin_argmax_looped(self):
369370
for fn in (torch.argmin, torch.argmax):
370-
args = (torch.randn([512, 512], device="cuda"), fn, torch.int64)
371+
args = (torch.randn([512, 512], device=DEVICE), fn, torch.int64)
371372
code, output = code_and_output(
372373
reduce_kernel,
373374
args,

test/test_register_tunable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def kernel_with_tunable(x: torch.Tensor) -> torch.Tensor:
2929

3030
return out
3131

32-
x = torch.randn(128, device="cuda", dtype=torch.float32)
32+
x = torch.randn(128, device=DEVICE, dtype=torch.float32)
3333
code, result = code_and_output(kernel_with_tunable, (x,))
3434
expected = x * 2.0
3535
torch.testing.assert_close(result, expected)
@@ -87,7 +87,7 @@ def kernel_with_int_param(x: torch.Tensor) -> torch.Tensor:
8787
out[tile_n] = x[tile_n] * multiplier
8888
return out
8989

90-
x = torch.randn(128, device="cuda", dtype=torch.float32)
90+
x = torch.randn(128, device=DEVICE, dtype=torch.float32)
9191
code, result = code_and_output(
9292
kernel_with_int_param, (x,), block_size=64, multiplier=4
9393
)
@@ -150,7 +150,7 @@ def kernel_with_enum(x: torch.Tensor) -> torch.Tensor:
150150

151151
return out
152152

153-
x = torch.randn(128, device="cuda", dtype=torch.float32)
153+
x = torch.randn(128, device=DEVICE, dtype=torch.float32)
154154
result = kernel_with_enum(x)
155155
expected = x * 2.0
156156
torch.testing.assert_close(result, expected)

0 commit comments

Comments
 (0)