|
20 | 20 | import torch.distributed as dist
|
21 | 21 | from vllm.distributed.device_communicators.base_device_communicator import \
|
22 | 22 | DeviceCommunicatorBase
|
| 23 | +from vllm.forward_context import get_forward_context |
| 24 | +from vllm.distributed.parallel_state import get_dp_group |
23 | 25 |
|
24 | 26 |
|
25 | 27 | class NPUCommunicator(DeviceCommunicatorBase):
|
@@ -73,3 +75,47 @@ def all_to_all(self,
|
73 | 75 | dist.all_to_all(output_list, input_list, group=self.device_group)
|
74 | 76 | output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
|
75 | 77 | return output_tensor
|
| 78 | + |
def naive_multicast(self, x: torch.Tensor,
                    cu_tokens_across_dp_cpu: torch.Tensor):
    """Assemble the rows of every data-parallel rank into one tensor.

    ``cu_tokens_across_dp_cpu`` is used as a table of cumulative row
    offsets: entry ``i`` is the end row of rank ``i``'s slice, and the
    slice starts at entry ``i - 1`` (or at 0 for rank 0). The last entry
    is therefore the total number of rows of the result.

    Args:
        x: 2-D tensor holding this rank's rows.
        cu_tokens_across_dp_cpu: Cumulative row offsets, indexed by
            rank in the DP group.

    Returns:
        A tensor with ``cu_tokens_across_dp_cpu[-1]`` rows and
        ``x.size(1)`` columns, on the device and with the dtype of ``x``.
    """
    assert x.dim() == 2
    group = get_dp_group()
    my_rank = group.rank_in_group

    def row_bounds(rank):
        # Slice [lower, upper) of the gathered tensor owned by `rank`.
        lower = cu_tokens_across_dp_cpu[rank - 1] if rank != 0 else 0
        upper = cu_tokens_across_dp_cpu[rank]
        return lower, upper

    gathered = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
                           dtype=x.dtype,
                           device=x.device)

    # Place the local rows into their slot before any communication.
    lower, upper = row_bounds(my_rank)
    gathered[lower:upper, :].copy_(x)

    # One broadcast per rank, each with that rank passed as the source.
    # NOTE(review): the return value of broadcast() is ignored, so this
    # presumably relies on it filling the slice view in place -- confirm
    # against the DP group implementation.
    for src in range(group.world_size):
        lower, upper = row_bounds(src)
        group.broadcast(gathered[lower:upper, :], src)

    return gathered
| 98 | + |
def dispatch(self, hidden_states: torch.Tensor,
             router_logits: torch.Tensor):
    """Multicast hidden states and router logits across the DP group.

    The cumulative per-rank row offsets are read from the current
    forward context's ``dp_metadata`` and both tensors are passed through
    ``naive_multicast`` with those same offsets.

    Args:
        hidden_states: This rank's hidden states.
        router_logits: This rank's router logits.

    Returns:
        A ``(hidden_states, router_logits)`` tuple of the multicast
        results.
    """
    dp_metadata = get_forward_context().dp_metadata
    row_offsets = dp_metadata.cu_tokens_across_dp_cpu

    # Keep this order: hidden states first, then router logits, so the
    # sequence of calls into naive_multicast is unchanged.
    gathered_states = self.naive_multicast(hidden_states, row_offsets)
    gathered_logits = self.naive_multicast(router_logits, row_offsets)
    return gathered_states, gathered_logits
| 109 | + |
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """All-reduce hidden states over the DP group and keep the local rows.

    The rows kept are ``[start, end)`` where ``end`` is this rank's entry
    in the forward context's ``cu_tokens_across_dp_cpu`` and ``start`` is
    the previous rank's entry (0 for rank 0).

    Args:
        hidden_states: Tensor handed to the DP group's ``all_reduce``.

    Returns:
        The slice of the all-reduced tensor belonging to this rank.
    """
    row_offsets = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    group = get_dp_group()
    rank = group.rank_in_group

    # Resolve the local row range before issuing the collective.
    if rank == 0:
        first_row = 0
    else:
        first_row = row_offsets[rank - 1]
    last_row = row_offsets[rank]

    reduced = group.all_reduce(hidden_states)
    return reduced[first_row:last_row, :]
0 commit comments