1 file changed: +9 -7 lines changed
Columns: original file line number, diff line number, diff line change
@@ -484,13 +484,8 @@ def __init__(
484
484
self .chunked_prefill_for_mla = ascend_config .chunked_prefill_for_mla
485
485
486
486
vllm_config = get_current_vllm_config ()
487
- RING_MLA_MASK_SIZE = 512
488
- self .prefill_mask = torch .triu (
489
- torch .ones (RING_MLA_MASK_SIZE ,
490
- RING_MLA_MASK_SIZE ,
491
- device = "npu" ,
492
- dtype = vllm_config .model_config .dtype ),
493
- 1 )
487
+ self .ring_mla_mask_size = 512
488
+ self .prefill_mask = None
494
489
495
490
# Adapt torch air graph mode with spec decoding.
496
491
speculative_config = vllm_config .speculative_config
@@ -686,6 +681,13 @@ def _forward_prefill(
686
681
num_tokens ,
687
682
dtype = torch .float32 ,
688
683
device = q_nope .device )
684
+ if self .prefill_mask is None :
685
+ self .prefill_mask = torch .triu (
686
+ torch .ones (self .ring_mla_mask_size ,
687
+ self .ring_mla_mask_size ,
688
+ device = q_nope .device ,
689
+ dtype = q_nope .dtype ),
690
+ 1 )
689
691
torch_npu .atb .npu_ring_mla (
690
692
q_nope = q_nope ,
691
693
q_rope = q_pe ,
You can’t perform that action at this time.
0 commit comments