Commit 7c82ee5

Author: angazenn (committed)

add all2all

Signed-off-by: angazenn <zengyanjia@huawei.com>

1 parent: cdece86

File tree: 5 files changed (+238, -15 lines)


vllm_ascend/distributed/communicator.py

Lines changed: 57 additions & 1 deletion

@@ -14,9 +14,10 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import Optional
+from typing import List, Optional
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
@@ -33,3 +34,58 @@ def __init__(self,
         # TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
         # init device according to rank
         self.device = torch.npu.current_device()
+
+    def all_to_all(self,
+                   input_: torch.Tensor,
+                   scatter_dim: int = 0,
+                   gather_dim: int = -1,
+                   scatter_sizes: Optional[List[int]] = None,
+                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
+
+        if scatter_dim < 0:
+            scatter_dim += input_.dim()
+        if gather_dim < 0:
+            gather_dim += input_.dim()
+
+        if scatter_sizes is not None and gather_sizes is not None:
+            # Uneven all-to-all: per-peer chunk sizes are given explicitly.
+            input_list = [
+                t.contiguous()
+                for t in torch.split(input_, scatter_sizes, scatter_dim)
+            ]
+            output_list = []
+            tensor_shape_base = input_list[self.rank].size()
+            for i in range(self.world_size):
+                tensor_shape = list(tensor_shape_base)
+                tensor_shape[gather_dim] = gather_sizes[i]
+                output_list.append(
+                    torch.empty(tensor_shape,
+                                dtype=input_.dtype,
+                                device=input_.device))
+        else:
+            # Even all-to-all: split the input into world_size equal chunks.
+            input_list = [
+                t.contiguous()
+                for t in torch.tensor_split(input_, self.world_size,
+                                            scatter_dim)
+            ]
+            output_list = [
+                torch.empty_like(input_list[i])
+                for i in range(self.world_size)
+            ]
+
+        dist.all_to_all(output_list, input_list, group=self.device_group)
+        output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
+        return output_tensor
+
+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       scatter_dim: int = 0) -> torch.Tensor:
+
+        if scatter_dim < 0:
+            scatter_dim += input_.dim()
+        if scatter_dim != 0:
+            # reduce_scatter_tensor scatters along dim 0, so move the target
+            # dim to the front and restore it afterwards.
+            input_ = torch.transpose(input_, 0, scatter_dim)
+        dim_size = list(input_.size())
+        dim_size[0] = dim_size[0] // self.world_size
+        output_tensor = torch.empty(dim_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        dist.reduce_scatter_tensor(output_tensor,
+                                   input_.contiguous(),
+                                   group=self.device_group)
+        if scatter_dim != 0:
+            output_tensor = torch.transpose(output_tensor, 0,
+                                            scatter_dim).contiguous()
+        return output_tensor
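For context, here is a minimal local sketch of the shape bookkeeping the new all_to_all performs around dist.all_to_all when explicit scatter_sizes/gather_sizes are supplied. The exchange itself is simulated so the snippet runs on CPU without a process group; the sizes and names are illustrative, not taken from the commit.

import torch

world_size, rank, hidden = 4, 1, 8
scatter_sizes = [3, 5, 2, 6]   # rows this rank sends to each peer
gather_sizes = [4, 5, 5, 2]    # rows this rank receives from each peer

input_ = torch.randn(sum(scatter_sizes), hidden)

# Split along scatter_dim=0 into one contiguous chunk per destination rank.
input_list = [t.contiguous() for t in torch.split(input_, scatter_sizes, dim=0)]

# Pre-allocate receive buffers: same trailing shape as the local chunk,
# but with the per-peer size on gather_dim=0.
output_list = [torch.empty(g, hidden) for g in gather_sizes]

# dist.all_to_all(output_list, input_list, group=...) would fill the buffers
# here; we fake the exchange so the example is self-contained.
for buf in output_list:
    buf.normal_()

output = torch.cat(output_list, dim=0)
assert output.shape == (sum(gather_sizes), hidden)

The key invariant is that rank r's gather_sizes[s] must equal rank s's scatter_sizes[r]; the commit obtains these lists by exchanging per-rank token counts (see w8a8_dynamic.py below).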

vllm_ascend/ops/fused_moe.py

Lines changed: 9 additions & 7 deletions

@@ -627,6 +627,7 @@ def forward(self,
         if int(os.environ.get("VLLM_ENABLE_MC2", '0')  # type: ignore
                ) == 1 and not is_prefill:
             ...
+        else:
         elif int(os.environ.get("USING_LCCL_COM",
                                 '0')) == 1:  # type: ignore
             hidden_states = get_dp_group().all_gather(
@@ -652,18 +653,19 @@ def forward(self,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
-            is_prefill=is_prefill)
+            is_prefill=is_prefill,
+            dp_size=self.dp_size)
 
         if self.dp_size > 1:
             if int(os.environ.get("VLLM_ENABLE_MC2", '0')  # type: ignore
                    ) == 1 and not is_prefill:
                 ...
-        else:
-            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                final_hidden_states,
-                "sum",
-                scatter_dim=0,
-                group=get_dp_group().device_group)
+            else:
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    final_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(
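As a reading aid, the branch order that forward() now follows across the data-parallel group can be condensed into a small pure-Python sketch; the function name and the returned labels are hypothetical, not identifiers from the codebase.

import os

def select_moe_comm_path(is_prefill: bool, dp_size: int) -> str:
    """Mirror of the env-var gating in forward(); labels are illustrative."""
    if dp_size == 1:
        return "no_dp_comm"          # single DP rank: no cross-DP exchange
    if int(os.environ.get("VLLM_ENABLE_MC2", "0")) == 1 and not is_prefill:
        return "mc2"                 # fused dispatch/combine kernels, decode only
    if int(os.environ.get("USING_LCCL_COM", "0")) == 1:
        return "lccl_all_gather"     # LCCL-based all-gather across the DP group
    return "hccl_all_gather"         # default all-gather, reduce-scatter on exit

os.environ["VLLM_ENABLE_MC2"] = "1"
assert select_moe_comm_path(is_prefill=False, dp_size=2) == "mc2"
assert select_moe_comm_path(is_prefill=True, dp_size=2).endswith("all_gather")
assert select_moe_comm_path(is_prefill=True, dp_size=1) == "no_dp_comm"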

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 38 additions & 2 deletions

@@ -16,16 +16,18 @@
 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.
+from typing import List, Optional
 
 import torch
-import vllm
-import vllm.distributed
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
+import vllm
 from vllm.config import ParallelConfig
+import vllm.distributed
+from vllm.distributed.parallel_state import GroupCoordinator
 
 
 def ascend_destroy_model_parallel():
@@ -176,6 +178,40 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group
 
 
+class GroupCoordinatorPatch(GroupCoordinator):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def all_to_all(self,
+                   input_: torch.Tensor,
+                   scatter_dim: int = 0,
+                   gather_dim: int = -1,
+                   scatter_sizes: Optional[List[int]] = None,
+                   gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= scatter_dim < input_.dim(), (
+            f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}"
+        )
+        assert -input_.dim() <= gather_dim < input_.dim(), (
+            f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
+        )
+        return self.device_communicator.all_to_all(input_, scatter_dim, gather_dim,
+                                                   scatter_sizes, gather_sizes)
+
+    def reduce_scatter(self,
+                       input_: torch.Tensor,
+                       scatter_dim: int = 0) -> torch.Tensor:
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= scatter_dim < input_.dim(), (
+            f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}"
+        )
+        return self.device_communicator.reduce_scatter(input_, scatter_dim)
+
+
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch  # Note: check the GroupCoordinator with online serving
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 vllm.distributed.stateless_init_torch_distributed_process_group = ascend_stateless_init_torch_distributed_process_group
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
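The patch takes effect by rebinding the class attribute on vllm's module before any coordinator is constructed, so later instantiations pick up the subclass. A self-contained sketch of that monkey-patching pattern, using stand-in names rather than vllm's real module layout:

import types

class GroupCoordinator:                      # stands in for vllm's class
    def __init__(self, world_size: int):
        self.world_size = world_size

class GroupCoordinatorPatch(GroupCoordinator):
    def reduce_scatter(self, x, scatter_dim: int = 0):
        if self.world_size == 1:
            return x                         # single-rank fast path, no comm
        raise NotImplementedError            # real version delegates to the device communicator

parallel_state = types.SimpleNamespace(GroupCoordinator=GroupCoordinator)
parallel_state.GroupCoordinator = GroupCoordinatorPatch   # the monkey-patch

coord = parallel_state.GroupCoordinator(world_size=1)    # resolves to the subclass
assert coord.reduce_scatter([1, 2, 3]) == [1, 2, 3]

This only works if the rebinding runs before any code reads the original attribute, which is why the assignment sits at module import time in the platform patch.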

vllm_ascend/quantization/quant_config.py

Lines changed: 3 additions & 1 deletion

@@ -317,14 +317,16 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
+        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
         return self.quant_method.apply(layer, x, router_logits, top_k,
                                        renormalize, use_grouped_topk,
                                        global_num_experts, expert_map,
                                        topk_group, num_expert_group,
                                        custom_routing_function, scoring_func,
-                                       e_score_correction_bias, is_prefill)
+                                       e_score_correction_bias, is_prefill,
+                                       dp_size)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
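The pattern here, a new keyword argument added just ahead of **kwargs so existing call sites keep working, in isolation (all class names below are hypothetical stand-ins, not the commit's code):

import torch

class InnerMethod:                            # stands in for self.quant_method
    def apply(self, x: torch.Tensor, is_prefill: bool = True,
              dp_size: int = 1, **kwargs) -> torch.Tensor:
        return x

class Wrapper:                                # stands in for the quant-config layer
    def __init__(self):
        self.quant_method = InnerMethod()

    def apply(self, x: torch.Tensor, is_prefill: bool = True,
              dp_size: int = 1, **kwargs) -> torch.Tensor:
        # dp_size is forwarded explicitly; unknown kwargs still pass through
        return self.quant_method.apply(x, is_prefill, dp_size, **kwargs)

out = Wrapper().apply(torch.ones(2))          # old call sites: dp_size defaults to 1
out = Wrapper().apply(torch.ones(2), dp_size=4)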

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 131 additions & 4 deletions

@@ -20,7 +20,8 @@
 
 import torch
 import torch_npu
-
+import torch.distributed as dist
+from vllm.distributed import GroupCoordinator
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts
 
@@ -201,6 +202,120 @@ def fused_experts_with_mc2(
     return hidden_states
 
 
+def fused_experts_with_all2all(hidden_states: torch.Tensor,
+                               w1: torch.Tensor,
+                               w1_scale: torch.Tensor,
+                               w2: torch.Tensor,
+                               w2_scale: torch.Tensor,
+                               topk_weights: torch.Tensor,
+                               topk_ids: torch.Tensor,
+                               top_k: int,
+                               expert_map: torch.Tensor = None,
+                               ep_group: GroupCoordinator = None):
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    num_tokens, _ = hidden_states.shape
+    num_experts = w1.shape[0]
+    dtype = hidden_states.dtype
+    device = hidden_states.device
+
+    if expert_map is not None:
+        # Expert-parallel path: shuffle tokens to the ranks that own their
+        # target experts via all-to-all, then shuffle the results back.
+        global_num_experts = len(expert_map)
+        local_num_experts = global_num_experts // ep_group.world_size
+        row_idx_len = num_tokens * top_k
+        row_idx = (torch.arange(0,
+                                row_idx_len,
+                                dtype=torch.int32,
+                                device=device).view(top_k, -1).permute(
+                                    1, 0).contiguous())
+        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
+        # Count tokens per global expert, then fold to per-rank send sizes.
+        global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                              minlength=global_num_experts)
+        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
+                                                  -1).sum(-1)
+
+        # Exchange send sizes so every rank knows its receive sizes.
+        gather_sizes = torch.empty_like(scatter_sizes)
+        dist.all_to_all_single(gather_sizes,
+                               scatter_sizes,
+                               group=ep_group.device_group)
+        scatter_size_list = scatter_sizes.cpu().tolist()
+        gather_size_list = gather_sizes.cpu().tolist()
+
+        expanded_expert_idx = expanded_expert_idx % local_num_experts
+        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
+                                            scatter_size_list,
+                                            gather_size_list)
+        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
+                                               scatter_size_list,
+                                               gather_size_list)
+
+        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            sorted_local_expert_idx, local_num_experts).to(torch.int64)
+
+        hidden_states = hidden_states[sorted_idx]
+        group_list_type = 0
+    else:
+        row_idx_len = num_tokens * top_k
+        row_idx = torch.arange(0,
+                               row_idx_len,
+                               dtype=torch.int32,
+                               device=topk_weights.device).view(
+                                   top_k, -1).permute(1, 0).contiguous()
+        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            expanded_expert_idx, num_experts)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 0
+
+    # Wrap in a list so apply_mlp can release the original tensor early.
+    hidden_states_wrapper = [hidden_states]
+    del hidden_states
+
+    hidden_states = apply_mlp(hidden_states_wrapper,
+                              w1,
+                              w1_scale,
+                              w2,
+                              w2_scale,
+                              expert_tokens,
+                              group_list_type=group_list_type)
+
+    if expert_map is not None:
+        # Undo the expert sort, send tokens back to their source ranks
+        # (size lists swapped), then combine the top-k outputs per token.
+        resorted_idx = torch.argsort(sorted_idx)
+        hidden_states = hidden_states[resorted_idx]
+        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
+                                            gather_size_list,
+                                            scatter_size_list)
+
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+    else:
+        # TODO: Reorder device memory 2 times here, replace the current
+        # implementation here when suitable operators become available.
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
@@ -387,10 +502,10 @@ class AscendW8A8DynamicFusedMoEMethod:
     def __init__(self):
         self.transpose_weight = True
 
-        ep_group = get_ep_group()
+        self.ep_group = get_ep_group()
 
         try:
-            device_group = ep_group.device_group
+            device_group = self.ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))
@@ -457,6 +572,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
+        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -503,7 +619,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        else:
+        elif dp_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
@@ -513,6 +629,17 @@ def apply(
                                  topk_ids=topk_ids,
                                  top_k=top_k,
                                  expert_map=expert_map)
+        else:
+            return fused_experts_with_all2all(hidden_states=x,
+                                              w1=layer.w13_weight,
+                                              w1_scale=layer.w13_weight_scale,
+                                              w2=layer.w2_weight,
+                                              w2_scale=layer.w2_weight_scale,
+                                              topk_weights=topk_weights,
+                                              topk_ids=topk_ids,
+                                              top_k=top_k,
+                                              expert_map=expert_map,
+                                              ep_group=self.ep_group)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
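To see why fused_experts_with_all2all exchanges token counts before the payload all-to-all, here is a single-process walk-through of the size bookkeeping with the communication simulated in plain Python; the routing values are illustrative, not from the commit.

import torch

world_size, experts_per_rank = 2, 4
global_num_experts = world_size * experts_per_rank

# Per-rank expanded_expert_idx after routing (illustrative values).
expanded = [torch.tensor([0, 1, 5, 5, 7]), torch.tensor([2, 3, 3, 4, 6, 6])]

scatter = []
for rank in range(world_size):
    counts = torch.bincount(expanded[rank], minlength=global_num_experts)
    # Tokens destined for each rank = sum over that rank's expert slice.
    scatter.append(counts.view(world_size, -1).sum(-1))

# dist.all_to_all_single(gather_sizes, scatter_sizes, group=...) effectively
# transposes this matrix: rank r receives scatter[s][r] tokens from rank s.
gather = [torch.stack([scatter[s][r] for s in range(world_size)])
          for r in range(world_size)]

assert gather[0].tolist() == [2, 3]   # rank 0: 2 local tokens + 3 from rank 1
assert gather[1].tolist() == [3, 3]   # rank 1's experts get 3 tokens from each rank

With both lists known, each rank calls ep_group.all_to_all for the hidden states, and on the return trip passes scatter_size_list and gather_size_list swapped, which is how the commit routes expert outputs back to their source ranks before npu_moe_finalize_routing combines them.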
