
Commit 3a6068a

clrs97 and CalvinXKY committed
fix format
Co-authored-by: CalvinXKY <kyxiezju@163.com>
Signed-off-by: clrs97 <524936896@qq.com>
1 parent 77e0af3 commit 3a6068a


6 files changed: +100 -77 lines


vllm_ascend/ascend_config.py

Lines changed: 2 additions & 1 deletion
@@ -105,7 +105,8 @@ def __init__(self, vllm_config):
         if self.pd_tp_ratio == 0:
             raise AssertionError(
                 "Only support P node tp size lagger then D node tp size")
-        self.enable_mla_prefill_dp_rebalancing = additional_config.get("enable_mla_prefill_dp_rebalancing", False)
+        self.enable_mla_prefill_dp_rebalancing = additional_config.get(
+            "enable_mla_prefill_dp_rebalancing", False)
 
 
 class TorchairGraphConfig:
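
The hunk above only re-wraps how the new enable_mla_prefill_dp_rebalancing flag is read from additional_config. For reference, a minimal sketch of switching it on from offline-inference code, assuming the usual vLLM additional_config plumbing that vllm-ascend uses for its other flags (the model path below is a placeholder):

# Sketch only: enable the flag consumed by AscendConfig in ascend_config.py.
# Assumes vLLM forwards `additional_config` to the platform config; the model path is a placeholder.
from vllm import LLM

llm = LLM(
    model="path/to/your-mla-model",  # placeholder
    additional_config={
        # Read as additional_config.get("enable_mla_prefill_dp_rebalancing", False)
        "enable_mla_prefill_dp_rebalancing": True,
    },
)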

vllm_ascend/attention/mla_v1.py

Lines changed: 41 additions & 34 deletions
@@ -3,13 +3,16 @@
                     TypeVar)
 
 import torch
+import torch.distributed as dist
 import torch_npu
 from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import (get_dp_group,
+                              get_tensor_model_parallel_world_size,
+                              get_tp_group)
 from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -21,21 +24,19 @@
                                           maybe_save_kv_layer_to_connector,
                                           split_decodes_and_prefills,
                                           wait_for_kv_layer_from_connector)
+from vllm_ascend.distributed.parallel_state import (
+    get_mla_dp_rebalancing_o_shared_group, get_mla_dp_rebalancing_world_group)
+from vllm_ascend.mla_dp_rebalancing import get_mla_dp_rebalancing_context
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.utils import npu_prefetch
+from vllm_ascend.torchair.ops.shared_weight_layer import (
+    post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
-from vllm_ascend.distributed.parallel_state import get_mla_dp_rebalancing_o_shared_group, get_mla_dp_rebalancing_world_group
-from vllm_ascend.torchair.ops.shared_weight_layer import (post_process_after_loading_for_shared_weight_series,
-                                                          reach_layer_for_shared_weight_series,
-                                                          register_layer_to_shared_weight_series)
-from vllm_ascend.utils import dispose_tensor
-from vllm_ascend.mla_dp_rebalancing import get_mla_dp_rebalancing_context
-import torch.distributed as dist
-from vllm.distributed import get_dp_group
-
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -510,7 +511,7 @@ def __init__(
         self.prefill_mask = None
 
         self.speculative_config = vllm_config.speculative_config
-        
+
         if ascend_config.enable_mla_prefill_dp_rebalancing:
             # Dispose tensor from the original o_proj
             for attr_name in dir(self.o_proj):
@@ -519,19 +520,21 @@ def __init__(
                     dispose_tensor(attr_value)
             # Construct the new o_proj using ReplicatedLinear
             config = vllm_config.model_config.hf_config
-            new_o_proj = ReplicatedLinear(config.num_attention_heads * config.v_head_dim,
-                                          config.hidden_size,
-                                          bias=False,
-                                          quant_config=vllm_config.quant_config,
-                                          prefix=self.o_proj.prefix)
+            new_o_proj = ReplicatedLinear(
+                config.num_attention_heads * config.v_head_dim,
+                config.hidden_size,
+                bias=False,
+                quant_config=vllm_config.quant_config,
+                prefix=self.o_proj.prefix)
             # Replace the o_proj with the new one
             self.o_proj.__class__ = new_o_proj.__class__
             self.o_proj.__dict__ = new_o_proj.__dict__
             # Register the o_proj into shared weight series to cut down memory usage
-            register_layer_to_shared_weight_series(series_name="o_proj",
-                                                   group=get_mla_dp_rebalancing_o_shared_group(),
-                                                   layer=self.o_proj,
-                                                   prefetch_step=1)
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_mla_dp_rebalancing_o_shared_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
 
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
@@ -991,18 +994,21 @@ def _forward_prefill_with_dp_rebalancing(
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
         npu_prefetch(self.q_a_proj.weight,
-                     hidden_states,
-                     enabled=self.enable_prefetch)
+                     hidden_states,
+                     enabled=self.enable_prefetch)
         sp_ckq = self.q_a_proj(device_sp_hidden_states)[0]
         sp_hidden_states_or_q_c = self.q_a_layernorm(sp_ckq)
         sp_kv_no_split = self.kv_a_proj_with_mqa(device_sp_hidden_states)[0]
         # Rearrange down_proj outputs across DP.
-        sp_down_proj_output = torch.cat([sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
+        sp_down_proj_output = torch.cat(
+            [sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
         sp_world_group = get_mla_dp_rebalancing_world_group()
-        global_sp_down_proj_output = sp_world_group.all_gather(sp_down_proj_output, 0)
+        global_sp_down_proj_output = sp_world_group.all_gather(
+            sp_down_proj_output, 0)
         local_dp = context.local_dp
-        dp_ori_down_proj_output = global_sp_down_proj_output[context.start_token_of_dp[local_dp]:
-                                                             context.end_token_of_dp[local_dp]]
+        dp_ori_down_proj_output = global_sp_down_proj_output[
+            context.start_token_of_dp[local_dp]:context.
+            end_token_of_dp[local_dp]]
         prefill_q_c, prefill_kv_no_split = dp_ori_down_proj_output.split(
             [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
             dim=-1)
@@ -1048,14 +1054,15 @@ def _forward_prefill_with_dp_rebalancing(
         tp_size = get_tp_group().world_size
         total_receive_len = context.local_device_total_receive_len
         sp_o_proj_input = torch.empty(
-            [total_receive_len * tp_size, self.num_heads * self.v_head_dim],
-            dtype=output_prefill.dtype,
-            device=output_prefill.device)
+            [total_receive_len * tp_size, self.num_heads * self.v_head_dim],
+            dtype=output_prefill.dtype,
+            device=output_prefill.device)
         if get_dp_group().world_size == 1:
             if output_prefill.shape[0] < context.num_padded_global_tokens:
                 output_prefill = nn.functional.pad(
                     output_prefill,
-                    (0, 0, 0, context.num_padded_global_tokens - output_prefill.shape[0]))
+                    (0, 0, 0, context.num_padded_global_tokens -
+                     output_prefill.shape[0]))
             dist.all_to_all_single(
                 output=sp_o_proj_input,
                 input=output_prefill,
@@ -1070,8 +1077,7 @@ def _forward_prefill_with_dp_rebalancing(
                 group=sp_world_group.device_group,
             )
         sp_o_proj_input = sp_o_proj_input.reshape(
-            total_receive_len,
-            tp_size * self.num_heads * self.v_head_dim)
+            total_receive_len, tp_size * self.num_heads * self.v_head_dim)
         if total_receive_len < context.num_tokens_per_device:
             sp_o_proj_input = nn.functional.pad(
                 sp_o_proj_input,
@@ -1094,7 +1100,8 @@ def forward(
         if get_ascend_config().enable_mla_prefill_dp_rebalancing:
             reach_layer_for_shared_weight_series(self.o_proj)
             if get_mla_dp_rebalancing_context() is not None:
-                output[...] = self._forward_prefill_with_dp_rebalancing(layer_name, hidden_states, kv_cache, attn_metadata)
+                output[...] = self._forward_prefill_with_dp_rebalancing(
+                    layer_name, hidden_states, kv_cache, attn_metadata)
                 return output
         if attn_metadata is None:
             # Profiling run.

vllm_ascend/distributed/parallel_state.py

Lines changed: 12 additions & 9 deletions
@@ -53,7 +53,8 @@ def get_mla_dp_rebalancing_world_group() -> GroupCoordinator:
 
 
 def get_mla_dp_rebalancing_o_shared_group() -> GroupCoordinator:
-    assert _MLA_DP_REBALANCING_O_SHARED is not None, ("o_proj shared weight group for MLA DP rebalancing is not initialized")
+    assert _MLA_DP_REBALANCING_O_SHARED is not None, (
+        "o_proj shared weight group for MLA DP rebalancing is not initialized")
     return _MLA_DP_REBALANCING_O_SHARED
 
 
@@ -152,14 +153,16 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     global _MLA_DP_REBALANCING_WORLD
     global _MLA_DP_REBALANCING_O_SHARED
     group_ranks = [list(range(torch.distributed.get_world_size()))]
-    _MLA_DP_REBALANCING_WORLD = init_model_parallel_group(group_ranks,
-                                                          get_world_group().local_rank,
-                                                          backend,
-                                                          group_name="mla_dp_rebalancing_world")
-    _MLA_DP_REBALANCING_O_SHARED = init_model_parallel_group(group_ranks,
-                                                             get_world_group().local_rank,
-                                                             backend,
-                                                             group_name="mla_dp_rebalancing_o_shared")
+    _MLA_DP_REBALANCING_WORLD = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        group_name="mla_dp_rebalancing_world")
+    _MLA_DP_REBALANCING_O_SHARED = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        group_name="mla_dp_rebalancing_o_shared")
 
 
 def get_mlp_tensor_model_parallel_world_size():

vllm_ascend/mla_dp_rebalancing.py

Lines changed: 33 additions & 22 deletions
@@ -7,25 +7,26 @@
 from vllm.distributed.parallel_state import get_dp_group, get_tp_group
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.distributed.parallel_state import get_ascend_config, get_mla_dp_rebalancing_world_group
+from vllm_ascend.distributed.parallel_state import (
+    get_ascend_config, get_mla_dp_rebalancing_world_group)
 
 
 @dataclass
 class RebalancingContext:
     num_padded_global_tokens: int
     num_tokens_per_dp: int
     num_tokens_per_device: int
-    start_token_of_dp: list[int] # no pad, original
-    end_token_of_dp: list[int] # no pad, original
+    start_token_of_dp: list[int]  # no pad, original
+    end_token_of_dp: list[int]  # no pad, original
     global_tokens: torch.Tensor
-    dp_sp_start_token: list[int] # i * num_tokens_per_dp
-    dp_sp_end_token: list[int] # (i + 1) * num_tokens_per_dp
-    device_sp_start_token: list[int] # i * num_tokens_per_device
-    device_sp_end_token: list[int] # (i + 1) * num_tokens_per_device
+    dp_sp_start_token: list[int]  # i * num_tokens_per_dp
+    dp_sp_end_token: list[int]  # (i + 1) * num_tokens_per_dp
+    device_sp_start_token: list[int]  # i * num_tokens_per_device
+    device_sp_end_token: list[int]  # (i + 1) * num_tokens_per_device
     local_dp: int
     local_device: int
-    local_device_sp_start_token_within_dp: int # tp_group.rank_in_group * num_tokens_per_device
-    local_device_sp_end_token_within_dp: int # (tp_group.rank_in_group + 1) * num_tokens_per_device
+    local_device_sp_start_token_within_dp: int  # tp_group.rank_in_group * num_tokens_per_device
+    local_device_sp_end_token_within_dp: int  # (tp_group.rank_in_group + 1) * num_tokens_per_device
     local_device_total_receive_len: int
     input_split_sizes: list[int]
     output_split_sizes: list[int]
@@ -66,7 +67,7 @@ def set_mla_dp_rebalancing_context(input_ids: torch.Tensor):
         num_input_tokens = attn_metadata.num_actual_tokens
     else:
         num_input_tokens = 1
-    
+
     input_ids = input_ids[:num_input_tokens]
 
     rebalancing_metadata = torch.cat([
@@ -76,7 +77,8 @@ def set_mla_dp_rebalancing_context(input_ids: torch.Tensor):
         nn.functional.pad(input_ids,
                           (0, max_num_tokens_across_dp - num_input_tokens)),
     ]).unsqueeze(0)
-    rebalancing_metadata_across_dp = dp_group.all_gather(rebalancing_metadata, 0)
+    rebalancing_metadata_across_dp = dp_group.all_gather(
+        rebalancing_metadata, 0)
     for i in range(dp_group.world_size):
         row = rebalancing_metadata_across_dp[i]
         feature_enabled = bool(row[0] > 0)
@@ -120,7 +122,8 @@ def set_mla_dp_rebalancing_context(input_ids: torch.Tensor):
     local_dp = dp_group.rank_in_group
     local_device = sp_world_group.rank_in_group
     local_device_sp_start_token_within_dp = tp_group.rank_in_group * num_tokens_per_device
-    local_device_sp_end_token_within_dp = (tp_group.rank_in_group + 1) * num_tokens_per_device
+    local_device_sp_end_token_within_dp = (tp_group.rank_in_group +
+                                           1) * num_tokens_per_device
 
     tp_size = tp_group.world_size
     input_split_sizes = []
@@ -160,7 +163,8 @@ def set_mla_dp_rebalancing_context(input_ids: torch.Tensor):
     if dp_metadata is not None:
         dp_metadata.max_tokens_across_dp_cpu.fill_(num_tokens_per_dp)
         for i in range(dp_group.world_size):
-            dp_metadata.cu_tokens_across_dp_cpu[i] = (i + 1) * num_tokens_per_dp
+            dp_metadata.cu_tokens_across_dp_cpu[i] = (i +
+                                                      1) * num_tokens_per_dp
 
     _mla_dp_rebalancing_context = RebalancingContext(
         num_padded_global_tokens=num_padded_global_tokens,
@@ -175,7 +179,8 @@ def set_mla_dp_rebalancing_context(input_ids: torch.Tensor):
         device_sp_end_token=device_sp_end_token,
         local_dp=local_dp,
         local_device=local_device,
-        local_device_sp_start_token_within_dp=local_device_sp_start_token_within_dp,
+        local_device_sp_start_token_within_dp=
+        local_device_sp_start_token_within_dp,
         local_device_sp_end_token_within_dp=local_device_sp_end_token_within_dp,
         local_device_total_receive_len=local_device_total_receive_len,
         input_split_sizes=input_split_sizes,
@@ -187,15 +192,16 @@ def set_mla_dp_rebalancing_context(input_ids: torch.Tensor):
 def calc_div_ceil(up: int, down: int) -> int:
     return (up + down - 1) // down
 
+
 def pre_forward_for_dp_rebalancing(input_ids: torch.Tensor) -> torch.Tensor:
     set_mla_dp_rebalancing_context(input_ids)
     context = get_mla_dp_rebalancing_context()
     if context is None:
         return input_ids
     local_dp = context.local_dp
-    return context.global_tokens[
-        context.dp_sp_start_token[local_dp]:context.
-        dp_sp_end_token[local_dp]]
+    return context.global_tokens[context.dp_sp_start_token[local_dp]:context.
+                                 dp_sp_end_token[local_dp]]
+
 
 def recover_output(hidden_states: torch.Tensor) -> torch.Tensor:
     context = get_mla_dp_rebalancing_context()
@@ -205,7 +211,10 @@ def recover_output(hidden_states: torch.Tensor) -> torch.Tensor:
     local_dp_end_token = context.end_token_of_dp[local_dp]
     local_dp_sp_start_token = context.dp_sp_start_token[local_dp]
     local_dp_sp_end_token = context.dp_sp_end_token[local_dp]
-    send = hidden_states[:max(0, min(local_dp_sp_end_token, context.end_token_of_dp[-1]) - local_dp_sp_start_token)]
+    send = hidden_states[:max(
+        0,
+        min(local_dp_sp_end_token, context.end_token_of_dp[-1]) -
+        local_dp_sp_start_token)]
     dp_group = get_dp_group()
     if dp_group.world_size == 1:
         return send
@@ -238,12 +247,14 @@ def recover_output(hidden_states: torch.Tensor) -> torch.Tensor:
     )
     return output
 
-def post_forward_for_dp_rebalancing(hidden_states: torch.Tensor) -> torch.Tensor:
+
+def post_forward_for_dp_rebalancing(
+        hidden_states: torch.Tensor) -> torch.Tensor:
     context = get_mla_dp_rebalancing_context()
     if context is None:
         return hidden_states
     output = recover_output(hidden_states)
     if output.shape[0] < context.num_output_tokens:
-        output = nn.functional.pad(output,
-                                   (0, 0, 0, context.num_output_tokens - output.shape[0]))
-    return output
+        output = nn.functional.pad(
+            output, (0, 0, 0, context.num_output_tokens - output.shape[0]))
+    return output

vllm_ascend/torchair/ops/shared_weight_layer.py

Lines changed: 10 additions & 9 deletions
@@ -16,7 +16,7 @@ def dispose_tensor(x: torch.Tensor):
 class LayerMetadata:
     """Metadata for a layer.
     """
-    layer_idx: int # The index of the layer.
+    layer_idx: int  # The index of the layer.
     layer: LinearBase  # The layer object.
     post_method: Callable[[
         torch.nn.Module
@@ -56,7 +56,7 @@ def post_process_after_loading(self):
         # This method only needs to be called once per series.
         if self.shared_windows:
             return
-        
+
         self.layers.sort(key=lambda x: x.layer_idx)
         self.num_layers = len(self.layers)
         assert self.num_layers > 0, "No layers in the series"
@@ -212,13 +212,14 @@ def register_layer_to_shared_weight_series(
     series = _series_dict[series_name]
     assert layer.quant_method is not None
     layer_idx = extract_layer_index(layer.prefix)
-    series.layers.append(LayerMetadata(
-        layer_idx=layer_idx,
-        layer=layer,
-        post_method=layer.quant_method.process_weights_after_loading,
-        weight=layer.weight,
-        window_idx=-1,
-    ))
+    series.layers.append(
+        LayerMetadata(
+            layer_idx=layer_idx,
+            layer=layer,
+            post_method=layer.quant_method.process_weights_after_loading,
+            weight=layer.weight,
+            window_idx=-1,
+        ))
     # Discard the original `process_weights_after_loading` method such that it won't be called by others.
     layer.quant_method.process_weights_after_loading = lambda layer: None
     # When the layer not intended to be stored in this device, dispose the tensor and skip weight loading.

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
@@ -111,6 +111,8 @@
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
+from vllm_ascend.mla_dp_rebalancing import (post_forward_for_dp_rebalancing,
+                                            pre_forward_for_dp_rebalancing)
 from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
@@ -126,8 +128,6 @@
                                lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
-from vllm_ascend.mla_dp_rebalancing import pre_forward_for_dp_rebalancing, post_forward_for_dp_rebalancing
-
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
