@@ -169,19 +169,13 @@ class MambaInnerFn(torch.autograd.Function):
     def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                 out_proj_weight, out_proj_bias,
                 A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-                C_proj_bias=None, delta_softplus=True, cu_seqlens=None, checkpoint_lvl=1):
+                C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None, checkpoint_lvl=1):
         """
              xz: (batch, dim, seqlen)
         """
         assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
-
-        if cu_seqlens is not None:
-            seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
-                                 for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])], dim=0).unsqueeze(0)
-        else:
-            seq_idx = None
-
+
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
         d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
@@ -355,38 +349,32 @@ def backward(ctx, dout):
                 dout_proj_weight, dout_proj_bias,
                 dA, dB, dC, dD,
                 ddelta_bias if delta_bias is not None else None,
-                dB_proj_bias, dC_proj_bias, None, None)
+                dB_proj_bias, dC_proj_bias, None, None, None)
 
 
 def mamba_inner_fn(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True, cu_seqlens=None
+    C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None,
 ):
     return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                               out_proj_weight, out_proj_bias,
-                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens)
+                              A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, cu_seqlens, seq_idx)
 
 
 def mamba_inner_ref(
     xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
     out_proj_weight, out_proj_bias,
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
-    C_proj_bias=None, delta_softplus=True, cu_seqlens=None
+    C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None,
 ):
     assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
 
-    if cu_seqlens is not None:
-        seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
-                             for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])], dim=0).unsqueeze(0)
-    else:
-        seq_idx = None
-
     x = causal_conv1d_fn(
         x.transpose(1, 2).contiguous().transpose(1, 2) if cu_seqlens is not None else x,
         rearrange(conv1d_weight, "d 1 w -> d w"),
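Since the wrappers no longer derive seq_idx internally, a caller packing variable-length sequences now builds it from cu_seqlens and passes it through. A minimal sketch of that caller-side construction, mirroring the computation the diff removes above (the helper name build_seq_idx is an assumption for illustration, not part of the library):

import torch

def build_seq_idx(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Label each packed position with its sequence index (hypothetical helper).

    cu_seqlens holds cumulative sequence boundaries, e.g. [0, 3, 7] for two
    packed sequences of lengths 3 and 4.
    """
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    return torch.cat(
        [torch.full((int(s),), i, dtype=torch.int32, device=cu_seqlens.device)
         for i, s in enumerate(lengths)],
        dim=0,
    ).unsqueeze(0)

cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32)
seq_idx = build_seq_idx(cu_seqlens)
print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 1, 1]], dtype=torch.int32)

The result is passed alongside cu_seqlens, e.g. mamba_inner_fn(..., cu_seqlens=cu_seqlens, seq_idx=seq_idx), so the index tensor can be built once per batch rather than recomputed inside every forward call.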