Skip to content

Commit aea08ca

Browse files
committed
fix typos
1 parent ca189f6 commit aea08ca

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
253253
while (left <= right) {
254254
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
255255
delta_a_exp = 0.f;
256+
break;
256257
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
257258
left = ((left + right) >> 1) + 1;
258259
} else {
@@ -356,6 +357,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
356357
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
357358
delta_a_exp.real_ = 0.f;
358359
delta_a_exp.imag_ = 0.f;
360+
break;
359361
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
360362
left = ((left + right) >> 1) + 1;
361363
} else {

csrc/selective_scan/selective_scan_fwd_kernel.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
223223
while (left <= right) {
224224
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
225225
thread_data[i].x = 0.f;
226+
break;
226227
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
227228
left = ((left + right) >> 1) + 1;
228229
} else {
@@ -248,6 +249,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
248249
if (cu_seqlens[(left + right) >> 1] == threadIdx.x * kNItems + i + chunk * kChunkSize) {
249250
thread_data[i].x = 0.f;
250251
thread_data[i].y = 0.f;
252+
break;
251253
} else if (cu_seqlens[(left + right) >> 1] < threadIdx.x * kNItems + i + chunk * kChunkSize) {
252254
left = ((left + right) >> 1) + 1;
253255
} else {

0 commit comments

Comments
 (0)