fix: fix large batch size & prefill size issue #708

Open · wants to merge 4 commits into main

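The change is the same in all three kernels: every tl.program_id(...) is promoted to tl.int64 before it is multiplied by a stride. Triton program ids are 32-bit integers, so with a large batch size or prefill length the products pid_* * stride_* used to advance the base pointers can exceed 2**31 - 1 and wrap around, producing out-of-bounds addresses. Promoting the ids first forces the whole offset computation into 64-bit arithmetic. A minimal sketch of the failure mode, with made-up sizes that are not taken from the PR:

import numpy as np

# Hypothetical numbers, chosen only so the offset crosses 2**31 - 1.
pid_b = 400                 # program id along the batch axis
stride_x_batch = 8_388_608  # elements between consecutive batch entries

# The 32-bit product wraps: 400 * 8_388_608 = 3_355_443_200 > 2**31 - 1.
bad = np.int32(pid_b) * np.int32(stride_x_batch)   # -939_524_096 (plus a RuntimeWarning)
# Promoting first keeps the product in 64 bits, mirroring .to(tl.int64):
good = np.int64(pid_b) * np.int64(stride_x_batch)  # 3_355_443_200

print(bad, good)
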
8 changes: 4 additions & 4 deletions mamba_ssm/ops/triton/ssd_chunk_scan.py
@@ -67,13 +67,13 @@ def _chunk_scan_fwd_kernel(
     BLOCK_SIZE_DSTATE: tl.constexpr,
     IS_TRITON_22: tl.constexpr,
 ):
-    pid_bc = tl.program_id(axis=1)
+    pid_bc = tl.program_id(axis=1).to(tl.int64)
     pid_c = pid_bc // batch
     pid_b = pid_bc - pid_c * batch
-    pid_h = tl.program_id(axis=2)
+    pid_h = tl.program_id(axis=2).to(tl.int64)
     num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
-    pid_m = tl.program_id(axis=0) // num_pid_n
-    pid_n = tl.program_id(axis=0) % num_pid_n
+    pid_m = tl.program_id(axis=0).to(tl.int64) // num_pid_n
+    pid_n = tl.program_id(axis=0).to(tl.int64) % num_pid_n
     cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
     x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
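
Note the fused grid axis in this kernel: axis 1 enumerates (batch, chunk) pairs, and the kernel splits it with a floor division and a multiply-subtract; with batch = 8, for instance, pid_bc = 19 yields pid_c = 2 and pid_b = 3. Because pid_bc is now int64, the derived pid_c and pid_b inherit the wider type, so the stride products on the following lines are also computed in 64-bit without further casts.
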
14 changes: 7 additions & 7 deletions mamba_ssm/ops/triton/ssd_chunk_state.py
@@ -48,9 +48,9 @@ def _chunk_cumsum_fwd_kernel(
     HAS_DT_BIAS: tl.constexpr,
     BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
 ):
-    pid_b = tl.program_id(axis=0)
-    pid_c = tl.program_id(axis=1)
-    pid_h = tl.program_id(axis=2)
+    pid_b = tl.program_id(axis=0).to(tl.int64)
+    pid_c = tl.program_id(axis=1).to(tl.int64)
+    pid_h = tl.program_id(axis=2).to(tl.int64)
     dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
     dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
     dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk

@@ -191,13 +191,13 @@ def _chunk_state_fwd_kernel(
     HAS_SEQ_IDX: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
 ):
-    pid_bc = tl.program_id(axis=1)
+    pid_bc = tl.program_id(axis=1).to(tl.int64)
     pid_c = pid_bc // batch
     pid_b = pid_bc - pid_c * batch
-    pid_h = tl.program_id(axis=2)
+    pid_h = tl.program_id(axis=2).to(tl.int64)
     num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
-    pid_m = tl.program_id(axis=0) // num_pid_n
-    pid_n = tl.program_id(axis=0) % num_pid_n
+    pid_m = tl.program_id(axis=0).to(tl.int64) // num_pid_n
+    pid_n = tl.program_id(axis=0).to(tl.int64) % num_pid_n
     b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
     x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head

6 changes: 3 additions & 3 deletions mamba_ssm/ops/triton/ssd_state_passing.py
@@ -42,9 +42,9 @@ def _state_passing_fwd_kernel(
     HAS_SEQ_IDX: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid_b = tl.program_id(axis=1)
-    pid_h = tl.program_id(axis=2)
-    pid_m = tl.program_id(axis=0)
+    pid_b = tl.program_id(axis=1).to(tl.int64)
+    pid_h = tl.program_id(axis=2).to(tl.int64)
+    pid_m = tl.program_id(axis=0).to(tl.int64)
     states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
     dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
     out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
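
The same promotion pattern, extracted into a self-contained toy kernel for anyone who wants to experiment with it. The kernel and its names are illustrative, not part of mamba_ssm, and it needs a CUDA-capable GPU to run:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_rows_kernel(x_ptr, out_ptr, stride_batch, n_cols, scale, BLOCK: tl.constexpr):
    # tl.program_id returns int32; promote it before the stride multiply so
    # pid_b * stride_batch is computed in 64-bit and cannot wrap for huge batches.
    pid_b = tl.program_id(axis=0).to(tl.int64)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    row = tl.load(x_ptr + pid_b * stride_batch + offs, mask=mask)
    tl.store(out_ptr + pid_b * stride_batch + offs, row * scale, mask=mask)

x = torch.randn(4, 1024, device="cuda")
out = torch.empty_like(x)
_scale_rows_kernel[(x.shape[0],)](x, out, x.stride(0), x.shape[1], 2.0, BLOCK=1024)
assert torch.allclose(out, 2 * x)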