@@ -402,18 +402,20 @@ def fused_experts_with_all2all_v2(x, topk_ids, topk_weight, w1, w2, w1_scale, w2
     # ep_size = get_expert_parallel_world_size()
     ep_size = 16
     combine_tokens = combine_tokens.view(2, ep_size, -1).sum(2)
+    combine_tokens_cpu = combine_tokens.to(torch.device("cpu"), non_blocking=True).numpy()
     all_tokens = combine_tokens[0].sum()
-    combine_tokens_cpu = combine_tokens.cpu().tolist()
-    # alltoall input splits, the total number of tokens routed from the current rank to other ranks
-    input_splits = combine_tokens_cpu[1]
-    # alltoall output splits, the number of tokens each rank receives from other cards
-    output_splits = combine_tokens_cpu[0]
     # alltoall output, unfolded into one dimension, the size is the sum of the number of tokens routed from other cards to the current rank.
     gathered_tokens = expanded_x.new_empty(
         all_tokens.item(), expanded_x.shape[1]
     )
-    dist.all_to_all_single(gathered_tokens, expanded_x, output_splits, input_splits)
     gathered_pertoken_scale = pertoken_scale.new_empty(gathered_tokens.shape[0])
+
+    # alltoall input splits, the total number of tokens routed from the current rank to other ranks
+    input_splits = combine_tokens_cpu[1]
+    # alltoall output splits, the number of tokens each rank receives from other cards
+    output_splits = combine_tokens_cpu[0]
+
+    dist.all_to_all_single(gathered_tokens, expanded_x, output_splits, input_splits)
     dist.all_to_all_single(gathered_pertoken_scale, pertoken_scale, output_splits, input_splits)
     # reroute
     # Tokens merged by experts, scales merged by experts, indices for FinalizeRouting, number of tokens processed by each expert
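For reference, the two split lists feed `dist.all_to_all_single` as follows: `input_splits[r]` is the number of rows of `expanded_x` the current rank sends to rank `r`, and `output_splits[r]` is the number of rows it receives from rank `r`. Below is a minimal sketch of that exchange, assuming an initialized `torch.distributed` process group; the helper name `exchange_tokens` is illustrative and not part of this patch.

```python
import torch
import torch.distributed as dist


def exchange_tokens(expanded_x: torch.Tensor,
                    input_splits: list,
                    output_splits: list) -> torch.Tensor:
    # Sketch only (assumed helper, not from the patch):
    #   input_splits[r]  = rows of expanded_x this rank sends to rank r
    #   output_splits[r] = rows this rank receives from rank r
    gathered = expanded_x.new_empty(sum(output_splits), expanded_x.shape[1])
    # all_to_all_single(output, input, output_split_sizes, input_split_sizes)
    dist.all_to_all_single(gathered, expanded_x, output_splits, input_splits)
    return gathered
```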