Skip to content

Commit 19b587b

Browse files
committed
Fix accuracy problem in the data-parallel (DP) case
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent 75e28d0 commit 19b587b

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

vllm_ascend/distributed/communicator.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import torch.distributed as dist
2121
from vllm.distributed.device_communicators.base_device_communicator import \
2222
DeviceCommunicatorBase
23+
from vllm.forward_context import get_forward_context
24+
from vllm.distributed.parallel_state import get_dp_group
2325

2426

2527
class NPUCommunicator(DeviceCommunicatorBase):
@@ -73,3 +75,47 @@ def all_to_all(self,
7375
dist.all_to_all(output_list, input_list, group=self.device_group)
7476
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
7577
return output_tensor
78+
79+
def naive_multicast(self, x: torch.Tensor,
80+
cu_tokens_across_dp_cpu: torch.Tensor):
81+
assert (len(x.shape) == 2)
82+
dp_group = get_dp_group()
83+
dp_rank = dp_group.rank_in_group
84+
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
85+
device=x.device,
86+
dtype=x.dtype)
87+
88+
start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[
89+
dp_rank - 1]
90+
end = cu_tokens_across_dp_cpu[dp_rank]
91+
buffer[start:end, :].copy_(x)
92+
for idx in range(dp_group.world_size):
93+
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
94+
end = cu_tokens_across_dp_cpu[idx]
95+
dp_group.broadcast(buffer[start:end, :], idx)
96+
97+
return buffer
98+
99+
def dispatch(self, hidden_states: torch.Tensor,
100+
router_logits: torch.Tensor):
101+
cu_tokens_across_dp_cpu = get_forward_context(
102+
).dp_metadata.cu_tokens_across_dp_cpu
103+
104+
hidden_states = self.naive_multicast(hidden_states,
105+
cu_tokens_across_dp_cpu)
106+
router_logits = self.naive_multicast(router_logits,
107+
cu_tokens_across_dp_cpu)
108+
return hidden_states, router_logits
109+
110+
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
111+
cu_tokens_across_dp_cpu = get_forward_context(
112+
).dp_metadata.cu_tokens_across_dp_cpu
113+
dp_group = get_dp_group()
114+
dp_rank = dp_group.rank_in_group
115+
start = 0 if dp_rank == 0 else cu_tokens_across_dp_cpu[
116+
dp_rank - 1]
117+
end = cu_tokens_across_dp_cpu[dp_rank]
118+
119+
all_hidden_states = dp_group.all_reduce(hidden_states)
120+
hidden_states = all_hidden_states[start:end, :]
121+
return hidden_states

0 commit comments

Comments
 (0)