Skip to content

Commit e223353

Browse files
committed
fix typos
1 parent a78a9eb commit e223353

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

csrc/selective_scan/selective_scan_bwd_kernel.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
136136
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
137137
float dD_val = 0;
138138
float ddelta_bias_val = 0;
139-
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride
139+
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride;
140140

141141
constexpr int kChunkSize = kNThreads * kNItems;
142142
u += (params.n_chunks - 1) * kChunkSize;

csrc/selective_scan/selective_scan_fwd_kernel.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
107107
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
108108
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
109109
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
110-
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride
110+
long *cu_seqlens = reinterpret_cast<long *>(params.cu_seqlens_ptr) + batch_id * params.u_batch_stride;
111111

112112
float D_val[kNRows] = {0};
113113
if (params.D_ptr != nullptr) {

0 commit comments

Comments
 (0)