7
7
import torch
8
8
9
9
import helion
10
+ from helion ._testing import DEVICE
10
11
from helion ._testing import code_and_output
11
12
import helion .language as hl
12
13
@@ -59,7 +60,7 @@ class TestReductions(TestCase):
59
60
maxDiff = 16384
60
61
61
62
def test_sum (self ):
62
- args = (torch .randn ([512 , 512 ], device = "cuda" ),)
63
+ args = (torch .randn ([512 , 512 ], device = DEVICE ),)
63
64
code , output = code_and_output (sum_kernel , args , block_size = 1 )
64
65
torch .testing .assert_close (output , args [0 ].sum (- 1 ), rtol = 1e-04 , atol = 1e-04 )
65
66
self .assertExpectedInline (
@@ -98,7 +99,7 @@ def _sum_kernel_make_precompiler(x: torch.Tensor):
98
99
)
99
100
100
101
def test_sum_keepdims (self ):
101
- args = (torch .randn ([512 , 512 ], device = "cuda" ),)
102
+ args = (torch .randn ([512 , 512 ], device = DEVICE ),)
102
103
code , output = code_and_output (
103
104
sum_kernel_keepdims , args , block_size = 16 , indexing = "block_ptr"
104
105
)
@@ -141,7 +142,7 @@ def _sum_kernel_keepdims_make_precompiler(x: torch.Tensor):
141
142
142
143
def test_argmin_argmax (self ):
143
144
for fn in (torch .argmin , torch .argmax ):
144
- args = (torch .randn ([512 , 512 ], device = "cuda" ), fn , torch .int64 )
145
+ args = (torch .randn ([512 , 512 ], device = DEVICE ), fn , torch .int64 )
145
146
code , output = code_and_output (
146
147
reduce_kernel , args , block_size = 16 , indexing = "block_ptr"
147
148
)
@@ -197,7 +198,7 @@ def test_reduction_functions(self):
197
198
torch .sum ,
198
199
torch .mean ,
199
200
):
200
- args = (torch .randn ([512 , 512 ], device = "cuda" ), fn )
201
+ args = (torch .randn ([512 , 512 ], device = DEVICE ), fn )
201
202
_ , output = code_and_output (
202
203
reduce_kernel ,
203
204
args ,
@@ -210,46 +211,46 @@ def test_reduction_functions(self):
210
211
)
211
212
212
213
def test_mean (self ):
213
- args = (torch .randn ([512 , 512 ], device = "cuda" ), torch .mean , torch .float32 )
214
+ args = (torch .randn ([512 , 512 ], device = DEVICE ), torch .mean , torch .float32 )
214
215
self .assertExpectedInline (
215
216
reduce_kernel .bind (args )._debug_str (),
216
217
"""\
217
218
def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
218
- # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:47 >)
219
+ # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:48 >)
219
220
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
220
221
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
221
222
n, _m = x.size()
222
- # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48 >)
223
+ # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:49 >)
223
224
# Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty')
224
225
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
225
- # List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:49 >)
226
- # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47 >), key=0)
226
+ # List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:50 >)
227
+ # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:48 >), key=0)
227
228
# Name: LiteralType(torch.float32) ArgumentOrigin(name='out_dtype')
228
229
# Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
229
230
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
230
231
# For: loop_type=GRID
231
232
out = torch.empty([n], dtype=out_dtype, device=x.device)
232
- # Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:53 >)
233
+ # Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:54 >)
233
234
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
234
235
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
235
- # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47 >), key=0)
236
+ # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:48 >), key=0)
236
237
for tile_n in hl.tile(n):
237
- # Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
238
- # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48 >)
239
- # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53 >)
240
- # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
238
+ # Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:55 >)
239
+ # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:49 >)
240
+ # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:54 >)
241
+ # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:55 >)
241
242
# Name: CallableType(_VariableFunctionsClass.mean) ArgumentOrigin(name='fn')
242
- # Subscript: TensorType([block_size_0, rdim_1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
243
+ # Subscript: TensorType([block_size_0, rdim_1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:55 >)
243
244
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
244
- # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53 >)
245
- # Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
246
- # UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
247
- # Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
245
+ # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:54 >)
246
+ # Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:55 >)
247
+ # UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:55 >)
248
+ # Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:55 >)
248
249
out[tile_n] = fn(x[tile_n, :], dim=-1)
249
250
return out
250
251
251
252
def root_graph_0():
252
- # File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
253
+ # File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
253
254
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
254
255
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
255
256
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)], None); x = None
@@ -260,15 +261,15 @@ def root_graph_0():
260
261
return None
261
262
262
263
def reduction_loop_1():
263
- # File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
264
+ # File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
264
265
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
265
266
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
266
267
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)], None); x = block_size_0 = None
267
268
mean_extra: "f32[u0]" = helion_language__tracing_ops__inductor_lowering_extra([load]); load = None
268
269
return [mean_extra]
269
270
270
271
def root_graph_2():
271
- # File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
272
+ # File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
272
273
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
273
274
_get_symnode: "Sym(s27)" = helion_language__tracing_ops__get_symnode('rdim1')
274
275
_for_loop = helion_language__tracing_ops__for_loop(1, [0], [_get_symnode], []); _get_symnode = None
@@ -318,7 +319,7 @@ def _reduce_kernel_make_precompiler(x: torch.Tensor, fn: Callable[[torch.Tensor]
318
319
)
319
320
320
321
def test_sum_looped (self ):
321
- args = (torch .randn ([512 , 512 ], device = "cuda" ),)
322
+ args = (torch .randn ([512 , 512 ], device = DEVICE ),)
322
323
code , output = code_and_output (
323
324
sum_kernel , args , block_size = 2 , reduction_loop = 64
324
325
)
@@ -367,7 +368,7 @@ def _sum_kernel_make_precompiler(x: torch.Tensor):
367
368
368
369
def test_argmin_argmax_looped (self ):
369
370
for fn in (torch .argmin , torch .argmax ):
370
- args = (torch .randn ([512 , 512 ], device = "cuda" ), fn , torch .int64 )
371
+ args = (torch .randn ([512 , 512 ], device = DEVICE ), fn , torch .int64 )
371
372
code , output = code_and_output (
372
373
reduce_kernel ,
373
374
args ,
0 commit comments