@@ -435,8 +435,9 @@ def torchair_fused_experts_with_all2all(

     gather_sizes = global_expert_tokens.new_empty(
         global_expert_tokens.shape[0])
-    dist.all_to_all_single(gather_sizes, global_expert_tokens)
-
+    dist.all_to_all_single(gather_sizes,
+                           global_expert_tokens,
+                           group=ep_group.device_group)
     token_counts_combined = torch.stack(
         [gather_sizes, global_expert_tokens], dim=0)
     token_counts_combined = token_counts_combined.view(
@@ -451,10 +452,16 @@ def torchair_fused_experts_with_all2all(
     gather_size_list = token_counts_combined_cpu[1]
     scatter_size_list = token_counts_combined_cpu[0]

-    dist.all_to_all_single(gathered_tokens, quantized_tokens,
-                           scatter_size_list, gather_size_list)
-    dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
-                           gather_size_list)
+    dist.all_to_all_single(gathered_tokens,
+                           quantized_tokens,
+                           scatter_size_list,
+                           gather_size_list,
+                           group=ep_group.device_group)
+    dist.all_to_all_single(dynamic_scale,
+                           token_scales,
+                           scatter_size_list,
+                           gather_size_list,
+                           group=ep_group.device_group)

     hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
         gathered_tokens,
@@ -502,9 +509,11 @@ def torchair_fused_experts_with_all2all(
         index=inverse_indices.to(torch.float32).argsort().to(torch.int32))

     hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
-    dist.all_to_all_single(hidden_states, reordered_outputs,
-                           gather_size_list, scatter_size_list)
-
+    dist.all_to_all_single(hidden_states,
+                           reordered_outputs,
+                           gather_size_list,
+                           scatter_size_list,
+                           group=ep_group.device_group)
     final_hidden_states = torch_npu.npu_moe_finalize_routing(
         hidden_states,
         skip1=None,
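For context, a minimal sketch of the collective pattern this commit applies: each `dist.all_to_all_single` call receives an explicit process group plus per-rank split sizes, rather than implicitly using the default group. The 2-rank CPU/`gloo` setup, port, and tensor contents below are illustrative assumptions only; the real code runs the exchange on the expert-parallel communicator (`ep_group.device_group`) with NPU tensors.

```python
# Sketch only: the "gloo" backend, port, and toy tensors are assumptions
# for illustration; they are not the EP group or data used in
# torchair_fused_experts_with_all2all.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    # gloo is assumed here to support all_to_all_single on CPU; in the real
    # code the group is built over the Ascend devices of the EP ranks.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Stand-in for ep_group.device_group: a dedicated group over all ranks.
    ep_device_group = dist.new_group(ranks=list(range(world_size)))

    # 1) Exchange per-destination token counts (fixed, even splits),
    #    mirroring the gather_sizes <- global_expert_tokens exchange.
    send_counts = torch.tensor([rank + 1, rank + 2], dtype=torch.int64)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=ep_device_group)

    # 2) Exchange variable-length payloads, using the counts as
    #    input/output split sizes (scatter_size_list / gather_size_list).
    scatter_sizes = send_counts.tolist()   # rows this rank sends to each peer
    gather_sizes = recv_counts.tolist()    # rows this rank receives from each peer
    sent = torch.arange(sum(scatter_sizes), dtype=torch.float32)
    received = torch.empty(sum(gather_sizes), dtype=torch.float32)
    dist.all_to_all_single(received,
                           sent,
                           gather_sizes,    # output_split_sizes
                           scatter_sizes,   # input_split_sizes
                           group=ep_device_group)

    # 3) The return trip swaps the two split-size lists, as in the
    #    hidden_states <- reordered_outputs exchange above.
    returned = torch.empty(sum(scatter_sizes), dtype=torch.float32)
    dist.all_to_all_single(returned,
                           received,
                           scatter_sizes,   # output_split_sizes
                           gather_sizes,    # input_split_sizes
                           group=ep_device_group)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

In the diff itself the split sizes come from `token_counts_combined_cpu`, and passing `group=ep_group.device_group` keeps all three exchanges on the expert-parallel communicator instead of whatever group `dist` would use by default.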