 from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
 
 import torch
+import torch.distributed as dist
 import torch_npu
 from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
+from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
+from vllm_ascend.mla_sp_context import get_sp_context
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.ops.shard import RowShardLinear
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
-import torch.distributed as dist
-from vllm.distributed.parallel_state import get_dp_group
-from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
-from vllm_ascend.mla_sp_context import get_sp_context
-from vllm_ascend.ops.shard import RowShardLinear
-
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -968,7 +967,7 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
                 prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
                 prefill_value)
         return decode_preprocess_res, prefill_preprocess_res
-
+
     def _forward_prefill_sp(
         self,
         hidden_states: torch.Tensor,
@@ -982,10 +981,13 @@ def _forward_prefill_sp(
                      enabled=self.enable_prefetch)
         # Split inputs from local DP to each device.
         dp_sp_hidden_states = hidden_states
-        rank_sp_hidden_states = dp_sp_hidden_states[sp_context.my_rank_sp_start_token_within_dp:sp_context.my_rank_sp_end_token_within_dp]
+        rank_sp_hidden_states = dp_sp_hidden_states[
+            sp_context.my_rank_sp_start_token_within_dp:sp_context.
+            my_rank_sp_end_token_within_dp]
         sp_tokens = rank_sp_hidden_states.shape[0]
         if sp_tokens == 0:
-            rank_sp_hidden_states = nn.functional.pad(rank_sp_hidden_states, (0, 0, 0, 1))
+            rank_sp_hidden_states = nn.functional.pad(rank_sp_hidden_states,
+                                                      (0, 0, 0, 1))
             sp_tokens = 1
         # MLA prefill:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
@@ -997,19 +999,23 @@ def _forward_prefill_sp(
         # Rearrange down_proj outputs across DP.
         sp_output = torch.cat([sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
         if sp_tokens < sp_context.num_tokens_per_rank:
-            sp_output = nn.functional.pad(sp_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+            sp_output = nn.functional.pad(
+                sp_output,
+                (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
         global_sp_output = get_mla_sp_world_group().all_gather(sp_output, 0)
         my_dp = sp_context.my_dp
-        dp_output = global_sp_output[sp_context.start_token_of_dp[my_dp]:sp_context.end_token_of_dp[my_dp]]
-        prefill_q_c, prefill_kv_no_split = dp_output.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
+        dp_output = global_sp_output[sp_context.start_token_of_dp[my_dp]:
+                                     sp_context.end_token_of_dp[my_dp]]
+        prefill_q_c, prefill_kv_no_split = dp_output.split(
+            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+            dim=-1)
 
         if attn_metadata is None:
             # Dummy run, just construct the attention outputs.
             output_prefill = torch.empty(
                 [prefill_q_c.shape[0], self.num_heads * self.v_head_dim],
                 dtype=hidden_states.dtype,
-                device=hidden_states.device
-            )
+                device=hidden_states.device)
         else:
             # Preprocess prefill tokens, write kv cache and get:
             # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
@@ -1025,7 +1031,7 @@ def _forward_prefill_sp(
             prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
                 prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
             prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
-                                            self.num_kv_heads, -1)
+                                             self.num_kv_heads, -1)
             prefill_k_nope, prefill_value = self.kv_b_proj(
                 prefill_k_c_normed)[0].view(
                     -1, self.num_heads,
@@ -1034,10 +1040,11 @@ def _forward_prefill_sp(
             prefill_k_pe = prefill_k_pe.expand(
                 (*prefill_k_nope.shape[:-1], -1))
             # Attention outputs.
-            output_prefill = self._forward_prefill(
-                prefill_q_nope, prefill_q_pe,
-                prefill_k_nope, prefill_k_pe,
-                prefill_value, kv_cache, attn_metadata)
+            output_prefill = self._forward_prefill(prefill_q_nope,
+                                                   prefill_q_pe,
+                                                   prefill_k_nope,
+                                                   prefill_k_pe, prefill_value,
+                                                   kv_cache, attn_metadata)
 
         # Rearrange attention outputs across DP to run SP.
         sp_world_group = get_mla_sp_world_group()
@@ -1050,33 +1057,38 @@ def _forward_prefill_sp(
         if get_dp_group().world_size == 1:
             padded_len = sp_context.num_tokens_per_rank * sp_world_group.world_size
             if sp_context.num_global_tokens < padded_len:
-                sp_send = nn.functional.pad(sp_send, (0, 0, 0, padded_len - sp_context.num_global_tokens))
-            sp_output = torch.empty(
-                [sp_context.num_tokens_per_rank * tp_size, self.num_heads * self.v_head_dim],
-                dtype=sp_send.dtype,
-                device=sp_send.device
-            )
+                sp_send = nn.functional.pad(
+                    sp_send,
+                    (0, 0, 0, padded_len - sp_context.num_global_tokens))
+            sp_output = torch.empty([
+                sp_context.num_tokens_per_rank * tp_size,
+                self.num_heads * self.v_head_dim
+            ],
+                                    dtype=sp_send.dtype,
+                                    device=sp_send.device)
             dist.all_to_all_single(
                 output=sp_output,
                 input=sp_send,
                 group=sp_world_group.device_group,
             )
-            sp_output = sp_output.reshape(sp_context.num_tokens_per_rank, tp_size * self.num_heads * self.v_head_dim)
+            sp_output = sp_output.reshape(
+                sp_context.num_tokens_per_rank,
+                tp_size * self.num_heads * self.v_head_dim)
             sp_output = sp_output[:num_sp_tokens]
         else:
             sp_output = torch.empty(
                 [num_sp_tokens * tp_size, self.num_heads * self.v_head_dim],
                 dtype=sp_send.dtype,
-                device=sp_send.device
-            )
+                device=sp_send.device)
             dist.all_to_all_single(
                 output=sp_output,
                 input=sp_send,
                 output_split_sizes=sp_context.output_split_sizes,
                 input_split_sizes=sp_context.input_split_sizes,
                 group=sp_world_group.device_group,
            )
-            sp_output = sp_output.reshape(num_sp_tokens, tp_size * self.num_heads * self.v_head_dim)
+            sp_output = sp_output.reshape(
+                num_sp_tokens, tp_size * self.num_heads * self.v_head_dim)
         sp_tokens = sp_output.shape[0]
         if sp_tokens == 0:
             sp_output = nn.functional.pad(sp_output, (0, 0, 0, 1))
@@ -1085,7 +1097,9 @@ def _forward_prefill_sp(
         o_output = self.o_proj(sp_output)[0]
         del sp_output
         if sp_tokens < sp_context.num_tokens_per_rank:
-            o_output = nn.functional.pad(o_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+            o_output = nn.functional.pad(
+                o_output,
+                (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
         dp_output = get_tp_group().all_gather(o_output, 0)
         dp_output = dp_output[:sp_context.num_my_dp_sp_tokens]
         return dp_output
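
Note on the hunks above: the SP path pads every rank's contribution to a fixed num_tokens_per_rank slot (empty shards first get one dummy row), all-gathers the padded blocks over the SP world group, slices each DP group's token range back out, and splits the concatenated projection into q_c and the latent-KV part. The following is a minimal CPU sketch of that bookkeeping only; the sizes are made-up examples, the field names mirror the sp_context attributes used in the diff, and the collective is simulated with torch.cat instead of get_mla_sp_world_group().all_gather.

import torch
from torch import nn

# Example sizes; stand-ins, not the real model dimensions.
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 8, 4, 2
width = q_lora_rank + kv_lora_rank + qk_rope_head_dim
num_tokens_per_rank = 5                 # fixed slot size every SP rank pads to
tokens_per_rank = [5, 3, 0, 4]          # actual token counts on the 4 SP ranks
start_token_of_dp = [0, 10]             # ranks 0-1 belong to DP0, ranks 2-3 to DP1
end_token_of_dp = [10, 20]

# Each rank pads its local projection output to num_tokens_per_rank; an empty
# shard first gets one dummy row so downstream ops never see a 0-length tensor.
per_rank_outputs = []
for n in tokens_per_rank:
    out = torch.randn(max(n, 1), width)
    if out.shape[0] < num_tokens_per_rank:
        out = nn.functional.pad(out, (0, 0, 0, num_tokens_per_rank - out.shape[0]))
    per_rank_outputs.append(out)

# Stand-in for get_mla_sp_world_group().all_gather(sp_output, 0).
global_sp_output = torch.cat(per_rank_outputs, dim=0)   # shape [4 * 5, width]

# Each DP group slices out its own padded token range ...
my_dp = 0
dp_output = global_sp_output[start_token_of_dp[my_dp]:end_token_of_dp[my_dp]]

# ... and splits the concatenated projection back into q_c and the latent KV.
prefill_q_c, prefill_kv_no_split = dp_output.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)
print(prefill_q_c.shape, prefill_kv_no_split.shape)  # [10, 8] and [10, 6]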
@@ -1101,7 +1115,8 @@ def forward(
         assert output is not None, "Output tensor must be provided."
         if get_sp_context() is not None:
             # SP across DP
-            output[...] = self._forward_prefill_sp(hidden_states, kv_cache, attn_metadata)
+            output[...] = self._forward_prefill_sp(hidden_states, kv_cache,
+                                                   attn_metadata)
             return output
         if attn_metadata is None:
             # Profiling run.
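
In the forward hunk above, the SP result is written with output[...] = ... rather than rebinding the name, so the caller-provided output tensor (asserted non-None earlier) is filled in place and any existing reference to that buffer observes the result. A minimal illustration of that behaviour; the helper name and shapes here are hypothetical.

import torch

def write_attention_output(output: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # output[...] copies into the existing storage instead of creating a new
    # tensor, so references the caller already holds see the new values.
    output[...] = result
    return output

buffer = torch.zeros(4, 3)
alias = buffer                       # e.g. a preallocated workspace the runner keeps
write_attention_output(buffer, torch.ones(4, 3))
assert alias.sum().item() == 12.0    # the alias observes the in-place update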