@@ -863,16 +863,17 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
863
863
values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[None, :], other=0)
864
864
_mask_to = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), values, float('-inf'))
865
865
local_amax = tl.max(_mask_to, 1)
866
- mi = triton_helpers.maximum(mi_copy_0, local_amax)
867
- v_1 = mi_copy_0 - mi
866
+ v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
867
+ v_1 = mi_copy_0 - v_0
868
868
v_2 = tl_math.exp(v_1)
869
869
v_3 = di_copy_0 * v_2
870
- subscript = mi [:, None]
870
+ subscript = v_0 [:, None]
871
871
v_4 = values - subscript
872
872
v_5 = tl_math.exp(v_4)
873
873
_mask_to_1 = tl.where(tl.broadcast_to(mask_1[None, :], [1, _BLOCK_SIZE_1]), v_5, 0)
874
874
sum_1 = tl.sum(_mask_to_1, 1)
875
875
di = v_3 + sum_1
876
+ mi = v_0
876
877
for offset_2 in range(0, n.to(tl.int32), _BLOCK_SIZE_1):
877
878
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
878
879
mask_2 = indices_2 < n
@@ -945,16 +946,17 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1,
945
946
values = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
946
947
_mask_to = tl.where(mask_0[:, None] & mask_1[None, :], values, float('-inf'))
947
948
local_amax = tl.max(_mask_to, 1)
948
- mi = triton_helpers.maximum(mi_copy_0, local_amax)
949
- v_1 = mi_copy_0 - mi
949
+ v_0 = triton_helpers.maximum(mi_copy_0, local_amax)
950
+ v_1 = mi_copy_0 - v_0
950
951
v_2 = tl_math.exp(v_1)
951
952
v_3 = di_copy_0 * v_2
952
- subscript = mi [:, None]
953
+ subscript = v_0 [:, None]
953
954
v_4 = values - subscript
954
955
v_5 = tl_math.exp(v_4)
955
956
_mask_to_1 = tl.where(mask_0[:, None] & mask_1[None, :], v_5, 0)
956
957
sum_1 = tl.sum(_mask_to_1, 1)
957
958
di = v_3 + sum_1
959
+ mi = v_0
958
960
for offset_2 in range(0, n.to(tl.int32), _BLOCK_SIZE_1):
959
961
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
960
962
mi_copy_1 = mi
@@ -1148,21 +1150,22 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1148
1150
amax = tl.max(qk, 2)
1149
1151
v_0 = 0.18033688
1150
1152
v_1 = amax * v_0
1151
- m_i = triton_helpers.maximum(m_i_copy_0, v_1)
1153
+ v_2 = triton_helpers.maximum(m_i_copy_0, v_1)
1152
1154
v_3 = 0.18033688
1153
1155
v_4 = qk * v_3
1154
- subscript = m_i [:, :, None]
1156
+ subscript = v_2 [:, :, None]
1155
1157
v_5 = v_4 - subscript
1156
1158
v_6 = libdevice.exp2(v_5)
1157
1159
l_ij = tl.sum(v_6, 2)
1158
- v_7 = m_i_copy_0 - m_i
1160
+ v_7 = m_i_copy_0 - v_2
1159
1161
v_8 = libdevice.exp2(v_7)
1160
1162
v_9 = l_i_copy_0 * v_8
1161
1163
l_i = v_9 + l_ij
1162
1164
subscript_1 = v_8[:, :, None]
1163
1165
v_11 = acc_copy_0 * subscript_1
1164
1166
v = tl.load(v_view + (indices_0[:, None, None] * 32768 + indices_2[None, :, None] * 64 + indices_4[None, None, :] * 1), None)
1165
1167
acc = tl.reshape(tl.dot(tl.reshape(v_6, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
1168
+ m_i = v_2
1166
1169
subscript_2 = l_i[:, :, None]
1167
1170
v_12 = acc / subscript_2
1168
1171
tl.store(out + (indices_0[:, None, None] * 32768 + indices_1[None, :, None] * 64 + indices_4[None, None, :] * 1), v_12, None)
@@ -1254,15 +1257,15 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1254
1257
v_0 = tl.full([], 0.18033688, tl.float16)
1255
1258
v_1 = amax * v_0
1256
1259
v_2 = v_1.to(tl.float32)
1257
- m_i = triton_helpers.maximum(m_i_copy_0, v_2)
1260
+ v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
1258
1261
v_4 = tl.full([], 0.18033688, tl.float16)
1259
1262
v_5 = qk * v_4
1260
- subscript = m_i [:, :, None]
1263
+ subscript = v_3 [:, :, None]
1261
1264
v_6 = v_5.to(tl.float32)
1262
1265
v_7 = v_6 - subscript
1263
1266
v_8 = libdevice.exp2(v_7)
1264
1267
l_ij = tl.sum(v_8, 2)
1265
- v_9 = m_i_copy_0 - m_i
1268
+ v_9 = m_i_copy_0 - v_3
1266
1269
v_10 = libdevice.exp2(v_9)
1267
1270
v_11 = l_i_copy_0 * v_10
1268
1271
l_i = v_11 + l_ij
@@ -1271,6 +1274,7 @@ def _attention_kernel(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
1271
1274
v = tl.load(tl.make_block_ptr(v_view, [64, 512, 64], [32768, 64, 1], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
1272
1275
v_14 = v_8.to(tl.float16)
1273
1276
acc = tl.reshape(tl.dot(tl.reshape(v_14, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_13, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
1277
+ m_i = v_3
1274
1278
subscript_2 = l_i[:, :, None]
1275
1279
v_15 = acc / subscript_2
1276
1280
v_16 = v_15.to(tl.float16)
@@ -1366,22 +1370,23 @@ def _attention_kernel(q_view, k_view, v_view, out, k_view_size_0, k_view_size_2,
1366
1370
amax = tl.max(_mask_to_2, 2)
1367
1371
v_0 = 0.18033688
1368
1372
v_1 = amax * v_0
1369
- m_i = triton_helpers.maximum(m_i_copy_0, v_1)
1373
+ v_2 = triton_helpers.maximum(m_i_copy_0, v_1)
1370
1374
v_3 = 0.18033688
1371
1375
v_4 = qk * v_3
1372
- subscript = m_i [:, :, None]
1376
+ subscript = v_2 [:, :, None]
1373
1377
v_5 = v_4 - subscript
1374
1378
v_6 = libdevice.exp2(v_5)
1375
1379
_mask_to_3 = tl.where(tl.broadcast_to(mask_1[None, :, None] & mask_3[None, None, :], [1, _BLOCK_SIZE_1, _BLOCK_SIZE_3]), v_6, 0)
1376
1380
l_ij = tl.sum(_mask_to_3, 2)
1377
- v_7 = m_i_copy_0 - m_i
1381
+ v_7 = m_i_copy_0 - v_2
1378
1382
v_8 = libdevice.exp2(v_7)
1379
1383
v_9 = l_i_copy_0 * v_8
1380
1384
l_i = v_9 + l_ij
1381
1385
subscript_1 = v_8[:, :, None]
1382
1386
v_11 = acc_copy_0 * subscript_1
1383
1387
v = tl.load(tl.make_block_ptr(v_view, [v_view_size_0, v_view_size_1, 64], [v_view_stride_0, v_view_stride_1, v_view_stride_2], [offset_0, offset_2, 0], [1, _BLOCK_SIZE_3, 64], [2, 1, 0]), boundary_check=[0, 1, 2], padding_option='zero')
1384
1388
acc = tl.reshape(tl.dot(tl.reshape(_mask_to_3, [_BLOCK_SIZE_1, _BLOCK_SIZE_3]), tl.reshape(v, [_BLOCK_SIZE_3, 64]), acc=tl.reshape(v_11, [_BLOCK_SIZE_1, 64]), input_precision='tf32'), [1, _BLOCK_SIZE_1, 64])
1389
+ m_i = v_2
1385
1390
subscript_2 = l_i[:, :, None]
1386
1391
v_12 = acc / subscript_2
1387
1392
tl.store(tl.make_block_ptr(out, [out_size_0, out_size_1, 64], [out_stride_0, out_stride_1, out_stride_2], [offset_0, offset_1, 0], [1, _BLOCK_SIZE_1, 64], [2, 1, 0]), v_12, boundary_check=[0, 1, 2])
0 commit comments