diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index a97588e6..b7c47721 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -79,7 +79,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, void* delta_bias_ptr, void* x_ptr, bool has_z, - bool delta_softplus) { + bool delta_softplus, + void* pos_ids_ptr) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -109,6 +110,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.x_ptr = x_ptr; params.z_ptr = has_z ? z.data_ptr() : nullptr; params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + params.pos_ids_ptr = pos_ids_ptr; + // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); params.A_dstate_stride = A.stride(1); @@ -173,7 +176,8 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, void* ddelta_bias_ptr, bool has_z, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + void* pos_ids_ptr) { // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, u, delta, A, B, C, has_z ? out : dout, @@ -181,7 +185,7 @@ void set_ssm_params_bwd(SSMParamsBwd ¶ms, // If not recompute_out_z, pass dout instead of out_z. // This won't be used by the bwd kernel recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus, pos_ids_ptr); if (!recompute_out_z) { params.out_z_ptr = nullptr; } // Set the pointers and strides. @@ -229,7 +233,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, - bool delta_softplus) { + bool delta_softplus, + const c10::optional &pos_ids_) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -293,6 +298,14 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, CHECK_SHAPE(delta_bias, dim); } + if (pos_ids_.has_value()) { + auto pos_ids = pos_ids_.value(); + TORCH_CHECK(pos_ids.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(pos_ids.is_cuda()); + CHECK_SHAPE(pos_ids, batch_size, seqlen); + TORCH_CHECK(batch_size == 1) + } + at::Tensor z, out_z; const bool has_z = z_.has_value(); if (has_z) { @@ -319,7 +332,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, x.data_ptr(), has_z, - delta_softplus); + delta_softplus, + pos_ids_.has_value() ? 
pos_ids_.value().data_ptr() : nullptr); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing @@ -346,7 +360,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, const c10::optional &out_, c10::optional &dz_, bool delta_softplus, - bool recompute_out_z) { + bool recompute_out_z, + const c10::optional &pos_ids_) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -414,6 +429,14 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, CHECK_SHAPE(delta_bias, dim); } + if (pos_ids_.has_value()) { + auto pos_ids = pos_ids_.value(); + TORCH_CHECK(pos_ids.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(pos_ids.is_cuda()); + CHECK_SHAPE(pos_ids, batch_size, seqlen); + TORCH_CHECK(batch_size == 1) + } + at::Tensor z, out, dz, out_z; const bool has_z = z_.has_value(); if (has_z) { @@ -474,7 +497,8 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, dout, du, ddelta, dA, dB, dC, dz, D_.has_value() ? dD.data_ptr() : nullptr, delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, - has_z, delta_softplus, recompute_out_z); + has_z, delta_softplus, recompute_out_z, + pos_ids_.has_value() ? pos_ids_.value().data_ptr() : nullptr); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing diff --git a/csrc/selective_scan/selective_scan.h b/csrc/selective_scan/selective_scan.h index e2c7bcdb..f8c9eaf7 100644 --- a/csrc/selective_scan/selective_scan.h +++ b/csrc/selective_scan/selective_scan.h @@ -66,6 +66,7 @@ struct SSMParamsBase { void *__restrict__ x_ptr; void *__restrict__ z_ptr; void *__restrict__ out_z_ptr; + void *__restrict__ pos_ids_ptr; }; struct SSMParamsBwd: public SSMParamsBase { diff --git a/csrc/selective_scan/selective_scan_bwd_kernel.cuh b/csrc/selective_scan/selective_scan_bwd_kernel.cuh index c720ba28..9572cb64 100755 --- a/csrc/selective_scan/selective_scan_bwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_bwd_kernel.cuh @@ -29,31 +29,39 @@ template<> __device__ __forceinline__ float conj(float x) { return x; } template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } template + bool kDeltaSoftplus_, bool kHasZ_, bool kIsVarLen_, typename input_t_, typename weight_t_> struct Selective_Scan_bwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; using weight_t = weight_t_; + using pos_t = uint32_t; static constexpr int kNThreads = kNThreads_; static constexpr int kNItems = kNItems_; static constexpr int kNBytes = sizeof(input_t); + static constexpr int kNBytesPos = sizeof(pos_t); static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static constexpr int kNEltsPos = kNBytesPos == 4 ? 
4 : constexpr_min(8, kNItems); static_assert(kNItems % kNElts == 0); static constexpr int kNLoads = kNItems / kNElts; + static constexpr int kNLoadsPos = kNItems / kNEltsPos; static constexpr bool kIsComplex = std::is_same_v; static constexpr bool kIsEvenLen = kIsEvenLen_; static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; static constexpr bool kHasZ = kHasZ_; + static constexpr bool kIsVarLen = kIsVarLen_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. // For complex this would lead to massive register spilling, so we keep it at 2. static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; using vec_t = typename BytesToType::Type; + using pos_vec_t = typename BytesToType::Type; using scan_t = std::conditional_t; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; + using BlockLoadPosIdsT = cub::BlockLoad; + using BlockLoadPosIdsVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; using BlockStoreT = cub::BlockStore; @@ -86,10 +94,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kIsVarLen = Ktraits::kIsVarLen; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; using input_t = typename Ktraits::input_t; using weight_t = typename Ktraits::weight_t; + using pos_t = typename Ktraits::pos_t; using scan_t = typename Ktraits::scan_t; // Shared memory. @@ -100,6 +110,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_pos_ids = reinterpret_cast(smem_); auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_store = reinterpret_cast(smem_); auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); @@ -142,9 +153,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; float dD_val = 0; float ddelta_bias_val = 0; + pos_t *pos_ids = !kIsVarLen ? nullptr :reinterpret_cast(params.pos_ids_ptr) + batch_id * params.seqlen; constexpr int kChunkSize = kNThreads * kNItems; u += (params.n_chunks - 1) * kChunkSize; + pos_ids += (params.n_chunks - 1) * kChunkSize; delta += (params.n_chunks - 1) * kChunkSize; dout += (params.n_chunks - 1) * kChunkSize; Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 
1 : 2); @@ -153,9 +166,15 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { input_t u_vals[kNItems]; input_t delta_vals_load[kNItems]; input_t dout_vals_load[kNItems]; + pos_t pos_ids_vals[kNItems]; __syncthreads(); load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); u -= kChunkSize; + if constexpr (kIsVarLen) { + __syncthreads(); + load_pos_ids(pos_ids, pos_ids_vals, smem_load_pos_ids, params.seqlen - chunk * kChunkSize); + pos_ids -= kChunkSize; + } __syncthreads(); load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); // Will reload delta at the same location if kDeltaSoftplus @@ -250,7 +269,13 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { if constexpr (!kIsComplex) { #pragma unroll for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + // Reset A bar for cumulative sequences (Real) + if constexpr (kIsVarLen) { + if (pos_ids_vals[i] == 0) { + delta_a_exp = 0.f; + } + } thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); if (i == 0) { smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; @@ -338,6 +363,13 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) { for (int i = 0; i < kNItems; ++i) { // Pytorch's implementation of complex exp (which calls thrust) is very slow complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + // Reset A bar for cumulative sequences (Complex) + if constexpr (kIsVarLen) { + if (pos_ids_vals[i] == 0) { + delta_a_exp.real_ = 0.f; + delta_a_exp.imag_ = 0.f; + } + } weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); if (i == 0) { @@ -501,30 +533,32 @@ void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_bwd_kernel_traits; - // using Ktraits = Selective_Scan_bwd_kernel_traits; - // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); + BOOL_SWITCH(params.pos_ids_ptr != nullptr , kIsVarLen, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + // using Ktraits = Selective_Scan_bwd_kernel_traits; + // TODO: check this + constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); - dim3 grid(params.batch, params.dim); - - auto kernel = &selective_scan_bwd_kernel; + dim3 grid(params.batch, params.dim); + + auto kernel = &selective_scan_bwd_kernel; - if (kSmemSize >= 48 * 1024) { + if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - #else - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). 
This might lead to undefined behavior. \n" << std::endl; - #endif + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif - } + } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/selective_scan/selective_scan_common.h b/csrc/selective_scan/selective_scan_common.h index 91328e91..2b44566d 100644 --- a/csrc/selective_scan/selective_scan_common.h +++ b/csrc/selective_scan/selective_scan_common.h @@ -196,6 +196,23 @@ inline __device__ void load_input(typename Ktraits::input_t *u, } } +template +inline __device__ void load_pos_ids(typename Ktraits::pos_t *u, + typename Ktraits::pos_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadPosIdsT::TempStorage &smem_load_pos_ids, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_pos_ids_vec = reinterpret_cast(smem_load_pos_ids); + using pos_vec_t = typename Ktraits::pos_vec_t; + Ktraits::BlockLoadPosIdsVecT(smem_load_pos_ids_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadPosIdsT(smem_load_pos_ids).Load(u, u_vals, seqlen, 0.f); + } +} + template inline __device__ void load_weight(typename Ktraits::input_t *Bvar, typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh index 80e9e37e..32f759e4 100755 --- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh +++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh @@ -23,34 +23,44 @@ template + bool kHasZ_, bool kIsVarLen_, typename input_t_, typename weight_t_> struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; using weight_t = weight_t_; + using pos_t = uint32_t; static constexpr int kNThreads = kNThreads_; // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; static constexpr int kNItems = kNItems_; static constexpr int kNRows = kNRows_; static constexpr int kNBytes = sizeof(input_t); + static constexpr int kNBytesPos = sizeof(pos_t); static_assert(kNBytes == 2 || kNBytes == 4); static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static constexpr int kNEltsPos = kNBytesPos == 4 ? 
4 : constexpr_min(8, kNItems); static_assert(kNItems % kNElts == 0); static constexpr int kNLoads = kNItems / kNElts; + static constexpr int kNLoadsPos = kNItems / kNEltsPos; static constexpr bool kIsComplex = std::is_same_v; static constexpr bool kIsEvenLen = kIsEvenLen_; static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kHasZ = kHasZ_; + static constexpr bool kIsVarLen = kIsVarLen_; static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - + static constexpr bool kDirectIOPos = kIsEvenLen && kNLoadsPos == 1; + using vec_t = typename BytesToType::Type; + using pos_vec_t = typename BytesToType::Type; using scan_t = std::conditional_t; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; + using BlockLoadPosIdsT = cub::BlockLoad; + using BlockLoadPosIdsVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; @@ -76,12 +86,15 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsVariableB = Ktraits::kIsVariableB; constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kIsVarLen = Ktraits::kIsVarLen; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int kNRows = Ktraits::kNRows; constexpr bool kDirectIO = Ktraits::kDirectIO; + constexpr bool kDirectIOPos = Ktraits::kDirectIOPos; using input_t = typename Ktraits::input_t; using weight_t = typename Ktraits::weight_t; + using pos_t = typename Ktraits::pos_t; using scan_t = typename Ktraits::scan_t; // Shared memory. @@ -92,6 +105,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_pos_ids = reinterpret_cast(smem_); auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_store = reinterpret_cast(smem_); auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); @@ -112,6 +126,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + pos_t *pos_ids = !kIsVarLen ? 
nullptr :reinterpret_cast(params.pos_ids_ptr) + batch_id * params.seqlen; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -136,6 +151,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr int kChunkSize = kNThreads * kNItems; for (int chunk = 0; chunk < params.n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + pos_t pos_ids_vals[kNRows][kNItems]; __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -145,9 +161,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); if constexpr (!kDirectIO) { __syncthreads(); } load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kIsVarLen) { + if constexpr (!kDirectIOPos) { __syncthreads(); } + load_pos_ids(pos_ids + r * params.delta_d_stride, pos_ids_vals[r], smem_load_pos_ids, params.seqlen - chunk * kChunkSize); + } } u += kChunkSize; delta += kChunkSize; + if constexpr (kIsVarLen) { pos_ids += kChunkSize; } float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; #pragma unroll @@ -220,6 +241,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (!kIsComplex) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + // Reset A bar for cumulative sequences (Real) + if constexpr (kIsVarLen) { + if (pos_ids_vals[r][i] == 0) { + thread_data[i].x = 0.f; + } + } if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); @@ -230,6 +257,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if constexpr (kIsVarLen) { + if (pos_ids_vals[r][i] == 0) { + thread_data[i].x = 0.f; + thread_data[i].y = 0.f; + } + } if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); @@ -316,31 +349,33 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - dim3 grid(params.batch, params.dim / kNRows); + BOOL_SWITCH(params.pos_ids_ptr != nullptr , kIsVarLen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); - // Had to change this substantially since potentially the hip - // interface for setting kernel launch attributes is slightly different from - // cuda's. In particualar, it seems to expect a plain const void * pointer. 
+ // Had to change this substantially since potentially the hip + // interface for setting kernel launch attributes is slightly different from + // cuda's. In particualar, it seems to expect a plain const void * pointer. - auto kernel = &selective_scan_fwd_kernel; + auto kernel = &selective_scan_fwd_kernel; - - if (kSmemSize >= 48 * 1024) { - #ifndef USE_ROCM - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - #else - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - #endif - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif + } + + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 4c8a3882..2bca7707 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -116,13 +116,30 @@ def __init__( self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, cu_seqlens=None, seq_idx=None, position_ids=None, inference_params=None): """ hidden_states: (B, L, D) + cu_seqlens: (Optional) cumulative sum of the sequence lengths, starting from 0 and end with L, and must already be sorted. 
Returns: same shape as hidden_states """ batch, seqlen, dim = hidden_states.shape + if cu_seqlens is not None: + # Sanity Check + assert batch == 1 and cu_seqlens.ndimension() == 1, "varlen mamba1 is only supported with B=1" + # Warning: + # 1) For testing, seq_idx and position_ids can be computed on-the-fly but would harm performance + # 2) For production, please betther prepare them in dataloader + if seq_idx is None: + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + if position_ids is None: + position_ids = (torch.arange((cu_seqlens[1:] - cu_seqlens[:-1]).sum(), device=cu_seqlens.device) + - torch.repeat_interleave(cu_seqlens[:-1], (cu_seqlens[1:] - cu_seqlens[:-1]))).to(torch.int32).unsqueeze(0) + else: + seq_idx = None + position_ids = None + conv_state, ssm_state = None, None if inference_params is not None: conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) @@ -157,6 +174,9 @@ def forward(self, hidden_states, inference_params=None): self.D.float(), delta_bias=self.dt_proj.bias.float(), delta_softplus=True, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + position_ids=position_ids, ) else: x, z = xz.chunk(2, dim=1) @@ -166,13 +186,23 @@ def forward(self, hidden_states, inference_params=None): # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + if cu_seqlens is not None: + # naive pure python implementation of varlen causal_conv1d + for i, s in enumerate(cu_seqlens[1:-1]): + x = torch.cat((x[..., :s + i*(self.d_conv - 1)], torch.zeros_like(x[..., :(self.d_conv - 1)]), x[..., s + i*(self.d_conv - 1):]), dim=2) + mask = torch.cat([torch.cat((torch.full((s,), True, dtype=torch.bool, device=x.device), + torch.full((self.d_conv - 1,), False, dtype=torch.bool, device=x.device)), dim=0) + for s in (cu_seqlens[1:] - cu_seqlens[:-1])], dim=0) + x = self.act(self.conv1d(x)[:, :, mask]) + else: + x = self.act(self.conv1d(x)[..., :seqlen]) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( - x=x, + x=x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, + seq_idx=seq_idx, activation=self.activation, ) @@ -197,6 +227,7 @@ def forward(self, hidden_states, inference_params=None): delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=ssm_state is not None, + cu_seqlens=cu_seqlens, ) if ssm_state is not None: y, last_state = y diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359..98968121 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -24,7 +24,7 @@ class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None, position_ids = None): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: @@ -43,26 +43,26 @@ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softp if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True - out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, 
delta_bias, delta_softplus) + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, position_ids) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: - ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x, cu_seqlens, position_ids) return out if not return_last_state else (out, last_state) else: - ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens, position_ids) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: - u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + u, delta, A, B, C, D, delta_bias, x, cu_seqlens, position_ids = ctx.saved_tensors z = None out = None else: - u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + u, delta, A, B, C, D, z, delta_bias, x, out, cu_seqlens, position_ids = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the @@ -70,7 +70,8 @@ def backward(ctx, dout, *args): # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, - False # option to recompute out_z, not used here + False, # option to recompute out_z, not used here + position_ids ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB @@ -80,6 +81,8 @@ def backward(ctx, dout, *args): dz, ddelta_bias if delta_bias is not None else None, None, + None, + None, None) @@ -104,16 +107,16 @@ def rms_norm_forward( def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None, position_ids=None): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. 
""" - return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, cu_seqlens, position_ids) def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False): + return_last_state=False, cu_seqlens=None, position_ids=None): """ u: r(B D L) delta: r(B D L) @@ -160,7 +163,10 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) last_state = None for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if cu_seqlens is not None and i in cu_seqlens[1:-1].tolist(): + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: @@ -188,12 +194,13 @@ class MambaInnerFn(torch.autograd.Function): def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6): + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6, cu_seqlens=None, seq_idx=None, position_ids=None): """ xz: (batch, dim, seqlen) """ assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." assert checkpoint_lvl in [0, 1] + L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) @@ -207,15 +214,25 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh xz = xz.contiguous() conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) - conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None conv1d_out = causal_conv1d_fwd_function( - x, conv1d_weight, conv1d_bias, None, None, None, True + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + seq_idx, + None, + None, + True ) + if conv1d_out.stride(-1) != 1: + conv1d_out = conv1d_out.contiguous() + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None ctx.is_variable_C = C is None ctx.B_proj_bias_is_None = B_proj_bias is None @@ -261,7 +278,16 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous() out, scan_intermediates, out_z = selective_scan_cuda.fwd( - conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + conv1d_out, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + position_ids, ) ctx.delta_softplus = delta_softplus ctx.out_proj_bias_is_None = out_proj_bias is None @@ -274,7 +300,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_out, delta = None, None ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, - A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) + A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out, cu_seqlens, seq_idx, position_ids) return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) @staticmethod @@ -283,7 +309,7 @@ def backward(ctx, dout): # dout: (batch, seqlen, dim) assert causal_conv1d_fwd_function is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, - conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out, cu_seqlens, seq_idx, position_ids) = ctx.saved_tensors L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) @@ -292,8 +318,16 @@ def backward(ctx, dout): dout = dout.contiguous() if ctx.checkpoint_lvl == 1: conv1d_out = causal_conv1d_fwd_function( - x, conv1d_weight, conv1d_bias, None, None, None, True + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + seq_idx, + None, + None, + True ) + if conv1d_out.stride(-1) != 1: + conv1d_out = conv1d_out.contiguous() delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) if dt_rms_weight is not None: @@ -324,7 +358,8 @@ def backward(ctx, dout): dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, ctx.delta_softplus, - True # option to recompute out_z + True, # option to recompute out_z + position_ids ) dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None @@ -358,41 +393,58 @@ def backward(ctx, dout): # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). 
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_function( - x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + conv1d_weight, + conv1d_bias, + dconv1d_out, + seq_idx, + None, + None, + dx.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else dx, + False, + True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") - return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + return (torch.cat((dx, dz), dim=1) if cu_seqlens is not None else dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, dout_proj_weight, dout_proj_bias, dA, dB, dC, dD, ddelta_bias if delta_bias is not None else None, # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps - dB_proj_bias, dC_proj_bias, None, None, None, None, None, None) + dB_proj_bias, dC_proj_bias, None, None, None, None, None, None, None, None, None) def mamba_inner_fn( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6 + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6, cu_seqlens=None, seq_idx=None, position_ids=None ): return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, - A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps) + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps, cu_seqlens, seq_idx, position_ids) def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, cu_seqlens=None, seq_idx=None, position_ids=None ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") + + x = causal_conv1d_fn( + x.transpose(1,2).contiguous().transpose(1,2) if cu_seqlens is not None else x, + rearrange(conv1d_weight, "d 1 w -> d w"), + conv1d_bias, + seq_idx=seq_idx, + activation="silu" + ) + # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
@@ -415,5 +467,5 @@ def mamba_inner_ref( C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() else: C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() - y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) - return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True, position_ids=position_ids) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) \ No newline at end of file diff --git a/tests/ops/test_mamba_varlen.py b/tests/ops/test_mamba_varlen.py new file mode 100644 index 00000000..33c6bd52 --- /dev/null +++ b/tests/ops/test_mamba_varlen.py @@ -0,0 +1,285 @@ +import random +import pytest +import torch + +from torch import nn +from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref + + +''' +unpack function: convert packed_hidden_states (batch_size=1) to hidden_states +''' +def unpack(packed_hidden_states, cu_seqlens): + batch_size = cu_seqlens.shape[0] - 1 + seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + hidden_dim = packed_hidden_states.shape[2] + hidden_states = torch.zeros(batch_size, seq_len, hidden_dim, dtype=packed_hidden_states.dtype, device=packed_hidden_states.device) + for i in range(batch_size): + hidden_states[i, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[:, cu_seqlens[i] : cu_seqlens[i + 1], :] + return hidden_states + + +''' +pack function: convert hidden_states to packed_hidden_states (batch_size=1) +''' +def pack(hidden_states, cu_seqlens): + batch_size, seq_len, hidden_dim = hidden_states.shape + seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] + seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) + indices_3d = ( + torch.arange(seq_len, device=hidden_states.device) + .unsqueeze(0) + .unsqueeze(2) + .repeat(batch_size, 1, hidden_dim) + ) + mask_3d = indices_3d < seq_len_list_3d + packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim) + return packed_hidden_states + + +class NLayerMambaModel(nn.Module): + def __init__(self, layer_num, hidden_dim, device): + super().__init__() + self.layers = nn.ModuleList( + [ + Mamba( + # This module uses roughly 3 * expand * d_model^2 parameters + d_model=hidden_dim, # Model dimension d_model + d_state=16, # SSM state expansion factor + d_conv=4, # Local convolution width + expand=2, # Block expansion factor + layer_idx=layer_idx, + ).to(device) for layer_idx in range(layer_num) + ] + ) + + def forward(self, x, cu_seqlens=None, seq_idx=None, position_ids=None): + residual = x + for layer in self.layers: + x = layer(x, cu_seqlens, seq_idx=seq_idx, position_ids=position_ids) + return x + residual + + +''' +Generate random cu_seqlens for testing +''' +def generate_random_cu_seqlens(seq_len, batch_size=None): + if batch_size is None: + batch_size = random.randint(1, seq_len) + if batch_size > 1: + ret = sorted(random.sample(range(1, seq_len), batch_size - 1)) + else: + ret = [] + cu_seqlens = [0] + ret + [seq_len] + assert batch_size == len(cu_seqlens) - 1 + return cu_seqlens + + +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('layer_num', [1, 2, 4, 8]) +@pytest.mark.parametrize("hidden_dim", [2048]) +@pytest.mark.parametrize('seq_len', [1024, 2048, 4096, 8192]) +def test_mamba_varlen(itype, layer_num, hidden_dim, seq_len): + device='cuda' + if itype == 
torch.float32: + rtol, atol = (6e-4, 2e-3) + elif itype == torch.bfloat16: + rtol, atol = (3e-2, 5e-2) + else: + rtol, atol = (3e-3, 5e-3) + + # Generate random cu_seqlens for testing + cu_seqlens = generate_random_cu_seqlens(seq_len) + cu_seqlens = torch.tensor(cu_seqlens, device=device) + print(f'Generate random cu_seqlens = {cu_seqlens.tolist()}') + + # Generate packed_hidden_states with random values for testing + # packed_hidden_states (packed_batch_size=1) should be forwarded with cu_seqlens + hidden_states_list = [torch.randn(l, hidden_dim, device=device) for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()] + packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) + # hidden_states should be forwarded without cu_seqlens + hidden_states = unpack(packed_hidden_states, cu_seqlens) + + # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states + assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] + # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states + assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] + + # create the reference and test models + mamba_ref = NLayerMambaModel(layer_num, hidden_dim, device) + mamba = NLayerMambaModel(layer_num, hidden_dim, device) + mamba.load_state_dict(mamba_ref.state_dict()) + print(f"show reference model for testing: {mamba_ref}", flush=True) + + # reference output for forwarding hidden_states + out_ref_original = mamba_ref(hidden_states) + out_ref = pack(out_ref_original, cu_seqlens).unsqueeze(0) + + # In production, cu_seqlens/seq_idx/position_ids should be prepared in the dataloader + seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + position_ids = (torch.arange((cu_seqlens[1:] - cu_seqlens[:-1]).sum(), device=cu_seqlens.device) + - torch.repeat_interleave(cu_seqlens[:-1], (cu_seqlens[1:] - cu_seqlens[:-1]))).to(torch.int32).unsqueeze(0) + # output for forwarding packed_hidden_states + out = mamba(packed_hidden_states, cu_seqlens=cu_seqlens, seq_idx=seq_idx, position_ids=position_ids) + + # Testing the max/mean diff + print(f"max diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().max().item()}", flush=True) + print(f"mean diff for output in varlen_mamba fwd pass: {(out - out_ref).abs().mean().item()}", flush=True) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + # Generate random loss for backward testing + loss_fn = nn.CrossEntropyLoss() + g = torch.randn_like(out) + g_ref = unpack(g, cu_seqlens) + loss = loss_fn(out, g) + loss_ref = loss_fn(out_ref_original, g_ref) + loss.backward() + loss_ref.backward() + + # Check weight grad + all_grads_match = True + for (name_ref, param_ref), (name_packed, param_packed) in zip( + mamba_ref.named_parameters(), mamba.named_parameters() + ): + grad_match = torch.allclose(param_ref.grad, param_packed.grad, rtol=rtol, atol=atol) + if not grad_match: + print(f"Gradient mismatch in {name_ref} and {name_packed}! 
Max diff: {(param_ref.grad - param_packed.grad).abs().max().item()}", flush=True) + all_grads_match = False + print(f"All gradients match: {all_grads_match}", flush=True) + + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seq_len', [1024, 2048, 4096, 8192]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seq_len, itype, wtype): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + if itype == torch.float32: + rtol, atol = (6e-4, 2e-3) + elif itype == torch.bfloat16: + rtol, atol = (3e-2, 5e-2) + else: + rtol, atol = (3e-3, 5e-3) + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + packed_batch_size = 1 + dim = 768 + dstate = 8 + is_complex = wtype == torch.complex64 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_() + if not is_variable_B: + B_shape = (dim, dstate) + elif varBC_groups == 1: + B_shape = (packed_batch_size, dstate, seq_len if not is_complex else seq_len * 2) + else: + B_shape = (packed_batch_size, varBC_groups, dstate, seq_len if not is_complex else seq_len * 2) + B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype, + requires_grad=True) + if not is_variable_C: + C_shape = (dim, dstate) + elif varBC_groups == 1: + C_shape = (packed_batch_size, dstate, seq_len if not is_complex else seq_len * 2) + else: + C_shape = (packed_batch_size, varBC_groups, dstate, seq_len if not is_complex else seq_len * 2) + C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype, + requires_grad=True) + if has_D: + D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) + else: + D = None + if has_z: + z = torch.randn(packed_batch_size, dim, seq_len, device=device, dtype=itype, requires_grad=True) + else: + z = None + if has_delta_bias: + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_() + else: + delta_bias = None + u = torch.randn(packed_batch_size, dim, seq_len, device=device, dtype=itype, requires_grad=True) + delta = (0.5 * torch.rand(packed_batch_size, dim, seq_len, device=device, dtype=itype)).requires_grad_() + A_ref = A.detach().clone().requires_grad_() + B_ref = B.detach().clone().requires_grad_() + C_ref = C.detach().clone().requires_grad_() + D_ref = D.detach().clone().requires_grad_() if D is not None else None + z_ref = z.detach().clone().requires_grad_() if z is not None else None + u_ref = u.detach().clone().requires_grad_() + delta_ref = delta.detach().clone().requires_grad_() + delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None + + # In production, cu_seqlens/seq_idx/position_ids should be prepared in the dataloader + cu_seqlens = torch.tensor(generate_random_cu_seqlens(seq_len), device=device) + 
# seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) + # for i, s in enumerate(cu_seqlens[1:]-cu_seqlens[:-1])], dim=0).unsqueeze(0) + position_ids = (torch.arange((cu_seqlens[1:] - cu_seqlens[:-1]).sum(), device=cu_seqlens.device) + - torch.repeat_interleave(cu_seqlens[:-1], (cu_seqlens[1:] - cu_seqlens[:-1]))).to(torch.int32).unsqueeze(0) + + out, *rest = selective_scan_fn( + u, delta, A, B, C, D, z=z, + delta_bias=delta_bias, delta_softplus=delta_softplus, + return_last_state=return_last_state, cu_seqlens=cu_seqlens, position_ids=position_ids + ) + if return_last_state: + state = rest[0] + out_ref, *rest = selective_scan_ref( + u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref, + delta_bias=delta_bias_ref, delta_softplus=delta_softplus, + return_last_state=return_last_state, cu_seqlens=cu_seqlens, position_ids=position_ids + ) + if return_last_state: + state_ref = rest[0] + # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + # dt_u = delta * u + + print(f'Output max diff: {(out - out_ref).abs().max().item()}') + print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + print(f'State max diff: {(state - state_ref).abs().max().item()}') + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + g = torch.randn_like(out) + out_ref.backward(g) + out.backward(g) + + print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}') + print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}') + print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}') + print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}') + print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}') + if has_D: + print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}') + if has_z: + print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}') + if has_delta_bias: + print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}') + + assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2) + assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10) + assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5) + assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol, + atol=atolw if not is_variable_B else atol) + assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol, + atol=atolw if not is_variable_C else atol) + if has_D: + assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw) + if has_z: + assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw) + if has_delta_bias: + assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw) + \ No newline at end of file
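
Usage note (illustrative, not part of the patch): the varlen path packs all sequences into a single batch row and describes the boundaries with three tensors — cu_seqlens (1-D cumulative lengths starting at 0), seq_idx (which packed sequence each token belongs to, consumed by causal-conv1d), and position_ids (each token's offset inside its own sequence; the kernels reset the SSM state wherever it is 0). The C++ checks require position_ids to be int32 on CUDA and the packed batch size to be 1. A minimal sketch of how a caller might prepare these tensors and drive a Mamba block, mirroring tests/ops/test_mamba_varlen.py; the helper name build_varlen_metadata is hypothetical, and the fast path assumes the optional causal-conv1d package is installed:

import itertools

import torch

from mamba_ssm.modules.mamba_simple import Mamba


def build_varlen_metadata(seq_lens, device):
    # Hypothetical helper (not part of the patch): build the metadata the varlen
    # forward expects for a list of per-sequence lengths packed into batch size 1.
    cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seq_lens)), device=device)
    seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device)
                         for i, s in enumerate(seq_lens)]).unsqueeze(0)
    position_ids = torch.cat([torch.arange(s, dtype=torch.int32, device=device)
                              for s in seq_lens]).unsqueeze(0)
    return cu_seqlens, seq_idx, position_ids


# Pack three sequences of different lengths into one row and run a single block.
device = "cuda"
seq_lens = [5, 3, 9]
cu_seqlens, seq_idx, position_ids = build_varlen_metadata(seq_lens, device)
block = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).to(device)
packed = torch.randn(1, sum(seq_lens), 256, device=device)  # (1, total_len, d_model)
out = block(packed, cu_seqlens=cu_seqlens, seq_idx=seq_idx, position_ids=position_ids)

Preparing these tensors in the dataloader, as the patch's comments recommend, avoids recomputing them on every forward call.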
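For intuition about the kernel change itself: at every token whose position_id is 0, the forward and backward kernels zero the discretized decay term exp(delta * A) (delta_a_exp), so the recurrent state cannot leak across packed sequence boundaries; selective_scan_ref achieves the same effect by restarting the state at the offsets listed in cu_seqlens. Below is a simplified pure-PyTorch sketch of that recurrence, under stated assumptions (real-valued A, single-group time-varying B and C, and no D, z, delta bias, or softplus); it is an illustration, not the library API:

import torch


def selective_scan_varlen_sketch(u, delta, A, B, C, position_ids):
    # u, delta: (batch, dim, L); A: (dim, dstate); B, C: (batch, dstate, L);
    # position_ids: (batch, L) int tensor, 0 at the first token of each packed sequence.
    batch, dim, L = u.shape
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))    # per-step decay
    deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)     # per-step input term
    x = u.new_zeros(batch, dim, A.shape[1])
    ys = []
    for t in range(L):
        decay = deltaA[:, :, t]
        # Mirror the kernel: zero the decay where a new sequence starts, so the
        # previous sequence's state is dropped instead of being carried over.
        is_start = (position_ids[:, t] == 0).view(batch, 1, 1)
        decay = torch.where(is_start, torch.zeros_like(decay), decay)
        x = decay * x + deltaB_u[:, :, t]
        ys.append(torch.einsum("bdn,bn->bd", x, C[:, :, t]))
    return torch.stack(ys, dim=2)  # (batch, dim, L)

At a boundary the update collapses to x = deltaB_u[:, :, t], which is exactly the cu_seqlens branch this patch adds to selective_scan_ref.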
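The depthwise convolution needs the same isolation. The patch handles it either by passing seq_idx to causal_conv1d_fn (fast path) or, in the pure-PyTorch fallback, by inserting d_conv - 1 zeros between packed sequences and masking those positions out of the conv output. The sketch below is an equivalent per-sequence formulation, shown only to make the fallback's intent explicit; the function name and argument layout are assumptions, not the patch's API:

import torch
import torch.nn.functional as F


def varlen_causal_conv1d_naive(x, weight, bias, cu_seqlens, act=F.silu):
    # x: (1, d, total_len); weight: (d, w) as in rearrange(conv1d.weight, "d 1 w -> d w").
    # Running the causal depthwise conv segment by segment keeps every tap inside its
    # own sequence, matching the zero-insertion-plus-mask trick in the fallback.
    d, w = weight.shape
    outs = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        seg = F.pad(x[..., start:end], (w - 1, 0))              # left-pad only: causal
        outs.append(F.conv1d(seg, weight.unsqueeze(1), bias, groups=d))
    return act(torch.cat(outs, dim=-1))

Passing seq_idx to causal_conv1d_fn is intended to produce the same per-sequence behavior without the Python loop.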