@@ -1121,14 +1121,14 @@ def forward(self,
         mc2_mask = forward_context.mc2_mask
         # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
         quantized_x_for_share, dynamic_scale_for_share = None, None
-        from vllm_ascend.quantization.w8a8_dynamic import \
-            AscendW8A8DynamicFusedMoEMethod
+        from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
+            TorchairAscendW8A8DynamicFusedMoEMethod
         if self.enable_multistream_moe:
             if not self.rm_router_logits:
                 router_logits, _ = gate(hidden_states)
             if hasattr(self.quant_method, "quant_method") and \
                 isinstance(self.quant_method.quant_method,
-                           AscendW8A8DynamicFusedMoEMethod
+                           TorchairAscendW8A8DynamicFusedMoEMethod
                            ) and fused_moe_state == FusedMoEState.MC2:
                 with npu_stream_switch("moe_secondary", 0):
                     quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
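Note on the first hunk: it only retargets the import and `isinstance` check to the torchair-scoped `TorchairAscendW8A8DynamicFusedMoEMethod`; the multistream branch still quantizes `hidden_states` on the `moe_secondary` stream so quantization can overlap with the `gate` projection. As a rough CPU-only sketch, assuming `torch_npu.npu_dynamic_quant` performs per-token symmetric int8 quantization (the helper below is illustrative and not part of the patch):

```python
import torch

def per_token_dynamic_quant(x: torch.Tensor):
    # One scale per token row: map the row's max absolute value to the int8 range.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    # Quantize and clamp to int8; the per-token scale is kept for dequantization.
    quantized = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return quantized, scale.squeeze(-1)

hidden_states = torch.randn(4, 16)
quantized_x, dynamic_scale = per_token_dynamic_quant(hidden_states)
print(quantized_x.dtype, dynamic_scale.shape)  # torch.int8 torch.Size([4])
```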
@@ -1154,30 +1154,32 @@ def forward(self,
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
         ] and not replace_allreduce):
-            if fused_moe_state in {FusedMoEState.MC2}:
-                padding_size = forward_context.padded_num_tokens
-            else:
-                # TODO: Determine if we can remove the padding
-                padding_size = tp_size
-            if num_tokens < padding_size and not self.enable_shared_expert_dp:
-                hidden_states = nn.functional.pad(
-                    hidden_states, (0, 0, 0, padding_size - num_tokens))
-                router_logits = nn.functional.pad(
-                    router_logits, (0, 0, 0, padding_size - num_tokens))
             if tp_size > 1:
                 tp_rank = get_tensor_model_parallel_rank()
-                if not self.enable_shared_expert_dp:
-                    chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                             tp_size,
-                                                             dim=0)
-                    chunk_router_logits = torch.tensor_split(router_logits,
-                                                             tp_size,
-                                                             dim=0)
-                    hidden_states = chunk_hidden_states[tp_rank]
-                    router_logits = chunk_router_logits[tp_rank]
-
                 chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]
+            if not replace_allreduce:
+                if fused_moe_state in {FusedMoEState.MC2}:
+                    padding_size = forward_context.padded_num_tokens
+                else:
+                    # TODO: Determine if we can remove the padding
+                    padding_size = tp_size
+                if num_tokens < padding_size and not self.enable_shared_expert_dp:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, padding_size - num_tokens))
+                    router_logits = nn.functional.pad(
+                        router_logits, (0, 0, 0, padding_size - num_tokens))
+                if tp_size > 1:
+                    tp_rank = get_tensor_model_parallel_rank()
+                    if not self.enable_shared_expert_dp:
+                        chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                                 tp_size,
+                                                                 dim=0)
+                        chunk_router_logits = torch.tensor_split(router_logits,
+                                                                 tp_size,
+                                                                 dim=0)
+                        hidden_states = chunk_hidden_states[tp_rank]
+                        router_logits = chunk_router_logits[tp_rank]
 
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
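The second hunk moves the padding and tensor-parallel chunking of `hidden_states` and `router_logits` under a new `if not replace_allreduce:` guard, while the `mc2_mask` split stays in the `if tp_size > 1:` block. A minimal standalone sketch of the pad-then-split pattern used there, with illustrative sizes (plain `torch`, no NPU required):

```python
import torch
import torch.nn as nn

tp_size, tp_rank = 4, 1          # illustrative TP world size and local rank
num_tokens, hidden = 2, 8        # fewer tokens than ranks, so padding applies
hidden_states = torch.randn(num_tokens, hidden)

# Non-MC2 branch of the patch: pad the token dimension up to padding_size = tp_size.
padding_size = tp_size
if num_tokens < padding_size:
    hidden_states = nn.functional.pad(
        hidden_states, (0, 0, 0, padding_size - num_tokens))

# Split along the token dimension and keep only this rank's chunk.
chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
hidden_states = chunk_hidden_states[tp_rank]
print(hidden_states.shape)  # torch.Size([1, 8])
```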