48 changes: 44 additions & 4 deletions vllm_ascend/distributed/communicator.py
@@ -14,22 +14,62 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional
from typing import List, Optional

import torch
from torch.distributed import ProcessGroup
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase


class NPUCommunicator(DeviceCommunicatorBase):

def __init__(self,
cpu_group: ProcessGroup,
cpu_group: dist.ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
device_group: Optional[dist.ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
        # Initialize the device according to the current rank.
self.device = torch.npu.current_device()

def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:

if scatter_dim < 0:
scatter_dim += input_.dim()
if gather_dim < 0:
gather_dim += input_.dim()

if scatter_sizes is not None and gather_sizes is not None:
input_list = [
t.contiguous()
for t in torch.split(input_, scatter_sizes, scatter_dim)
]
output_list = []
tensor_shape_base = input_list[self.rank].size()
for i in range(self.world_size):
tensor_shape = list(tensor_shape_base)
tensor_shape[gather_dim] = gather_sizes[i]
output_list.append(
torch.empty(tensor_shape,
dtype=input_.dtype,
device=input_.device))

else:
input_list = [
t.contiguous() for t in torch.tensor_split(
input_, self.world_size, scatter_dim)
]
output_list = [
torch.empty_like(input_list[i]) for i in range(self.world_size)
]

dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor
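
For reference, the shape bookkeeping behind the uneven all_to_all above can be checked without an NPU or an initialized process group. The snippet below is a minimal single-process sketch with made-up values (world_size, scatter_sizes, gather_sizes and the random inputs are illustrative, not taken from this PR); it only mirrors the split-along-scatter_dim / concat-along-gather_dim pattern that dist.all_to_all plus torch.cat implement in the communicator.

import torch

# Illustrative values; in NPUCommunicator they come from the process group
# and from the caller.
world_size = 2
scatter_dim, gather_dim = 0, -1
scatter_sizes = [3, 1]   # rows each rank sends to rank 0 and rank 1
gather_sizes = [4, 4]    # gather_dim size each peer contributes
inputs = [torch.randn(4, 4) for _ in range(world_size)]

# Emulate dist.all_to_all: every rank splits its input along scatter_dim,
# then rank `dst` concatenates the dst-th chunk of every peer along gather_dim.
chunks = [list(torch.split(t, scatter_sizes, dim=scatter_dim)) for t in inputs]
outputs = [
    torch.cat([chunks[src][dst] for src in range(world_size)], dim=gather_dim)
    for dst in range(world_size)
]

# Received chunk sizes along gather_dim match gather_sizes, as the
# pre-allocated output_list in the communicator assumes.
for dst in range(world_size):
    for src in range(world_size):
        assert chunks[src][dst].shape[gather_dim] == gather_sizes[src]

print(outputs[0].shape)  # torch.Size([3, 8])
print(outputs[1].shape)  # torch.Size([1, 8])
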
64 changes: 40 additions & 24 deletions vllm_ascend/models/deepseek_v2.py
@@ -205,50 +205,66 @@ def __init__(
)
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

vllm_config = get_current_vllm_config()
self.dp_size = get_dp_group().world_size
batch_size = vllm_config.scheduler_config.max_num_seqs

params_dtype = torch.get_default_dtype()
self.final_hidden_states = torch.zeros(
[batch_size, config.hidden_size], dtype=params_dtype, device="npu")
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
        # During a profile run, force the experts to load-balance the tokens
        # to avoid high memory consumption on a single rank.
        # TODO: find a better flag to indicate whether this is a profile run.
if attn_metadata is None:
# for profile run
is_prefill = True
enable_force_load_balance = True
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
chunks = torch.chunk(hidden_states,
get_tp_group().world_size,
dim=0)
hidden_states = chunks[get_tp_group().rank_in_group]
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

        if self.tp_size > 1:
num_tokens, hidden_size = hidden_states.shape
if num_tokens < self.tp_size:
target_size = self.tp_size
new_hidden_states = torch.empty([target_size, hidden_size],
dtype=hidden_states.dtype,
device=hidden_states.device)
new_hidden_states[:num_tokens] = hidden_states
hidden_states = new_hidden_states
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
local_hidden_states = chunk_hidden_states[self.tp_rank]
else:
local_hidden_states = hidden_states

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits, _ = self.gate(local_hidden_states)

final_hidden_states = self.experts(
hidden_states=hidden_states,
router_hidden_states = self.experts(
hidden_states=local_hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor

if self.tp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
dist.all_gather_into_tensor(self.final_hidden_states,
final_hidden_states, self.tp_group)
final_hidden_states = self.final_hidden_states
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
dist.all_gather(list(chunk_hidden_states), router_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < self.tp_size:
final_hidden_states = final_hidden_states[:num_tokens]
else:
final_hidden_states = router_hidden_states

if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output

return final_hidden_states.view(num_tokens, hidden_dim)
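
The num_tokens < tp_size handling above is easiest to see with concrete shapes. The following is a minimal single-process sketch with made-up sizes; it skips the gate, the experts and the actual dist.all_gather, and only reproduces the pad / tensor_split / concatenate / un-pad bookkeeping of the forward pass.

import torch

# Illustrative sizes; in the model they come from the batch and the TP group.
tp_size, num_tokens, hidden_size = 4, 2, 8
hidden_states = torch.randn(num_tokens, hidden_size)

# Pad the token dimension so it can be split evenly across tp_size ranks.
if num_tokens < tp_size:
    padded = torch.zeros(tp_size, hidden_size, dtype=hidden_states.dtype)
    padded[:num_tokens] = hidden_states
    hidden_states = padded

chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
# Each TP rank would run the experts on chunk_hidden_states[tp_rank];
# dist.all_gather then reassembles the full (padded) token dimension,
# which torch.cat stands in for here.
gathered = torch.cat(chunk_hidden_states, dim=0)

# Finally the padding rows are dropped again.
final = gathered[:num_tokens]
assert final.shape == (num_tokens, hidden_size)
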
32 changes: 9 additions & 23 deletions vllm_ascend/ops/fused_moe.py
@@ -18,7 +18,6 @@
from typing import Callable, Optional

import torch
import torch.distributed as dist
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_all_reduce
@@ -636,6 +635,7 @@ def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k=None):
assert self.quant_method is not None

@@ -644,17 +644,8 @@
else:
real_top_k = self.top_k

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif USING_LCCL_COM: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
router_logits, 0, False)
else:
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
if VLLM_ENABLE_MC2 and not is_prefill:
...

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -671,17 +662,12 @@
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill)

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
else:
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
final_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
dp_size=self.dp_size)

if VLLM_ENABLE_MC2 and not is_prefill:
...

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(
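
The data-parallel collectives that this hunk removes from forward (an all_gather of hidden states and router logits before the experts, and a summing reduce_scatter afterwards) appear, judging from the new dp_size argument, to move behind quant_method.apply. Their shape/reduction pattern can be illustrated with a single-process stand-in; the tensors and the * (r + 1) arithmetic below are placeholders, only the layout matters.

import torch

# Illustrative data-parallel setup; in vLLM these values come from get_dp_group().
dp_size, tokens_per_rank, hidden_size = 2, 3, 4
per_rank_inputs = [torch.randn(tokens_per_rank, hidden_size) for _ in range(dp_size)]

# all_gather along dim 0: every DP rank sees the tokens of all DP ranks.
gathered = torch.cat(per_rank_inputs, dim=0)  # (dp_size * tokens_per_rank, hidden_size)

# Each DP rank produces a partial expert output for *all* gathered tokens
# (different ranks own different experts, so the partials must be summed).
partial_outputs = [gathered * (r + 1) for r in range(dp_size)]  # stand-in math

def simulated_reduce_scatter_sum(parts, rank):
    # reduce_scatter("sum", scatter_dim=0): rank `rank` keeps only its own
    # token chunk, summed over every rank's partial output.
    own_chunks = [torch.chunk(p, dp_size, dim=0)[rank] for p in parts]
    return torch.stack(own_chunks).sum(dim=0)

local_out = simulated_reduce_scatter_sum(partial_outputs, rank=0)
assert local_out.shape == (tokens_per_rank, hidden_size)
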
2 changes: 2 additions & 0 deletions vllm_ascend/patch/worker/patch_common/__init__.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# patch_utils should be the first import, because it will be used by other
# patch files.
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
49 changes: 49 additions & 0 deletions vllm_ascend/patch/worker/patch_common/patch_distributed.py
@@ -0,0 +1,49 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import List, Optional

import torch
import vllm
from vllm.distributed.parallel_state import GroupCoordinator


class GroupCoordinatorPatch(GroupCoordinator):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
gather_dim: int = -1,
scatter_sizes: Optional[List[int]] = None,
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
if self.world_size == 1:
return input_
assert -input_.dim() <= scatter_dim < input_.dim(), (
f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}"
)
assert -input_.dim() <= gather_dim < input_.dim(), (
f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
)
return self.device_communicator.all_to_all(input_, scatter_dim,
gather_dim, scatter_sizes,
gather_sizes)


vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: verify that the patched GroupCoordinator also works with online serving
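
Because this is a monkey patch of the class object, it only affects GroupCoordinator instances created after the patch module has been imported; references captured earlier (for example a module that did "from vllm.distributed.parallel_state import GroupCoordinator") keep pointing at the original class, which is presumably what the online-serving note above is about. A rough sanity check, assuming vllm and vllm-ascend are installed on the target environment:

import vllm.distributed.parallel_state as parallel_state

# Importing the patch module (its package __init__ pulls in patch_utils first)
# swaps the class that future parallel groups are built from.
import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa: F401

# Groups constructed from here on get the uneven all_to_all support.
assert parallel_state.GroupCoordinator.__name__ == "GroupCoordinatorPatch"
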
13 changes: 7 additions & 6 deletions vllm_ascend/quantization/quant_config.py
@@ -321,14 +321,15 @@ def apply(
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
dp_size: int = 1,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(layer, x, router_logits, top_k,
renormalize, use_grouped_topk,
global_num_experts, expert_map,
topk_group, num_expert_group,
custom_routing_function, scoring_func,
e_score_correction_bias, is_prefill)
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance, dp_size)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
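
The only change in this file is plumbing: enable_force_load_balance and dp_size are accepted by the outer apply and handed straight to the inner quant_method.apply. A toy sketch of that thread-through (the class names and parameter handling below are stand-ins, not the real vllm_ascend classes):

class InnerQuantKernelSketch:
    def apply(self, x, is_prefill, enable_force_load_balance, dp_size):
        # A real kernel would force-balance tokens across experts during the
        # profile run and use dp_size to size its data-parallel collectives.
        return x

class OuterMoEMethodSketch:
    def __init__(self):
        self.quant_method = InnerQuantKernelSketch()

    def apply(self, x, is_prefill=True, enable_force_load_balance=False, dp_size=1):
        # Mirrors the diff: the new flags are simply forwarded one level down.
        return self.quant_method.apply(x, is_prefill, enable_force_load_balance,
                                       dp_size)

out = OuterMoEMethodSketch().apply([1, 2, 3], is_prefill=False,
                                   enable_force_load_balance=True, dp_size=2)
assert out == [1, 2, 3]
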