Skip to content

Commit cda4b5a

Browse files
committed
Remove smem (shared-memory) implementation because const values and binary search are enough
1 parent 210b6f6 commit cda4b5a

File tree

2 files changed

+10
-23
lines changed

2 files changed

+10
-23
lines changed

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -143,14 +143,8 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
143143
float dD_val = 0;
144144
float ddelta_bias_val = 0;
145145

146-
// Load cu_seqlens into shared memory
147146
const int cu_seqlens_size = params.cu_seqlens_size;
148-
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
149-
__shared__ long smem_cu_seqlens[1024]; // Adjust size as needed
150-
for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) {
151-
smem_cu_seqlens[i] = cu_seqlens[i];
152-
}
153-
__syncthreads();
147+
const long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
154148

155149
constexpr int kChunkSize = kNThreads * kNItems;
156150
u += (params.n_chunks - 1) * kChunkSize;
@@ -267,10 +261,10 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
267261
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
268262
while (left <= right) {
269263
int mid = (left + right) >> 1;
270-
if (smem_cu_seqlens[mid] == idx) {
264+
if (cu_seqlens[mid] == idx) {
271265
delta_a_exp = 0.f;
272266
break;
273-
} else if (smem_cu_seqlens[mid] < idx) {
267+
} else if (cu_seqlens[mid] < idx) {
274268
left = mid + 1;
275269
} else {
276270
right = mid - 1;
@@ -372,11 +366,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
372366
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
373367
while (left <= right) {
374368
int mid = (left + right) >> 1;
375-
if (smem_cu_seqlens[mid] == idx) {
369+
if (cu_seqlens[mid] == idx) {
376370
delta_a_exp.real_ = 0.f;
377371
delta_a_exp.imag_ = 0.f;
378372
break;
379-
} else if (smem_cu_seqlens[mid] < idx) {
373+
} else if (cu_seqlens[mid] < idx) {
380374
left = mid + 1;
381375
} else {
382376
right = mid - 1;

csrc/selective_scan/selective_scan_fwd_kernel.cuh

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -113,15 +113,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
113113
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
114114
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
115115

116-
// Load cu_seqlens into shared memory
117116
const int cu_seqlens_size = params.cu_seqlens_size;
118-
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
119-
__shared__ long smem_cu_seqlens[1024]; // Adjust size as needed
120-
for (int i = threadIdx.x; i < cu_seqlens_size; i += blockDim.x) {
121-
smem_cu_seqlens[i] = cu_seqlens[i];
122-
}
123-
__syncthreads();
124-
117+
const long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr);
125118

126119
float D_val[kNRows] = {0};
127120
if (params.D_ptr != nullptr) {
@@ -237,10 +230,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
237230
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
238231
while (left <= right) {
239232
int mid = (left + right) >> 1;
240-
if (smem_cu_seqlens[mid] == idx) {
233+
if (cu_seqlens[mid] == idx) {
241234
thread_data[i].x = 0.f;
242235
break;
243-
} else if (smem_cu_seqlens[mid] < idx) {
236+
} else if (cu_seqlens[mid] < idx) {
244237
left = mid + 1;
245238
} else {
246239
right = mid - 1;
@@ -264,11 +257,11 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
264257
int idx = threadIdx.x * kNItems + i + chunk * kChunkSize;
265258
while (left <= right) {
266259
int mid = (left + right) >> 1;
267-
if (smem_cu_seqlens[mid] == idx) {
260+
if (cu_seqlens[mid] == idx) {
268261
thread_data[i].x = 0.f;
269262
thread_data[i].y = 0.f;
270263
break;
271-
} else if (smem_cu_seqlens[mid] < idx) {
264+
} else if (cu_seqlens[mid] < idx) {
272265
left = mid + 1;
273266
} else {
274267
right = mid - 1;

0 commit comments

Comments
 (0)