 import torch
 import torch_npu
-
+import torch.distributed as dist
+from vllm.distributed import GroupCoordinator
 
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts
@@ -201,6 +202,120 @@ def fused_experts_with_mc2(
     return hidden_states
 
 
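+# All-to-all based expert-parallel MoE path. AscendW8A8DynamicFusedMoEMethod.apply()
+# below selects it when the MC2 path is not taken and dp_size > 1: tokens are
+# dispatched to the ranks that own their routed experts, processed by the
+# quantized grouped MLP (apply_mlp), and sent back before the weighted combine.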
+def fused_experts_with_all2all(hidden_states: torch.Tensor,
+                               w1: torch.Tensor,
+                               w1_scale: torch.Tensor,
+                               w2: torch.Tensor,
+                               w2_scale: torch.Tensor,
+                               topk_weights: torch.Tensor,
+                               topk_ids: torch.Tensor,
+                               top_k: int,
+                               expert_map: torch.Tensor = None,
+                               ep_group: GroupCoordinator = None,
+                               ):
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    num_tokens, _ = hidden_states.shape
+    num_experts = w1.shape[0]
+    dtype = hidden_states.dtype
+    device = hidden_states.device
+
+    if expert_map is not None:
+        global_num_experts = len(expert_map)
+        local_num_experts = global_num_experts // ep_group.world_size
+        row_idx_len = num_tokens * top_k
+        row_idx = (torch.arange(0,
+                                row_idx_len,
+                                dtype=torch.int32,
+                                device=device).view(top_k, -1).permute(
+                                    1, 0).contiguous())
+        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
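+        # Count expanded tokens per global expert, then fold the counts into
+        # per-destination-rank send sizes; experts are assumed to be laid out
+        # contiguously across EP ranks (global_num_experts // world_size each).
+        # E.g. with world_size=2 and 4 experts, per-expert counts [3, 1, 2, 2]
+        # become scatter_sizes [4, 4]: 4 rows go to rank 0 and 4 to rank 1.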
+        global_expert_tokens = torch.bincount(expanded_expert_idx,
+                                              minlength=global_num_experts)
+        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
+                                                  -1).sum(-1)
+
+        gather_sizes = torch.empty_like(scatter_sizes)
+        dist.all_to_all_single(gather_sizes,
+                               scatter_sizes,
+                               group=ep_group.device_group)
+        scatter_size_list = scatter_sizes.cpu().tolist()
+        gather_size_list = gather_sizes.cpu().tolist()
+
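+        # After the size exchange, every rank knows how many rows it sends to
+        # (scatter) and receives from (gather) each peer; the .cpu().tolist()
+        # calls force a host sync before the variable-sized all-to-all below.
+        # Remap global expert ids to rank-local ids and dispatch both the
+        # expanded tokens and their expert ids.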
+        expanded_expert_idx = expanded_expert_idx % local_num_experts
+        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
+                                            scatter_size_list,
+                                            gather_size_list)
+        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
+                                               scatter_size_list,
+                                               gather_size_list)
+
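+        # Received rows arrive grouped by source rank; sort them by local
+        # expert id so each expert's tokens form one contiguous segment, and
+        # compute expert_tokens, the per-expert grouping information that
+        # apply_mlp consumes for its grouped matmuls.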
+        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            sorted_local_expert_idx, local_num_experts).to(torch.int64)
+
+        hidden_states = hidden_states[sorted_idx]
+        group_list_type = 0
+    else:
+        row_idx_len = num_tokens * top_k
+        row_idx = torch.arange(0,
+                               row_idx_len,
+                               dtype=torch.int32,
+                               device=topk_weights.device).view(
+                                   top_k, -1).permute(1, 0).contiguous()
+        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            expanded_expert_idx, num_experts)
+        expert_tokens = expert_tokens.to(torch.int64)
+        group_list_type = 0
+
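+    # The input is handed over inside a list and the local reference is
+    # dropped, presumably so apply_mlp can release the routed activations as
+    # soon as it has consumed them and keep peak memory lower.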
+    hidden_states_wrapper = [hidden_states]
+    del hidden_states
+
+    hidden_states = apply_mlp(hidden_states_wrapper,
+                              w1,
+                              w1_scale,
+                              w2,
+                              w2_scale,
+                              expert_tokens,
+                              group_list_type=group_list_type)
+
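+    # Combine: undo the local expert sort, send the results back to their
+    # origin ranks by swapping the scatter/gather sizes, then let
+    # npu_moe_finalize_routing apply the top-k weights and scatter the rows
+    # back to the original token order.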
+    if expert_map is not None:
+        resorted_idx = torch.argsort(sorted_idx)
+        hidden_states = hidden_states[resorted_idx]
+        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
+                                            gather_size_list,
+                                            scatter_size_list)
+
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+    else:
+        # TODO: Device memory is reordered twice here; replace this
+        # implementation once suitable operators become available.
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            hidden_states,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
@@ -387,10 +502,10 @@ class AscendW8A8DynamicFusedMoEMethod:
     def __init__(self):
         self.transpose_weight = True
 
-        ep_group = get_ep_group()
+        self.ep_group = get_ep_group()
 
         try:
-            device_group = ep_group.device_group
+            device_group = self.ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))
@@ -457,6 +572,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
+        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -503,7 +619,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        else:
+        elif dp_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
@@ -513,6 +629,17 @@ def apply(
                                  topk_ids=topk_ids,
                                  top_k=top_k,
                                  expert_map=expert_map)
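+        # dp_size > 1 and the MC2 path was not taken: route tokens between EP
+        # ranks with the all-to-all implementation added above.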
+        else:
+            return fused_experts_with_all2all(hidden_states=x,
+                                              w1=layer.w13_weight,
+                                              w1_scale=layer.w13_weight_scale,
+                                              w2=layer.w2_weight,
+                                              w2_scale=layer.w2_weight_scale,
+                                              topk_weights=topk_weights,
+                                              topk_ids=topk_ids,
+                                              top_k=top_k,
+                                              expert_map=expert_map,
+                                              ep_group=self.ep_group)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight: