from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch

+import torch.distributed as dist
+from vllm.distributed.parallel_state import get_dp_group
+from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
+from vllm_ascend.mla_sp_context import get_sp_context
+from vllm_ascend.ops.shard import RowShardLinear
+
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

@@ -962,6 +968,126 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
            prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
            prefill_value)
        return decode_preprocess_res, prefill_preprocess_res
+
+    def _forward_prefill_sp(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
+        attn_metadata: M,
+    ) -> torch.Tensor:
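+        """Prefill forward pass with sequence parallelism (SP) across DP ranks.
+
+        Each SP rank down-projects its own slice of tokens, the slices are
+        re-gathered per DP group for attention, and the attention outputs are
+        exchanged back with an all_to_all before o_proj.
+        """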
+        sp_context = get_sp_context()
+        assert sp_context is not None
+        npu_prefetch(self.q_a_proj.weight,
+                     hidden_states,
+                     enabled=self.enable_prefetch)
+        # Split inputs from local DP to each device.
+        dp_sp_hidden_states = hidden_states
+        rank_sp_hidden_states = dp_sp_hidden_states[
+            sp_context.my_rank_sp_start_token_within_dp:
+            sp_context.my_rank_sp_end_token_within_dp]
+        sp_tokens = rank_sp_hidden_states.shape[0]
+        if sp_tokens == 0:
+            rank_sp_hidden_states = nn.functional.pad(rank_sp_hidden_states,
+                                                      (0, 0, 0, 1))
+            sp_tokens = 1
+        # MLA prefill:
+        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
+        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
+        # 3. If need_gather_q_kv, perform all_gather.
+        sp_ckq = self.q_a_proj(rank_sp_hidden_states)[0]
+        sp_hidden_states_or_q_c = self.q_a_layernorm(sp_ckq)
+        sp_kv_no_split = self.kv_a_proj_with_mqa(rank_sp_hidden_states)[0]
+        # Rearrange down_proj outputs across DP.
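+        # Every SP rank pads its slice to num_tokens_per_rank and all-gathers
+        # it over the SP world group, then each rank keeps only the tokens that
+        # belong to its own DP group for the attention computation.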
+        sp_output = torch.cat([sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
+        if sp_tokens < sp_context.num_tokens_per_rank:
+            sp_output = nn.functional.pad(
+                sp_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+        global_sp_output = get_mla_sp_world_group().all_gather(sp_output, 0)
+        my_dp = sp_context.my_dp
+        dp_output = global_sp_output[sp_context.start_token_of_dp[my_dp]:
+                                     sp_context.end_token_of_dp[my_dp]]
+        prefill_q_c, prefill_kv_no_split = dp_output.split(
+            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+            dim=-1)
+
+        if attn_metadata is None:
+            # Dummy run, just construct the attention outputs.
+            output_prefill = torch.empty(
+                [prefill_q_c.shape[0], self.num_heads * self.v_head_dim],
+                dtype=hidden_states.dtype,
+                device=hidden_states.device)
+        else:
+            # Preprocess prefill tokens, write kv cache and get:
+            # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
+            # prefill_value
+            prefill_q = self.q_proj(prefill_q_c)[0] \
+                .view(-1, self.num_heads, self.qk_head_dim)
+            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+            cos = attn_metadata.prefill.cos
+            sin = attn_metadata.prefill.sin
+            prefill_slots = attn_metadata.slot_mapping
+            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+            prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
+                prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
+            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
+                                             self.num_kv_heads, -1)
+            prefill_k_nope, prefill_value = self.kv_b_proj(
+                prefill_k_c_normed)[0].view(
+                    -1, self.num_heads,
+                    self.qk_nope_head_dim + self.v_head_dim).split(
+                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            prefill_k_pe = prefill_k_pe.expand(
+                (*prefill_k_nope.shape[:-1], -1))
+            # Attention outputs.
+            output_prefill = self._forward_prefill(
+                prefill_q_nope, prefill_q_pe,
+                prefill_k_nope, prefill_k_pe,
+                prefill_value, kv_cache, attn_metadata)
+
+        # Rearrange attention outputs across DP to run SP.
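+        # output_prefill holds this rank's attention-head shard for every token
+        # of its DP group; the all_to_all below redistributes these partial
+        # outputs across the SP world group so that each rank ends up with the
+        # tokens it owns in the SP layout before running o_proj.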
+        sp_world_group = get_mla_sp_world_group()
+        tp_size = get_tp_group().world_size
+        my_rank = sp_context.my_rank
+        my_rank_start_token = sp_context.rank_sp_start_token[my_rank]
+        my_rank_end_token = sp_context.rank_sp_end_token[my_rank]
+        num_sp_tokens = max(my_rank_end_token - my_rank_start_token, 0)
+        sp_send = output_prefill
+        if get_dp_group().world_size == 1:
+            padded_len = sp_context.num_tokens_per_rank * sp_world_group.world_size
+            if sp_context.num_global_tokens < padded_len:
+                sp_send = nn.functional.pad(
+                    sp_send, (0, 0, 0, padded_len - sp_context.num_global_tokens))
+            sp_output = torch.empty(
+                [sp_context.num_tokens_per_rank * tp_size,
+                 self.num_heads * self.v_head_dim],
+                dtype=sp_send.dtype,
+                device=sp_send.device)
+            dist.all_to_all_single(
+                output=sp_output,
+                input=sp_send,
+                group=sp_world_group.device_group,
+            )
+            sp_output = sp_output.reshape(
+                sp_context.num_tokens_per_rank,
+                tp_size * self.num_heads * self.v_head_dim)
+            sp_output = sp_output[:num_sp_tokens]
+        else:
+            sp_output = torch.empty(
+                [num_sp_tokens * tp_size, self.num_heads * self.v_head_dim],
+                dtype=sp_send.dtype,
+                device=sp_send.device)
+            dist.all_to_all_single(
+                output=sp_output,
+                input=sp_send,
+                output_split_sizes=sp_context.output_split_sizes,
+                input_split_sizes=sp_context.input_split_sizes,
+                group=sp_world_group.device_group,
+            )
+            sp_output = sp_output.reshape(
+                num_sp_tokens, tp_size * self.num_heads * self.v_head_dim)
+        sp_tokens = sp_output.shape[0]
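+        # If this rank owns no prefill tokens after the exchange, pad in a
+        # single zero row so o_proj still receives a non-empty input.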
+        if sp_tokens == 0:
+            sp_output = nn.functional.pad(sp_output, (0, 0, 0, 1))
+            sp_tokens = 1
+        # O proj
+        o_output = self.o_proj(sp_output)[0]
+        del sp_output
+        if sp_tokens < sp_context.num_tokens_per_rank:
+            o_output = nn.functional.pad(
+                o_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+        dp_output = get_tp_group().all_gather(o_output, 0)
+        dp_output = dp_output[:sp_context.num_my_dp_sp_tokens]
+        return dp_output

    def forward(
        self,
@@ -972,6 +1098,10 @@ def forward(
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
+        if get_sp_context() is not None:
+            # SP across DP: run prefill through the sequence-parallel path.
+            output[...] = self._forward_prefill_sp(hidden_states, kv_cache,
+                                                   attn_metadata)
+            return output
        if attn_metadata is None:
            # Profiling run.
            return output
@@ -1024,6 +1154,11 @@ def forward(
                    current_ms_metadata.after_comm_event.record()
            else:
                o_proj_input[num_decode_tokens:] = output_prefill
+
+        # When o_proj sharding is enabled, make sure the second dimension is
+        # complete while decoding.
+        if isinstance(self.o_proj, RowShardLinear):
+            o_proj_input = get_tp_group().all_gather(o_proj_input, dim=-1)
+
        # O proj
        current_ms_metadata = get_multistream_comm_context()
        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024