 from einops import rearrange, repeat
 
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, selective_scan_ref
+from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -119,10 +119,12 @@ def __init__(
     def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
         """
         hidden_states: (B, L, D)
-        cu_seqlens: one-dimensional tensor representing cumulative start indexes of packed sequence, a.k.a., B=1
+        cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and ending with L; must already be sorted.
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
+        if cu_seqlens is not None:
+            assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1"
 
         conv_state, ssm_state = None, None
         if inference_params is not None:
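Note (not part of the commit): the new docstring and assert above define the packed-sequence contract. All sequences are concatenated along the length dimension of a single batch row, and cu_seqlens marks their boundaries. A minimal sketch of a valid call, assuming three sequences of lengths 3, 4 and 5; the layer instance name and model dimension are made up for illustration:

    import torch

    # Three sequences of lengths 3, 4 and 5 packed along L of a single batch row: B = 1, L = 12.
    lengths = [3, 4, 5]
    cu_seqlens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)  # cumulative sum of lengths, from 0 to L
    hidden_states = torch.randn(1, 12, 16)                       # (B=1, L=12, D=16); D is arbitrary here
    # out = mamba_layer(hidden_states, cu_seqlens=cu_seqlens)    # hypothetical Mamba(d_model=16) instance
    # out.shape == (1, 12, 16), same as hidden_states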
@@ -158,46 +160,40 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens,
-                d_conv=torch.tensor(self.d_conv)
+                cu_seqlens=cu_seqlens
             )
         else:
             x, z = xz.chunk(2, dim=1)
-
-            # (Optional Step1 for cu_seqlens): Right padding zeros at sequence boundary for con1d ops in cumulative sequences
-            if cu_seqlens is not None:
-                padded_x = x
-                count = 0
-                for idx in cu_seqlens[1:-1].tolist():
-                    padded_idx = idx + count * (self.d_conv - 1)
-                    padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
-                    count = count + 1
-                x = padded_x
-
             # Compute short convolution
             if conv_state is not None:
                 # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                 # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                 conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
             if causal_conv1d_fn is None:
-                x = self.act(self.conv1d(x)[..., :seqlen])
+                if cu_seqlens is not None:
+                    # naive pure python implementation of varlen causal_conv1d
+                    for i, s in enumerate(cu_seqlens[1:-1]):
+                        x = torch.cat((x[..., :s + i * (self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i * (self.d_conv - 1):]), dim=2)
+                    mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device),
+                                                 torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0)
+                                      for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0)
+                    x = self.act(self.conv1d(x)[:, :, mask])
+                else:
+                    x = self.act(self.conv1d(x)[..., :seqlen])
             else:
                 assert self.activation in ["silu", "swish"]
+                if cu_seqlens is not None:
+                    seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device)
+                                         for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1])], dim=0).unsqueeze(0)
+                else:
+                    seq_idx = None
                 x = causal_conv1d_fn(
-                    x=x,
+                    x=x.transpose(1, 2).contiguous().transpose(1, 2) if cu_seqlens is not None else x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                     bias=self.conv1d.bias,
+                    seq_idx=seq_idx,
                     activation=self.activation,
                 )
-
-            # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
-            if cu_seqlens is not None:
-                mask = []
-                for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
-                    mask.extend([True] * seq_len)
-                    mask.extend([False] * (self.d_conv - 1))
-                mask = mask[:-(self.d_conv - 1)]
-                x = x[:, :, mask]
 
             # We're careful here about the layout, to avoid extra transposes.
             # We want dt to have d as the slowest moving dimension
@@ -208,7 +204,6 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
             B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
             C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
-
             assert self.activation in ["silu", "swish"]
             y = selective_scan_fn(
                 x,
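Note (not part of the commit): in the pure-Python fallback above, variable-length support works by inserting d_conv - 1 zeros at each interior sequence boundary before the depthwise convolution, then dropping those padded positions from the conv output with a boolean mask, so the causal window of one sequence never reads the tail of the previous one. A self-contained sketch of the same padding-and-mask trick, written against plain F.conv1d with a made-up function name so it can be checked in isolation (the module's activation, self.act, is omitted; only the convolution is shown):

    import torch
    import torch.nn.functional as F

    def varlen_causal_conv1d_ref(x, weight, bias, cu_seqlens):
        # x: (1, D, L) packed along L; weight: (D, W) depthwise filters; cu_seqlens: [0, ..., L].
        d, w = weight.shape
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        # Insert W - 1 zeros after every sequence so the causal window never crosses a boundary.
        chunks = []
        for i, s in enumerate(lengths):
            chunks.append(x[..., int(cu_seqlens[i]):int(cu_seqlens[i + 1])])
            chunks.append(torch.zeros(1, d, w - 1, dtype=x.dtype, device=x.device))
        padded = torch.cat(chunks, dim=2)                        # length L + n_seq * (W - 1)
        # Left-pad by W - 1 for causality and run the depthwise (grouped) convolution.
        out = F.conv1d(F.pad(padded, (w - 1, 0)), weight.unsqueeze(1), bias, groups=d)
        # Keep the first s outputs of each sequence; drop the W - 1 positions that follow it.
        mask = torch.cat([torch.tensor([True] * s + [False] * (w - 1), device=x.device) for s in lengths])
        return out[..., mask]                                    # (1, D, L), one causal conv per sequence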
@@ -317,59 +312,3 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
             conv_state.zero_()
             ssm_state.zero_()
         return conv_state, ssm_state
-
-class Block(nn.Module):
-    def __init__(
-        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
-    ):
-        """
-        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
-        This Block has a slightly different structure compared to a regular
-        prenorm Transformer block.
-        The standard block is: LN -> MHA/MLP -> Add.
-        [Ref: https://arxiv.org/abs/2002.04745]
-        Here we have: Add -> LN -> Mixer, returning both
-        the hidden_states (output of the mixer) and the residual.
-        This is purely for performance reasons, as we can fuse add and LayerNorm.
-        The residual needs to be provided (except for the very first block).
-        """
-        super().__init__()
-        self.residual_in_fp32 = residual_in_fp32
-        self.fused_add_norm = fused_add_norm
-        self.mixer = mixer_cls(dim)
-        self.norm = norm_cls(dim)
-        if self.fused_add_norm:
-            assert RMSNorm is not None, "RMSNorm import fails"
-            assert isinstance(
-                self.norm, (nn.LayerNorm, RMSNorm)
-            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
-
-    def forward(
-        self, hidden_states: Tensor, residual: Optional[Tensor] = None, cu_seqlens=None, inference_params=None
-    ):
-        r"""Pass the input through the encoder layer.
-        Args:
-            hidden_states: the sequence to the encoder layer (required).
-            residual: hidden_states = Mixer(LN(residual))
-        """
-        if not self.fused_add_norm:
-            residual = (hidden_states + residual) if residual is not None else hidden_states
-            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
-            if self.residual_in_fp32:
-                residual = residual.to(torch.float32)
-        else:
-            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
-            hidden_states, residual = fused_add_norm_fn(
-                hidden_states,
-                self.norm.weight,
-                self.norm.bias,
-                residual=residual,
-                prenorm=True,
-                residual_in_fp32=self.residual_in_fp32,
-                eps=self.norm.eps,
-            )
-        hidden_states = self.mixer(hidden_states, cu_seqlens=cu_seqlens, inference_params=inference_params)
-        return hidden_states, residual
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
-        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)