
Commit 3b372e4

root authored and CalvinXKY committed
Fix isort and yapf
Co-authored-by: CalvinXKY <kyxiezju@163.com>
Signed-off-by: clrs97 <524936896@qq.com>
1 parent d526479 commit 3b372e4


6 files changed: +139 -75 lines changed


vllm_ascend/ascend_config.py

Lines changed: 5 additions & 2 deletions
@@ -63,9 +63,12 @@ def __init__(self, vllm_config):
         )
 
         self.enable_mla_sp = additional_config.get("enable_mla_sp", False)
-        self.o_shard_parallel_size = int(additional_config.get("o_shard_parallel_size", -1))
+        self.o_shard_parallel_size = int(
+            additional_config.get("o_shard_parallel_size", -1))
         self.enable_o_shard = self.o_shard_parallel_size > 0
-        self.o_shard_full_layers = int(additional_config.get("o_shard_full_layers", 0))
+        self.o_shard_full_layers = int(
+            additional_config.get("o_shard_full_layers", 0))
+
 
 class TorchairGraphConfig:
     """

vllm_ascend/attention/mla_v1.py

Lines changed: 46 additions & 31 deletions
@@ -2,13 +2,15 @@
 from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
 
 import torch
+import torch.distributed as dist
 import torch_npu
 from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -17,19 +19,16 @@
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          split_decodes_and_prefills)
+from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
+from vllm_ascend.mla_sp_context import get_sp_context
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.ops.shard import RowShardLinear
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
-import torch.distributed as dist
-from vllm.distributed.parallel_state import get_dp_group
-from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
-from vllm_ascend.mla_sp_context import get_sp_context
-from vllm_ascend.ops.shard import RowShardLinear
-
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -968,7 +967,7 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
                 prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
                 prefill_value)
         return decode_preprocess_res, prefill_preprocess_res
-
+
     def _forward_prefill_sp(
         self,
         hidden_states: torch.Tensor,
@@ -982,10 +981,13 @@ def _forward_prefill_sp(
                      enabled=self.enable_prefetch)
         # Split inputs from local DP to each device.
        dp_sp_hidden_states = hidden_states
-        rank_sp_hidden_states = dp_sp_hidden_states[sp_context.my_rank_sp_start_token_within_dp:sp_context.my_rank_sp_end_token_within_dp]
+        rank_sp_hidden_states = dp_sp_hidden_states[
+            sp_context.my_rank_sp_start_token_within_dp:sp_context.
+            my_rank_sp_end_token_within_dp]
         sp_tokens = rank_sp_hidden_states.shape[0]
         if sp_tokens == 0:
-            rank_sp_hidden_states = nn.functional.pad(rank_sp_hidden_states, (0, 0, 0, 1))
+            rank_sp_hidden_states = nn.functional.pad(rank_sp_hidden_states,
+                                                      (0, 0, 0, 1))
             sp_tokens = 1
         # MLA prefill:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
@@ -997,19 +999,23 @@ def _forward_prefill_sp(
         # Rearrange down_proj outputs across DP.
         sp_output = torch.cat([sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
         if sp_tokens < sp_context.num_tokens_per_rank:
-            sp_output = nn.functional.pad(sp_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+            sp_output = nn.functional.pad(
+                sp_output,
+                (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
         global_sp_output = get_mla_sp_world_group().all_gather(sp_output, 0)
         my_dp = sp_context.my_dp
-        dp_output = global_sp_output[sp_context.start_token_of_dp[my_dp]:sp_context.end_token_of_dp[my_dp]]
-        prefill_q_c, prefill_kv_no_split = dp_output.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
+        dp_output = global_sp_output[sp_context.start_token_of_dp[my_dp]:
+                                     sp_context.end_token_of_dp[my_dp]]
+        prefill_q_c, prefill_kv_no_split = dp_output.split(
+            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+            dim=-1)
 
         if attn_metadata is None:
             # Dummy run, just construct the attention outputs.
             output_prefill = torch.empty(
                 [prefill_q_c.shape[0], self.num_heads * self.v_head_dim],
                 dtype=hidden_states.dtype,
-                device=hidden_states.device
-            )
+                device=hidden_states.device)
         else:
             # Preprocess prefill tokens, write kv cache and get:
             # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
@@ -1025,7 +1031,7 @@ def _forward_prefill_sp(
             prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
                 prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
             prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
-                                              self.num_kv_heads, -1)
+                                             self.num_kv_heads, -1)
             prefill_k_nope, prefill_value = self.kv_b_proj(
                 prefill_k_c_normed)[0].view(
                     -1, self.num_heads,
@@ -1034,10 +1040,11 @@ def _forward_prefill_sp(
             prefill_k_pe = prefill_k_pe.expand(
                 (*prefill_k_nope.shape[:-1], -1))
             # Attention outputs.
-            output_prefill = self._forward_prefill(
-                prefill_q_nope, prefill_q_pe,
-                prefill_k_nope, prefill_k_pe,
-                prefill_value, kv_cache, attn_metadata)
+            output_prefill = self._forward_prefill(prefill_q_nope,
+                                                   prefill_q_pe,
+                                                   prefill_k_nope,
+                                                   prefill_k_pe, prefill_value,
+                                                   kv_cache, attn_metadata)
 
         # Rearrange attention outputs across DP to run SP.
         sp_world_group = get_mla_sp_world_group()
@@ -1050,33 +1057,38 @@ def _forward_prefill_sp(
         if get_dp_group().world_size == 1:
             padded_len = sp_context.num_tokens_per_rank * sp_world_group.world_size
             if sp_context.num_global_tokens < padded_len:
-                sp_send = nn.functional.pad(sp_send, (0, 0, 0, padded_len - sp_context.num_global_tokens))
-            sp_output = torch.empty(
-                [sp_context.num_tokens_per_rank * tp_size, self.num_heads * self.v_head_dim],
-                dtype=sp_send.dtype,
-                device=sp_send.device
-            )
+                sp_send = nn.functional.pad(
+                    sp_send,
+                    (0, 0, 0, padded_len - sp_context.num_global_tokens))
+            sp_output = torch.empty([
+                sp_context.num_tokens_per_rank * tp_size,
+                self.num_heads * self.v_head_dim
+            ],
+                                    dtype=sp_send.dtype,
+                                    device=sp_send.device)
             dist.all_to_all_single(
                 output=sp_output,
                 input=sp_send,
                 group=sp_world_group.device_group,
             )
-            sp_output = sp_output.reshape(sp_context.num_tokens_per_rank, tp_size * self.num_heads * self.v_head_dim)
+            sp_output = sp_output.reshape(
+                sp_context.num_tokens_per_rank,
+                tp_size * self.num_heads * self.v_head_dim)
             sp_output = sp_output[:num_sp_tokens]
         else:
             sp_output = torch.empty(
                 [num_sp_tokens * tp_size, self.num_heads * self.v_head_dim],
                 dtype=sp_send.dtype,
-                device=sp_send.device
-            )
+                device=sp_send.device)
             dist.all_to_all_single(
                 output=sp_output,
                 input=sp_send,
                 output_split_sizes=sp_context.output_split_sizes,
                 input_split_sizes=sp_context.input_split_sizes,
                 group=sp_world_group.device_group,
             )
-            sp_output = sp_output.reshape(num_sp_tokens, tp_size * self.num_heads * self.v_head_dim)
+            sp_output = sp_output.reshape(
+                num_sp_tokens, tp_size * self.num_heads * self.v_head_dim)
         sp_tokens = sp_output.shape[0]
         if sp_tokens == 0:
             sp_output = nn.functional.pad(sp_output, (0, 0, 0, 1))
@@ -1085,7 +1097,9 @@ def _forward_prefill_sp(
         o_output = self.o_proj(sp_output)[0]
         del sp_output
         if sp_tokens < sp_context.num_tokens_per_rank:
-            o_output = nn.functional.pad(o_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+            o_output = nn.functional.pad(
+                o_output,
+                (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
         dp_output = get_tp_group().all_gather(o_output, 0)
         dp_output = dp_output[:sp_context.num_my_dp_sp_tokens]
         return dp_output
@@ -1101,7 +1115,8 @@ def forward(
         assert output is not None, "Output tensor must be provided."
         if get_sp_context() is not None:
             # SP across DP
-            output[...] = self._forward_prefill_sp(hidden_states, kv_cache, attn_metadata)
+            output[...] = self._forward_prefill_sp(hidden_states, kv_cache,
+                                                   attn_metadata)
             return output
         if attn_metadata is None:
             # Profiling run.
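
For reference, a standalone sketch (plain CPU tensors, illustrative sizes, no process group) of the shape arithmetic that yapf reflowed in _forward_prefill_sp above: pad the per-rank rows up to num_tokens_per_rank, and after the all-to-all fold the tp_size chunks back into the hidden dimension. The random tensors only stand in for real activations.

import torch
import torch.nn.functional as F

num_tokens_per_rank, tp_size = 5, 4   # illustrative values
num_heads, v_head_dim = 16, 128

# Pad a short per-rank slice up to num_tokens_per_rank rows, as nn.functional.pad
# does above: (0, 0, 0, n) pads the first dimension on the right by n rows.
sp_send = torch.randn(3, num_heads * v_head_dim)
sp_send = F.pad(sp_send, (0, 0, 0, num_tokens_per_rank - sp_send.shape[0]))

# Stand-in for the dist.all_to_all_single result: one chunk per TP peer stacked
# along dim 0, then folded back into the hidden dimension by the reshape above.
sp_output = torch.randn(num_tokens_per_rank * tp_size, num_heads * v_head_dim)
sp_output = sp_output.reshape(num_tokens_per_rank,
                              tp_size * num_heads * v_head_dim)

print(sp_send.shape, sp_output.shape)
# torch.Size([5, 2048]) torch.Size([5, 8192])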

vllm_ascend/distributed/parallel_state.py

Lines changed: 11 additions & 2 deletions
@@ -93,6 +93,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
 
     init_ascend_mla_sp_model_parallel()
 
+
 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_mlp_tp_group().world_size
@@ -102,19 +103,24 @@ def get_mlp_tensor_model_parallel_rank():
     """Return world size for the tensor model parallel group."""
     return get_mlp_tp_group().rank_in_group
 
+
 # vllm-ascend will maintain its own MLA SP world GroupCoordinator and o_proj sharding GroupCoordinator for
 # customize parallel solution
 _MLA_SP_WORLD: Optional[GroupCoordinator] = None
 _O_SHARD: Optional[GroupCoordinator] = None
 
+
 def get_mla_sp_world_group() -> GroupCoordinator:
-    assert _MLA_SP_WORLD is not None, ("MLA sequence parallel world group is not initialized")
+    assert _MLA_SP_WORLD is not None, (
+        "MLA sequence parallel world group is not initialized")
     return _MLA_SP_WORLD
 
+
 def get_o_shard_group() -> GroupCoordinator:
     assert _O_SHARD is not None, ("o_proj sharding group is not initialized")
     return _O_SHARD
 
+
 def init_ascend_mla_sp_model_parallel():
     from vllm_ascend.ascend_config import get_ascend_config
     ascend_config = get_ascend_config()
@@ -138,13 +144,16 @@ def init_ascend_mla_sp_model_parallel():
     num_o_shard_parallel_groups = world_size // o_shard_parallel_size
     group_ranks = []
     for i in range(num_o_shard_parallel_groups):
-        ranks = list(range(i * o_shard_parallel_size, (i + 1) * o_shard_parallel_size))
+        ranks = list(
+            range(i * o_shard_parallel_size,
+                  (i + 1) * o_shard_parallel_size))
         group_ranks.append(ranks)
     _O_SHARD = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name="o_shard")
 
+
 def destroy_ascend_model_parallel():
     global _MC2
     if _MC2:
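
For reference, a standalone sketch of the group_ranks construction reflowed in the last hunk; world_size 8 and o_shard_parallel_size 4 are illustrative values, while the loop itself mirrors init_ascend_mla_sp_model_parallel above.

world_size, o_shard_parallel_size = 8, 4  # illustrative values
num_o_shard_parallel_groups = world_size // o_shard_parallel_size

group_ranks = []
for i in range(num_o_shard_parallel_groups):
    # Consecutive ranks of size o_shard_parallel_size form one o_proj sharding group.
    ranks = list(
        range(i * o_shard_parallel_size,
              (i + 1) * o_shard_parallel_size))
    group_ranks.append(ranks)

print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]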

vllm_ascend/mla_sp_context.py

Lines changed: 33 additions & 14 deletions
@@ -1,14 +1,15 @@
+from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
 from torch import nn
-
-from dataclasses import dataclass
 from vllm.attention import AttentionMetadata
-from vllm.distributed.parallel_state import (get_dp_group, get_tp_group)
+from vllm.distributed.parallel_state import get_dp_group, get_tp_group
 from vllm.forward_context import get_forward_context
+
 from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
 
+
 @dataclass
 class SPContext:
     num_global_tokens: int
@@ -29,14 +30,18 @@ class SPContext:
     input_split_sizes: list[int]
     output_split_sizes: list[int]
 
+
 _sp_context: Optional[SPContext] = None
 
+
 def get_sp_context() -> Optional[SPContext]:
     return _sp_context
 
+
 def set_sp_context(
     input_ids: torch.Tensor,
-    attn_metadata: Optional[Union["AttentionMetadata", dict[str, "AttentionMetadata"]]] = None,
+    attn_metadata: Optional[Union["AttentionMetadata",
+                                  dict[str, "AttentionMetadata"]]] = None,
 ):
     global _sp_context
     _sp_context = None
@@ -66,8 +71,11 @@ def set_sp_context(
         assert num_input_tokens == 1, "Length of dummy run must be 1."
 
     sp_metadata = torch.cat([
-        torch.tensor([sp_enabled, num_input_tokens], device=input_ids.device, dtype=torch.int32),
-        nn.functional.pad(input_ids, (0, max_num_tokens_across_dp - num_input_tokens)),
+        torch.tensor([sp_enabled, num_input_tokens],
+                     device=input_ids.device,
+                     dtype=torch.int32),
+        nn.functional.pad(input_ids,
+                          (0, max_num_tokens_across_dp - num_input_tokens)),
     ]).unsqueeze(0)
     sp_metadata_across_dp = dp_group.all_gather(sp_metadata, 0)
     for i in range(dp_group.world_size):
@@ -86,13 +94,17 @@ def set_sp_context(
         num_global_tokens += num_tokens
         end_token_of_dp.append(num_global_tokens)
 
-    num_tokens_per_rank = calc_div_ceil(num_global_tokens, sp_world_group.world_size)
+    num_tokens_per_rank = calc_div_ceil(num_global_tokens,
+                                        sp_world_group.world_size)
     num_tokens_per_dp = num_tokens_per_rank * tp_group.world_size
-    global_tokens = torch.empty(num_global_tokens, dtype=input_ids.dtype, device=input_ids.device)
+    global_tokens = torch.empty(num_global_tokens,
+                                dtype=input_ids.dtype,
+                                device=input_ids.device)
     for i in range(dp_group.world_size):
         row = sp_metadata_across_dp[i]
         num_tokens = row[1]
-        global_tokens[start_token_of_dp[i]:end_token_of_dp[i]] = row[2:num_tokens+2]
+        global_tokens[start_token_of_dp[i]:end_token_of_dp[i]] = row[
+            2:num_tokens + 2]
 
     dp_sp_start_token = []
     dp_sp_end_token = []
@@ -101,17 +113,22 @@ def set_sp_context(
     for i in range(dp_group.world_size):
         start_token = i * num_tokens_per_dp
         dp_sp_start_token.append(start_token)
-        dp_sp_end_token.append(min(start_token + num_tokens_per_dp, num_global_tokens))
+        dp_sp_end_token.append(
+            min(start_token + num_tokens_per_dp, num_global_tokens))
     for i in range(sp_world_group.world_size):
         start_token = i * num_tokens_per_rank
         rank_sp_start_token.append(start_token)
-        rank_sp_end_token.append(min(start_token + num_tokens_per_rank, num_global_tokens))
+        rank_sp_end_token.append(
+            min(start_token + num_tokens_per_rank, num_global_tokens))
 
     my_dp = dp_group.rank_in_group
     my_rank = sp_world_group.rank_in_group
     my_rank_sp_start_token_within_dp = tp_group.rank_in_group * num_tokens_per_rank
-    my_rank_sp_end_token_within_dp = min(my_rank_sp_start_token_within_dp + num_tokens_per_rank, max(0, dp_sp_end_token[my_dp] - dp_sp_start_token[my_dp]))
-    num_my_dp_sp_tokens = max(0, dp_sp_end_token[my_dp] - dp_sp_start_token[my_dp])
+    my_rank_sp_end_token_within_dp = min(
+        my_rank_sp_start_token_within_dp + num_tokens_per_rank,
+        max(0, dp_sp_end_token[my_dp] - dp_sp_start_token[my_dp]))
+    num_my_dp_sp_tokens = max(
+        0, dp_sp_end_token[my_dp] - dp_sp_start_token[my_dp])
 
     tp_size = tp_group.world_size
     input_split_sizes = []
@@ -138,7 +155,8 @@ def set_sp_context(
 
     forward_context.with_prefill = True
     forward_context.max_tokens_across_dp = num_tokens_per_dp
-    forward_context.padded_num_tokens = calc_div_ceil(num_tokens_per_dp, tp_size) * tp_size
+    forward_context.padded_num_tokens = calc_div_ceil(num_tokens_per_dp,
+                                                      tp_size) * tp_size
     from vllm_ascend.ascend_forward_context import FusedMoEState
     if forward_context.fused_moe_state == FusedMoEState.NaiveMulticast:
         forward_context.fused_moe_state = FusedMoEState.AllGather
@@ -170,5 +188,6 @@ def set_sp_context(
         output_split_sizes=output_split_sizes,
     )
 
+
 def calc_div_ceil(up: int, down: int) -> int:
     return (up + down - 1) // down
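
For reference, a standalone sketch of the ceil-division token partitioning reflowed above: every SP rank gets num_tokens_per_rank tokens and the trailing rank may receive a short (or empty) slice. The token and world-size numbers are illustrative; the loop mirrors set_sp_context.

def calc_div_ceil(up: int, down: int) -> int:
    return (up + down - 1) // down

num_global_tokens, sp_world_size = 10, 4  # illustrative values
num_tokens_per_rank = calc_div_ceil(num_global_tokens, sp_world_size)  # 3

rank_sp_start_token, rank_sp_end_token = [], []
for i in range(sp_world_size):
    start_token = i * num_tokens_per_rank
    rank_sp_start_token.append(start_token)
    rank_sp_end_token.append(
        min(start_token + num_tokens_per_rank, num_global_tokens))

print(list(zip(rank_sp_start_token, rank_sp_end_token)))
# [(0, 3), (3, 6), (6, 9), (9, 10)]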
