16
16
from vllm_ascend .ascend_config import get_ascend_config
17
17
from vllm_ascend .attention .attention_v1 import AscendAttentionState
18
18
from vllm_ascend .attention .utils import (AscendCommonAttentionMetadata ,
19
- split_decodes_and_prefills )
19
+ maybe_save_kv_layer_to_connector ,
20
+ split_decodes_and_prefills ,
21
+ wait_for_kv_layer_from_connector )
20
22
from vllm_ascend .multistream .base import MSAttentionMetadataSplitConfig
21
23
from vllm_ascend .multistream .context import get_multistream_comm_context
22
24
from vllm_ascend .multistream .ms_split import model_input_split_v1_mla_attn
@@ -853,8 +855,8 @@ def _forward_decode(
853
855
current_ms_metadata .before_comm_event .wait ()
854
856
return self ._v_up_proj (attn_output )
855
857
856
- def _mla_preprocess (self , hidden_states , kv_cache , attn_metadata ,
857
- need_gather_q_kv ):
858
+ def _mla_preprocess (self , layer_name , hidden_states , kv_cache ,
859
+ attn_metadata , need_gather_q_kv ):
858
860
# MLA Preprocess:
859
861
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
860
862
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -883,6 +885,8 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
883
885
kv_no_split = get_tp_group ().all_gather (kv_no_split , 0 )
884
886
decode_preprocess_res = None
885
887
prefill_preprocess_res = None
888
+ if has_prefill :
889
+ wait_for_kv_layer_from_connector (layer_name )
886
890
# Preprocess for decode tokens
887
891
if has_decode :
888
892
decode_q_c = q_c [:num_decode_tokens ]
@@ -929,6 +933,7 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
929
933
930
934
def forward (
931
935
self ,
936
+ layer_name ,
932
937
hidden_states : torch .Tensor , # query in unified attn
933
938
kv_cache : Tuple [torch .Tensor ],
934
939
attn_metadata : M ,
@@ -955,7 +960,8 @@ def forward(
955
960
956
961
# MLA Preprocess
957
962
decode_preprocess_res , prefill_preprocess_res = self ._mla_preprocess (
958
- hidden_states , kv_cache , attn_metadata , need_gather_q_kv )
963
+ layer_name , hidden_states , kv_cache , attn_metadata ,
964
+ need_gather_q_kv )
959
965
960
966
if decode_preprocess_res is not None :
961
967
# MLA Preprocess for decoding
@@ -1013,4 +1019,8 @@ def forward(
1013
1019
is_force_scatter = self .enable_shared_expert_dp )[0 ]
1014
1020
current_ms_metadata .after_comm_event .record ()
1015
1021
del o_proj_input
1022
+
1023
+ has_prefill = attn_metadata .num_prefills > 0
1024
+ if has_prefill :
1025
+ maybe_save_kv_layer_to_connector (layer_name , list (kv_cache ))
1016
1026
return output_padded
0 commit comments