@@ -495,11 +495,12 @@ def build(
         graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size != -1
         if num_decodes > 0:
+            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
-            seq_lens = seq_lens[:num_decode_tokens]
+            seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decode_tokens, ...]
+            block_table = block_table[:num_decodes, ...]
             num_token_pad_size = 0
             if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
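Editorial aside (not part of the diff): the num_decodes vs. num_decode_tokens distinction in the hunk above only matters when speculative decoding is active, where each decode request contributes more than one query token. Per-request tensors (seq_lens, block_table) must then be sliced by the number of requests, while per-token tensors (input_positions) keep the token count. A minimal sketch, assuming a uniform number of draft tokens per request; all variable values below are illustrative:

num_spec_tokens = 1                                       # e.g. one MTP draft token per request (assumption)
num_decodes = 16                                          # decode requests in the batch
num_decode_tokens = num_decodes * (1 + num_spec_tokens)   # 32 query tokens in total

seq_lens = list(range(100, 100 + num_decodes))            # one entry per request
input_positions = list(range(num_decode_tokens))          # one entry per query token

seq_lens = seq_lens[:num_decodes]                         # request-level slice, as in the fix
input_positions = input_positions[:num_decode_tokens]     # token-level slice, unchanged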
@@ -538,10 +539,9 @@ def build(
                     device=input_positions.device)
                 input_positions = torch.cat(
                     [input_positions, position_padding])
-                actual_seq_lengths_q = (
-                    actual_seq_lengths_q + common_attn_metadata.
-                    actual_seq_lengths_q[num_reqs:num_reqs +
-                                         num_reqs_pad_size])
+                actual_seq_lengths_q = self.pad_actual_seq_len_q(
+                    num_reqs_pad_size, num_reqs, actual_seq_lengths_q,
+                    common_attn_metadata)
             else:
                 seq_lens_list = seq_lens.tolist()
                 # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
@@ -584,6 +584,48 @@ def build(
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
         )
 
+    def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
+                             actual_seq_lengths_q, common_attn_metadata):
+        """
+        Pad actual_seq_lengths_q evenly so that no padded step exceeds 16 tokens
+        per request, in order to meet the constraint of npu_fused_infer_attention_score.
+
+        In the Torchair scenario, the query lengths must be padded to the same
+        length, and npu_fused_infer_attention_score requires the last element to
+        equal batch_size (num_tokens).
+
+        For example:
+        batch_size=36, num_reqs_pad_size=2, num_reqs=16
+        By default, each request infers 2 tokens, so actual_seq_lengths_q should be
+        [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36].
+
+        However, in the mtp torchair + PD scenario, actual_seq_lengths_q may be
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] before padding,
+        since on the first decode step each request only has 1 token.
+        To meet the constraint of npu_fused_infer_attention_score, we pad
+        actual_seq_lengths_q evenly so that no padded step exceeds 16 tokens;
+        after padding, actual_seq_lengths_q should be similar to
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, 36].
+        """
+        FIA_SEQ_LEN_LIMIT = 16
+        need_padding = num_reqs_pad_size != 0 and \
+            len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \
+            common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
+        if need_padding:
+            padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+            start_val = actual_seq_lengths_q[-1]
+            end_val = padding_seq_len_q[-1]
+
+            num_step = len(padding_seq_len_q)
+            interpolated = np.round(
+                np.linspace(start_val, end_val,
+                            num_step + 1)[1:]).astype(int).tolist()
+            assert interpolated[-1] == end_val
+            assert len(interpolated) == len(padding_seq_len_q)
+            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
+        else:
+            actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[
+                num_reqs:num_reqs + num_reqs_pad_size]
+
+        return actual_seq_lengths_q
+
 
 
 class AscendMLATorchairImpl(MLAAttentionImpl):
     """