Commit 3c27ba5

angazenn committed: add more comments for force load balance
Signed-off-by: angazenn <zengyanjia@huawei.com>
1 parent fda8418 commit 3c27ba5

3 files changed: +8 −6 lines changed

vllm_ascend/models/deepseek_v2.py

Lines changed: 2 additions & 3 deletions
@@ -25,7 +25,6 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""

-import os
 from typing import Any, Dict, List, Optional, Union

 import torch
@@ -213,8 +212,8 @@ def __init__(

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts load balance to avoid high memory
-        # consumption from 1 rank.
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
         # TODO: need a better flag to indicate whether in profile run or not.
         if attn_metadata is None or attn_metadata.slot_mapping[-1] < 0:
             # for profile run
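To make the rewritten comment concrete: the condition in this hunk treats a request with no attention metadata, or with a slot mapping that ends in a negative padding value, as a profile run. Below is a minimal standalone sketch of that heuristic; FakeAttnMetadata and is_profile_run are hypothetical names introduced only for illustration and are not part of vllm-ascend.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeAttnMetadata:
    # stand-in for vLLM's attention metadata; only slot_mapping matters here
    slot_mapping: torch.Tensor


def is_profile_run(attn_metadata: Optional[FakeAttnMetadata]) -> bool:
    # No metadata at all, or a slot mapping ending in a negative padding
    # sentinel, is taken to mean the memory-profiling dry run -- the same
    # heuristic the TODO above wants to replace with an explicit flag.
    return attn_metadata is None or int(attn_metadata.slot_mapping[-1]) < 0


# usage: dummy/padded inputs are classified as a profile run
print(is_profile_run(None))                                        # True
print(is_profile_run(FakeAttnMetadata(torch.tensor([3, 7, -1]))))  # True
print(is_profile_run(FakeAttnMetadata(torch.tensor([3, 7, 12]))))  # False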

vllm_ascend/ops/fused_moe.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py

-import os
 from typing import Callable, Optional

 import torch

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 6 additions & 2 deletions
@@ -15,16 +15,15 @@
 # limitations under the License.
 #

-import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
 import torch.distributed as dist
 import torch_npu
 from vllm.distributed import GroupCoordinator

-from vllm_ascend.distributed.parallel_state import get_ep_group
 import vllm_ascend.envs as envs_ascend
+from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
@@ -200,6 +199,8 @@ def fused_experts_with_mc2(
     return hidden_states


+# currently expert parallelism implemented with all2all
+# is under-optimized.
 def fused_experts_with_all2all(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -616,6 +617,9 @@ def apply(
             e_score_correction_bias=e_score_correction_bias,
         )

+        # this is a naive implementation for experts load balance so as
+        # to avoid accumulating too much tokens on a single rank.
+        # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
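The comment added here documents what enable_force_load_balance already does a few lines below it. As a rough standalone sketch of that behavior (force_balanced_topk_ids is a hypothetical helper name; the torch.randint_like call is the same one shown in the context line above): discard the router's expert choices and redraw them uniformly, so tokens spread evenly across experts and no single rank holds most of the activation memory during the profile run.

import torch


def force_balanced_topk_ids(topk_ids: torch.Tensor,
                            global_num_experts: int) -> torch.Tensor:
    # Replace the router's per-token expert ids with uniform random ids in
    # [0, global_num_experts); routing quality is irrelevant during the
    # profile run, which only measures peak memory.
    return torch.randint_like(topk_ids, 0, global_num_experts)


# usage: a fully skewed routing (all tokens sent to expert 0) becomes ~uniform
topk_ids = torch.zeros((1024, 8), dtype=torch.int64)
balanced = force_balanced_topk_ids(topk_ids, global_num_experts=64)
print(torch.bincount(balanced.flatten(), minlength=64))

Random routing would of course be wrong for real inference, which is why this branch is only taken when the profile-run heuristic fires.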
