Commit ddcd924

Fix issue with phi nodes and aliasing (#220)
Fixes #218
1 parent c86b278 · commit ddcd924

File tree: 3 files changed, +107 −24 lines

helion/_compiler/device_ir.py

Lines changed: 6 additions & 1 deletion

@@ -728,7 +728,12 @@ def visit_Assign(self, node: ast.Assign) -> None:
         (target,) = node.targets
         if isinstance(target, ast.Name):
             # TODO(jansel): should assert that name is only used on device
-            self._assign(target, self.visit(node.value))
+            value = self.visit(node.value)
+            # For simple variable assignments like `a = b`, we need to create a new
+            # variable to avoid phi node issues when the source variable gets mutated
+            if isinstance(node.value, ast.Name) and isinstance(value, torch.Tensor):
+                value = _new_var(value)
+            self._assign(target, value)
             return None
         if isinstance(target, ast.Tuple):
             # Handle tuple unpacking
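In user code, the pattern this change guards against looks roughly like the sketch below (a hypothetical eager-mode reference, not Helion internals; chebyshev_like_reference and its names are made up for illustration). In eager PyTorch the alias is harmless, but in the device IR U1 becomes a loop-carried phi node; if it still aliases two_x, rebinding U1 each iteration effectively mutates two_x as well, and the recurrence reads a corrupted value. Copying on `U1 = two_x` keeps two_x loop-invariant.

import torch

def chebyshev_like_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Mirrors the recurrence in the new test added below (test_loops.py).
    two_x = 2.0 * x
    U1 = two_x                       # must behave as a fresh value, not an alias
    U0 = torch.ones_like(x)
    acc = w[0] * U0 + w[1] * U1
    for order in range(2, w.shape[0]):
        acc = acc + w[order] * U1
        U_new = two_x * U1 - U0      # correct only while two_x stays unchanged
        U0, U1 = U1, U_new
    return acc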

test/test_examples.py

Lines changed: 20 additions & 15 deletions

@@ -863,16 +863,17 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
         values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[None, :], other=0)
         _mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), values, float('-inf'))
         local_amax = tl.max(_mask_to, 1)
-        mi = triton_helpers.maximum(mi_copy_0, local_amax)
-        v_1 = mi_copy_0 - mi
+        v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
+        v_1 = mi_copy_0 - v_0
         v_2 = tl_math.exp(v_1)
         v_3 = di_copy_0 * v_2
-        subscript = mi[:, None]
+        subscript = v_0[:, None]
         v_4 = values - subscript
         v_5 = tl_math.exp(v_4)
         _mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), v_5, 0)
         sum_1 = tl.sum(_mask_to_1, 1)
         di = v_3 + sum_1
+        mi = v_0
     for offset_2 in range(0, n.to(tl.int32), _BLOCK_SIZE_1):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mask_2 = indices_2 < n
@@ -945,16 +946,17 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1,
         values = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
         _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], values, float('-inf'))
         local_amax = tl.max(_mask_to, 1)
-        mi = triton_helpers.maximum(mi_copy_0, local_amax)
-        v_1 = mi_copy_0 - mi
+        v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
+        v_1 = mi_copy_0 - v_0
         v_2 = tl_math.exp(v_1)
         v_3 = di_copy_0 * v_2
-        subscript = mi[:, None]
+        subscript = v_0[:, None]
         v_4 = values - subscript
         v_5 = tl_math.exp(v_4)
         _mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, 0)
         sum_1 = tl.sum(_mask_to_1, 1)
         di = v_3 + sum_1
+        mi = v_0
     for offset_2 in range(0, n.to(tl.int32), _BLOCK_SIZE_1):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
         mi_copy_1 = mi
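The generated two-pass softmax above shows the effect of the fix: the new running max goes into the fresh temporary v_0, every use inside the iteration reads v_0 or the loop-carried copy mi_copy_0, and the phi variable mi is rebound only at the end of the body. Conceptually the loop maintains the usual online-softmax statistics; a rough eager-mode sketch follows (assumed, for illustration only; online_softmax_stats is not part of the codebase, and the block size 128 stands in for _BLOCK_SIZE_1).

import torch

def online_softmax_stats(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Running (row max, denominator) over column blocks, one row per output element.
    mi = torch.full((x.shape[0],), float("-inf"), dtype=x.dtype, device=x.device)
    di = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for block in x.split(128, dim=1):
        new_mi = torch.maximum(mi, block.amax(dim=1))    # v_0 in the kernel
        di = di * torch.exp(mi - new_mi) + torch.exp(block - new_mi[:, None]).sum(dim=1)
        mi = new_mi                                      # rebound only after all reads of the old mi
    return mi, di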
@@ -1148,21 +1150,22 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         amax = tl.max(qk, 2)
         v_0 = 0.18033688
         v_1 = amax * v_0
-        m_i = triton_helpers.maximum(m_i_copy_0, v_1)
+        v_2 = triton_helpers.maximum(m_i_copy_0, v_1)
         v_3 = 0.18033688
         v_4 = qk * v_3
-        subscript = m_i[:, :, None]
+        subscript = v_2[:, :, None]
         v_5 = v_4 - subscript
         v_6 = libdevice.exp2(v_5)
         l_ij = tl.sum(v_6, 2)
-        v_7 = m_i_copy_0 - m_i
+        v_7 = m_i_copy_0 - v_2
         v_8 = libdevice.exp2(v_7)
         v_9 = l_i_copy_0 * v_8
         l_i = v_9 + l_ij
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
         v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
         acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
+        m_i = v_2
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
     tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1254,15 +1257,15 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         v_0 = tl.full([], 0.18033688, tl.float16)
         v_1 = amax * v_0
         v_2 = v_1.to(tl.float32)
-        m_i = triton_helpers.maximum(m_i_copy_0, v_2)
+        v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
         v_4 = tl.full([], 0.18033688, tl.float16)
         v_5 = qk * v_4
-        subscript = m_i[:, :, None]
+        subscript = v_3[:, :, None]
         v_6 = v_5.to(tl.float32)
         v_7 = v_6 - subscript
         v_8 = libdevice.exp2(v_7)
         l_ij = tl.sum(v_8, 2)
-        v_9 = m_i_copy_0 - m_i
+        v_9 = m_i_copy_0 - v_3
         v_10 = libdevice.exp2(v_9)
         v_11 = l_i_copy_0 * v_10
         l_i = v_11 + l_ij
@@ -1271,6 +1274,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
         v_14 = v_8.to(tl.float16)
         acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
+        m_i = v_3
     subscript_2 = l_i[:, :, None]
     v_15 = acc / subscript_2
     v_16 = v_15.to(tl.float16)
@@ -1366,22 +1370,23 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
         amax = tl.max(_mask_to_2, 2)
         v_0 = 0.18033688
         v_1 = amax * v_0
-        m_i = triton_helpers.maximum(m_i_copy_0, v_1)
+        v_2 = triton_helpers.maximum(m_i_copy_0, v_1)
         v_3 = 0.18033688
         v_4 = qk * v_3
-        subscript = m_i[:, :, None]
+        subscript = v_2[:, :, None]
         v_5 = v_4 - subscript
         v_6 = libdevice.exp2(v_5)
         _mask_to_3 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), v_6, 0)
         l_ij = tl.sum(_mask_to_3, 2)
-        v_7 = m_i_copy_0 - m_i
+        v_7 = m_i_copy_0 - v_2
         v_8 = libdevice.exp2(v_7)
         v_9 = l_i_copy_0 * v_8
         l_i = v_9 + l_ij
         subscript_1 = v_8[:, :, None]
         v_11 = acc_copy_0 * subscript_1
         v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
         acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
+        m_i = v_2
     subscript_2 = l_i[:, :, None]
     v_12 = acc / subscript_2
     tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
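The attention kernels get the same treatment: the new row max lands in a temporary (v_2 or v_3 depending on the variant), the exponentials and the accumulator rescale read that temporary together with the loop-carried copies, and m_i is rebound only after the tl.dot update. One inner-loop step corresponds roughly to the eager sketch below (hypothetical, for illustration only; attention_block_update is not part of the codebase, and 0.18033688 is the qk scale that appears in the generated code).

import torch

def attention_block_update(m_i, l_i, acc, qk, v, scale=0.18033688):
    # Flash-attention style running statistics in the exp2 domain.
    m_new = torch.maximum(m_i, qk.amax(dim=-1) * scale)   # v_2 / v_3 in the kernels
    p = torch.exp2(qk * scale - m_new[..., None])         # unnormalized probabilities
    alpha = torch.exp2(m_i - m_new)                       # rescale factor for old stats
    l_i = l_i * alpha + p.sum(dim=-1)
    acc = acc * alpha[..., None] + p @ v
    return m_new, l_i, acc                                # caller rebinds m_i last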

test/test_loops.py

Lines changed: 81 additions & 8 deletions

@@ -1327,32 +1327,35 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0,
     offset_1 = pid_1 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < C
-    T1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    in_x = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
     T0 = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 1.0, tl.float32)
+    in_x_0 = in_x
     load_1 = tl.load(w + (0 * w_stride_0 + indices_1 * w_stride_1), mask_1, other=0)
     subscript = load_1[None, :]
     v_0 = subscript * T0
     load_2 = tl.load(w + (1 * w_stride_0 + indices_1 * w_stride_1), mask_1, other=0)
     subscript_1 = load_2[None, :]
-    v_1 = subscript_1 * T1
+    v_1 = subscript_1 * in_x_0
     v_2 = v_0 + v_1
     v_3 = 2.0
-    v_4 = T1 * v_3
+    v_4 = in_x * v_3
     for offset_2 in range(2, 5, 1):
         indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32)
         v_4_copy = v_4
-        T1_copy = T1
+        in_x_0_copy = in_x_0
         T0_copy = T0
         v_2_copy = v_2
         v_4_copy_0 = v_4_copy
-        T0 = T1_copy
+        in_x_0_copy_0 = in_x_0_copy
         T0_copy_0 = T0_copy
         v_2_copy_0 = v_2_copy
-        v_5 = v_4_copy_0 * T0
-        T1 = v_5 - T0_copy_0
+        v_5 = v_4_copy_0 * in_x_0_copy_0
+        v_6 = v_5 - T0_copy_0
         load = tl.load(w + (indices_2[:, None] * w_stride_0 + indices_1[None, :] * w_stride_1), mask_1[None, :], other=0)
-        v_7 = load * T1
+        v_7 = load * v_6
         v_2 = v_2_copy_0 + v_7
+        T0 = in_x_0_copy_0
+        in_x_0 = v_6
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
 
 def chebyshev_kernel(x: torch.Tensor, w: torch.Tensor):
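For reference, _chebyshev_kernel_kernel evaluates a Chebyshev series; after the rewrite the recurrence result lands in the temporary v_6 and the loop-carried values (T0, in_x_0) are rebound at the end of the body rather than mid-block. A rough eager-mode equivalent is sketched below (assumed, for illustration only; chebyshev_eval is not part of the test file).

import torch

def chebyshev_eval(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # out = sum_n w[n] * T_n(x), with T_0 = 1, T_1 = x, T_{n+1} = 2*x*T_n - T_{n-1}
    T0 = torch.ones_like(x)
    T1 = x
    acc = w[0] * T0 + w[1] * T1
    two_x = 2.0 * x
    for n in range(2, w.shape[0]):
        T_next = two_x * T1 - T0     # temporary first (v_6 in the kernel) ...
        acc = acc + w[n] * T_next
        T0, T1 = T1, T_next          # ... loop-carried values rebound last
    return acc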
@@ -1499,6 +1502,76 @@ def _fn_make_precompiler(x: torch.Tensor):
     return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
         )
 
+    def test_variable_assignment_phi_nodes(self):
+        """Test for phi node issue with variable assignments like U1 = two_x.
+
+        This test ensures that simple variable assignments create new variables
+        rather than aliases, preventing phi node issues when the source variable
+        gets mutated in loops.
+        """
+
+        @helion.kernel(use_default_config=True)
+        def kernel_with_assignment(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            B, C = x.shape
+            N, _ = w.shape
+            hl.specialize(N)
+            grad_x = torch.zeros_like(x)
+
+            for b_tile, c_tile in hl.tile([B, C]):
+                in_x = x[b_tile, c_tile]
+                two_x = 2.0 * in_x
+
+                # This assignment should create a new variable, not an alias
+                U1 = two_x
+                U0 = hl.full((b_tile, c_tile), 1.0, x.dtype)
+
+                acc = w[0, c_tile] * U0 + w[1, c_tile] * U1
+
+                for order in hl.tile(2, N, block_size=1):
+                    acc += w[order, c_tile] * U1
+                    U_new = two_x * U1 - U0
+                    U0 = U1
+                    U1 = U_new
+
+                grad_x[b_tile, c_tile] = acc
+            return grad_x
+
+        @helion.kernel(use_default_config=True)
+        def kernel_without_assignment(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            B, C = x.shape
+            N, _ = w.shape
+            hl.specialize(N)
+            grad_x = torch.zeros_like(x)
+
+            for b_tile, c_tile in hl.tile([B, C]):
+                in_x = x[b_tile, c_tile]
+                two_x = 2.0 * in_x
+
+                # Direct use without assignment
+                U1 = 2.0 * in_x
+                U0 = hl.full((b_tile, c_tile), 1.0, x.dtype)
+
+                acc = w[0, c_tile] * U0 + w[1, c_tile] * U1
+
+                for order in hl.tile(2, N, block_size=1):
+                    acc += w[order, c_tile] * U1
+                    U_new = two_x * U1 - U0
+                    U0 = U1
+                    U1 = U_new
+
+                grad_x[b_tile, c_tile] = acc
+            return grad_x
+
+        # Test with small tensor
+        x = torch.randn(4, 8, device=DEVICE, dtype=torch.float32)
+        w = torch.randn(5, 8, device=DEVICE, dtype=torch.float32)
+
+        code1, result1 = code_and_output(kernel_with_assignment, (x, w))
+        code2, result2 = code_and_output(kernel_without_assignment, (x, w))
+
+        # Both should produce identical results
+        torch.testing.assert_close(result1, result2, rtol=1e-5, atol=1e-5)
+
 
 if __name__ == "__main__":
     unittest.main()
