@@ -119,6 +119,7 @@ def __init__(
     def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
         """
         hidden_states: (B, L, D)
+        cu_seqlens: one-dimensional tensor, following the flash-attn varlen API; only used when variable-length sequences are packed into a single sequence, i.e. batch_size B=1
         Returns: same shape as hidden_states
         """
         batch, seqlen, dim = hidden_states.shape
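
A minimal sketch of the intended call pattern, assuming a constructed Mamba block named `mamba` (the variable names, sequence lengths, and int32 dtype below are illustrative): sequences are concatenated along the length dimension into one batch entry, and `cu_seqlens` holds the cumulative boundaries in the flash-attn varlen convention.

```python
import torch

# Hypothetical usage: pack three sequences of lengths 5, 3 and 7 into one
# (B=1, total_len, D) tensor and build the cumulative boundaries [0, 5, 8, 15].
seq_lens = [5, 3, 7]
D = 16
seqs = [torch.randn(L, D) for L in seq_lens]

hidden_states = torch.cat(seqs, dim=0).unsqueeze(0)                             # (1, 15, D)
cu_seqlens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0).to(torch.int32)  # [0, 5, 8, 15]

# out = mamba(hidden_states, cu_seqlens=cu_seqlens)   # same shape as hidden_states
```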
@@ -157,7 +158,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 self.D.float(),
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
-                cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+                cu_seqlens=cu_seqlens,
             )
         else:
             x, z = xz.chunk(2, dim=1)
@@ -166,12 +167,12 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             if cu_seqlens is not None:
                 padded_x = x
                 count = 0
-                for idx in cu_seqlens[0][1:-1].tolist():
+                for idx in cu_seqlens[1:-1].tolist():
                     padded_idx = idx + count * (self.d_conv - 1)
                     padded_x = torch.cat((padded_x[:, :, :padded_idx], torch.zeros(1, x.shape[1], self.d_conv - 1, dtype=x.dtype, device=x.device), padded_x[:, :, padded_idx:]), dim=2)
                     count = count + 1
                 x = padded_x
-                assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[0][1:-1]) + z.shape[2]
+                # assert x.shape[2] == (self.d_conv - 1) * len(cu_seqlens[1:-1]) + z.shape[2]

             # Compute short convolution
             if conv_state is not None:
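
The loop above (Step 1 for cu_seqlens) inserts `d_conv - 1` zeros at every internal sequence boundary so the depthwise causal conv1d cannot mix tokens from adjacent packed sequences. A self-contained toy run of the same index arithmetic, with illustrative shapes and a hypothetical `d_conv`:

```python
import torch

d_conv = 4                                      # hypothetical conv width
cu_seqlens = torch.tensor([0, 5, 8, 15])        # three packed sequences
x = torch.randn(1, 2, 15)                       # (B=1, d_inner=2, total_len=15), illustrative

padded_x, count = x, 0
for idx in cu_seqlens[1:-1].tolist():           # internal boundaries only: 5, 8
    padded_idx = idx + count * (d_conv - 1)     # shift by the zeros already inserted
    zeros = torch.zeros(1, x.shape[1], d_conv - 1, dtype=x.dtype, device=x.device)
    padded_x = torch.cat((padded_x[:, :, :padded_idx], zeros, padded_x[:, :, padded_idx:]), dim=2)
    count += 1

# 15 original positions + 2 internal boundaries * (d_conv - 1) zeros = 21
assert padded_x.shape[2] == x.shape[2] + (d_conv - 1) * (len(cu_seqlens) - 2)
```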
@@ -192,13 +193,13 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
             # (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
             if cu_seqlens is not None:
                 mask = []
-                for seq_len in (cu_seqlens[0][1:] - cu_seqlens[0][:-1]).tolist():
+                for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():
                     mask.extend([True] * seq_len)
                     mask.extend([False] * (self.d_conv - 1))
                 mask = mask[:-(self.d_conv - 1)]
-                assert x.shape[2] == len(mask)
+                # assert x.shape[2] == len(mask)
                 x = x[:, :, mask]
-                assert x.shape[2] == z.shape[2]
+                # assert x.shape[2] == z.shape[2]

             # We're careful here about the layout, to avoid extra transposes.
             # We want dt to have d as the slowest moving dimension
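
Step 2 drops the padded positions again after the convolution: the boolean mask keeps `seq_len` outputs per sequence and discards the `d_conv - 1` slots inserted at each internal boundary (no padding follows the last sequence, hence the final trim). A toy check with the same hypothetical values as above:

```python
import torch

d_conv = 4
cu_seqlens = torch.tensor([0, 5, 8, 15])

mask = []
for seq_len in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist():   # per-sequence lengths: 5, 3, 7
    mask.extend([True] * seq_len)
    mask.extend([False] * (d_conv - 1))
mask = mask[:-(d_conv - 1)]                                    # nothing was padded after the last sequence

# 21 padded positions, of which the 15 real ones are kept
assert len(mask) == 21 and sum(mask) == 15
```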
@@ -222,7 +223,7 @@ def forward(self, hidden_states, cu_seqlens=None, inference_params=None):
                 delta_bias=self.dt_proj.bias.float(),
                 delta_softplus=True,
                 return_last_state=ssm_state is not None,
-                cu_seqlens=cu_seqlens[0] if cu_seqlens is not None else None,
+                cu_seqlens=cu_seqlens,
             )
             if ssm_state is not None:
                 y, last_state = y