
Commit ae758dd

[Bugfix] Fix mtp torchair in pd Disaggregation scenario (#2951)

### What this PR does / why we need it?
1. As a follow-up to #2509, fix mtp torchair in the pd Disaggregation scenario.
2. Fix an mla bug in the SpecDecoding scenario, since num_decodes != num_decode_tokens.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@5206ab2

Signed-off-by: xuyexiong <xuyexiong@huawei.com>

1 parent: 6b7117d

3 files changed: +58 −9 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 3 additions & 2 deletions
@@ -379,11 +379,12 @@ def build(

         decode_metadata = None
         if num_decodes > 0:
+            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
-            seq_lens = seq_lens[:num_decode_tokens]
+            seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decode_tokens, ...]
+            block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()

             cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
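
For context, a minimal sketch of why the two counts diverge under speculative decoding (the numbers below are hypothetical, not from the diff): each decode request carries one verified token plus its draft tokens, so per-request tensors must be sliced with num_decodes, while per-token tensors keep num_decode_tokens.

```python
# Hypothetical numbers for illustration only.
num_decodes = 4          # decode *requests* in the batch
num_spec_tokens = 1      # extra draft token per request proposed by MTP
num_decode_tokens = num_decodes * (1 + num_spec_tokens)  # 8 decode *tokens*

# Per-request tensors (one row per request) are sliced with num_decodes:
#   seq_lens    = seq_lens[:num_decodes]
#   block_table = block_table[:num_decodes, ...]
# Per-token tensors are sliced with num_decode_tokens:
#   input_positions = input_positions[:num_decode_tokens]
```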

vllm_ascend/torchair/torchair_mla.py

Lines changed: 48 additions & 6 deletions
@@ -495,11 +495,12 @@ def build(
         graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size != -1
         if num_decodes > 0:
+            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
-            seq_lens = seq_lens[:num_decode_tokens]
+            seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decode_tokens, ...]
+            block_table = block_table[:num_decodes, ...]
             num_token_pad_size = 0
             if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
@@ -538,10 +539,9 @@ def build(
                     device=input_positions.device)
                 input_positions = torch.cat(
                     [input_positions, position_padding])
-                actual_seq_lengths_q = (
-                    actual_seq_lengths_q + common_attn_metadata.
-                    actual_seq_lengths_q[num_reqs:num_reqs +
-                                         num_reqs_pad_size])
+                actual_seq_lengths_q = self.pad_actual_seq_len_q(
+                    num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
+                    common_attn_metadata)
             else:
                 seq_lens_list = seq_lens.tolist()
             # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
@@ -584,6 +584,48 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )

+    def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
+                             actual_seq_lengths_q, common_attn_metadata):
+        """
+        Pad actual_seq_lengths_q evenly, adding at most 16 tokens per request,
+        to meet the requirements of npu_fused_infer_attention_score.
+
+        In the Torchair scenario, the query lengths must be padded to the same
+        length, and npu_fused_infer_attention_score requires the last element
+        to equal batch_size (num_tokens).
+
+        For example, with batch_size=36, num_reqs_pad_size=2, num_reqs=16:
+        by default each request infers 2 tokens, so actual_seq_lengths_q should be
+        [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36].
+
+        However, in the mtp torchair + PD scenario, actual_seq_lengths_q may be
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] before padding,
+        since the first decode request only has 1 token. To meet the requirement
+        of npu_fused_infer_attention_score, actual_seq_lengths_q is padded evenly
+        with at most 16 tokens per request; after padding it should be similar to
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, 36].
+        """
+        FIA_SEQ_LEN_LIMIT = 16
+        need_padding = num_reqs_pad_size != 0 and \
+            len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
+            common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
+        if need_padding:
+            padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+            start_val = actual_seq_lengths_q[-1]
+            end_val = padding_seq_len_q[-1]
+
+            num_step = len(padding_seq_len_q)
+            interpolated = np.round(
+                np.linspace(start_val, end_val,
+                            num_step + 1)[1:]).astype(int).tolist()
+            assert interpolated[-1] == end_val
+            assert len(interpolated) == len(padding_seq_len_q)
+            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
+        else:
+            actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+
+        return actual_seq_lengths_q
+

 class AscendMLATorchairImpl(MLAAttentionImpl):
     """

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 7 additions & 1 deletion
@@ -424,7 +424,13 @@ def select_torchair_padded_batch_size(self, batch_size: int):
     def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
-        if len(self.torchair_graph_batch_sizes) == 0:
+        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+            # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
+            self.torchair_graph_batch_sizes = [self.max_num_reqs]
+            logger.warning(
+                "is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs]"
+            )
+        elif len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
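
For context, a rough sketch of why pinning the graph batch size matters (the helper below is hypothetical; only the attribute semantics mirror the diff): torchair captures one graph per entry in torchair_graph_batch_sizes, and each incoming batch is padded up to the nearest captured size, so a kv_consumer (decode) node running deepseek_mtp is forced to a single capture at max_num_reqs.

```python
# Hypothetical helper, for illustration only: pick the captured graph size a
# batch gets padded to.
def select_padded_batch_size(batch_size: int, graph_batch_sizes: list) -> int:
    for padded in sorted(graph_batch_sizes):
        if batch_size <= padded:
            return padded
    raise ValueError("batch_size exceeds the largest captured graph")

# With the fix, a kv_consumer node running deepseek_mtp captures only
# [max_num_reqs], so every decode batch pads to that single size:
print(select_padded_batch_size(5, [32]))    # 32
print(select_padded_batch_size(32, [32]))   # 32
```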
