Commit 78d3141

clrs97 and CalvinXKY committed

Run SP across DP for DeepSeekV2 MLA

Co-authored-by: CalvinXKY <kyxiezju@163.com>
Signed-off-by: clrs97 <524936896@qq.com>

1 parent ea53f90 · commit 78d3141

File tree

6 files changed: +633 -1 lines


vllm_ascend/ascend_config.py

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,10 @@ def __init__(self, vllm_config):
                 "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
             )
 
+        self.enable_mla_sp = additional_config.get("enable_mla_sp", False)
+        self.o_shard_parallel_size = int(additional_config.get("o_shard_parallel_size", -1))
+        self.enable_o_shard = self.o_shard_parallel_size > 0
+        self.o_shard_full_layers = int(additional_config.get("o_shard_full_layers", 0))
 
 class TorchairGraphConfig:
     """

vllm_ascend/attention/mla_v1.py

Lines changed: 135 additions & 0 deletions
@@ -24,6 +24,12 @@
 from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
+import torch.distributed as dist
+from vllm.distributed.parallel_state import get_dp_group
+from vllm_ascend.distributed.parallel_state import get_mla_sp_world_group
+from vllm_ascend.mla_sp_context import get_sp_context
+from vllm_ascend.ops.shard import RowShardLinear
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
@@ -962,6 +968,126 @@ def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
                 prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe,
                 prefill_value)
         return decode_preprocess_res, prefill_preprocess_res
+
+    def _forward_prefill_sp(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor],
+        attn_metadata: M,
+    ) -> torch.Tensor:
+        sp_context = get_sp_context()
+        assert sp_context is not None
+        npu_prefetch(self.q_a_proj.weight,
+                     hidden_states,
+                     enabled=self.enable_prefetch)
+        # Split inputs from local DP to each device.
+        dp_sp_hidden_states = hidden_states
+        rank_sp_hidden_states = dp_sp_hidden_states[sp_context.my_rank_sp_start_token_within_dp:sp_context.my_rank_sp_end_token_within_dp]
+        sp_tokens = rank_sp_hidden_states.shape[0]
+        if sp_tokens == 0:
+            rank_sp_hidden_states = nn.functional.pad(rank_sp_hidden_states, (0, 0, 0, 1))
+            sp_tokens = 1
+        # MLA prefill:
+        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
+        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
+        # 3. If need_gather_q_kv, perform all_gather.
+        sp_ckq = self.q_a_proj(rank_sp_hidden_states)[0]
+        sp_hidden_states_or_q_c = self.q_a_layernorm(sp_ckq)
+        sp_kv_no_split = self.kv_a_proj_with_mqa(rank_sp_hidden_states)[0]
+        # Rearrange down_proj outputs across DP.
+        sp_output = torch.cat([sp_hidden_states_or_q_c, sp_kv_no_split], dim=1)
+        if sp_tokens < sp_context.num_tokens_per_rank:
+            sp_output = nn.functional.pad(sp_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+        global_sp_output = get_mla_sp_world_group().all_gather(sp_output, 0)
+        my_dp = sp_context.my_dp
+        dp_output = global_sp_output[sp_context.start_token_of_dp[my_dp]:sp_context.end_token_of_dp[my_dp]]
+        prefill_q_c, prefill_kv_no_split = dp_output.split([self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1)
+
+        if attn_metadata is None:
+            # Dummy run, just construct the attention outputs.
+            output_prefill = torch.empty(
+                [prefill_q_c.shape[0], self.num_heads * self.v_head_dim],
+                dtype=hidden_states.dtype,
+                device=hidden_states.device
+            )
+        else:
+            # Preprocess prefill tokens, write kv cache and get:
+            # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
+            prefill_q = self.q_proj(prefill_q_c)[0]\
+                .view(-1, self.num_heads, self.qk_head_dim)
+            prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+            prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
+            cos = attn_metadata.prefill.cos
+            sin = attn_metadata.prefill.sin
+            prefill_slots = attn_metadata.slot_mapping
+            prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+            prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
+                prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
+            prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0],
+                                             self.num_kv_heads, -1)
+            prefill_k_nope, prefill_value = self.kv_b_proj(
+                prefill_k_c_normed)[0].view(
+                    -1, self.num_heads,
+                    self.qk_nope_head_dim + self.v_head_dim).split(
+                        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            prefill_k_pe = prefill_k_pe.expand(
+                (*prefill_k_nope.shape[:-1], -1))
+            # Attention outputs.
+            output_prefill = self._forward_prefill(
+                prefill_q_nope, prefill_q_pe,
+                prefill_k_nope, prefill_k_pe,
+                prefill_value, kv_cache, attn_metadata)
+
+        # Rearrange attention outputs across DP to run SP.
+        sp_world_group = get_mla_sp_world_group()
+        tp_size = get_tp_group().world_size
+        my_rank = sp_context.my_rank
+        my_rank_start_token = sp_context.rank_sp_start_token[my_rank]
+        my_rank_end_token = sp_context.rank_sp_end_token[my_rank]
+        num_sp_tokens = max(my_rank_end_token - my_rank_start_token, 0)
+        sp_send = output_prefill
+        if get_dp_group().world_size == 1:
+            padded_len = sp_context.num_tokens_per_rank * sp_world_group.world_size
+            if sp_context.num_global_tokens < padded_len:
+                sp_send = nn.functional.pad(sp_send, (0, 0, 0, padded_len - sp_context.num_global_tokens))
+            sp_output = torch.empty(
+                [sp_context.num_tokens_per_rank * tp_size, self.num_heads * self.v_head_dim],
+                dtype=sp_send.dtype,
+                device=sp_send.device
+            )
+            dist.all_to_all_single(
+                output=sp_output,
+                input=sp_send,
+                group=sp_world_group.device_group,
+            )
+            sp_output = sp_output.reshape(sp_context.num_tokens_per_rank, tp_size * self.num_heads * self.v_head_dim)
+            sp_output = sp_output[:num_sp_tokens]
+        else:
+            sp_output = torch.empty(
+                [num_sp_tokens * tp_size, self.num_heads * self.v_head_dim],
+                dtype=sp_send.dtype,
+                device=sp_send.device
+            )
+            dist.all_to_all_single(
+                output=sp_output,
+                input=sp_send,
+                output_split_sizes=sp_context.output_split_sizes,
+                input_split_sizes=sp_context.input_split_sizes,
+                group=sp_world_group.device_group,
+            )
+            sp_output = sp_output.reshape(num_sp_tokens, tp_size * self.num_heads * self.v_head_dim)
+        sp_tokens = sp_output.shape[0]
+        if sp_tokens == 0:
+            sp_output = nn.functional.pad(sp_output, (0, 0, 0, 1))
+            sp_tokens = 1
+        # O proj
+        o_output = self.o_proj(sp_output)[0]
+        del sp_output
+        if sp_tokens < sp_context.num_tokens_per_rank:
+            o_output = nn.functional.pad(o_output, (0, 0, 0, sp_context.num_tokens_per_rank - sp_tokens))
+        dp_output = get_tp_group().all_gather(o_output, 0)
+        dp_output = dp_output[:sp_context.num_my_dp_sp_tokens]
+        return dp_output
 
     def forward(
         self,
@@ -972,6 +1098,10 @@ def forward(
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
+        if get_sp_context() is not None:
+            # SP across DP
+            output[...] = self._forward_prefill_sp(hidden_states, kv_cache, attn_metadata)
+            return output
         if attn_metadata is None:
             # Profiling run.
             return output
@@ -1024,6 +1154,11 @@ def forward(
                     current_ms_metadata.after_comm_event.record()
             else:
                 o_proj_input[num_decode_tokens:] = output_prefill
+
+        # When o_proj sharding is enabled, make sure the second dimension is complete while decoding.
+        if isinstance(self.o_proj, RowShardLinear):
+            o_proj_input = get_tp_group().all_gather(o_proj_input, dim=-1)
+
         # O proj
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
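
The SP path in `_forward_prefill_sp` leans on `torch.distributed.all_to_all_single` twice: with equal splits when `get_dp_group().world_size == 1`, and with explicit `input_split_sizes`/`output_split_sizes` taken from `sp_context` otherwise (the `sp_context` and `RowShardLinear` definitions live in files not shown in this excerpt). As a reference for the collective's semantics only, here is a single-process sketch that simulates the data movement with plain tensors; the rank count, row counts, and hidden size are made up for illustration:

# Simulates what all_to_all_single leaves on each rank; no process group is needed.
import torch

world_size = 4
# input_split_sizes[i][j]: rows rank i sends to rank j (row counts are made up).
input_split_sizes = [[2, 1, 0, 3], [1, 1, 1, 1], [0, 2, 2, 0], [3, 0, 1, 1]]
hidden = 8
inputs = [torch.randn(sum(splits), hidden) for splits in input_split_sizes]

outputs = []
for dst in range(world_size):
    chunks = []
    for src in range(world_size):
        # Each source rank's input is split per its input_split_sizes; the chunk
        # addressed to dst starts after the chunks addressed to lower ranks.
        start = sum(input_split_sizes[src][:dst])
        end = start + input_split_sizes[src][dst]
        chunks.append(inputs[src][start:end])
    # The destination receives the chunks concatenated in source-rank order.
    outputs.append(torch.cat(chunks, dim=0))

# output_split_sizes on rank dst is just column dst of the send matrix.
output_split_sizes = [[input_split_sizes[src][dst] for src in range(world_size)]
                      for dst in range(world_size)]
print(outputs[0].shape, output_split_sizes[0])   # torch.Size([6, 8]) [2, 1, 0, 3]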

vllm_ascend/distributed/parallel_state.py

Lines changed: 51 additions & 0 deletions
@@ -91,6 +91,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                           backend,
                                           group_name="lmheadtp")
 
+    init_ascend_mla_sp_model_parallel()
 
 def get_mlp_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
@@ -101,6 +102,46 @@ def get_mlp_tensor_model_parallel_rank():
     """Return world size for the tensor model parallel group."""
     return get_mlp_tp_group().rank_in_group
 
+# vllm-ascend maintains its own MLA SP world GroupCoordinator and o_proj sharding
+# GroupCoordinator for its customized parallel solution.
+_MLA_SP_WORLD: Optional[GroupCoordinator] = None
+_O_SHARD: Optional[GroupCoordinator] = None
+
+def get_mla_sp_world_group() -> GroupCoordinator:
+    assert _MLA_SP_WORLD is not None, ("MLA sequence parallel world group is not initialized")
+    return _MLA_SP_WORLD
+
+def get_o_shard_group() -> GroupCoordinator:
+    assert _O_SHARD is not None, ("o_proj sharding group is not initialized")
+    return _O_SHARD
+
+def init_ascend_mla_sp_model_parallel():
+    from vllm_ascend.ascend_config import get_ascend_config
+    ascend_config = get_ascend_config()
+    world_size = torch.distributed.get_world_size()
+    backend = torch.distributed.get_backend(get_world_group().device_group)
+
+    if ascend_config.enable_mla_sp:
+        assert ascend_config.enable_o_shard, "MLA SP must be enabled with o_proj sharding"
+        global _MLA_SP_WORLD
+        group_ranks = [list(range(torch.distributed.get_world_size()))]
+        _MLA_SP_WORLD = init_model_parallel_group(group_ranks,
+                                                  get_world_group().local_rank,
+                                                  backend,
+                                                  group_name="mla_sp_world")
+
+    if ascend_config.enable_o_shard:
+        o_shard_parallel_size = ascend_config.o_shard_parallel_size
+        assert o_shard_parallel_size >= 2, "o_shard_parallel_size must be >= 2"
+        assert world_size % o_shard_parallel_size == 0, "o_shard_parallel_size must be a divisor of world_size"
+        global _O_SHARD
+        all_ranks = torch.arange(world_size)
+        group_ranks = all_ranks.view(-1, o_shard_parallel_size).unbind(0)
+        group_ranks = [x.tolist() for x in group_ranks]
+        _O_SHARD = init_model_parallel_group(group_ranks,
+                                             get_world_group().local_rank,
+                                             backend,
+                                             group_name="o_shard")
 
 def destroy_ascend_model_parallel():
     global _MC2
@@ -117,3 +158,13 @@ def destroy_ascend_model_parallel():
     if _LMTP:
         _LMTP.destroy()
         _LMTP = None
+
+    global _MLA_SP_WORLD
+    if _MLA_SP_WORLD:
+        _MLA_SP_WORLD.destroy()
+        _MLA_SP_WORLD = None
+
+    global _O_SHARD
+    if _O_SHARD:
+        _O_SHARD.destroy()
+        _O_SHARD = None
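
To make the group layout built in `init_ascend_mla_sp_model_parallel` concrete: every rank joins a single MLA SP world group, while the o_proj sharding groups are consecutive blocks of `o_shard_parallel_size` ranks, as the `arange`/`view`/`unbind` sequence produces. A standalone sketch with illustrative sizes:

# Reproduces the rank grouping logic above outside of torch.distributed.
import torch

world_size = 8                 # illustrative
o_shard_parallel_size = 4      # illustrative; must be >= 2 and divide world_size
all_ranks = torch.arange(world_size)
group_ranks = [x.tolist() for x in all_ranks.view(-1, o_shard_parallel_size).unbind(0)]
print(group_ranks)             # [[0, 1, 2, 3], [4, 5, 6, 7]]
mla_sp_world = [list(range(world_size))]
print(mla_sp_world)            # [[0, 1, 2, 3, 4, 5, 6, 7]]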
