@@ -87,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
87
87
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
88
88
acc_copy = acc
89
89
acc_copy_0 = acc_copy
90
- load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
90
+ load = tl.load(x + (offset_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
91
91
load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
92
92
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
93
93
v_0 = acc.to(tl.float16)
94
- tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
94
+ tl.store(out + (offset_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
95
95
96
96
def grid_1d(x: torch.Tensor, y: torch.Tensor):
97
97
b, m, k = x.size()
@@ -225,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
225
225
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
226
226
acc_copy = acc
227
227
acc_copy_0 = acc_copy
228
- load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
228
+ load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
229
229
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
230
230
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
231
231
v_0 = acc.to(tl.float16)
232
- tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
232
+ tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
233
233
234
234
def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
235
235
bi, bj, m, k = x.size()
@@ -363,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
363
363
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
364
364
acc_copy = acc
365
365
acc_copy_0 = acc_copy
366
- load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
366
+ load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
367
367
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
368
368
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
369
369
v_0 = acc.to(tl.float16)
370
- tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
370
+ tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
371
371
372
372
def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
373
373
bi, bj, m, k = x.size()
@@ -425,10 +425,10 @@ def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0):
425
425
pid_0 = tl.program_id(0)
426
426
begin_0 = 2
427
427
offset_0 = begin_0 + pid_0
428
- load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
428
+ load = tl.load(x + offset_0 * x_stride_0, None)
429
429
v_0 = 2.0
430
430
v_1 = load * v_0
431
- tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
431
+ tl.store(out + offset_0 * out_stride_0, v_1, None)
432
432
433
433
def grid_begin_end(x: torch.Tensor):
434
434
n = x.size(0)
@@ -475,10 +475,10 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor:
475
475
def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
476
476
pid_0 = tl.program_id(0)
477
477
offset_0 = pid_0 * _BLOCK_SIZE_0
478
- load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
478
+ load = tl.load(x + offset_0 * x_stride_0, None)
479
479
v_0 = 2.0
480
480
v_1 = load * v_0
481
- tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
481
+ tl.store(out + offset_0 * out_stride_0, v_1, None)
482
482
483
483
def grid_begin_end_step(x: torch.Tensor):
484
484
n = x.size(0)
@@ -527,10 +527,10 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor:
527
527
def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
528
528
pid_0 = tl.program_id(0)
529
529
offset_0 = pid_0 * _BLOCK_SIZE_0
530
- load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
530
+ load = tl.load(x + offset_0 * x_stride_0, None)
531
531
v_0 = 2.0
532
532
v_1 = load * v_0
533
- tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
533
+ tl.store(out + offset_0 * out_stride_0, v_1, None)
534
534
535
535
def grid_end_step_kwarg(x: torch.Tensor):
536
536
n = x.size(0)
@@ -587,10 +587,10 @@ def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride
587
587
offset_0 = begin_0 + pid_0
588
588
begin_1 = 1
589
589
offset_1 = begin_1 + pid_1
590
- load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
590
+ load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
591
591
v_0 = 2.0
592
592
v_1 = load * v_0
593
- tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
593
+ tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
594
594
595
595
def grid_multidim_begin_end(x: torch.Tensor):
596
596
m, n = x.size()
@@ -643,10 +643,10 @@ def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_s
643
643
pid_1 = tl.program_id(0) // num_blocks_0
644
644
offset_0 = pid_0 * _BLOCK_SIZE_0
645
645
offset_1 = pid_1 * _BLOCK_SIZE_1
646
- load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
646
+ load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
647
647
v_0 = 2.0
648
648
v_1 = load * v_0
649
- tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
649
+ tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
650
650
651
651
def grid_multidim_begin_end_step(x: torch.Tensor):
652
652
m, n = x.size()
0 commit comments