Skip to content

Commit 210b6f6

Browse files
committed
Move cu_seqlens in the SSM kernels to shared memory (smem)
1 parent 3bc4a51 commit 210b6f6

File tree

2 files changed

+48
-22
lines changed

2 files changed

+48
-22
lines changed

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,15 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
142142
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
143143
float dD_val = 0;
144144
float ddelta_bias_val = 0;
145-
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride;
145+
146+
// Load cu_seqlens into shared memory
147+
const int cu_seqlens_size = params.cu_seqlens_size;
148+
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
149+
__shared__ long smem_cu_seqlens[1024]; // Fixed capacity of 1024 entries: cu_seqlens_size must not exceed it, or the copy loop below writes out of bounds
150+
for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) {
151+
smem_cu_seqlens[i] = cu_seqlens[i];
152+
}
153+
__syncthreads();
146154

147155
constexpr int kChunkSize = kNThreads * kNItems;
148156
u += (params.n_chunks - 1) * kChunkSize;
@@ -255,15 +263,17 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
255263

256264
// Reset A bar for cumulative sequences (Real)
257265
int left = 1;
258-
int right = params.cu_seqlens_size - 2;
266+
int right = cu_seqlens_size - 2;
267+
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
259268
while (left <= right) {
260-
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
269+
int mid = (left + right) >> 1;
270+
if (smem_cu_seqlens[mid] == idx) {
261271
delta_a_exp = 0.f;
262272
break;
263-
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
264-
left = ((left + right) >> 1) + 1;
273+
} else if (smem_cu_seqlens[mid] < idx) {
274+
left = mid + 1;
265275
} else {
266-
right = ((left + right) >> 1) - 1;
276+
right = mid - 1;
267277
}
268278
}
269279

@@ -358,16 +368,18 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
358368

359369
// Reset A bar for cumulative sequences (Complex)
360370
int left = 1;
361-
int right = params.cu_seqlens_size - 2;
371+
int right = cu_seqlens_size - 2;
372+
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
362373
while (left <= right) {
363-
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
374+
int mid = (left + right) >> 1;
375+
if (smem_cu_seqlens[mid] == idx) {
364376
delta_a_exp.real_ = 0.f;
365377
delta_a_exp.imag_ = 0.f;
366378
break;
367-
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
368-
left = ((left + right) >> 1) + 1;
379+
} else if (smem_cu_seqlens[mid] < idx) {
380+
left = mid + 1;
369381
} else {
370-
right = ((left + right) >> 1) - 1;
382+
right = mid - 1;
371383
}
372384
}
373385

csrc/selective_scan/selective_scan_fwd_kernel.cuh

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
112112
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
113113
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
114114
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
115-
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride;
115+
116+
// Load cu_seqlens into shared memory
117+
const int cu_seqlens_size = params.cu_seqlens_size;
118+
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
119+
__shared__ long smem_cu_seqlens[1024]; // Fixed capacity of 1024 entries: cu_seqlens_size must not exceed it, or the copy loop below writes out of bounds
120+
for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) {
121+
smem_cu_seqlens[i] = cu_seqlens[i];
122+
}
123+
__syncthreads();
124+
116125

117126
float D_val[kNRows] = {0};
118127
if (params.D_ptr != nullptr) {
@@ -224,15 +233,17 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
224233

225234
// Reset A bar for cumulative sequences (Real)
226235
int left = 1;
227-
int right = params.cu_seqlens_size - 2;
236+
int right = cu_seqlens_size - 2;
237+
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
228238
while (left <= right) {
229-
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
239+
int mid = (left + right) >> 1;
240+
if (smem_cu_seqlens[mid] == idx) {
230241
thread_data[i].x = 0.f;
231242
break;
232-
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
233-
left = ((left + right) >> 1) + 1;
243+
} else if (smem_cu_seqlens[mid] < idx) {
244+
left = mid + 1;
234245
} else {
235-
right = ((left + right) >> 1) - 1;
246+
right = mid - 1;
236247
}
237248
}
238249

@@ -249,19 +260,22 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
249260

250261
// Reset A bar for cumulative sequences (Complex)
251262
int left = 1;
252-
int right = params.cu_seqlens_size - 2;
263+
int right = cu_seqlens_size - 2;
264+
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
253265
while (left <= right) {
254-
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
266+
int mid = (left + right) >> 1;
267+
if (smem_cu_seqlens[mid] == idx) {
255268
thread_data[i].x = 0.f;
256269
thread_data[i].y = 0.f;
257270
break;
258-
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
259-
left = ((left + right) >> 1) + 1;
271+
} else if (smem_cu_seqlens[mid] < idx) {
272+
left = mid + 1;
260273
} else {
261-
right = ((left + right) >> 1) - 1;
274+
right = mid - 1;
262275
}
263276
}
264277

278+
265279
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
266280
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
267281
thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);

0 commit comments

Comments
 (0)