@@ -1055,11 +1055,11 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
1055
1055
@triton .autotune (
1056
1056
configs = [
1057
1057
triton .Config ({'BLOCK_SIZE_M' : 32 , 'BLOCK_SIZE_N' : 32 }, num_stages = 3 , num_warps = 4 ),
1058
- triton .Config ({'BLOCK_SIZE_M' : 32 , 'BLOCK_SIZE_N' : 64 }, num_stages = 3 , num_warps = 4 ),
1059
- triton .Config ({'BLOCK_SIZE_M' : 32 , 'BLOCK_SIZE_N' : 128 }, num_stages = 3 , num_warps = 4 ),
1058
+ # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
1059
+ # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
1060
1060
triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 32 }, num_stages = 3 , num_warps = 4 ),
1061
1061
triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 64 }, num_stages = 3 , num_warps = 4 ),
1062
- triton .Config ({'BLOCK_SIZE_M' : 64 , 'BLOCK_SIZE_N' : 128 }, num_stages = 3 , num_warps = 4 ),
1062
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
1063
1063
triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 32 }, num_stages = 3 , num_warps = 4 ),
1064
1064
triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 64 }, num_stages = 3 , num_warps = 4 ),
1065
1065
triton .Config ({'BLOCK_SIZE_M' : 128 , 'BLOCK_SIZE_N' : 128 }, num_stages = 3 , num_warps = 4 ),
@@ -1133,7 +1133,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
1133
1133
# If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]
1134
1134
cb = tl .load (cb_ptrs , mask = (offs_m [:, None ] < chunk_size ) & (offs_n [None , :] < chunk_size - start_n ), other = 0.0 ).to (tl .float32 )
1135
1135
acc *= cb
1136
- dA_cs_n = tl .load (dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize , mask = offs_n < chunk_size - start_n , other = 0.0 ).to (tl .float32 )
1136
+ dA_cs_n = tl .load (dA_cumsum_ptr + ( start_n + offs_n ) * stride_dA_cs_csize , mask = offs_n < chunk_size - start_n , other = 0.0 ).to (tl .float32 )
1137
1137
acc *= tl .exp (dA_cs_m [:, None ] - dA_cs_n [None , :])
1138
1138
mask = offs_m [:, None ] >= start_n + offs_n [None , :] + 1
1139
1139
acc = tl .where (mask , acc , 0.0 )
0 commit comments