
Commit b69b957

migrate to tridao's native varlen causal_conv1d kernel for speedup
1 parent 6961faa commit b69b957

5 files changed, +122 -172 lines changed

mamba_ssm/models/mixer_seq_simple.py

Lines changed: 5 additions & 7 deletions
@@ -16,7 +16,6 @@
 from mamba_ssm.modules.mamba2 import Mamba2
 from mamba_ssm.modules.mha import MHA
 from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.mamba_simple import Block as Block_Mamba1
 from mamba_ssm.modules.block import Block
 from mamba_ssm.utils.generation import GenerationMixin
 from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
@@ -71,8 +70,7 @@ def create_block(
     mlp_cls = partial(
         GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
     )
-    block_cls = Block if ssm_layer == "Mamba2" else Block_Mamba1
-    block = block_cls(
+    block = Block(
         d_model,
         mixer_cls,
         mlp_cls,
@@ -189,12 +187,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
             for i, layer in enumerate(self.layers)
         }
 
-    def forward(self, input_ids, cu_seqlens=None, inference_params=None, **mixer_kwargs):
+    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
         hidden_states = self.embedding(input_ids)
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
-                hidden_states, residual, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs
+                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
             )
         if not self.fused_add_norm:
             residual = (hidden_states + residual) if residual is not None else hidden_states
@@ -273,12 +271,12 @@ def tie_weights(self):
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, cu_seqlens=None, **mixer_kwargs):
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
         """
         "position_ids" is just to be compatible with Transformer generation. We don't use it.
         num_last_tokens: if > 0, only return the logits for the last n tokens
         """
-        hidden_states = self.backbone(input_ids, cu_seqlens=cu_seqlens, inference_params=inference_params, **mixer_kwargs)
+        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
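Note that cu_seqlens is no longer a named argument of the model-level forward methods. If Block.forward passes extra keyword arguments through to the mixer, as in upstream mamba_ssm, it can presumably still be supplied via **mixer_kwargs. A hedged sketch (config values are illustrative, not from this commit):

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
config = MambaConfig(d_model=256, n_layer=2, vocab_size=1000)
model = MambaLMHeadModel(config, device=device)

# Two sequences of lengths 3 and 5 packed into one row of token ids (B=1, L=8).
input_ids = torch.randint(0, 1000, (1, 8), device=device)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)

# Assumption: cu_seqlens rides along in **mixer_kwargs down to each Mamba layer.
logits = model(input_ids, cu_seqlens=cu_seqlens).logits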

mamba_ssm/modules/block.py

Lines changed: 1 addition & 1 deletion
@@ -88,4 +88,4 @@ def forward(
         return hidden_states, residual
 
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

(The two lines are textually identical; the change is presumably only a trailing newline at the end of the file.)

mamba_ssm/modules/mamba_simple.py

Lines changed: 22 additions & 83 deletions
@@ -10,7 +10,7 @@
 
 from einops import rearrange, repeat
 
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, selective_scan_ref
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -119,10 +119,12 @@ def __init__(
     def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
         """
         hidden_states: (B, L, D)
-        cu_seqlens: one-dimensional tensor representing cumulative start indexes of packed sequence, a.k.a., B=1
+        cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and ending with L; must already be sorted.
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
+        if cu_seqlens is not None:
+            assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1"
 
         conv_state, ssm_state = None, None
         if inference_params is not None:
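For concreteness, this is what the cu_seqlens layout described in the docstring looks like for a packed batch of two sequences with lengths 3 and 5 (a minimal sketch; the lengths are illustrative):

import torch

# Two sequences of lengths 3 and 5, concatenated along L into a single row: B=1, L=8.
lengths = torch.tensor([3, 5])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)]).to(torch.int32)
print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32): starts at 0, ends at L, sorted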
@@ -158,46 +160,40 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens,
-                d_conv=torch.tensor(self.d_conv)
+                cu_seqlens=cu_seqlens
             )
         else:
             x, z = xz.chunk(2, dim=1)
-
-            # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences
-            if cu_seqlens is not None:
-                padded_x = x
-                count = 0
-                for idx in cu_seqlens[1:-1].tolist():
-                    padded_idx = idx + count*(self.d_conv - 1)
-                    padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
-                    count = count + 1
-                x = padded_x
-
             # Compute short convolution
             if conv_state is not None:
                 # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                 # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                 conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
             if causal_conv1d_fn is None:
-                x = self.act(self.conv1d(x)[..., :seqlen])
+                if cu_seqlens is not None:
+                    # naive pure python implementation of varlen causal_conv1d
+                    for i, s in enumerate(cu_seqlens[1:-1]):
+                        x = torch.cat((x[..., :s + i*(self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i*(self.d_conv - 1):]), dim=2)
+                    mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device),
+                                                 torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0)
+                                      for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0)
+                    x = self.act(self.conv1d(x)[:, :, mask])
+                else:
+                    x = self.act(self.conv1d(x)[..., :seqlen])
             else:
                 assert self.activation in ["silu", "swish"]
+                if cu_seqlens is not None:
+                    seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
+                                         for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
+                else:
+                    seq_idx = None
                 x = causal_conv1d_fn(
-                    x=x,
+                    x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x,
                     weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                     bias=self.conv1d.bias,
+                    seq_idx=seq_idx,
                     activation=self.activation,
                 )
-
-            # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
-            if cu_seqlens is not None:
-                mask = []
-                for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
-                    mask.extend([True] * seq_len)
-                    mask.extend([False] * (self.d_conv - 1))
-                mask = mask[:-(self.d_conv - 1)]
-                x = x[:, :, mask]
 
             # We're careful here about the layout, to avoid extra transposes.
             # We want dt to have d as the slowest moving dimension
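The new path hands boundary handling to the kernel: seq_idx labels every packed position with its sequence index, so the fused convolution never mixes neighbouring sequences, replacing the old pad-then-mask workaround. A minimal standalone sketch of the same call, assuming a causal_conv1d build whose causal_conv1d_fn accepts seq_idx (the weight and bias tensors are random stand-ins for self.conv1d):

import torch
from causal_conv1d import causal_conv1d_fn

B, D, W = 1, 64, 4                                   # packed batch, channels, conv width
device, dtype = "cuda", torch.float16
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)
L = int(cu_seqlens[-1])

# The varlen path expects a channel-last memory layout, mirroring the
# x.transpose(1,2).contiguous().transpose(1,2) in the diff above.
x = torch.randn(B, L, D, device=device, dtype=dtype).transpose(1, 2)  # (B, D, L)
weight = torch.randn(D, W, device=device, dtype=dtype)                # stand-in for the rearranged conv1d.weight
bias = torch.randn(D, device=device, dtype=dtype)                     # stand-in for conv1d.bias

# seq_idx: (B, L) int32; position i belongs to sequence seq_idx[0, i], here [0, 0, 0, 1, 1, 1, 1, 1].
seq_idx = torch.cat([
    torch.full((s,), i, dtype=torch.int32, device=device)
    for i, s in enumerate((cu_seqlens[1:] - cu_seqlens[:-1]).tolist())
]).unsqueeze(0)

y = causal_conv1d_fn(x=x, weight=weight, bias=bias, seq_idx=seq_idx, activation="silu")
# y: (B, D, L); each position only sees past positions from its own sequence.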
@@ -208,7 +204,6 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
             B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
             C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-
             assert self.activation in ["silu", "swish"]
             y = selective_scan_fn(
                 x,
@@ -317,59 +312,3 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
                 conv_state.zero_()
                 ssm_state.zero_()
         return conv_state, ssm_state
-
-class Block(nn.Module):
-    def __init__(
-        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
-    ):
-        """
-        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
-        This Block has a slightly different structure compared to a regular
-        prenorm Transformer block.
-        The standard block is: LN -> MHA/MLP -> Add.
-        [Ref: https://arxiv.org/abs/2002.04745]
-        Here we have: Add -> LN -> Mixer, returning both
-        the hidden_states (output of the mixer) and the residual.
-        This is purely for performance reasons, as we can fuse add and LayerNorm.
-        The residual needs to be provided (except for the very first block).
-        """
-        super().__init__()
-        self.residual_in_fp32 = residual_in_fp32
-        self.fused_add_norm = fused_add_norm
-        self.mixer = mixer_cls(dim)
-        self.norm = norm_cls(dim)
-        if self.fused_add_norm:
-            assert RMSNorm is not None, "RMSNorm import fails"
-            assert isinstance(
-                self.norm, (nn.LayerNorm, RMSNorm)
-            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
-
-    def forward(
-        self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None
-    ):
-        r"""Pass the input through the encoder layer.
-        Args:
-            hidden_states: the sequence to the encoder layer (required).
-            residual: hidden_states = Mixer(LN(residual))
-        """
-        if not self.fused_add_norm:
-            residual = (hidden_states + residual) if residual is not None else hidden_states
-            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
-            if self.residual_in_fp32:
-                residual = residual.to(torch.float32)
-        else:
-            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
-            hidden_states, residual = fused_add_norm_fn(
-                hidden_states,
-                self.norm.weight,
-                self.norm.bias,
-                residual=residual,
-                prenorm=True,
-                residual_in_fp32=self.residual_in_fp32,
-                eps=self.norm.eps,
-            )
-        hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params)
-        return hidden_states, residual
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
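Taken together, the layer-level varlen path after this commit can presumably be exercised as below (a hedged end-to-end sketch; the sizes are illustrative and the causal_conv1d package is assumed to be installed for the fast path):

import torch
from mamba_ssm.modules.mamba_simple import Mamba

device, dtype = "cuda", torch.float16
d_model = 64

# Two sequences of lengths 3 and 5 packed into one row: B=1, L=8, as required by the assert above.
hidden_states = torch.randn(1, 8, d_model, device=device, dtype=dtype)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)

mixer = Mamba(d_model=d_model, device=device, dtype=dtype)
out = mixer(hidden_states, cu_seqlens=cu_seqlens)  # same shape as hidden_states: (1, 8, d_model)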
