@@ -113,15 +113,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
     scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
 
-    // Load cu_seqlens into shared memory
     const int cu_seqlens_size = params.cu_seqlens_size;
-    long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
-    __shared__ long smem_cu_seqlens[1024]; // Adjust size as needed
-    for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) {
-        smem_cu_seqlens[i] = cu_seqlens[i];
-    }
-    __syncthreads();
-
+    const long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
 
     float D_val[kNRows] = {0};
     if (params.D_ptr != nullptr) {
@@ -237,10 +230,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
                 while (left <= right) {
                     int mid = (left + right) >> 1;
-                    if (smem_cu_seqlens[mid] == idx) {
+                    if (cu_seqlens[mid] == idx) {
                         thread_data[i].x = 0.f;
                         break;
-                    } else if (smem_cu_seqlens[mid] < idx) {
+                    } else if (cu_seqlens[mid] < idx) {
                         left = mid + 1;
                     } else {
                         right = mid - 1;
@@ -264,11 +257,11 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
                 int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
                 while (left <= right) {
                     int mid = (left + right) >> 1;
-                    if (smem_cu_seqlens[mid] == idx) {
+                    if (cu_seqlens[mid] == idx) {
                         thread_data[i].x = 0.f;
                         thread_data[i].y = 0.f;
                         break;
-                    } else if (smem_cu_seqlens[mid] < idx) {
+                    } else if (cu_seqlens[mid] < idx) {
                         left = mid + 1;
                     } else {
                         right = mid - 1;
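For reference: the change replaces the fixed-capacity shared-memory staging of cu_seqlens (capped at 1024 entries, per the deleted "Adjust size as needed" comment) with binary searches directly on the global-memory array, so the lookup no longer depends on the offset count fitting the shared buffer. The standalone CUDA sketch below isolates that lookup; the kernel and variable names (mark_boundaries, flag, total_len) are illustrative, not from the PR. Each thread binary-searches cu_seqlens for its flattened index to decide whether it sits at a sequence start in a packed variable-length batch, which is exactly where the kernel above zeroes its scan carry (thread_data[i].x and .y).

// Hypothetical standalone sketch (names are illustrative, not from the PR).
#include <cstdio>
#include <cuda_runtime.h>

// For each flattened position idx in a packed batch, set flag[idx] = 1 iff
// idx appears in cu_seqlens, i.e. iff idx is the first token of a sequence.
// This is the same binary search the kernel runs before zeroing thread_data.
__global__ void mark_boundaries(const long *cu_seqlens, int cu_seqlens_size,
                                int total_len, int *flag) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_len) return;
    int found = 0;
    int left = 0, right = cu_seqlens_size - 1;
    while (left <= right) {                  // classic binary search
        int mid = (left + right) >> 1;
        if (cu_seqlens[mid] == idx)      { found = 1; break; }
        else if (cu_seqlens[mid] < idx)  { left = mid + 1; }
        else                             { right = mid - 1; }
    }
    flag[idx] = found;
}

int main() {
    // Three packed sequences of lengths 4, 3, 5: cumulative offsets 0, 4, 7, 12.
    const long h_cu[] = {0, 4, 7, 12};
    const int n_cu = 4, total = 12;
    long *d_cu; int *d_flag;
    cudaMalloc(&d_cu, sizeof(h_cu));
    cudaMalloc(&d_flag, total * sizeof(int));
    cudaMemcpy(d_cu, h_cu, sizeof(h_cu), cudaMemcpyHostToDevice);
    mark_boundaries<<<1, 32>>>(d_cu, n_cu, total, d_flag);
    int h_flag[12];
    cudaMemcpy(h_flag, d_flag, sizeof(h_flag), cudaMemcpyDeviceToHost);
    for (int i = 0; i < total; ++i) printf("%d ", h_flag[i]);  // 1 0 0 0 1 0 0 1 0 0 0 0
    printf("\n");
    cudaFree(d_cu); cudaFree(d_flag);
    return 0;
}

Compiled with nvcc, this prints 1 at positions 0, 4, and 7 (the starts of the three packed sequences) and 0 elsewhere. Each lookup costs O(log cu_seqlens_size) global loads per element, and since the array is small and read-only, the caches are likely to absorb most of the traffic the shared-memory copy was meant to save.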