Skip to content

Commit 442fab4

Browse files
authored
Fix Incorrect Gradients and Illegal Memory Access Error in Mamba2 (#537)
* Fix incorrect gradients
* Fix another pointer error in ddAcs_stable
1 parent bda9af3 commit 442fab4

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mamba_ssm/ops/triton/ssd_chunk_scan.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1055,11 +1055,11 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old(
1055 1055
@triton.autotune(
1056 1056
configs=[
1057 1057
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
1058-
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
1059-
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
1058+
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
1059+
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
1060 1060
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
1061 1061
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
1062-
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
1062+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
1063 1063
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
1064 1064
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
1065 1065
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
@@ -1133,7 +1133,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel(
1133 1133
# If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]
1134 1134
cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
1135 1135
acc *= cb
1136-
dA_cs_n = tl.load(dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
1136+
dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
1137 1137
acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
1138 1138
mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
1139 1139
acc = tl.where(mask, acc, 0.0)

0 commit comments

Comments
 (0)