Commit 9dfd075

Author: unknown
Message: fc1 for glm
Parent: 4c380f3

File tree: 2 files changed (+18, -2 lines)


vllm_ascend/ops/common_fused_moe.py
Lines changed: 17 additions & 1 deletion

@@ -35,6 +35,9 @@
                                  AlltoAllCommImpl, MC2CommImpl)
 from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+import torch.nn.functional as F
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__

@@ -305,17 +308,30 @@ def maybe_all_reduce_tensor_model_parallel(
         """
         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            if flashcomm_v1_enabled:
+                pad_size = forward_context.pad_size
+                if pad_size > 0:
+                    final_hidden_states = F.pad(final_hidden_states, (0, 0, 0, pad_size))
+                tp_size = get_tensor_model_parallel_world_size()
+                tp_rank = get_tensor_model_parallel_rank()
+                final_hidden_states = torch.chunk(final_hidden_states, tp_size, dim=0)[tp_rank]
             return final_hidden_states
         else:
-            return tensor_model_parallel_all_reduce(final_hidden_states)
+            return torch.ops.vllm.maybe_pad_and_reduce(final_hidden_states)

     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None

         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        if flashcomm_v1_enabled:
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True)
+            if router_logits.shape[0] != hidden_states.shape[0]:
+                router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True)

         forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
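Aside (not part of the commit itself): the flashcomm_v1 branch above pads the token dimension so it splits evenly across tensor-parallel ranks, then keeps only the local rank's chunk; the all-gather added in forward_impl is the inverse operation. A minimal self-contained sketch of that round trip, with invented shapes and a plain Python loop standing in for real TP ranks:

# Sketch of the FlashComm v1 pad/chunk/all-gather round trip. Shapes and the
# rank loop are illustrative only; the real code uses the TP process group
# (get_tensor_model_parallel_world_size/_rank) shown in the diff.
import torch
import torch.nn.functional as F

tp_size = 4
hidden = torch.randn(10, 8)                  # 10 tokens; not divisible by tp_size
pad_size = (-hidden.shape[0]) % tp_size      # 2 extra rows needed
padded = F.pad(hidden, (0, 0, 0, pad_size))  # pad token dim (dim 0) at the end -> 12 rows

# Each "rank" keeps one equal chunk along the token dimension.
shards = [torch.chunk(padded, tp_size, dim=0)[r] for r in range(tp_size)]

# The inverse step (maybe_all_gather_and_maybe_unpad in the diff) concatenates
# the shards back and strips the padding.
restored = torch.cat(shards, dim=0)[:hidden.shape[0]]
assert torch.equal(restored, hidden)

Padding before chunking guarantees identically shaped shards on every rank, which the subsequent all-gather collective requires.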
vllm_ascend/ops/linear.py
Lines changed: 1 addition & 1 deletion

@@ -151,7 +151,7 @@ def __init__(
             comm_group = get_tp_group()
             self.forward_type = "matmul_allreduce"
             self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
-        elif dense_optim_enable():
+        elif prefix.find("shared_experts") == -1 and dense_optim_enable():
             comm_group = get_tp_group()
             self.forward_type = "dense_optim"
         else:
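Aside (not part of the commit itself): the guard added above keeps any linear layer created under a "shared_experts" module prefix off the dense-optim path, even when dense_optim_enable() returns true. A small sketch of the condition's effect, using invented module prefixes:

# Hypothetical illustration of the new guard in linear.py: layers whose
# prefix contains "shared_experts" never take the dense-optim forward path.
def takes_dense_optim_path(prefix: str, dense_optim_enabled: bool) -> bool:
    return prefix.find("shared_experts") == -1 and dense_optim_enabled

# Invented example prefixes:
assert takes_dense_optim_path("model.layers.0.mlp.gate_up_proj", True)
assert not takes_dense_optim_path("model.layers.0.mlp.shared_experts.gate_up_proj", True)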
