@@ -274,16 +274,42 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
274
274
275
275
def test_config_flatten_issue(self):
    """Compile and run kernels that index an output by tile attributes.

    Three small kernels are defined, one each for ``tile.begin``,
    ``tile.end`` and ``tile.id``.  Each kernel iterates over
    ``hl.tile(x.size())`` and stores ``1`` into an int32, zero-initialised
    output at the position given by that attribute of the two tiles.

    Every kernel is exercised the same way on a 64x64 CUDA tensor with
    ``block_sizes=[16, 16]``: Triton code is generated for the config, the
    kernel is compiled with that config and called, and the sum of the
    output is compared with the expected number of written elements.
    """

    # --- tile.begin -------------------------------------------------------
    @helion.kernel(use_default_config=True)
    def test_tile_begin(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile_m, tile_n in hl.tile(x.size()):
            out[tile_m.begin, tile_n.begin] = 1
        return out

    x = torch.randn(64, 64, device="cuda")
    config = helion.Config(block_sizes=[16, 16])
    # Code generation and compilation go through two separate bind() calls,
    # exactly as in the original test.
    test_tile_begin.bind((x,)).to_triton_code(config)
    result = test_tile_begin.bind((x,)).compile_config(config)(x)
    self.assertEqual(result.sum().item(), 16)

    # --- tile.end ---------------------------------------------------------
    @helion.kernel(use_default_config=True)
    def test_tile_end(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile_m, tile_n in hl.tile(x.size()):
            out[tile_m.end, tile_n.end] = 1
        return out

    x = torch.randn(64, 64, device="cuda")
    config = helion.Config(block_sizes=[16, 16])
    test_tile_end.bind((x,)).to_triton_code(config)
    result = test_tile_end.bind((x,)).compile_config(config)(x)
    # NOTE(review): expected count is 12 here rather than 16 as in the
    # begin/id cases -- presumably because the last tile's `.end` in each
    # dimension is not a valid index into a 64x64 tensor; TODO confirm how
    # such stores are handled before relying on this value.
    self.assertEqual(result.sum().item(), 12)

    # --- tile.id ----------------------------------------------------------
    @helion.kernel(use_default_config=True)
    def test_tile_id(x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros_like(x, dtype=torch.int32)
        for tile_m, tile_n in hl.tile(x.size()):
            out[tile_m.id, tile_n.id] = 1
        return out

    x = torch.randn(64, 64, device="cuda")
    config = helion.Config(block_sizes=[16, 16])
    test_tile_id.bind((x,)).to_triton_code(config)
    result = test_tile_id.bind((x,)).compile_config(config)(x)
    self.assertEqual(result.sum().item(), 16)
288
314
289
315
0 commit comments