
Commit 8174c45

use seq_idx if provided, or compute it from cu_seqlens
1 parent 909f970 commit 8174c45
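For illustration, a minimal sketch (not part of the commit) of what the fallback computation produces: seq_idx labels every packed token with the index of the sequence it belongs to, derived from the cumulative boundaries in cu_seqlens. The concrete tensor values below are an assumed example.

import torch

# Three packed sequences of lengths 3, 4 and 2 -> boundaries [0, 3, 7, 9].
cu_seqlens = torch.tensor([0, 3, 7, 9], dtype=torch.int32)

# Same expression the commit moves into Mamba.forward: one int32 label per
# token, with a leading batch dimension of 1.
seq_idx = torch.cat(
    [torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
     for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])],
    dim=0,
).unsqueeze(0)

print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 1, 1, 2, 2]], dtype=torch.int32)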


2 files changed: +16, -20 lines

mamba_ssm/modules/mamba_simple.py

Lines changed: 10 additions & 2 deletions
@@ -116,15 +116,22 @@ def __init__(
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
 
-    def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
+    def forward(self, hidden_states, cu_seqlens=None, seq_idx=None, inference_params=None):
         """
         hidden_states: (B, L, D)
         cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and end with L, and must already be sorted.
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
+
         if cu_seqlens is not None:
             assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1"
+            # compute seq_idx if not provided
+            if seq_idx is None:
+                seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
+                                     for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
+        else:
+            seq_idx = None
 
         conv_state, ssm_state = None, None
         if inference_params is not None:
@@ -160,7 +167,8 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens
+                cu_seqlens=cu_seqlens,
+                seq_idx=seq_idx,
             )
         else:
             x, z = xz.chunk(2, dim=1)
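As a usage note, a hypothetical sketch of calling the updated forward. The import path and the d_model constructor argument are assumptions, as is running on CUDA; only the cu_seqlens and seq_idx keywords come from the diff above. Two variable-length sequences are packed into a single batch row, and seq_idx is either left to the fallback or precomputed once and reused.

import torch
from mamba_ssm.modules.mamba_simple import Mamba

d_model = 16
block = Mamba(d_model=d_model).cuda()

# Two sequences of lengths 5 and 3 packed into one row: B=1, L=8.
hidden_states = torch.randn(1, 8, d_model, device="cuda")
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")

# seq_idx may be omitted; forward() now derives it from cu_seqlens.
out = block(hidden_states, cu_seqlens=cu_seqlens)

# Or build seq_idx once and pass it explicitly, e.g. to avoid recomputing it per layer.
seq_idx = torch.cat(
    [torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
     for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])],
    dim=0,
).unsqueeze(0)
out = block(hidden_states, cu_seqlens=cu_seqlens, seq_idx=seq_idx)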

mamba_ssm/ops/selective_scan_interface.py

Lines changed: 6 additions & 18 deletions
@@ -169,19 +169,13 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1):
+                C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None, checkpoint_lvl=1):
         """
             xz: (batch, dim, seqlen)
         """
         assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
-
-        if cu_seqlens is not None:
-            seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
-                                 for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
-        else:
-            seq_idx = None
-
+
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
         d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
@@ -355,38 +349,32 @@ def backward(ctx, dout):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None, None)
+                dB_proj_bias, dC_proj_bias, None, None, None)
 
 
 def mamba_inner_fn(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True, cu_seqlens=None
+    C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None,
 ):
     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
-                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens)
+                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, seq_idx)
 
 
 def mamba_inner_ref(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True, cu_seqlens=None
+    C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None,
 ):
     assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
 
-    if cu_seqlens is not None:
-        seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
-                             for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
-    else:
-        seq_idx = None
-
     x = causal_conv1d_fn(
         x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x,
         rearrange(conv1d_weight, "d 1 w -> d w"),
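The extra None appended to backward()'s return above reflects the torch.autograd.Function contract: backward() must return one value per argument of forward(), and the new non-differentiable seq_idx input is answered with None. A minimal standalone sketch of that contract (hypothetical names, not the library's code):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, seq_idx=None):  # seq_idx stands in for a non-tensor extra argument
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward() input: grad for x, None for factor, None for seq_idx.
        return grad_out * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # every element is 2.0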
