@@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
418
418
typename Ktraits::BlockStoreT (smem_store).Store (out, out_vals_store, seqlen - chunk * kChunkSize );
419
419
}
420
420
out += kChunkSize ;
421
+
422
+ int final_state_position = ((seqlen - (kWidth - 1 )) - (n_chunks - 1 ) * kChunkSize );
423
+ // in case the final state is separated between the last "smem_exchange" and
424
+ // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
425
+ // (which occurs when `final_state_position` is a non-positivie index)
426
+ // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
427
+ if (final_state_position < 0 && seqlen > kWidth ){
428
+ input_t vals_load[kNElts ] = {0 };
429
+ if ((chunk == n_chunks - 2 ) && (tidx == kNThreads - 1 )){
430
+ // chunk = n_chunks - 2, a segment of the final state sits in the last index
431
+ reinterpret_cast <vec_t *>(vals_load)[0 ] = smem_exchange[kNThreads - 1 ];
432
+ #pragma unroll
433
+ for (int w = 0 ; w < -final_state_position; ++w){
434
+ conv_states[w] = vals_load[kNElts + final_state_position + w];
435
+ }
436
+ }
437
+ if ((chunk == n_chunks - 1 ) && tidx == 0 ){
438
+ // chunk = n_chunks - 1, the second segment of the final state first positions
439
+ reinterpret_cast <vec_t *>(vals_load)[0 ] = smem_exchange[0 ];
440
+ for (int w = -final_state_position; w < kWidth - 1 ; ++w){
441
+ conv_states[w] = vals_load[w + final_state_position];
442
+ }
443
+ return ;
444
+ }
445
+ }
421
446
}
422
447
// Final state is stored in the smem_exchange last token slot,
423
448
// in case seqlen < kWidth, we would need to take the final state from the
@@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
446
471
}
447
472
else {
448
473
// in case the final state is in between the threads data
449
- reinterpret_cast <vec_t *>(x_vals_load)[1 ] = smem_exchange[last_thread + 1 ];
450
- reinterpret_cast <vec_t *>(x_vals_load)[0 ] = smem_exchange[last_thread];
451
474
const int offset = ((seqlen - (kWidth - 1 )) % (kNElts ));
475
+ if ((offset + kWidth - 2 ) >= kNElts && (last_thread + 1 < kNThreads )){
476
+ // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
477
+ // illegal access error on H100.
478
+ // Therefore, we access last_thread + 1, only if the final state data sits there
479
+ reinterpret_cast <vec_t *>(x_vals_load)[1 ] = smem_exchange[last_thread + 1 ];
480
+ }
481
+ reinterpret_cast <vec_t *>(x_vals_load)[0 ] = smem_exchange[last_thread];
452
482
#pragma unroll
453
483
for (int w = 0 ; w < kWidth - 1 ; ++w){
454
484
conv_states[w] = x_vals_load[offset + w ];
0 commit comments