|
35 | 35 | AlltoAllCommImpl, MC2CommImpl)
|
36 | 36 | from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
|
37 | 37 | from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
| 38 | +import torch.nn.functional as F |
| 39 | +from vllm.distributed import (get_tensor_model_parallel_rank, |
| 40 | + get_tensor_model_parallel_world_size) |
38 | 41 |
|
39 | 42 | original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
40 | 43 |
|
@@ -305,17 +308,30 @@ def maybe_all_reduce_tensor_model_parallel(
|
305 | 308 | """
|
306 | 309 | forward_context = get_forward_context()
|
307 | 310 | moe_comm_method_name = forward_context.moe_comm_method_name
|
| 311 | + flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled |
308 | 312 | if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
| 313 | + if flashcomm_v1_enabled: |
| 314 | + pad_size = forward_context.pad_size |
| 315 | + if pad_size > 0: |
| 316 | + final_hidden_states = F.pad(final_hidden_states, (0, 0, 0, pad_size)) |
| 317 | + tp_size = get_tensor_model_parallel_world_size() |
| 318 | + tp_rank = get_tensor_model_parallel_rank() |
| 319 | + final_hidden_states = torch.chunk(final_hidden_states, tp_size, dim=0)[tp_rank] |
309 | 320 | return final_hidden_states
|
310 | 321 | else:
|
311 |
| - return tensor_model_parallel_all_reduce(final_hidden_states) |
| 322 | + return torch.ops.vllm.maybe_pad_and_reduce(final_hidden_states) |
312 | 323 |
|
313 | 324 | def forward_impl(self, hidden_states: torch.Tensor,
|
314 | 325 | router_logits: torch.Tensor):
|
315 | 326 | assert self.quant_method is not None
|
316 | 327 |
|
317 | 328 | forward_context = get_forward_context()
|
318 | 329 | moe_comm_method_name = forward_context.moe_comm_method_name
|
| 330 | + flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled |
| 331 | + if flashcomm_v1_enabled: |
| 332 | + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True) |
| 333 | + if router_logits.shape[0] != hidden_states.shape[0]: |
| 334 | + router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True) |
319 | 335 |
|
320 | 336 | forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
321 | 337 |
|
|
0 commit comments