From f61eae0824e4266dfa09941cccc0177abd9be8c5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 19 Mar 2025 11:25:08 +0000 Subject: [PATCH 1/4] fix: cast Triton program ids to int64 to fix large-bs (large batch size) issue --- mamba_ssm/ops/triton/ssd_chunk_scan.py | 8 ++++---- mamba_ssm/ops/triton/ssd_chunk_state.py | 14 +++++++------- mamba_ssm/ops/triton/ssd_state_passing.py | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mamba_ssm/ops/triton/ssd_chunk_scan.py b/mamba_ssm/ops/triton/ssd_chunk_scan.py index fa5b813a..c2301a47 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_scan.py +++ b/mamba_ssm/ops/triton/ssd_chunk_scan.py @@ -67,13 +67,13 @@ def _chunk_scan_fwd_kernel( BLOCK_SIZE_DSTATE: tl.constexpr, IS_TRITON_22: tl.constexpr, ): - pid_bc = tl.program_id(axis=1) + pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) + pid_h = tl.program_id(axis=2).to(tl.int64) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n + pid_m = tl.program_id(axis=0).to(tl.int64) // num_pid_n + pid_n = tl.program_id(axis=0).to(tl.int64) % num_pid_n cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head diff --git a/mamba_ssm/ops/triton/ssd_chunk_state.py b/mamba_ssm/ops/triton/ssd_chunk_state.py index bb49c9a9..08069898 100644 --- a/mamba_ssm/ops/triton/ssd_chunk_state.py +++ b/mamba_ssm/ops/triton/ssd_chunk_state.py @@ -48,9 +48,9 @@ def _chunk_cumsum_fwd_kernel( HAS_DT_BIAS: tl.constexpr, BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr, ): - pid_b = tl.program_id(axis=0) - pid_c = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) + pid_b = tl.program_id(axis=0).to(tl.int64) + pid_c = 
tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2).to(tl.int64) dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -191,13 +191,13 @@ def _chunk_state_fwd_kernel( HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): - pid_bc = tl.program_id(axis=1) + pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) + pid_h = tl.program_id(axis=2).to(tl.int64) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n + pid_m = tl.program_id(axis=0).to(tl.int64) // num_pid_n + pid_n = tl.program_id(axis=0).to(tl.int64) % num_pid_n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index 63863b82..6a7ff1cd 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -42,9 +42,9 @@ def _state_passing_fwd_kernel( HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2).to(tl.int64) + pid_m = tl.program_id(axis=0).to(tl.int64) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head From 
1018e1141e94cf4cabcaec27659f01ae3d7eec17 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 19 Mar 2025 12:22:47 +0000 Subject: [PATCH 2/4] Empty commit for attribution Co-authored-by: LuJunru From 02396519319712049ba0c4a04dade86417ee3c3e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 19 Mar 2025 12:23:04 +0000 Subject: [PATCH 3/4] Empty commit for attribution Co-authored-by: LuJunru From e9cde7ee901001c63f3987e1ec7377b8ef071d78 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 19 Mar 2025 12:23:20 +0000 Subject: [PATCH 4/4] Empty commit for attribution Co-authored-by: LuJunru