Skip to content

Commit bd96103

Browse files
committed
x[i] returns a scalar when i is a scalar
1 parent f363965 commit bd96103

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ def create(
274274
mask_values.setdefault(f"({mask}){expand}")
275275
output_idx += 1
276276
else:
277-
expand = tile_strategy.expand_str(output_size, output_idx)
277+
# When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated.
278278
val = state.device_function.literal_expr(k)
279-
index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
279+
index_values.append(f"({val})")
280280
elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
281281
expand = tile_strategy.expand_str(output_size, output_idx)
282282
size = fake_value.size(len(index_values))

test/test_grid.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
8787
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
8888
acc_copy = acc
8989
acc_copy_0 = acc_copy
90-
load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
90+
load = tl.load(x + (offset_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
9191
load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
9292
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
9393
v_0 = acc.to(tl.float16)
94-
tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
94+
tl.store(out + (offset_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
9595
9696
def grid_1d(x: torch.Tensor, y: torch.Tensor):
9797
b, m, k = x.size()
@@ -225,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
225225
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
226226
acc_copy = acc
227227
acc_copy_0 = acc_copy
228-
load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
228+
load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
229229
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
230230
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
231231
v_0 = acc.to(tl.float16)
232-
tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
232+
tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
233233
234234
def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
235235
bi, bj, m, k = x.size()
@@ -363,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
363363
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
364364
acc_copy = acc
365365
acc_copy_0 = acc_copy
366-
load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
366+
load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
367367
load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
368368
acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
369369
v_0 = acc.to(tl.float16)
370-
tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
370+
tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
371371
372372
def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
373373
bi, bj, m, k = x.size()
@@ -425,10 +425,10 @@ def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0):
425425
pid_0 = tl.program_id(0)
426426
begin_0 = 2
427427
offset_0 = begin_0 + pid_0
428-
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
428+
load = tl.load(x + offset_0 * x_stride_0, None)
429429
v_0 = 2.0
430430
v_1 = load * v_0
431-
tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
431+
tl.store(out + offset_0 * out_stride_0, v_1, None)
432432
433433
def grid_begin_end(x: torch.Tensor):
434434
n = x.size(0)
@@ -475,10 +475,10 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor:
475475
def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
476476
pid_0 = tl.program_id(0)
477477
offset_0 = pid_0 * _BLOCK_SIZE_0
478-
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
478+
load = tl.load(x + offset_0 * x_stride_0, None)
479479
v_0 = 2.0
480480
v_1 = load * v_0
481-
tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
481+
tl.store(out + offset_0 * out_stride_0, v_1, None)
482482
483483
def grid_begin_end_step(x: torch.Tensor):
484484
n = x.size(0)
@@ -527,10 +527,10 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor:
527527
def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
528528
pid_0 = tl.program_id(0)
529529
offset_0 = pid_0 * _BLOCK_SIZE_0
530-
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
530+
load = tl.load(x + offset_0 * x_stride_0, None)
531531
v_0 = 2.0
532532
v_1 = load * v_0
533-
tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
533+
tl.store(out + offset_0 * out_stride_0, v_1, None)
534534
535535
def grid_end_step_kwarg(x: torch.Tensor):
536536
n = x.size(0)
@@ -587,10 +587,10 @@ def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride
587587
offset_0 = begin_0 + pid_0
588588
begin_1 = 1
589589
offset_1 = begin_1 + pid_1
590-
load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
590+
load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
591591
v_0 = 2.0
592592
v_1 = load * v_0
593-
tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
593+
tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
594594
595595
def grid_multidim_begin_end(x: torch.Tensor):
596596
m, n = x.size()
@@ -643,10 +643,10 @@ def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_s
643643
pid_1 = tl.program_id(0) // num_blocks_0
644644
offset_0 = pid_0 * _BLOCK_SIZE_0
645645
offset_1 = pid_1 * _BLOCK_SIZE_1
646-
load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
646+
load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
647647
v_0 = 2.0
648648
v_1 = load * v_0
649-
tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
649+
tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
650650
651651
def grid_multidim_begin_end_step(x: torch.Tensor):
652652
m, n = x.size()

0 commit comments

Comments
 (0)