
Commit a78a9eb

add notes for variable length sequences
1 parent d28e1b0 · commit a78a9eb


mamba_ssm/modules/mamba_simple.py

Lines changed: 8 additions & 7 deletions
@@ -119,6 +119,7 @@ def __init__(
     def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
         """
         hidden_states: (B, L, D)
+        cu_seqlens: one-dimensional cumulative-sequence-length tensor, as in the flash-attn varlen API; only used when packing variable-length sequences into a single batch element, i.e., batch_size B=1
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
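
Note (not part of the commit): cu_seqlens follows the flash-attn varlen layout, i.e. cumulative sequence boundaries that start at 0 and end at the packed length. A minimal illustrative sketch with made-up lengths:

import torch

# Three sequences of lengths 5, 9 and 2 packed into a single batch element (B=1, L=16).
# cu_seqlens holds the cumulative boundaries, starting at 0 and ending at the total length.
cu_seqlens = torch.tensor([0, 5, 14, 16])
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([5, 9, 2]); the masking step below recovers per-sequence lengths this way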
@@ -157,7 +158,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+                cu_seqlens=cu_seqlens,
             )
         else:
             x, z = xz.chunk(2, dim=1)
@@ -166,12 +167,12 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             if cu_seqlens is not None:
                 padded_x = x
                 count = 0
-                for idx in cu_seqlens[0][1:-1].tolist():
+                for idx in cu_seqlens[1:-1].tolist():
                     padded_idx = idx + count*(self.d_conv - 1)
                     padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
                     count = count + 1
                 x = padded_x
-                assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2]
+                # assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]

             # Compute short convolution
             if conv_state is not None:
@@ -192,13 +193,13 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
             if cu_seqlens is not None:
                 mask = []
-                for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist():
+                for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
                     mask.extend([True] * seq_len)
                     mask.extend([False] * (self.d_conv - 1))
                 mask = mask[:-(self.d_conv - 1)]
-                assert x.shape[2] == len(mask)
+                # assert x.shape[2] == len(mask)
                 x = x[:, :, mask]
-                assert x.shape[2] == z.shape[2]
+                # assert x.shape[2] == z.shape[2]

             # We're careful here about the layout, to avoid extra transposes.
             # We want dt to have d as the slowest moving dimension
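
To make the pad-then-mask bookkeeping in the two hunks above concrete, here is a standalone sketch (toy values; d_conv and cu_seqlens are assumptions, and the convolution itself is omitted). It inserts d_conv - 1 zeros at every internal sequence boundary so the causal conv1d cannot mix neighbouring sequences, then builds the boolean mask that removes exactly those positions again:

import torch

d_conv = 4
cu_seqlens = torch.tensor([0, 3, 7, 10])                    # three sequences packed into B=1, L=10
x = torch.arange(10, dtype=torch.float32).view(1, 1, 10)    # (B=1, d_inner=1, L=10)

# Step 1: insert (d_conv - 1) zeros at each internal boundary, shifting later
# boundaries by the padding already inserted (the role of `count` in the diff).
padded_x, count = x, 0
for idx in cu_seqlens[1:-1].tolist():
    padded_idx = idx + count * (d_conv - 1)
    pad = torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device)
    padded_x = torch.cat((padded_x[:, :, :padded_idx], pad, padded_x[:, :, padded_idx:]), dim=2)
    count += 1

# Step 2: after the convolution would run, drop the padded positions again.
mask = []
for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
    mask.extend([True] * seq_len)
    mask.extend([False] * (d_conv - 1))
mask = mask[:-(d_conv - 1)]                                  # no padding after the last sequence

restored = padded_x[:, :, mask]
assert torch.equal(restored, x)                              # the round trip is exact
print(padded_x.shape, restored.shape)                        # torch.Size([1, 1, 16]) torch.Size([1, 1, 10])

The mask keeps exactly seq_len positions per sequence and drops the d_conv - 1 inserted zeros, so the masked length always matches the original packed length.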
@@ -222,7 +223,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
                 return_last_state=ssm_state is not None,
-                cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+                cu_seqlens=cu_seqlens,
             )
             if ssm_state is not None:
                 y, last_state = y
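
Finally, a hedged end-to-end usage sketch of the signature introduced by this commit, forward(hidden_states, cu_seqlens=None, inference_params=None). The model size, sequence lengths, dtype, and the CUDA device are illustrative assumptions; the only part taken from the commit is that forward accepts a packed (1, L, D) input together with a cu_seqlens tensor:

import torch
from mamba_ssm.modules.mamba_simple import Mamba

device, dtype = "cuda", torch.float16

# Pack three variable-length sequences into a single batch element (B=1).
seq_lens = torch.tensor([5, 9, 2])
cu_seqlens = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)]).to(device)        # tensor([0, 5, 14, 16])
hidden_states = torch.randn(1, int(seq_lens.sum()), 64, device=device, dtype=dtype)   # (B=1, L=16, D=64)

layer = Mamba(d_model=64, d_conv=4).to(device=device, dtype=dtype)
out = layer(hidden_states, cu_seqlens=cu_seqlens)   # same shape as hidden_states
assert out.shape == hidden_states.shape

Individual sequences can then be sliced back out of the output as out[:, cu_seqlens[i]:cu_seqlens[i+1]].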
