File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -136,7 +136,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
136
136
: reinterpret_cast <scan_t *>(params.x_ptr ) + (batch_id * params.dim + dim_id) * (params.n_chunks ) * params.dstate ;
137
137
float dD_val = 0 ;
138
138
float ddelta_bias_val = 0 ;
139
- long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride
139
+ long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride ;
140
140
141
141
constexpr int kChunkSize = kNThreads * kNItems ;
142
142
u += (params.n_chunks - 1 ) * kChunkSize ;
Original file line number Diff line number Diff line change @@ -107,7 +107,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
107
107
weight_t *C = reinterpret_cast <weight_t *>(params.C_ptr ) + dim_id * kNRows * params.C_d_stride ;
108
108
input_t *Cvar = reinterpret_cast <input_t *>(params.C_ptr ) + batch_id * params.C_batch_stride + group_id * params.C_group_stride ;
109
109
scan_t *x = reinterpret_cast <scan_t *>(params.x_ptr ) + (batch_id * params.dim + dim_id * kNRows ) * params.n_chunks * params.dstate ;
110
- long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride
110
+ long *cu_seqlens = reinterpret_cast <long *>(params.cu_seqlens_ptr ) + batch_id * params.u_batch_stride ;
111
111
112
112
float D_val[kNRows ] = {0 };
113
113
if (params.D_ptr != nullptr ) {
You can’t perform that action at this time.
0 commit comments