Merged

50 commits
59e6abf
Migrate mamba_ssm and causal_conv1d kernels to vLLM
mzusman Aug 19, 2024
d2348ec
Casual conv1d compiles
mzusman Aug 20, 2024
66ee5af
Add casual_conv1d to _custom_ops
mzusman Aug 20, 2024
7a0d206
Add mamba ops and triton kernels
mzusman Aug 20, 2024
145b6b7
Add casual_conv1d update
mzusman Aug 20, 2024
2bdd7f5
setup selective scan fwd pass
mzusman Aug 20, 2024
e25dbfe
Format
mzusman Aug 20, 2024
64b6160
Do not have a mamba layer for now, push in a future PR
mzusman Aug 20, 2024
2ff36cb
Format
mzusman Aug 20, 2024
5f9c383
Take off mamba from image and requirements
mzusman Aug 20, 2024
ac8354e
Add tests
mzusman Aug 20, 2024
ea80282
Some small fixes, tests still do not pass
mzusman Aug 22, 2024
2f15495
Fix tests
mzusman Aug 22, 2024
b51fd28
Causal conv1d tests are passing
mzusman Aug 22, 2024
0cc2252
Import
mzusman Aug 22, 2024
d65dfb6
Tests
mzusman Aug 22, 2024
e7b2b32
Format
mzusman Aug 22, 2024
2c9fe00
Cleanup
mzusman Aug 22, 2024
c82cc30
Align with main
mzusman Aug 22, 2024
6c83e5f
Format
mzusman Aug 22, 2024
cd78cf6
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 22, 2024
b6a00cb
Add init py files
mzusman Aug 22, 2024
ef69b6c
Move kernels to cuda only
mzusman Aug 22, 2024
152f331
Revert "Move kernels to cuda only"
mzusman Aug 22, 2024
39f0fa0
move kernels to if cuda
mzusman Aug 22, 2024
42f94b7
Fix tests
mzusman Aug 22, 2024
f050781
Revert formating
mzusman Aug 25, 2024
c8ffba5
Format
mzusman Aug 25, 2024
04f947b
Add comments on adapted from mamba/casual conv1d repos
mzusman Aug 25, 2024
732db18
pare down number of w/i dtype combinations
mzusman Aug 25, 2024
fdca1ff
Clean up not used
mzusman Aug 25, 2024
fe70a39
Rename typo
mzusman Aug 25, 2024
9a0e538
Add comment on einops
mzusman Aug 25, 2024
619a40a
Remove requirement for einops
mzusman Aug 25, 2024
5d0d2db
Fix tests after paring down kernels
mzusman Aug 25, 2024
c622375
format
mzusman Aug 25, 2024
cdc9205
Fix typo
mzusman Aug 25, 2024
42d9c59
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 25, 2024
308c922
register meta functions to the kernels
mzusman Aug 25, 2024
d921a48
Revert "register meta functions to the kernels"
mzusman Aug 25, 2024
a8078e7
move to ifndef ROCm
mzusman Aug 26, 2024
2ca8db7
Format
mzusman Aug 26, 2024
abf02fa
Reduce combinations of bool switch to reduce wheel size
mzusman Aug 27, 2024
633225c
Fix, use float as weight dtype
mzusman Aug 27, 2024
ec0112b
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 28, 2024
1f35bbe
Take down seq_pos_idx, not used atm, will comeback in a following PR
mzusman Aug 28, 2024
bed44c4
Add comments and guard checks on disabled "features"
mzusman Aug 28, 2024
950701a
Fix header file
mzusman Aug 28, 2024
4e5d6b4
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 28, 2024
d23a429
Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
mzusman Aug 28, 2024
25 changes: 12 additions & 13 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -404,19 +404,18 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] {
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch, params.dim);
auto kernel = &causal_conv1d_fwd_kernel<Ktraits, kHasSeqPosIdx>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
constexpr bool kHasSeqPosIdx = false;
Member:
Could you add a `TORCH_CHECK(params.seq_pos_idx_ptr == nullptr)`?

Member:
A comment that some kernel cases have been disabled to reduce binary size would be good to add for documentation as well.

Contributor Author:
Agree. By the way, this variable is used for batch varlen enablement, which is out of scope IMO for this PR; I've taken it down completely and will open a separate follow-up PR for varlen batching.

BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch, params.dim);
auto kernel = &causal_conv1d_fwd_kernel<Ktraits, kHasSeqPosIdx>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
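To tie the review thread above to the change: once kHasSeqPosIdx is hard-coded to false, the reviewers ask for a TORCH_CHECK guard plus a comment documenting why the template branch went away. Below is a minimal sketch of such a guard, not necessarily the code that was merged; it assumes ConvParamsBase still carries a seq_pos_idx_ptr member, which the later commit "Take down seq_pos_idx, not used atm, will comeback in a following PR" removes entirely, so the merged file may not need the check at all.

// Sketch only: reject inputs that rely on the disabled seq_pos_idx path and
// document why it is gone. Some template combinations were dropped here to
// reduce binary/wheel size; varlen batching via seq_pos_idx is planned for a
// follow-up PR.
TORCH_CHECK(params.seq_pos_idx_ptr == nullptr,
            "seq_pos_idx is not supported in this build");
constexpr bool kHasSeqPosIdx = false;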

49 changes: 19 additions & 30 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
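Before the hunk below, a note on the dispatch pattern both files use: BOOL_SWITCH turns a runtime condition into a compile-time template parameter by instantiating the kernel body for both values, so every nested switch doubles the number of compiled kernels. That multiplication is what the commit "Reduce combinations of bool switch to reduce wheel size" targets, and it is why this hunk hard-codes kIsVariableB, kIsVariableC, and kHasZ instead of switching on them. The definition below is a simplified sketch of the usual static-switch idiom from the upstream Mamba/causal-conv1d code, not a line copied from this PR.

// Assumed shape of BOOL_SWITCH (simplified sketch): both branches instantiate
// the lambda body with a constexpr bool, so N nested switches -> 2^N kernels.
#define BOOL_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                      \
    if (COND) {                              \
      constexpr bool CONST_NAME = true;      \
      return __VA_ARGS__();                  \
    } else {                                 \
      constexpr bool CONST_NAME = false;     \
      return __VA_ARGS__();                  \
    }                                        \
  }()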
@@ -311,26 +311,21 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row.
constexpr int kNRows = 1;
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
Comment on lines +315 to +317
Member:
Could you add torch checks to guard against cases where kIsVariableB, kIsVariableC, or kHasZ is false?

Contributor Author:
Done

BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
// constexpr int kSmemSize = Ktraits::kSmemSize;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
// printf("smem_size = %d\n", kSmemSize);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
}
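The "Done" in the thread above corresponds to the commit "Add comments and guard checks on disabled 'features'". A hedged sketch of what such host-side guards might look like is shown here; the exact messages and placement in the merged file may differ.

// Sketch: fail fast when a caller depends on a template path that is now
// hard-coded away (kIsVariableB, kIsVariableC and kHasZ are fixed to true).
TORCH_CHECK(params.is_variable_B, "is_variable_B == false is not supported");
TORCH_CHECK(params.is_variable_C, "is_variable_C == false is not supported");
TORCH_CHECK(params.z_ptr != nullptr, "z is required (kHasZ is fixed to true)");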
@@ -369,27 +364,23 @@ template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
if (ITYPE == at::ScalarType::Half) { \
using input_t = at::Half; \
using weight_t = at::Half; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::BFloat16) { \
using input_t = at::BFloat16; \
using weight_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (ITYPE == at::ScalarType::Float) { \
using input_t = float; \
using weight_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
}

#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \
if (WTYPE == at::ScalarType::Float) { \
using weight_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
}
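To make the macro consolidation concrete: the separate weight-type dispatch above is deleted because the combined macro now picks both input_t and weight_t from the input's scalar type in a single if/else chain. Written out for the float case, the call site shown further down (DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16 around selective_scan_fwd_cuda) reduces to roughly the following sketch.

// Rough expansion for u.scalar_type() == at::ScalarType::Float:
using input_t = float;
using weight_t = float;  // chosen by the same macro; no second dispatch needed
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);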

template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
@@ -598,10 +589,8 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] {
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
});
std::vector<at::Tensor> result = {out, x.value()};
if (has_z) { result.push_back(out_z); }