[Kernel][Model] Improve continuous batching for Jamba and Mamba #9189

Merged
Changes from 5 commits (27 commits in total)
12 changes: 7 additions & 5 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -334,16 +334,17 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];

int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];

input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr

// cache_index == -1 is defined as padding and cache shouldn't been written/read
input_t *conv_states = params.conv_states_ptr == nullptr || cache_index == -1 ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;

bool has_initial_state = params.has_initial_state_ptr == nullptr || conv_states == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id] ;

// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
if (tidx == 0) {
input_t initial_state[kNElts] = {0};
@@ -528,6 +529,7 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
if (conv_state_batch_coord == -1) return;
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
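The guards above establish the padding contract used for continuous batching: a cache_indices entry of -1 (exposed to callers as PAD_SLOT_ID) marks a padded sequence whose convolution state must be neither read nor written. Below is a minimal Python sketch of that contract; the helper name and the new_states tensor are illustrative only and not part of this PR, whose real entry points are causal_conv1d_fn/causal_conv1d_update, exercised by the tests further down.

import torch

# Assumed to be -1 here, matching the kernels' `cache_index == -1` check;
# the tests below import it from vllm.attention.backends.utils.
PAD_SLOT_ID = -1


def scatter_final_states(final_states: torch.Tensor,
                         new_states: torch.Tensor,
                         cache_indices: torch.Tensor) -> None:
    """Illustrative semantics only: write each sequence's final conv state
    into its cache slot, skipping entries whose slot index is PAD_SLOT_ID,
    mirroring the `cache_index == -1` guards added to the CUDA kernels above."""
    for seq_id, slot in enumerate(cache_indices.tolist()):
        if slot == PAD_SLOT_ID:
            continue  # padded entry: its cache slot is neither read nor written
        final_states[slot].copy_(new_states[seq_id])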
11 changes: 6 additions & 5 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -109,12 +109,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
sequence_start_index = query_start_loc[batch_id];
seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
}
const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];

const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
const bool wr_cache = cache_index != -1;
const bool has_initial_state = params.has_initial_state_ptr == nullptr && wr_cache ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];


input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
@@ -250,7 +252,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) {
if (chunk == n_chunks - 1 && wr_cache) {
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
}
}
@@ -626,7 +628,6 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(ssm_states.scalar_type() == input_type);
TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1);
CHECK_SHAPE(ssm_states, batch_size, dim, dstate);

SSMParamsBase params;
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
96 changes: 96 additions & 0 deletions tests/kernels/test_causal_conv1d.py
@@ -6,6 +6,7 @@

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
@@ -271,6 +272,52 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
))


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048])
def test_causal_conv1d_update_with_batch_gather_padding_unchanged(
dim, width, seqlen, has_bias, silu_activation, itype):
device = "cuda"
# set seed
seed_everything(0)
batch = 64

x = torch.randn(batch, dim, 1, device=device, dtype=itype)

total_entries = 10 * batch
conv_state = torch.randn(total_entries,
dim,
width - 1,
device=device,
dtype=itype)
conv_state_before = conv_state.clone()
conv_state_indices = torch.as_tensor([PAD_SLOT_ID] * batch,
dtype=torch.int32,
device=device)

weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
activation = None if not silu_activation else "silu"
causal_conv1d_update(x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices)

assert torch.equal(conv_state, conv_state_before)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@@ -422,3 +469,52 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
cache_indices, has_initial_states, final_states,
activation)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen', [256])
@pytest.mark.parametrize('dim', [64])
def test_causal_conv1d_varlen_check_padding_state_unchanged(
dim, seqlen, width, has_bias, silu_activation, itype):
device = "cuda"
# set seed
seed_everything(0)
batch = 1
seqlens = []
nsplits = 3
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])

cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(nsplits + 1,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
final_states_before = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
cache_indices = torch.as_tensor([PAD_SLOT_ID] * (cumsum.shape[0] - 1),
dtype=torch.int32,
device=x.device)
causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), cache_indices,
has_initial_states, final_states, activation)
assert torch.equal(final_states, final_states_before)
114 changes: 114 additions & 0 deletions tests/kernels/test_mamba_ssm.py
@@ -5,6 +5,7 @@

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
@@ -515,6 +516,119 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
has_initial_state, prev_state)


@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [256])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
def test_selective_scan_varlen_padding_state_unchanged(
is_variable_B, is_variable_C, varBC_groups, has_D, has_z,
has_delta_bias, delta_softplus, return_last_state, seqlen, itype,
wtype):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
# set seed
torch.random.manual_seed(0)
seqlens = []
nsplits = 3
if seqlen < 10:
nsplits = 0
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])

cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0).cuda()

dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
B_shape = [varBC_groups, dstate, seqlen]
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
C_shape = [varBC_groups, dstate, seqlen]
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
z = torch.randn(dim, seqlen, device=device, dtype=itype)
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
u = torch.randn(dim, seqlen, device=device, dtype=itype)
delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1]))
prev_state = torch.randn(prev_state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
prev_state_ref = prev_state.clone()
cache_indices = torch.as_tensor([-1] * (cumsum.shape[0] - 1),
dtype=torch.int32,
device=u.device)

has_initial_state = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=u.device)
selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices, has_initial_state)
assert torch.equal(prev_state, prev_state_ref)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16])
@pytest.mark.parametrize("dim", [2048])
def test_selective_state_update_with_batch_indices_padding_unchanged(
dim, dstate, has_z, itype):
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 3

total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.as_tensor([PAD_SLOT_ID] * batch_size,
dtype=torch.int32,
device=device)

x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.clone()
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
assert torch.equal(state_ref, state)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
25 changes: 19 additions & 6 deletions vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -144,9 +144,16 @@ def _selective_scan_update_kernel(
z_ptrs = z_ptr + offs_m * stride_z_dim
out_ptrs = out_ptr + offs_m * stride_out_dim

state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0)
if HAS_STATE_BATCH_INDICES:
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) &
(offs_n[None, :] < dstate) & (state_batch_idx != -1),
other=0.0)
else:
state = tl.load(state_ptrs,
mask=(offs_m[:, None] < dim) &
(offs_n[None, :] < dstate),
other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
@@ -177,9 +184,15 @@ def _selective_scan_update_kernel(

dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
state = state * dA + dB * x[:, None]
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
if HAS_STATE_BATCH_INDICES:
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) &
(state_batch_idx != -1))
else:
tl.store(state_ptrs,
state,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D