@@ -29,14 +29,13 @@
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ascend_forward_context import MoECommImpl
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                               determine_default_log2phy_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl,
-                                                 NaiveMulticastCommImpl)
+from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
                                get_all_reduce_merge_state,
                                get_rm_router_logits_state, is_310p)
@@ -145,6 +144,8 @@ def apply(self,


 class AscendFusedMoE(FusedMoE):
+    # The moe_counter parameter is required during the initialization of EPLB
+    # to identify the current layer index within the MOE model.
     moe_counter = -1

     def __init__(self, *args, **kwargs):
@@ -172,14 +173,11 @@ def __init__(self, *args, **kwargs):

         assert self.quant_method is not None

-        AscendFusedMoE.moe_counter += 1
-        self.moe_instance_id = AscendFusedMoE.moe_counter
         self.moe_config.tp_group = get_tp_group()
         self.moe_config.dp_group = get_dp_group()
         self.moe_config.ep_group = get_ep_group()
         self.moe_config.mc2_group = get_mc2_group()
         ascend_config = get_ascend_config()
-        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.dynamic_eplb = ascend_config.dynamic_eplb
         self.expert_map_path = ascend_config.expert_map_path
         self.global_redundant_expert_num = ascend_config.init_redundancy_expert
@@ -215,13 +213,9 @@ def __init__(self, *args, **kwargs):
         if self.dynamic_eplb:
             self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

-        for method in {
-                AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
-                NaiveMulticastCommImpl
-        }:
-            setattr(
-                self, method.__name__.lower(),
-                method(moe_config=self.moe_config))  # type: ignore[abstract]
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+
+        setup_moe_comm_method(self.moe_config)

     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
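
The per-layer loop that attached one comm-impl instance to every `AscendFusedMoE` via `setattr` is replaced by a single `setup_moe_comm_method(self.moe_config)` call. The body of that helper lives in `vllm_ascend/ops/moe/moe_comm_method.py` and is not shown in this diff; the snippet below is only a rough sketch of the registry pattern it presumably implements, with every name other than `MoECommImpl` and `setup_moe_comm_method` invented for illustration.

```python
# Rough sketch only -- the real implementation is in
# vllm_ascend/ops/moe/moe_comm_method.py and may differ in every detail.
from enum import Enum, auto
from typing import Any, Dict


class MoECommImpl(Enum):
    # Member names chosen to match the comparisons later in this diff;
    # AllGather / NaiveMulticast are assumed from the removed impl classes.
    AllGather = auto()
    AllTOAll = auto()
    MC2 = auto()
    NaiveMulticast = auto()


class _PlaceholderCommImpl:
    """Stand-in for AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, ..."""

    def __init__(self, moe_config: Any) -> None:
        self.moe_config = moe_config

    def prepare(self, hidden_states, router_logits):
        # Real backends redistribute tokens across ranks here.
        return hidden_states, router_logits


_COMM_METHODS: Dict[MoECommImpl, _PlaceholderCommImpl] = {}


def setup_moe_comm_method(moe_config: Any) -> None:
    """Build each comm backend once and cache it by its MoECommImpl type."""
    for comm_type in MoECommImpl:
        _COMM_METHODS[comm_type] = _PlaceholderCommImpl(moe_config)


def get_moe_comm_method(comm_type: MoECommImpl) -> _PlaceholderCommImpl:
    """Resolve the shared backend for the type stored in the forward context."""
    return _COMM_METHODS[comm_type]
```

The point of the change is that the backends are built once per process and looked up through the forward context's `moe_comm_method_type`, instead of being duplicated on every MoE layer.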
@@ -245,8 +239,8 @@ def maybe_all_reduce_tensor_model_parallel(
         outputs since each rank only has partial outputs.
         """
         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+        moe_comm_method_type = forward_context.moe_comm_method_type
+        if moe_comm_method_type in {MoECommImpl.AllTOAll, MoECommImpl.MC2}:
             return final_hidden_states
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
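
The string keys `"alltoallcommimpl"` and `"mc2commimpl"` become `MoECommImpl` enum members, so a typo now fails loudly instead of silently falling through to the all-reduce branch. Reusing the hypothetical `MoECommImpl` sketch above, the decision this hunk encodes can be read as:

```python
# Illustration only, building on the MoECommImpl sketch above.
def needs_tp_all_reduce(comm_type: MoECommImpl) -> bool:
    # Per the surrounding code, AllToAll and MC2 already return combined
    # outputs, so only the remaining backends need a tensor-parallel all-reduce.
    return comm_type not in {MoECommImpl.AllTOAll, MoECommImpl.MC2}


assert needs_tp_all_reduce(MoECommImpl.AllGather)
assert not needs_tp_all_reduce(MoECommImpl.MC2)
```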
@@ -260,9 +254,6 @@ def forward_impl(self, hidden_states: torch.Tensor,

         forward_context = get_forward_context()
         enable_force_load_balance = forward_context.in_profile_run
-        moe_comm_method_name = forward_context.moe_comm_method_name
-
-        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

         hidden_states, router_logits = forward_context.moe_comm_method.prepare(
             hidden_states=hidden_states,
@@ -287,11 +278,13 @@ def forward_impl(self, hidden_states: torch.Tensor,
             e_score_correction_bias=self.e_score_correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
-            enable_eplb=self.enable_eplb,
-            expert_load_view=self.expert_load_view,
-            logical_to_physical_map=self.logical_to_physical_map,
-            logical_replica_count=self.logical_replica_count,
-        )
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
+            shared_experts=None,
+            enable_force_load_balance=enable_force_load_balance,
+            log2phy=self.log2phy,
+            global_redundant_expert_num=self.global_redundant_expert_num)
+
         if isinstance(final_hidden_states, tuple):
             final_hidden_states, group_list_type, expert_tokens = final_hidden_states

@@ -410,8 +403,8 @@ def forward(

         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+        moe_comm_method_type = forward_context.moe_comm_method_type
+        if moe_comm_method_type in {MoECommImpl.AllTOAll, MoECommImpl.MC2}:
             shared_out = tensor_model_parallel_all_reduce(shared_out)

         fused_out = super().forward(
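
As the NOTE above says, the shared-expert branch mirrors `maybe_all_reduce_tensor_model_parallel`: for any given comm type, exactly one of the two outputs receives the tensor-parallel all-reduce. A compact restatement, reusing the hypothetical `needs_tp_all_reduce` helper sketched earlier:

```python
# Illustration only; depends on the needs_tp_all_reduce sketch above.
def tp_reduce_targets(comm_type: MoECommImpl) -> dict:
    """Which output is all-reduced across the TP group for a given comm type."""
    reduce_fused = needs_tp_all_reduce(comm_type)  # maybe_all_reduce_tensor_model_parallel
    return {"fused_out": reduce_fused, "shared_out": not reduce_fused}


# e.g. MC2 all-reduces only the shared-expert output.
assert tp_reduce_targets(MoECommImpl.MC2) == {"fused_out": False, "shared_out": True}
```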