
Commit 03a38fb

Implement varlen generation
1 parent 3462302 commit 03a38fb

8 files changed, +372 -23 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla
 
 ## Installation
 
-- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
+- [Option] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
 - `pip install mamba-ssm`: the core Mamba package.
 
 It can also be built from source with `pip install .` from this repository.

mamba_ssm/models/mixer_seq_simple.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
-                hidden_states, residual, inference_params=inference_params
+                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
             )
         if not self.fused_add_norm:
             residual = (hidden_states + residual) if residual is not None else hidden_states
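
Forwarding `**mixer_kwargs` through `MixerModel.forward` is what lets per-call arguments such as `cu_seqlens` reach every Mamba2 block. A hedged sketch of the intended call, not part of the diff; it assumes `Block.forward` passes `**mixer_kwargs` on to its mixer, and `mixer_model`, `input_ids`, `inference_params`, `cu_seqlens` are placeholder names:

# Hypothetical call: input_ids packs several prompts into a single row and
# cu_seqlens marks their boundaries; the kwarg is forwarded to every block.
hidden_states = mixer_model(
    input_ids,
    inference_params=inference_params,
    cu_seqlens=cu_seqlens,
)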

mamba_ssm/modules/mamba2.py

Lines changed: 33 additions & 13 deletions
@@ -13,6 +13,11 @@
 except ImportError:
     causal_conv1d_fn, causal_conv1d_update = None, None
 
+try:
+    from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
+except ImportError:
+    causal_conv1d_varlen_states = None
+
 try:
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 except ImportError:
@@ -144,7 +149,7 @@ def __init__(
                                               process_group=self.process_group, sequence_parallel=self.sequence_parallel,
                                               **factory_kwargs)
 
-    def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
+    def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
         """
         u: (batch, seqlen, hidden_dim) if seqlen=None.
             If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
@@ -161,7 +166,8 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
 
         conv_state, ssm_state = None, None
         if inference_params is not None:
-            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
+            inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
+            conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
             if inference_params.seqlen_offset > 0:
                 # The states are updated inplace
                 out, _, _ = self.step(u, conv_state, ssm_state)
@@ -206,14 +212,22 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
                 dim=-1
             )
             if conv_state is not None:
-                # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
-                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
-                xBC_t = rearrange(xBC, "b l d -> b d l")
-                conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)))  # Update state (B D W)
+                if cu_seqlens is None:
+                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
+                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
+                    xBC_t = rearrange(xBC, "b l d -> b d l")
+                    conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)))  # Update state (B D W)
+                else:
+                    assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
+                    assert batch == 1, "varlen inference only supports batch dimension 1"
+                    conv_varlen_states = causal_conv1d_varlen_states(
+                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
+                    )
+                    conv_state.copy_(conv_varlen_states)
             assert self.activation in ["silu", "swish"]
             if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                 xBC = self.act(
-                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
+                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, :-(self.d_conv - 1)]
                 )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
             else:
                 xBC = causal_conv1d_fn(
@@ -235,12 +249,18 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
                 dt_bias=self.dt_bias,
                 dt_softplus=True,
                 seq_idx=seq_idx,
+                cu_seqlens=cu_seqlens,
                 **dt_limit_kwargs,
                 return_final_states=ssm_state is not None,
+                return_varlen_states=cu_seqlens is not None and inference_params is not None,
             )
             if ssm_state is not None:
-                y, last_state = y
-                ssm_state.copy_(last_state)
+                y, last_state, *rest = y
+                if cu_seqlens is None:
+                    ssm_state.copy_(last_state)
+                else:
+                    varlen_states = rest[0]
+                    ssm_state.copy_(varlen_states)
             y = rearrange(y, "b l h p -> b l (h p)")
             if self.rmsnorm:
                 y = self.norm(y, z)
@@ -322,8 +342,8 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
         device = self.out_proj.weight.device
         conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
         conv_state = torch.zeros(
-            batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype
-        )
+            batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
+        ).transpose(1, 2)
         ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
         ssm_state = torch.zeros(
             batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
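
This hunk (and the matching one in `_get_states_from_cache` below) allocates the conv-state buffer as `(batch, d_conv, dim)` and hands back a transposed `(batch, dim, d_conv)` view, so the channel dimension becomes the contiguous one. A small stride check in plain PyTorch, added here for illustration; the presumed motivation, matching the layout the varlen conv-state helpers write into, is an assumption:

import torch

batch_size, dim, d_conv = 2, 512, 4

old_state = torch.zeros(batch_size, dim, d_conv)                   # width (d_conv) contiguous
new_state = torch.zeros(batch_size, d_conv, dim).transpose(1, 2)   # same shape, channels contiguous

print(old_state.stride())  # (2048, 4, 1)   -> stride 1 along d_conv
print(new_state.stride())  # (2048, 1, 512) -> stride 1 along dim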
@@ -336,11 +356,11 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
             batch_shape = (batch_size,)
             conv_state = torch.zeros(
                 batch_size,
-                self.conv1d.weight.shape[0],
                 self.d_conv,
+                self.conv1d.weight.shape[0],
                 device=self.conv1d.weight.device,
                 dtype=self.conv1d.weight.dtype,
-            )
+            ).transpose(1, 2)
             ssm_state = torch.zeros(
                 batch_size,
                 self.nheads,
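
Taken together, these changes let several variable-length prompts be prefilled in one packed forward pass: `cu_seqlens` marks the sequence boundaries, the cache is sized by the number of sequences rather than the packed batch dimension, and the final conv/SSM state of every sequence is written into its own cache slot. A minimal usage sketch follows; it is not part of the commit, and the hyperparameters, the packed shapes, and the need to pass `seq_idx` alongside `cu_seqlens` are assumptions based on the surrounding code.

import torch
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.utils.generation import InferenceParams

# Hypothetical example: three prompts of lengths 5, 3 and 7 packed into one row.
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32, device="cuda")
num_seqs = cu_seqlens.shape[0] - 1
total_len = int(cu_seqlens[-1])

layer = Mamba2(d_model=256, layer_idx=0, device="cuda")
u = torch.randn(1, total_len, 256, device="cuda")  # varlen prefill requires batch == 1

# Which sequence each packed position belongs to (assumed to be needed so the
# conv and the chunked scan do not mix states across sequence boundaries).
seq_idx = torch.repeat_interleave(
    torch.arange(num_seqs, dtype=torch.int32, device="cuda"), cu_seqlens.diff()
).unsqueeze(0)

# One cache slot per sequence: with cu_seqlens given, the layer asks the cache
# for cu_seqlens.shape[0] - 1 states instead of the packed batch size.
inference_params = InferenceParams(max_seqlen=64, max_batch_size=num_seqs)

y = layer(u, seq_idx=seq_idx, cu_seqlens=cu_seqlens, inference_params=inference_params)
# y: (1, total_len, 256); the cached conv_state / ssm_state now hold one final
# state per sequence, ready for subsequent single-token decode steps.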

mamba_ssm/ops/triton/ssd_chunk_state.py

Lines changed: 120 additions & 0 deletions
@@ -571,6 +571,97 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
     tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
 
 
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
+    ],
+    key=['hdim', 'dstate', 'chunk_size'],
+)
+@triton.jit
+def _chunk_state_varlen_kernel(
+    # Pointers to matrices
+    x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
+    # Matrix dimensions
+    hdim, dstate, chunk_size,
+    seqlen, nheads_ngroups_ratio,
+    # Strides
+    stride_x_seqlen, stride_x_head, stride_x_hdim,
+    stride_b_seqlen, stride_b_head, stride_b_dstate,
+    stride_dt_chunk, stride_dt_head, stride_dt_csize,
+    stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
+    stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
+    stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+):
+    pid_b = tl.program_id(axis=1)
+    pid_h = tl.program_id(axis=2)
+    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
+    pid_m = tl.program_id(axis=0) // num_pid_n
+    pid_n = tl.program_id(axis=0) % num_pid_n
+    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
+    pid_c = (end_idx - 1) // chunk_size
+    b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
+    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
+    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
+    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
+    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
+
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
+    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
+    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
+    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
+    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
+
+    chunk_size_limit = end_idx - pid_c * chunk_size
+    start_idx = tl.load(cu_seqlens_ptr + pid_b)
+    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
+
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
+        x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
+        b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
+        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
+        scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
+                         tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
+        b *= scale[:, None]
+        b = b.to(x_ptr.dtype.element_ty)
+        acc += tl.dot(x, b)
+        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
+        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
+        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
+        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
+
+    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
+    if start_idx < pid_c * chunk_size:
+        chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
+        chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
+        # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
+        scale = tl.exp(dA_cs_last)
+        acc += chunk_states * scale
+
+    states = acc.to(states_ptr.dtype.element_ty)
+
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
+    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
+    tl.store(states_ptrs, states, mask=c_mask)
+
+
 def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
     batch, seqlen, nheads = dt.shape
     assert A.shape == (nheads,)
@@ -790,6 +881,35 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
     return ddA_cumsum
 
 
+def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
+    total_seqlen, nheads, headdim = x.shape
+    _, nchunks, chunk_size = dt.shape
+    _, ngroups, dstate = B.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    assert nheads % ngroups == 0
+    assert B.shape == (total_seqlen, ngroups, dstate)
+    assert dt.shape == (nheads, nchunks, chunk_size)
+    assert dA_cumsum.shape == dt.shape
+    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
+    states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
+    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
+                         batch, nheads)
+    with torch.cuda.device(x.device.index):
+        _chunk_state_varlen_kernel[grid](
+            x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
+            headdim, dstate, chunk_size,
+            total_seqlen, nheads // ngroups,
+            x.stride(0), x.stride(1), x.stride(2),
+            B.stride(0), B.stride(1), B.stride(2),
+            dt.stride(1), dt.stride(0), dt.stride(2),
+            dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
+            chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
+            states.stride(0), states.stride(1), states.stride(2), states.stride(3),
+        )
+    return states
+
+
 class ChunkStateFn(torch.autograd.Function):
 
     @staticmethod
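
For intuition, the following pure-PyTorch reference mirrors what `chunk_state_varlen` and the kernel above compute: for each sequence, the final SSM state is the decay-weighted sum of the outer products x_t ⊗ B_t over the sequence's tokens that lie in its last chunk, plus the decayed state at that chunk's start when the sequence began in an earlier chunk. This is an illustrative sketch written for this write-up, not code from the commit, and it is only meant for checking small inputs against the Triton kernel.

import torch

def chunk_state_varlen_ref(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
    """Per-sequence final SSM state; naive loop version for small inputs."""
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    states = torch.zeros(batch, nheads, headdim, dstate, dtype=torch.float32, device=x.device)
    for b in range(batch):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        c = (end - 1) // chunk_size                         # chunk holding the sequence's last token
        dA_last = dA_cumsum[:, c, (end - 1) % chunk_size]   # (nheads,)
        for h in range(nheads):
            g = h // (nheads // ngroups)
            acc = torch.zeros(headdim, dstate, dtype=torch.float32, device=x.device)
            # Tokens of this sequence that fall inside its last chunk
            for t in range(max(start, c * chunk_size), end):
                k = t - c * chunk_size
                scale = torch.exp(dA_last[h] - dA_cumsum[h, c, k]) * dt[h, c, k]
                acc += scale * torch.outer(x[t, h].float(), B[t, g].float())
            # If the sequence started in an earlier chunk, carry in the state at this chunk's start
            if start < c * chunk_size:
                acc += torch.exp(dA_last[h]) * chunk_states[c, h].float()
            states[b, h] = acc
    return states.to(chunk_states.dtype)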
