File tree Expand file tree Collapse file tree 1 file changed +7
-11
lines changed Expand file tree Collapse file tree 1 file changed +7
-11
lines changed Original file line number Diff line number Diff line change @@ -483,7 +483,13 @@ def __init__(
483
483
self .enable_kv_nz = ascend_config .torchair_graph_config .enable_kv_nz
484
484
self .chunked_prefill_for_mla = ascend_config .chunked_prefill_for_mla
485
485
486
- self .prefill_mask = None
486
+ vllm_config = get_current_vllm_config ()
487
+ self .prefill_mask = torch .triu (
488
+ torch .ones (512 ,
489
+ 512 ,
490
+ device = "npu" ,
491
+ dtype = vllm_config .model_config .dtype ),
492
+ 1 ) # 512: mask only support 512
487
493
488
494
# Adapt torch air graph mode with spec decoding.
489
495
speculative_config = get_current_vllm_config ().speculative_config
@@ -679,16 +685,6 @@ def _forward_prefill(
679
685
num_tokens ,
680
686
dtype = torch .float32 ,
681
687
device = q_nope .device )
682
- if self .prefill_mask is None :
683
- self .prefill_mask = torch .triu (
684
- torch .ones (512 ,
685
- 512 ,
686
- device = q_nope .device ,
687
- dtype = q_nope .dtype ),
688
- 1 ) # 512: mask only support 512
689
- if attn_metadata .num_prefills > 1 :
690
- self .prefill_mask = self .prefill_mask .unsqueeze (0 ).repeat (
691
- attn_metadata .num_prefills , 1 , 1 )
692
688
torch_npu .atb .npu_ring_mla (
693
689
q_nope = q_nope ,
694
690
q_rope = q_pe ,
You can’t perform that action at this time.
0 commit comments