[Feature] Support variable-length sequences for mamba block #244

Open

wants to merge 43 commits into state-spaces:main from feat/add-cu_seqlens
Commits (43)
d28e1b0
add cu_seqlens support and ensure numerical equality
zigzagcai Mar 8, 2024
a78a9eb
add notes for variable length sequences
zigzagcai Mar 14, 2024
e223353
fix typos
zigzagcai Mar 15, 2024
5955450
fix typos
zigzagcai Mar 18, 2024
ca189f6
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Mar 18, 2024
c2d5b88
fix typos
Dmovic Mar 18, 2024
db0dd09
fix typos
zigzagcai Mar 18, 2024
842bef5
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Mar 18, 2024
e7774aa
refine cu_seqlens implementation
zigzagcai Mar 18, 2024
1ccc60f
Merge branch 'feat/add-cu_seqlens' into feat/add-cu_seqlens
Dmovic Mar 19, 2024
4bf2697
Merge pull request #1 from Dmovic/feat/add-cu_seqlens
zigzagcai Mar 19, 2024
f357c44
add unit test for variable length
Dmovic Mar 19, 2024
6b98161
update unit test
Dmovic Mar 19, 2024
e4af927
fix typos
zigzagcai Mar 19, 2024
4221d48
update selective scan
zigzagcai Mar 25, 2024
934c0e6
Add logic for variable-length sequences
wang-zerui Mar 25, 2024
63b646d
Merge branch 'main' into feat/add-cu_seqlens
zigzagcai Apr 18, 2024
f6bb7e2
add example test to prove the mathematical equivalence of cu_seqlens …
zigzagcai Apr 26, 2024
bffcd97
fix typos
zigzagcai Apr 26, 2024
e3cab98
add cu_seqlens support for MixerModel
zigzagcai Apr 26, 2024
2f01ede
code refine for tests
zigzagcai Apr 30, 2024
f0a6508
refine code for tests
zigzagcai Apr 30, 2024
623d246
update API notes
zigzagcai Apr 30, 2024
ef3f760
update test code
zigzagcai Apr 30, 2024
71c77b1
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Jun 6, 2024
2d27ccc
fix conflicts with latest main branch
zigzagcai Jun 6, 2024
f802627
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Jul 16, 2024
596943c
fix unittest for test_selective_state_update_with_heads
zigzagcai Jul 16, 2024
6961faa
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Jul 18, 2024
b69b957
migrate to tridao's native varlen causal_conv1d kernel for speedup
zigzagcai Jul 19, 2024
50bffae
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Jul 22, 2024
909f970
typo fix
zigzagcai Jul 23, 2024
8174c45
use seq_idx if provided, or compute it by cu_seqlens
zigzagcai Aug 5, 2024
59be631
use seq_idx if provided, or compute it by cu_seqlens
zigzagcai Aug 5, 2024
3bc4a51
Merge branch 'state-spaces:main' into feat/add-cu_seqlens
zigzagcai Aug 6, 2024
210b6f6
mv cu_seqlens in ssm kernel to smem
zigzagcai Aug 7, 2024
cda4b5a
remove smem implementation because const vals and bi-search is enough
zigzagcai Aug 7, 2024
f463a65
Add unittest for test_mamba_cu_seqlens_equivalence_nlayers.py
zigzagcai Apr 24, 2025
7cd6bb2
Merge remote-tracking branch 'origin/main' into feat/add-cu_seqlens
zigzagcai Apr 25, 2025
b02b243
Merge remote-tracking branch 'upstream/main' into feat/add-cu_seqlens
zigzagcai Apr 25, 2025
1887385
Merge remote-tracking branch 'upstream/main' into feat/add-cu_seqlens
zigzagcai Jun 11, 2025
fc30cf7
refactor implementation to use pos_ids as the sequence boundary ident…
zigzagcai Jun 18, 2025
93403bf
remove unnecessary syncthreads
zigzagcai Jun 18, 2025
38 changes: 31 additions & 7 deletions csrc/selective_scan/selective_scan.cpp
@@ -79,7 +79,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
void* delta_bias_ptr,
void* x_ptr,
bool has_z,
bool delta_softplus) {
bool delta_softplus,
void* pos_ids_ptr) {

// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -109,6 +110,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.x_ptr = x_ptr;
params.z_ptr = has_z ? z.data_ptr() : nullptr;
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
params.pos_ids_ptr = pos_ids_ptr;

// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
params.A_dstate_stride = A.stride(1);
@@ -173,15 +176,16 @@ void set_ssm_params_bwd(SSMParamsBwd &params,
void* ddelta_bias_ptr,
bool has_z,
bool delta_softplus,
bool recompute_out_z) {
bool recompute_out_z,
void* pos_ids_ptr) {
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
u, delta, A, B, C, has_z ? out : dout,
has_z ? z : dout,
// If not recompute_out_z, pass dout instead of out_z.
// This won't be used by the bwd kernel
recompute_out_z ? out_z : dout,
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, pos_ids_ptr);
if (!recompute_out_z) { params.out_z_ptr = nullptr; }

// Set the pointers and strides.
@@ -229,7 +233,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &D_,
const c10::optional<at::Tensor> &z_,
const c10::optional<at::Tensor> &delta_bias_,
bool delta_softplus) {
bool delta_softplus,
const c10::optional<at::Tensor> &pos_ids_) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -293,6 +298,14 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
CHECK_SHAPE(delta_bias, dim);
}

if (pos_ids_.has_value()) {
auto pos_ids = pos_ids_.value();
TORCH_CHECK(pos_ids.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(pos_ids.is_cuda());
CHECK_SHAPE(pos_ids, batch_size, seqlen);
TORCH_CHECK(batch_size == 1)
}

at::Tensor z, out_z;
const bool has_z = z_.has_value();
if (has_z) {
@@ -319,7 +332,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
x.data_ptr(),
has_z,
delta_softplus);
delta_softplus,
pos_ids_.has_value() ? pos_ids_.value().data_ptr() : nullptr);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
@@ -346,7 +360,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
const c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> &dz_,
bool delta_softplus,
bool recompute_out_z) {
bool recompute_out_z,
const c10::optional<at::Tensor> &pos_ids_) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -414,6 +429,14 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
CHECK_SHAPE(delta_bias, dim);
}

if (pos_ids_.has_value()) {
auto pos_ids = pos_ids_.value();
TORCH_CHECK(pos_ids.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(pos_ids.is_cuda());
CHECK_SHAPE(pos_ids, batch_size, seqlen);
TORCH_CHECK(batch_size == 1)
}

at::Tensor z, out, dz, out_z;
const bool has_z = z_.has_value();
if (has_z) {
@@ -474,7 +497,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
dout, du, ddelta, dA, dB, dC, dz,
D_.has_value() ? dD.data_ptr() : nullptr,
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
has_z, delta_softplus, recompute_out_z);
has_z, delta_softplus, recompute_out_z,
pos_ids_.has_value() ? pos_ids_.value().data_ptr() : nullptr);

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
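The binding changes above thread an optional `pos_ids_` tensor through the forward and backward entry points: int32, CUDA, shape `(batch_size, seqlen)`, with `batch_size` required to be 1, i.e. all variable-length sequences are packed into a single row and each sequence's positions restart from zero. A minimal sketch of how such a tensor could be derived from `cu_seqlens` on the Python side (the helper name is illustrative and not part of this PR):

```python
import torch

def pos_ids_from_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: expand cumulative sequence lengths into per-token
    position ids that restart at 0 at every sequence boundary."""
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    pos_ids = torch.cat([torch.arange(n, dtype=torch.int32) for n in seqlens])
    # Shape (1, total_seqlen): the bindings above check Int dtype and batch_size == 1;
    # move the result to the GPU before calling the kernels (pos_ids.is_cuda() is checked).
    return pos_ids.unsqueeze(0)

cu_seqlens = torch.tensor([0, 3, 7, 9], dtype=torch.int32)
print(pos_ids_from_cu_seqlens(cu_seqlens))
# tensor([[0, 1, 2, 0, 1, 2, 3, 0, 1]], dtype=torch.int32)
```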
1 change: 1 addition & 0 deletions csrc/selective_scan/selective_scan.h
@@ -66,6 +66,7 @@ struct SSMParamsBase {
void *__restrict__ x_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ pos_ids_ptr;
};

struct SSMParamsBwd: public SSMParamsBase {
76 changes: 55 additions & 21 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
@@ -29,31 +29,39 @@ template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }

template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
bool kDeltaSoftplus_, bool kHasZ_, bool kIsVarLen_, typename input_t_, typename weight_t_>
struct Selective_Scan_bwd_kernel_traits {
static_assert(kNItems_ % 4 == 0);
using input_t = input_t_;
using weight_t = weight_t_;
using pos_t = uint32_t;
static constexpr int kNThreads = kNThreads_;
static constexpr int kNItems = kNItems_;
static constexpr int kNBytes = sizeof(input_t);
static constexpr int kNBytesPos = sizeof(pos_t);
static_assert(kNBytes == 2 || kNBytes == 4);
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
static constexpr int kNEltsPos = kNBytesPos == 4 ? 4 : constexpr_min(8, kNItems);
static_assert(kNItems % kNElts == 0);
static constexpr int kNLoads = kNItems / kNElts;
static constexpr int kNLoadsPos = kNItems / kNEltsPos;
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
static constexpr bool kIsEvenLen = kIsEvenLen_;
static constexpr bool kIsVariableB = kIsVariableB_;
static constexpr bool kIsVariableC = kIsVariableC_;
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
static constexpr bool kHasZ = kHasZ_;
static constexpr bool kIsVarLen = kIsVarLen_;
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
// For complex this would lead to massive register spilling, so we keep it at 2.
static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
using pos_vec_t = typename BytesToType<kNBytesPos * kNEltsPos>::Type;
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadPosIdsT = cub::BlockLoad<pos_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadPosIdsVecT = cub::BlockLoad<pos_vec_t, kNThreads, kNLoadsPos, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
@@ -86,10 +94,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kIsVarLen = Ktraits::kIsVarLen;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using pos_t = typename Ktraits::pos_t;
using scan_t = typename Ktraits::scan_t;

// Shared memory.
@@ -100,6 +110,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
auto& smem_load_pos_ids = reinterpret_cast<typename Ktraits::BlockLoadPosIdsT::TempStorage&>(smem_);
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
@@ -142,9 +153,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
float dD_val = 0;
float ddelta_bias_val = 0;
pos_t *pos_ids = !kIsVarLen ? nullptr : reinterpret_cast<pos_t *>(params.pos_ids_ptr) + batch_id * params.seqlen;

constexpr int kChunkSize = kNThreads * kNItems;
u += (params.n_chunks - 1) * kChunkSize;
pos_ids += (params.n_chunks - 1) * kChunkSize;
delta += (params.n_chunks - 1) * kChunkSize;
dout += (params.n_chunks - 1) * kChunkSize;
Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
@@ -153,9 +166,15 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
input_t u_vals[kNItems];
input_t delta_vals_load[kNItems];
input_t dout_vals_load[kNItems];
pos_t pos_ids_vals[kNItems];
__syncthreads();
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
u -= kChunkSize;
if constexpr (kIsVarLen) {
__syncthreads();
load_pos_ids<Ktraits>(pos_ids, pos_ids_vals, smem_load_pos_ids, params.seqlen - chunk * kChunkSize);
pos_ids -= kChunkSize;
}
__syncthreads();
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
// Will reload delta at the same location if kDeltaSoftplus
@@ -250,7 +269,13 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
if constexpr (!kIsComplex) {
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
// Reset A bar for cumulative sequences (Real)
if constexpr (kIsVarLen) {
if (pos_ids_vals[i] == 0) {
delta_a_exp = 0.f;
}
}
thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
if (i == 0) {
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
@@ -338,6 +363,13 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
for (int i = 0; i < kNItems; ++i) {
// Pytorch's implementation of complex exp (which calls thrust) is very slow
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
// Reset A bar for cumulative sequences (Complex)
if constexpr (kIsVarLen) {
if (pos_ids_vals[i] == 0) {
delta_a_exp.real_ = 0.f;
delta_a_exp.imag_ = 0.f;
}
}
weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
if (i == 0) {
@@ -501,30 +533,32 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// TODO: check this
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
BOOL_SWITCH(params.pos_ids_ptr != nullptr , kIsVarLen, [&] {
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, kIsVarLen, input_t, weight_t>;
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// TODO: check this
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);

dim3 grid(params.batch, params.dim);

auto kernel = &selective_scan_bwd_kernel<Ktraits>;
dim3 grid(params.batch, params.dim);
auto kernel = &selective_scan_bwd_kernel<Ktraits>;

if (kSmemSize >= 48 * 1024) {
if (kSmemSize >= 48 * 1024) {

#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif

}
}

kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
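In the hunks above, whenever `pos_ids_vals[i] == 0` (the first token of a packed sequence), `delta_a_exp` — the discretized decay Ā = exp(Δ·A) — is forced to zero, so the scan recurrence x_t = Ā_t·x_{t-1} + B̄_t·u_t cannot carry state across a sequence boundary. A tiny NumPy reference (toy numbers, not the library's kernel) sketching why the packed scan with this reset matches scanning each sequence independently:

```python
import numpy as np

def scan(decay, inp, pos_ids=None):
    """Reference recurrence x_t = decay_t * x_{t-1} + inp_t; the decay is zeroed
    wherever the position id is 0, mirroring the kernel's A-bar reset."""
    x, out = 0.0, []
    for t in range(len(inp)):
        d = 0.0 if (pos_ids is not None and pos_ids[t] == 0) else decay[t]
        x = d * x + inp[t]
        out.append(x)
    return np.array(out)

decay = np.array([0.9, 0.8, 0.7, 0.6, 0.5])
inp = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
pos = np.array([0, 1, 2, 0, 1])  # two sequences of lengths 3 and 2 packed together

packed = scan(decay, inp, pos)
separate = np.concatenate([scan(decay[:3], inp[:3]), scan(decay[3:], inp[3:])])
assert np.allclose(packed, separate)  # state never leaks across the boundary
```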
17 changes: 17 additions & 0 deletions csrc/selective_scan/selective_scan_common.h
@@ -196,6 +196,23 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
}
}

template<typename Ktraits>
inline __device__ void load_pos_ids(typename Ktraits::pos_t *u,
typename Ktraits::pos_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadPosIdsT::TempStorage &smem_load_pos_ids,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_pos_ids_vec = reinterpret_cast<typename Ktraits::BlockLoadPosIdsVecT::TempStorage&>(smem_load_pos_ids);
using pos_vec_t = typename Ktraits::pos_vec_t;
Ktraits::BlockLoadPosIdsVecT(smem_load_pos_ids_vec).Load(
reinterpret_cast<pos_vec_t*>(u),
reinterpret_cast<pos_vec_t(&)[Ktraits::kNLoadsPos]>(u_vals)
);
} else {
Ktraits::BlockLoadPosIdsT(smem_load_pos_ids).Load(u, u_vals, seqlen, 0.f);
}
}

template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
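The branch also carries unit tests ("add unit test for variable length", "Add unittest for test_mamba_cu_seqlens_equivalence_nlayers.py") that check numerical equivalence between the packed and per-sequence paths at the module level. A rough sketch of such a check, assuming the PR exposes a `cu_seqlens` keyword on `Mamba.forward` (the exact Python-level argument name and signature are not visible in this diff and are assumptions here):

```python
import torch
from mamba_ssm import Mamba

torch.manual_seed(0)
device, dtype = "cuda", torch.float32
model = Mamba(d_model=64, d_state=16, d_conv=4, expand=2).to(device=device, dtype=dtype)

# Two variable-length sequences of lengths 3 and 5 packed into one row of length 8.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)
x = torch.randn(1, 8, 64, device=device, dtype=dtype)

# Packed forward (cu_seqlens keyword assumed to be added by this PR).
out_packed = model(x, cu_seqlens=cu_seqlens)

# Per-sequence forward on the unpacked pieces.
bounds = cu_seqlens.tolist()
out_split = torch.cat([model(x[:, s:e]) for s, e in zip(bounds[:-1], bounds[1:])], dim=1)

assert torch.allclose(out_packed, out_split, atol=1e-4, rtol=1e-4)
```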