
Optimization of TP4 Parallelism in DeepSeek MLP Dense Layers #1738


Open
wants to merge 33 commits into base: main

Commits (33)
518572d
update to vllm-ascend main branch
zhanghw0354 Jun 27, 2025
0fdd699
[Test]Add unit test for platform.py
zhanghw0354 Jun 27, 2025
3b3a41b
[Doc] Add guidance on how to implement and register new models (#1426)
shen-shanshan Jun 27, 2025
932fe28
delete dirs and files not exist in vllm-ascend main branch
zhanghw0354 Jun 27, 2025
3482fff
[CI] Pin transformers<4.53.0 and fix EPLB load_weights to make CI pas…
MengqingCao Jun 27, 2025
5119b6e
[Doc] Add Qwen2.5-VL eager mode doc (#1394)
shen-shanshan Jun 28, 2025
e5271a3
[PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in…
MengqingCao Jun 28, 2025
ecf3a36
[CI/Build] Fix version conflict on transformers (#1490)
MengqingCao Jun 28, 2025
8955ac2
[BugFix]Fix bugs when initializing communication groups with dp on 30…
Angazenn Jun 28, 2025
0c0e3e0
[PERF]support H2P communication optimization for PanguProMoe (#1463)
Angazenn Jun 28, 2025
6e355ac
[PERF]support MERRouter (#1421)
Angazenn Jun 28, 2025
cfbee30
support pangumoe w8a8c8 and docs (#1477)
GDzhu01 Jun 28, 2025
9c84c1f
Merge branch 'vllm-project:main' into main
zhanghw0354 Jun 30, 2025
8961621
fix codespell check problem with assertIn function
zhanghw0354 Jun 30, 2025
4c1ac3a
Merge branch 'main' of https://github.com/zhanghw0354/vllm-ascend
zhanghw0354 Jun 30, 2025
abf62dc
fix problem in the github pipeline step analysing the code with ruff
zhanghw0354 Jun 30, 2025
4322edc
Merge branch 'main' of https://github.com/zhanghw0354/vllm-ascend
zhanghw0354 Jun 30, 2025
9dbeb8f
fix isort problems
zhanghw0354 Jun 30, 2025
fcf6c3d
fix yapf problems
zhanghw0354 Jun 30, 2025
1fdba8d
Merge branch 'vllm-project:main' into main
zhanghw0354 Jul 1, 2025
da00921
Update the parent class of TestNPUPlatform to TestBase
zhanghw0354 Jul 1, 2025
d248032
Merge branch 'main' of https://github.com/zhanghw0354/vllm-ascend
zhanghw0354 Jul 1, 2025
c234ec2
fix isort check problem
zhanghw0354 Jul 1, 2025
6bc0dc1
Merge branch 'vllm-project:main' into main
zhanghw0354 Jul 2, 2025
4483fce
fix mypy check problem
zhanghw0354 Jul 2, 2025
a0445ab
Merge branch 'main' of https://github.com/zhanghw0354/vllm-ascend
zhanghw0354 Jul 2, 2025
6277056
test deepseek v2 mlp layer tp4
zhanghw0354 Jul 7, 2025
0bb5803
fix import set_weight_attrs problem
zhanghw0354 Jul 7, 2025
9d9da60
fix mlp_tensor_parallel_size problem
zhanghw0354 Jul 7, 2025
2d8e881
add rank log
zhanghw0354 Jul 8, 2025
2fc2129
update rank log
zhanghw0354 Jul 8, 2025
1b54054
sync test_platform.py from vllm-ascend main branch
zhanghw0354 Jul 11, 2025
5712554
sync changes from vllm-ascend main branch and fix conflict codes
zhanghw0354 Jul 11, 2025
45 changes: 43 additions & 2 deletions vllm_ascend/distributed/parallel_state.py
@@ -2,12 +2,14 @@

import torch
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             init_model_parallel_group)
                                             init_model_parallel_group,
                                             logger, get_dp_group, get_pp_group, get_tp_group)

# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for
# customize parallel solution
_EP: Optional[GroupCoordinator] = None
_ETP: Optional[GroupCoordinator] = None
_MLPTP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
@@ -20,16 +22,29 @@
"expert tensor parallel group is not initialized")
return _ETP

def get_mlptp_group() -> GroupCoordinator:
    assert _MLPTP is not None, (
        "mlp tensor parallel group is not initialized")
    return _MLPTP

def get_mlp_tensor_model_parallel_world_size():
    """Return world size for the mlp tensor model parallel group."""
    return get_mlptp_group().world_size

def get_mlp_tensor_model_parallel_rank():
    """Return my rank for the mlp tensor model parallel group."""
    return get_mlptp_group().rank_in_group

def model_parallel_initialized():
    return (_ETP is not None and _EP is not None)
    return (_ETP is not None and _EP is not None and _MLPTP is not None)


def init_ascend_model_parallel(
    expert_parallel_size: int = 1,
    expert_tensor_parallel_size: int = 1,
    world_size: Optional[int] = None,
    backend: Optional[str] = None,
    mlp_tensor_parallel_size: Optional[int] = 4,

Collaborator review comment on the line above: Can this optimization be used when TP=8?

):
    if model_parallel_initialized():
        return
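The model-side code that consumes this group is not part of this file, so the sketch below is only a hypothetical illustration of how a dense MLP layer could shard weights over the new MLP TP group instead of the global TP group. The function names and sharding layout are assumptions, not code from this PR; GroupCoordinator.all_reduce is the existing vLLM API.

# Hypothetical usage sketch (not part of this diff). Assumes the MLP TP group
# has already been initialized via init_ascend_model_parallel().
import torch

from vllm_ascend.distributed.parallel_state import (
    get_mlp_tensor_model_parallel_rank,
    get_mlp_tensor_model_parallel_world_size, get_mlptp_group)


def shard_mlp_weight(full_weight: torch.Tensor) -> torch.Tensor:
    """Slice a column-parallel dense MLP weight for the current MLP TP rank."""
    tp_size = get_mlp_tensor_model_parallel_world_size()
    tp_rank = get_mlp_tensor_model_parallel_rank()
    assert full_weight.shape[0] % tp_size == 0
    shard = full_weight.shape[0] // tp_size
    return full_weight[tp_rank * shard:(tp_rank + 1) * shard]


def combine_mlp_output(partial: torch.Tensor) -> torch.Tensor:
    """Sum the row-parallel projection's partial outputs across the MLP TP group."""
    return get_mlptp_group().all_reduce(partial)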
@@ -39,10 +54,12 @@
        get_world_group().device_group)
    num_expert_parallel_groups = expert_tensor_parallel_size
    num_expert_tensor_parallel_groups = expert_parallel_size
    assert world_size % mlp_tensor_parallel_size == 0, "world_size must be divisible by mlp_tensor_parallel_size"
    num_mlp_tensor_parallel_groups = world_size // mlp_tensor_parallel_size

    global _EP
    group_ranks = []
    for i in range(num_expert_parallel_groups):
        ranks = list(range(i, world_size, num_expert_parallel_groups))
        group_ranks.append(ranks)

Check failure annotations on lines 61 and 62 in vllm_ascend/distributed/parallel_state.py (GitHub Actions / lint / pre-commit): Unsupported operand types for % and // ("int" and "None") [operator]
@@ -64,6 +81,25 @@
                                     backend,
                                     group_name="etp")

    group_ranks = []
    global _MLPTP
    for i in range(num_mlp_tensor_parallel_groups):
        ranks = list(
            range(i * mlp_tensor_parallel_size,
                  (i + 1) * mlp_tensor_parallel_size))
        group_ranks.append(ranks)
    # Build the mlp tensor model-parallel groups.
    _MLPTP = init_model_parallel_group(group_ranks,
                                       get_world_group().local_rank,
                                       backend,
                                       group_name="mlptp")

    logger.info(
        "vllm-ascend: rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s, MLP TP rank %s",
        torch.distributed.get_rank(), world_size,
        get_dp_group().rank_in_group, get_pp_group().rank_in_group,
        get_tp_group().rank_in_group, _EP.rank_in_group, _MLPTP.rank_in_group)

Check failure annotations on lines 92 and 93 in vllm_ascend/distributed/parallel_state.py (GitHub Actions / lint / pre-commit): Unsupported operand types for * ("int" and "None") [operator]
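For reference, a standalone sketch (not part of the diff) that reproduces the group-ranks arithmetic above and makes the resulting layout easy to check: with 8 NPUs and the default mlp_tensor_parallel_size of 4, the MLP TP groups come out as [0, 1, 2, 3] and [4, 5, 6, 7].

# Standalone sketch (not part of the diff) reproducing the MLP TP group-ranks
# arithmetic from init_ascend_model_parallel above.
def mlp_tp_group_ranks(world_size: int, mlp_tensor_parallel_size: int):
    assert world_size % mlp_tensor_parallel_size == 0
    num_groups = world_size // mlp_tensor_parallel_size
    return [
        list(range(i * mlp_tensor_parallel_size,
                   (i + 1) * mlp_tensor_parallel_size))
        for i in range(num_groups)
    ]


print(mlp_tp_group_ranks(8, 4))  # -> [[0, 1, 2, 3], [4, 5, 6, 7]]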


def destory_ascend_model_parallel():
    global _EP
@@ -75,3 +111,8 @@
    if _ETP:
        _ETP.destroy()
        _ETP = None

    global _MLPTP
    if _MLPTP:
        _MLPTP.destroy()
        _MLPTP = None
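The pre-commit failures quoted above stem from mlp_tensor_parallel_size being annotated Optional[int] while it is used with %, // and *. One minimal way to satisfy mypy is sketched below; the fallback to the default of 4 is only an assumption, and rejecting None, or annotating the parameter as plain int, may be what the author intends instead.

# Possible narrowing before the arithmetic in init_ascend_model_parallel
# (assumption: None should fall back to the default of 4).
if mlp_tensor_parallel_size is None:
    mlp_tensor_parallel_size = 4
assert world_size % mlp_tensor_parallel_size == 0, (
    "world_size must be divisible by mlp_tensor_parallel_size")
num_mlp_tensor_parallel_groups = world_size // mlp_tensor_parallel_size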