Commit 544f8ed

[bugfix] fix wasted NPU memory buffer allocation for all_to_all_single operation

Signed-off-by: linfeng-yuan <1102311262@qq.com>

1 parent 12bcbd0 · commit 544f8ed
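Context for the fix: when group is omitted, torch.distributed.all_to_all_single runs on the default (world) process group, so on Ascend the HCCL backend can end up allocating communication buffers sized for the whole world even though these MoE exchanges only span the expert-parallel ranks. Passing group=ep_group.device_group scopes each collective, and the buffers behind it, to the EP communicator. A minimal sketch of the pattern, assuming an HCCL job launched with torchrun on at least two NPUs; ep_ranks and the tensor shapes are illustrative, not taken from the repository:

import torch
import torch.distributed as dist
import torch_npu  # noqa: F401  registers the NPU device and the HCCL backend

# The default group spans every rank in the job.
dist.init_process_group(backend="hccl")

# Hypothetical expert-parallel subgroup; dist.new_group must run on all ranks.
ep_ranks = [0, 1]
ep_device_group = dist.new_group(ranks=ep_ranks)

tokens = torch.randn(8, 16).npu()
gathered = torch.empty_like(tokens)

# Without group=..., this collective would run on the world group and HCCL
# would size its buffers for all ranks; scoping it to the EP subgroup keeps
# the NPU memory allocation proportional to the subgroup.
dist.all_to_all_single(gathered, tokens, group=ep_device_group)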

File tree: 1 file changed, +18 -9 lines

vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py (18 additions, 9 deletions)
@@ -435,8 +435,9 @@ def torchair_fused_experts_with_all2all(

     gather_sizes = global_expert_tokens.new_empty(
         global_expert_tokens.shape[0])
-    dist.all_to_all_single(gather_sizes, global_expert_tokens)
-
+    dist.all_to_all_single(gather_sizes,
+                           global_expert_tokens,
+                           group=ep_group.device_group)
     token_counts_combined = torch.stack(
         [gather_sizes, global_expert_tokens], dim=0)
     token_counts_combined = token_counts_combined.view(
@@ -451,10 +452,16 @@ def torchair_fused_experts_with_all2all(
     gather_size_list = token_counts_combined_cpu[1]
     scatter_size_list = token_counts_combined_cpu[0]

-    dist.all_to_all_single(gathered_tokens, quantized_tokens,
-                           scatter_size_list, gather_size_list)
-    dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
-                           gather_size_list)
+    dist.all_to_all_single(gathered_tokens,
+                           quantized_tokens,
+                           scatter_size_list,
+                           gather_size_list,
+                           group=ep_group.device_group)
+    dist.all_to_all_single(dynamic_scale,
+                           token_scales,
+                           scatter_size_list,
+                           gather_size_list,
+                           group=ep_group.device_group)

     hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
         gathered_tokens,
@@ -502,9 +509,11 @@ def torchair_fused_experts_with_all2all(
         index=inverse_indices.to(torch.float32).argsort().to(torch.int32))

     hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
-    dist.all_to_all_single(hidden_states, reordered_outputs,
-                           gather_size_list, scatter_size_list)
-
+    dist.all_to_all_single(hidden_states,
+                           reordered_outputs,
+                           gather_size_list,
+                           scatter_size_list,
+                           group=ep_group.device_group)
     final_hidden_states = torch_npu.npu_moe_finalize_routing(
         hidden_states,
         skip1=None,
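For reference, the variable-split form used in the last two hunks is positionally all_to_all_single(output, input, output_split_sizes, input_split_sizes, group=...): in the forward exchange scatter_size_list fills the output-split slot (rows received per peer) and gather_size_list the input-split slot (rows sent per peer), and the return exchange swaps them; the group keyword again pins the transfer to the EP communicator. A hedged sketch with an illustrative helper name, not repository code:

import torch
import torch.distributed as dist

def exchange_rows(inp, output_splits, input_splits, ep_device_group):
    # sum(output_splits) rows arrive from the peers in the group.
    out = inp.new_empty((sum(output_splits), *inp.shape[1:]))
    dist.all_to_all_single(out,
                           inp,
                           output_splits,  # rows this rank receives, per peer
                           input_splits,   # rows this rank sends, per peer
                           group=ep_device_group)
    return out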

0 commit comments