@@ -3,13 +3,16 @@
                     TypeVar)
 
 import torch
+import torch.distributed as dist
 import torch_npu
 from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import (get_dp_group,
+                              get_tensor_model_parallel_world_size,
+                              get_tp_group)
 from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -21,21 +24,19 @@
     maybe_save_kv_layer_to_connector,
     split_decodes_and_prefills,
     wait_for_kv_layer_from_connector)
+from vllm_ascend.distributed.parallel_state import (
+    get_mla_dp_rebalancing_o_shared_group, get_mla_dp_rebalancing_world_group)
+from vllm_ascend.mla_dp_rebalancing import get_mla_dp_rebalancing_context
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.utils import npu_prefetch
+from vllm_ascend.torchair.ops.shared_weight_layer import (
+    post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
-from vllm_ascend.distributed.parallel_state import get_mla_dp_rebalancing_o_shared_group, get_mla_dp_rebalancing_world_group
-from vllm_ascend.torchair.ops.shared_weight_layer import (post_process_after_loading_for_shared_weight_series,
-                                                          reach_layer_for_shared_weight_series,
-                                                          register_layer_to_shared_weight_series)
-from vllm_ascend.utils import dispose_tensor
-from vllm_ascend.mla_dp_rebalancing import get_mla_dp_rebalancing_context
-import torch.distributed as dist
-from vllm.distributed import get_dp_group
-
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -510,7 +511,7 @@ def __init__(
         self.prefill_mask = None
 
         self.speculative_config = vllm_config.speculative_config
-
+
         if ascend_config.enable_mla_prefill_dp_rebalancing:
             # Dispose tensor from the original o_proj
             for attr_name in dir(self.o_proj):
@@ -519,19 +520,21 @@ def __init__(
                     dispose_tensor(attr_value)
             # Construct the new o_proj using ReplicatedLinear
             config = vllm_config.model_config.hf_config
-            new_o_proj = ReplicatedLinear(config.num_attention_heads * config.v_head_dim,
-                                          config.hidden_size,
-                                          bias=False,
-                                          quant_config=vllm_config.quant_config,
-                                          prefix=self.o_proj.prefix)
+            new_o_proj = ReplicatedLinear(
+                config.num_attention_heads * config.v_head_dim,
+                config.hidden_size,
+                bias=False,
+                quant_config=vllm_config.quant_config,
+                prefix=self.o_proj.prefix)
             # Replace the o_proj with the new one
             self.o_proj.__class__ = new_o_proj.__class__
             self.o_proj.__dict__ = new_o_proj.__dict__
             # Register the o_proj into shared weight series to cut down memory usage
-            register_layer_to_shared_weight_series(series_name="o_proj",
-                                                   group=get_mla_dp_rebalancing_o_shared_group(),
-                                                   layer=self.o_proj,
-                                                   prefetch_step=1)
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_mla_dp_rebalancing_o_shared_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
 
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
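
Reviewer note: the hunk above replaces a live o_proj in place by grafting a new module's class and attribute dict onto the old object. Below is a minimal, self-contained sketch of that pattern using plain torch.nn.Linear stand-ins instead of vLLM's parallel linear layers and dispose_tensor; all names here are illustrative only, not part of the PR.

import torch
from torch import nn


class RowParallelStandIn(nn.Linear):
    """Stand-in for the original tensor-parallel o_proj."""


class ReplicatedStandIn(nn.Linear):
    """Stand-in for the ReplicatedLinear that replaces it."""


old_proj = RowParallelStandIn(16, 8, bias=False)
holder = {"o_proj": old_proj}  # something else already holds a reference

# Free the old weights first (rough stand-in for dispose_tensor on each
# tensor-valued attribute found via dir()/getattr()).
for attr_name in dir(old_proj):
    attr_value = getattr(old_proj, attr_name)
    if isinstance(attr_value, torch.Tensor):
        attr_value.data = torch.empty(0)

# Build the replacement and graft its class and state onto the old object,
# so every existing reference (e.g. holder["o_proj"]) sees the new layer.
new_proj = ReplicatedStandIn(16, 8, bias=False)
old_proj.__class__ = new_proj.__class__
old_proj.__dict__ = new_proj.__dict__

assert isinstance(holder["o_proj"], ReplicatedStandIn)
print(holder["o_proj"](torch.randn(2, 16)).shape)  # torch.Size([2, 8])

The point of the swap (rather than simply assigning self.o_proj = new_o_proj) is that any code already holding a reference to the original module keeps working without being rewired.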
@@ -991,18 +994,21 @@ def _forward_prefill_with_dp_rebalancing(
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
         npu_prefetch(self.q_a_proj.weight,
-                         hidden_states,
-                         enabled=self.enable_prefetch)
+                     hidden_states,
+                     enabled=self.enable_prefetch)
         sp_ckq = self.q_a_proj(device_sp_hidden_states)[0]
         sp_hidden_states_or_q_c = self.q_a_layernorm(sp_ckq)
         sp_kv_no_split = self.kv_a_proj_with_mqa(device_sp_hidden_states)[0]
         # Rearrange down_proj outputs across DP.
-        sp_down_proj_output = torch.cat([sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
+        sp_down_proj_output = torch.cat(
+            [sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
         sp_world_group = get_mla_dp_rebalancing_world_group()
-        global_sp_down_proj_output = sp_world_group.all_gather(sp_down_proj_output, 0)
+        global_sp_down_proj_output = sp_world_group.all_gather(
+            sp_down_proj_output, 0)
         local_dp = context.local_dp
-        dp_ori_down_proj_output = global_sp_down_proj_output[context.start_token_of_dp[local_dp]:
-                                                             context.end_token_of_dp[local_dp]]
+        dp_ori_down_proj_output = global_sp_down_proj_output[
+            context.start_token_of_dp[local_dp]:context.
+            end_token_of_dp[local_dp]]
         prefill_q_c, prefill_kv_no_split = dp_ori_down_proj_output.split(
             [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
             dim=-1)
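
Reviewer note: the rebalancing step above all-gathers every device's sequence-parallel slice of the down-projection output and then slices back, per DP rank, the token range that rank originally owned. A single-process sketch of that access pattern, with the all_gather simulated by torch.cat and all shapes/ranges made up:

import torch

num_devices = 4        # devices in the rebalancing world group
tokens_per_device = 3  # padded sequence-parallel slice length per device
hidden = 8             # q_lora_rank + kv_lora_rank + qk_rope_head_dim, say

# Each device holds an equally sized slice after rebalancing.
per_device = [torch.full((tokens_per_device, hidden), float(r))
              for r in range(num_devices)]

# all_gather along dim 0 -> every device sees the full, globally ordered output.
global_output = torch.cat(per_device, dim=0)  # [num_devices * tokens_per_device, hidden]

# Original (unbalanced) token ownership per DP rank, as tracked by the context.
start_token_of_dp = [0, 5]   # DP rank 0 owned tokens [0, 5), rank 1 owned [5, 12)
end_token_of_dp = [5, 12]

local_dp = 1
dp_slice = global_output[start_token_of_dp[local_dp]:end_token_of_dp[local_dp]]
print(dp_slice.shape)  # torch.Size([7, 8]) -- rank 1's original tokens restored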
@@ -1048,14 +1054,15 @@ def _forward_prefill_with_dp_rebalancing(
         tp_size = get_tp_group().world_size
         total_receive_len = context.local_device_total_receive_len
         sp_o_proj_input = torch.empty(
-                [total_receive_len * tp_size, self.num_heads * self.v_head_dim],
-                dtype=output_prefill.dtype,
-                device=output_prefill.device)
+            [total_receive_len * tp_size, self.num_heads * self.v_head_dim],
+            dtype=output_prefill.dtype,
+            device=output_prefill.device)
         if get_dp_group().world_size == 1:
             if output_prefill.shape[0] < context.num_padded_global_tokens:
                 output_prefill = nn.functional.pad(
                     output_prefill,
-                    (0, 0, 0, context.num_padded_global_tokens - output_prefill.shape[0]))
+                    (0, 0, 0, context.num_padded_global_tokens -
+                     output_prefill.shape[0]))
             dist.all_to_all_single(
                 output=sp_o_proj_input,
                 input=output_prefill,
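
Reviewer note: for readers unfamiliar with dist.all_to_all_single, here is a minimal single-process sketch of its equal-split semantics (the actual call above runs on the rebalancing world group's device group and may also pass split sizes; world size and shapes below are made up):

import torch

world_size = 4
rows_per_chunk, hidden = 2, 3

# Rank r's input: chunk i (rows_per_chunk rows) is the slice destined for rank i.
inputs = [torch.full((world_size * rows_per_chunk, hidden), float(r))
          for r in range(world_size)]

# all_to_all_single: rank r's output is the concatenation, in sender order,
# of chunk r taken from every rank's input (including its own).
outputs = []
for r in range(world_size):
    received = [inputs[src].chunk(world_size, dim=0)[r]
                for src in range(world_size)]
    outputs.append(torch.cat(received, dim=0))

print(outputs[0].shape)  # torch.Size([8, 3]): rows_per_chunk rows from each of 4 ranks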
@@ -1070,8 +1077,7 @@ def _forward_prefill_with_dp_rebalancing(
                 group=sp_world_group.device_group,
             )
             sp_o_proj_input = sp_o_proj_input.reshape(
-                total_receive_len,
-                tp_size * self.num_heads * self.v_head_dim)
+                total_receive_len, tp_size * self.num_heads * self.v_head_dim)
         if total_receive_len < context.num_tokens_per_device:
             sp_o_proj_input = nn.functional.pad(
                 sp_o_proj_input,
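
Reviewer note: the (0, 0, 0, n) argument to nn.functional.pad used in this function is easy to misread. Pad pairs are applied from the last dimension backwards, so this leaves the hidden dimension untouched and appends n zero rows on the token dimension. A small standalone check with arbitrary shapes:

import torch
from torch import nn

x = torch.randn(5, 16)                        # [tokens, hidden]
padded = nn.functional.pad(x, (0, 0, 0, 3))   # last dim: (0, 0); token dim: (0, 3)
print(padded.shape)                           # torch.Size([8, 16])
assert torch.equal(padded[5:], torch.zeros(3, 16))  # 3 zero rows appended at the end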
@@ -1094,7 +1100,8 @@ def forward(
         if get_ascend_config().enable_mla_prefill_dp_rebalancing:
             reach_layer_for_shared_weight_series(self.o_proj)
             if get_mla_dp_rebalancing_context() is not None:
-                output[...] = self._forward_prefill_with_dp_rebalancing(layer_name, hidden_states, kv_cache, attn_metadata)
+                output[...] = self._forward_prefill_with_dp_rebalancing(
+                    layer_name, hidden_states, kv_cache, attn_metadata)
                 return output
         if attn_metadata is None:
             # Profiling run.
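
Reviewer note: output[...] = ... above writes the result into the caller-provided buffer in place rather than rebinding the local name, so the caller's tensor is updated. A quick illustration with a hypothetical helper:

import torch

def fill(out: torch.Tensor) -> None:
    # In-place copy into the caller's buffer; same storage, no new tensor bound.
    out[...] = torch.ones_like(out)

buf = torch.zeros(2, 3)
fill(buf)
print(buf)  # all ones -- the caller observes the update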