@@ -112,7 +112,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
112
112
weight_t *C = reinterpret_cast <weight_t *>(params.C_ptr ) + dim_id * kNRows * params.C_d_stride ;
113
113
input_t *Cvar = reinterpret_cast <input_t *>(params.C_ptr ) + batch_id * params.C_batch_stride + group_id * params.C_group_stride ;
114
114
scan_t *x = reinterpret_cast <scan_t *>(params.x_ptr ) + (batch_id * params.dim + dim_id * kNRows ) * params.n_chunks * params.dstate ;
115
- long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride ;
115
+
116
+ // Load cu_seqlens into shared memory
117
+ const int cu_seqlens_size = params.cu_seqlens_size ;
118
+ long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr );
119
+ __shared__ long smem_cu_seqlens[1024 ]; // Adjust size as needed
120
+ for (int i = threadIdx .x ; i < cu_seqlens_size; i += blockDim .x ) {
121
+ smem_cu_seqlens[i] = cu_seqlens[i];
122
+ }
123
+ __syncthreads ();
124
+
116
125
117
126
float D_val[kNRows ] = {0 };
118
127
if (params.D_ptr != nullptr ) {
@@ -224,15 +233,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
224
233
225
234
// Reset A bar for cumulative sequences (Real)
226
235
int left = 1 ;
227
- int right = params.cu_seqlens_size - 2 ;
236
+ int right = cu_seqlens_size - 2 ;
237
+ int idx = threadIdx .x * kNItems + i + chunk * kChunkSize ;
228
238
while (left <= right) {
229
- if (cu_seqlens[(left + right) >> 1 ] == threadIdx .x * kNItems + i + chunk * kChunkSize ) {
239
+ int mid = (left + right) >> 1 ;
240
+ if (smem_cu_seqlens[mid] == idx) {
230
241
thread_data[i].x = 0 .f ;
231
242
break ;
232
- } else if (cu_seqlens[(left + right) >> 1 ] < threadIdx . x * kNItems + i + chunk * kChunkSize ) {
233
- left = ((left + right) >> 1 ) + 1 ;
243
+ } else if (smem_cu_seqlens[mid ] < idx ) {
244
+ left = mid + 1 ;
234
245
} else {
235
- right = ((left + right) >> 1 ) - 1 ;
246
+ right = mid - 1 ;
236
247
}
237
248
}
238
249
@@ -249,19 +260,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
249
260
250
261
// Reset A bar for cumulative sequences (Complex)
251
262
int left = 1 ;
252
- int right = params.cu_seqlens_size - 2 ;
263
+ int right = cu_seqlens_size - 2 ;
264
+ int idx = threadIdx .x * kNItems + i + chunk * kChunkSize ;
253
265
while (left <= right) {
254
- if (cu_seqlens[(left + right) >> 1 ] == threadIdx .x * kNItems + i + chunk * kChunkSize ) {
266
+ int mid = (left + right) >> 1 ;
267
+ if (smem_cu_seqlens[mid] == idx) {
255
268
thread_data[i].x = 0 .f ;
256
269
thread_data[i].y = 0 .f ;
257
270
break ;
258
- } else if (cu_seqlens[(left + right) >> 1 ] < threadIdx . x * kNItems + i + chunk * kChunkSize ) {
259
- left = ((left + right) >> 1 ) + 1 ;
271
+ } else if (smem_cu_seqlens[mid ] < idx ) {
272
+ left = mid + 1 ;
260
273
} else {
261
- right = ((left + right) >> 1 ) - 1 ;
274
+ right = mid - 1 ;
262
275
}
263
276
}
264
277
278
+
265
279
if constexpr (!Ktraits::kIsEvenLen ) { // So that the last state is correct
266
280
if (threadIdx .x * kNItems + i >= params.seqlen - chunk * kChunkSize ) {
267
281
thread_data[i] = make_float4 (1 .f , 0 .f , 0 .f , 0 .f );
0 commit comments