Commit 6a4c2e7

Fix CI break: upstream adds routed_scaling_factor to the forward_oot interface, so vllm-ascend needs to adapt.

Signed-off-by: leo-pony <nengjunma@outlook.com>

1 parent: ad13964
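Note: upstream vLLM now passes routed_scaling_factor when it calls forward_oot, so an out-of-tree override with the old signature fails at call time. A minimal sketch of the failure mode (names hypothetical, not the actual vLLM call site):

    # Old-style override: no routed_scaling_factor parameter.
    def forward_oot_old(x, scoring_func="softmax"):
        return x

    # The caller now passes the new keyword argument:
    forward_oot_old("hidden_states", routed_scaling_factor=1.0)
    # TypeError: forward_oot_old() got an unexpected keyword argument
    # 'routed_scaling_factor'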

2 files changed: +72 −2 lines

vllm_ascend/ops/common_fused_moe.py

Lines changed: 69 additions & 2 deletions
@@ -35,7 +35,7 @@
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
     setup_token_dispatchers
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -246,6 +246,67 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
                              and not vllm_config.model_config.enforce_eager)


+def forward_oot_v01011(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    topk_weights, topk_ids, _ = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        top_k=top_k,
+        use_grouped_topk=use_grouped_topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        scoring_func=scoring_func,
+        routed_scaling_factor=1.0,
+        e_score_correction_bias=e_score_correction_bias,
+        global_num_experts=global_num_experts)
+
+    if topk_ids.shape[1] < top_k or is_310p():
+        assert global_num_experts is not None
+        return fused_experts_moge(
+            hidden_states=x,
+            w1=layer.w13_weight,
+            w2=layer.w2_weight,
+            moe_parallel_config=self.moe.moe_parallel_config,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input)
+
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+    )
+
+
 def forward_oot(
         self,
         layer: torch.nn.Module,
@@ -258,6 +319,7 @@ def forward_oot(
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
         e_score_correction_bias: Optional[torch.Tensor] = None,
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
@@ -278,6 +340,7 @@ def forward_oot(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         scoring_func=scoring_func,
+        routed_scaling_factor=routed_scaling_factor,
         e_score_correction_bias=e_score_correction_bias,
         global_num_experts=global_num_experts)

@@ -441,4 +504,8 @@ def forward_impl(self, hidden_states: torch.Tensor,

 UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
 UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
-UnquantizedFusedMoEMethod.forward_oot = forward_oot
+
+if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
+    UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
+else:
+    UnquantizedFusedMoEMethod.forward_oot = forward_oot
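The patch keeps a frozen copy of the old signature (forward_oot_v01011) for the pinned 0.10.1/0.10.1.1 releases and lets the default forward_oot track upstream. For reference, a minimal sketch of what a vllm_version_is helper could look like, assuming it compares against vllm.__version__; the real implementation in vllm_ascend.utils may differ:

    import vllm

    def vllm_version_is(target: str) -> bool:
        # Sketch only: exact-match the installed vLLM version string.
        return vllm.__version__ == target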

vllm_ascend/ops/layers/experts_selector.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@ def select_experts(hidden_states: torch.Tensor,
                    num_expert_group: Optional[int] = None,
                    custom_routing_function: Optional[Callable] = None,
                    scoring_func: str = "softmax",
+                   routed_scaling_factor=1.0,
                    e_score_correction_bias: Optional[torch.Tensor] = None,
                    indices_type: Optional[torch.dtype] = None,
                    is_unquantized: bool = False,
@@ -78,6 +79,7 @@ def select_experts(hidden_states: torch.Tensor,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         scoring_func=scoring_func,
+        routed_scaling_factor=routed_scaling_factor,
         global_num_experts=global_num_experts,
         is_unquantized=is_unquantized)

@@ -180,6 +182,7 @@ def _select_experts_with_fusion_ops(
         num_expert_group: Optional[int],
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
+        routed_scaling_factor=1.0,
         global_num_experts: int = -1,
         is_unquantized: bool = False):
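The changes above only thread routed_scaling_factor through select_experts into _select_experts_with_fusion_ops; how the fused ops consume it is not shown in this diff. For intuition, DeepSeek-style MoE routing typically multiplies the selected top-k weights by this factor. A hypothetical sketch, not vllm-ascend's actual kernel path:

    import torch

    def scale_topk_weights(topk_weights: torch.Tensor,
                           routed_scaling_factor: float = 1.0) -> torch.Tensor:
        # Hypothetical: scale per-token expert weights after top-k selection.
        return topk_weights * routed_scaling_factor

    w = torch.tensor([[0.6, 0.4]])
    print(scale_topk_weights(w, 2.5))  # tensor([[1.5000, 1.0000]])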
