Skip to content

Commit 7a1448b

Browse files
mzusman and LeiWang1999
authored and committed
[Kernel][Model] Improve continuous batching for Jamba and Mamba (vllm-project#9189)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 4e25c7c commit 7a1448b

File tree

15 files changed

+511
-439
lines changed

15 files changed

+511
-439
lines changed

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
5555
const at::Tensor out,
5656
const c10::optional<at::Tensor>& bias,
5757
bool silu_activation,
58+
int64_t pad_slot_id,
5859
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
5960
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
6061
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
6667
params.dim = dim;
6768
params.seqlen = seqlen;
6869
params.width = width;
70+
params.pad_slot_id = pad_slot_id;
6971

7072
params.silu_activation = silu_activation;
7173

@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
9092
}
9193

9294

93-
at::Tensor
94-
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
95+
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
9596
const c10::optional<at::Tensor> &bias_,
9697
const c10::optional<at::Tensor> &conv_states,
9798
const c10::optional<at::Tensor> &query_start_loc,
9899
const c10::optional<at::Tensor> &cache_indices,
99100
const c10::optional<at::Tensor> &has_initial_state,
100-
bool silu_activation) {
101+
bool silu_activation,
102+
// used to identify padding entries if cache_indices provided
103+
// in case of padding, the kernel will return early
104+
int64_t pad_slot_id) {
101105
auto input_type = x.scalar_type();
102106
auto weight_type = weight.scalar_type();
103107
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
153157
CHECK_SHAPE(cache_indices_, batch_size);
154158
}
155159

156-
at::Tensor out = torch::empty_like(x);
160+
at::Tensor out = x;
157161

158162
ConvParamsBase params;
159163
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
160164
bias_,
161165
silu_activation,
166+
pad_slot_id,
162167
query_start_loc,
163168
cache_indices,
164169
has_initial_state
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
183188
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
184189
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
185190
});
186-
return out;
187191
}
188192

189193

190-
at::Tensor
191-
causal_conv1d_update(const at::Tensor &x,
194+
void causal_conv1d_update(const at::Tensor &x,
192195
const at::Tensor &conv_state,
193196
const at::Tensor &weight,
194197
const c10::optional<at::Tensor> &bias_,
195198
bool silu_activation,
196199
const c10::optional<at::Tensor> &cache_seqlens_,
197-
const c10::optional<at::Tensor> &conv_state_indices_) {
200+
const c10::optional<at::Tensor> &conv_state_indices_,
201+
// used to identify padding entries if cache_indices provided
202+
// in case of padding, the kernel will return early
203+
int64_t pad_slot_id) {
198204
auto input_type = x.scalar_type();
199205
auto weight_type = weight.scalar_type();
200206
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
227233
CHECK_SHAPE(bias, dim);
228234
}
229235

230-
at::Tensor out = torch::empty_like(x);
236+
at::Tensor out = x;
231237

232238
ConvParamsBase params;
233239
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
234240
bias_,
235-
silu_activation);
241+
silu_activation,
242+
pad_slot_id);
236243
params.conv_state_ptr = conv_state.data_ptr();
237244
params.conv_state_len = conv_state_len;
238245
// All stride are in elements, not bytes.
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
274281
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
275282
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
276283
});
277-
return out;
278284
}
279285

280286
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
340346
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
341347
: reinterpret_cast<int *>(params.cache_indices_ptr);
342348
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
343-
349+
// cache_index == params.pad_slot_id is defined as padding, so we exit early
350+
if (cache_index == params.pad_slot_id){
351+
return;
352+
}
344353
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
345354
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
346355

@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
528537
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
529538
? batch_id
530539
: params.conv_state_indices_ptr[batch_id];
540+
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
541+
if (conv_state_batch_coord == params.pad_slot_id){
542+
return;
543+
}
531544
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
532545
+ conv_state_batch_coord * params.conv_state_batch_stride
533546
+ channel_id * params.conv_state_c_stride;

csrc/mamba/causal_conv1d/causal_conv1d.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct ConvParamsBase {
1313
using index_t = uint32_t;
1414

1515
int batch, dim, seqlen, width;
16+
int64_t pad_slot_id;
1617
bool silu_activation;
1718

1819
index_t x_batch_stride;

csrc/mamba/mamba_ssm/selective_scan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct SSMParamsBase {
2121
int dim_ngroups_ratio;
2222
bool is_variable_B;
2323
bool is_variable_C;
24+
int64_t pad_slot_id;
2425

2526
bool delta_softplus;
2627

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
115115
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
116116
: reinterpret_cast<int *>(params.cache_indices_ptr);
117117
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
118+
// cache_index == params.pad_slot_id is defined as padding, so we exit early
119+
if (cache_index == params.pad_slot_id){
120+
return;
121+
}
118122
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
119123
+ dim_id * kNRows * params.u_d_stride;
120124
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
387391
const size_t seqlen,
388392
const size_t dstate,
389393
const size_t n_groups,
390-
const size_t n_chunks,
391394
const bool is_variable_B,
392395
const bool is_variable_C,
393396
// device pointers
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
407410
const c10::optional<at::Tensor>& query_start_loc,
408411
const c10::optional<at::Tensor>& cache_indices,
409412
const c10::optional<at::Tensor>& has_initial_state,
410-
bool varlen) {
413+
bool varlen,
414+
int64_t pad_slot_id) {
411415

412416
// Reset the parameters
413417
memset(&params, 0, sizeof(params));
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
417421
params.seqlen = seqlen;
418422
params.dstate = dstate;
419423
params.n_groups = n_groups;
420-
params.n_chunks = n_chunks;
421424
params.dim_ngroups_ratio = dim / n_groups;
425+
params.pad_slot_id = pad_slot_id;
422426

423427
params.delta_softplus = delta_softplus;
424428

@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
507511
const c10::optional<torch::Tensor> &query_start_loc,
508512
const c10::optional<torch::Tensor> &cache_indices,
509513
const c10::optional<torch::Tensor> &has_initial_state,
510-
const torch::Tensor &ssm_states) {
514+
const torch::Tensor &ssm_states,
515+
// used to identify padding entries if cache_indices provided
516+
// in case of padding, the kernel will return early
517+
int64_t pad_slot_id) {
511518
auto input_type = u.scalar_type();
512519
auto weight_type = A.scalar_type();
513520
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
618625

619626
out_z = z;
620627

621-
const int n_chunks = (seqlen + 2048 - 1) / 2048;
622-
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
623-
// at::Tensor out = torch::empty_like(u);
624628
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
625629
at::Tensor out = delta;
626630
TORCH_CHECK(ssm_states.scalar_type() == input_type);
627631
TORCH_CHECK(ssm_states.is_cuda());
628632
TORCH_CHECK(ssm_states.stride(-1) == 1);
629-
CHECK_SHAPE(ssm_states, batch_size, dim, dstate);
630633

631634
SSMParamsBase params;
632-
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
635+
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
633636
u, delta, A, B, C, out, z, out_z,
634637
D_,
635638
delta_bias_,
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
639642
query_start_loc,
640643
cache_indices,
641644
has_initial_state,
642-
varlen
645+
varlen,
646+
pad_slot_id
643647
);
644648

645649

csrc/ops.h

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -157,21 +157,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
157157
const c10::optional<torch::Tensor>& query_start_loc,
158158
const c10::optional<torch::Tensor>& cache_indices,
159159
const c10::optional<torch::Tensor>& has_initial_state,
160-
const torch::Tensor& ssm_states);
161-
162-
at::Tensor causal_conv1d_update(
163-
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
164-
const c10::optional<at::Tensor>& bias_, bool silu_activation,
165-
const c10::optional<at::Tensor>& cache_seqlens_,
166-
const c10::optional<at::Tensor>& conv_state_indices_);
167-
168-
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
169-
const c10::optional<at::Tensor>& bias_,
170-
const c10::optional<at::Tensor>& conv_states,
171-
const c10::optional<at::Tensor>& query_start_loc,
172-
const c10::optional<at::Tensor>& cache_indices,
173-
const c10::optional<at::Tensor>& has_initial_state,
174-
bool silu_activation);
160+
const torch::Tensor& ssm_states, int64_t pad_slot_id);
161+
162+
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
163+
const at::Tensor& weight,
164+
const c10::optional<at::Tensor>& bias_,
165+
bool silu_activation,
166+
const c10::optional<at::Tensor>& cache_seqlens_,
167+
const c10::optional<at::Tensor>& conv_state_indices_,
168+
int64_t pad_slot_id);
169+
170+
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
171+
const c10::optional<at::Tensor>& bias_,
172+
const c10::optional<at::Tensor>& conv_states,
173+
const c10::optional<at::Tensor>& query_start_loc,
174+
const c10::optional<at::Tensor>& cache_indices,
175+
const c10::optional<at::Tensor>& has_initial_state,
176+
bool silu_activation, int64_t pad_slot_id);
175177

176178
#ifndef USE_ROCM
177179
using fptr_t = int64_t;

csrc/torch_bindings.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
278278
"Tensor? query_start_loc,"
279279
"Tensor? cache_indices,"
280280
"Tensor? has_initial_state,"
281-
"Tensor! ssm_states) -> ()");
281+
"Tensor! ssm_states,"
282+
"int pad_slot_id) -> ()");
282283
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
283284

284285
ops.def(
@@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
288289
"Tensor? bias_,"
289290
"bool silu_activation,"
290291
"Tensor? cache_seqlens_,"
291-
"Tensor? conv_state_indices) -> Tensor");
292+
"Tensor? conv_state_indices,"
293+
"int pad_slot_id) -> ()");
292294
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
293295

294296
ops.def(
@@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
298300
"Tensor? query_start_loc,"
299301
"Tensor? cache_indices,"
300302
"Tensor? has_initial_state,"
301-
"bool silu_activation) -> Tensor");
303+
"bool silu_activation,"
304+
"int pad_slot_id) -> ()");
302305
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
303306
#endif
304307

0 commit comments

Comments (0)