From b069d47659435f613d3820d6f5a5e188ad46a4f2 Mon Sep 17 00:00:00 2001 From: Hayden Prairie <55720063+Hprairie@users.noreply.github.com> Date: Fri, 23 Aug 2024 22:49:57 -0500 Subject: [PATCH 1/2] Fix incorrect gradients --- mamba_ssm/ops/triton/ssd_chunk_scan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 9fa3a934..3d01beca 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -1055,11 +1055,11 @@ def _chunk_scan_bwd_ddAcs_stable_kernel_old( @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), From 11c7860fa59982378d5a409ef1fb1091bd8a5e79 Mon Sep 17 00:00:00 2001 From: Hayden Prairie <55720063+Hprairie@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:27:21 -0500 Subject: [PATCH 2/2] Fix another pointer error in ddAcs_stable --- mamba_ssm/ops/triton/ssd_chunk_scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index 3d01beca..fa5b813a 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -1133,7 +1133,7 @@ def _chunk_scan_bwd_ddAcs_stable_kernel( # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) acc *= cb - dA_cs_n = tl.load(dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0)