Skip to content

Commit 59be631

Browse files
committed
use seq_idx if provided, or compute it by cu_seqlens
1 parent 8174c45 commit 59be631

File tree

1 file changed

+0
-5
lines changed

1 file changed

+0
-5
lines changed

mamba_ssm/modules/mamba_simple.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,6 @@ def forward(self, hidden_states, cu_seqlens=None, seq_idx=None, inference_params
190190
x = self.act(self.conv1d(x)[..., :seqlen])
191191
else:
192192
assert self.activation in ["silu", "swish"]
193-
if cu_seqlens is not None:
194-
seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
195-
for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0)
196-
else:
197-
seq_idx = None
198193
x = causal_conv1d_fn(
199194
x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x,
200195
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),

0 commit comments

Comments
 (0)