From a20ba7ab84598963d201c4291721bb8b0f5439d3 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 13 Oct 2025 16:51:00 -0700 Subject: [PATCH 01/13] support multi routing method in flashinfer trtllm moe Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../model_executor/layers/fused_moe/config.py | 19 +++++++++++++++++++ .../layers/fused_moe/flashinfer_trtllm_moe.py | 10 ++++++---- vllm/model_executor/layers/fused_moe/layer.py | 3 +++ .../model_executor/layers/quantization/fp8.py | 7 ++++--- vllm/model_executor/models/qwen3_next.py | 2 ++ 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cbc3caafcf2f..868791eb074e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +from enum import IntEnum from typing import Optional, Union import torch @@ -91,6 +92,24 @@ def _quant_flags_to_group_shape( return a_shape, w_shape +# The type of method in top-K routing +# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups + # -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # Qwen3: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # Unspecified + Unspecified = 5.0 + + @dataclass class FusedMoEQuantDesc: """ diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index f21fe16c5108..2276138dc8c6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -30,19 +30,20 @@ def flashinfer_fused_moe_blockscale_fp8( local_num_experts: int, block_shape: list[int], routed_scaling: float = 1.0, + routing_method_type: int = 2, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe assert top_k <= global_num_experts - assert top_k <= 8 + assert top_k <= 10 assert topk_group <= 4 assert global_num_experts > num_expert_group assert global_num_experts % num_expert_group == 0 assert global_num_experts % 4 == 0 assert top_k < (topk_group * global_num_experts / num_expert_group) assert block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 256 - assert global_num_experts <= 256 + # Routing kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) # NOTE: scales of hidden states have to be transposed! 
@@ -67,7 +68,7 @@ def flashinfer_fused_moe_blockscale_fp8( tile_tokens_dim=calculate_tile_tokens_dim( x.shape[0], top_k, global_num_experts ), - routing_method_type=2, # DeepSeek-styled routing method + routing_method_type=routing_method_type, use_shuffled_weight=False, ) @@ -89,6 +90,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( local_num_experts: int, block_shape: list[int], routed_scaling: float = 1.0, + routing_method_type: int = 2, ) -> torch.Tensor: return torch.empty_like(x) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e69ead074c50..5331a2e5999e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1218,6 +1218,7 @@ def __init__( zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, + routing_method_type: int = 2, ): super().__init__() @@ -1391,6 +1392,8 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation + # Optional routing method id for backends that support multiple types + self.routing_method_type = routing_method_type if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ce40645782e5..09afaf128c1f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1226,9 +1226,9 @@ def apply( assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) - assert scoring_func == "sigmoid", ( - f"Expected 'sigmoid' scoring func but got {scoring_func}" - ) + # assert scoring_func == "sigmoid", ( + # f"Expected 'sigmoid' scoring func but got {scoring_func}" + # ) if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 @@ -1257,6 +1257,7 @@ def apply( local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, routed_scaling=routed_scaling_factor, + routing_method_type=getattr(layer, "routing_method_type", 2), ) else: assert not renormalize and custom_routing_function is not None diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index b095c79dc954..a685174956c1 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -38,6 +38,7 @@ GemmaRMSNorm as Qwen3NextRMSNorm, ) from vllm.model_executor.layers.layernorm import RMSNormGated +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -173,6 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: From a34e3df1ee146d66ed13bcaff504f32e617fcd6d Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Tue, 14 Oct 2025 13:36:15 -0700 Subject: [PATCH 02/13] fix Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 16 +++++++------- .../model_executor/layers/quantization/fp8.py | 13 
++++++------ .../layers/quantization/modelopt.py | 1 + .../quantization/utils/flashinfer_utils.py | 21 +++++++++++-------- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 2276138dc8c6..a20783bbfd1f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch - +from typing import Optional from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -23,24 +23,24 @@ def flashinfer_fused_moe_blockscale_fp8( w2_weight_scale_inv: torch.Tensor, global_num_experts: int, top_k: int, - num_expert_group: int, - topk_group: int, + num_expert_group: Optional[int], + topk_group: Optional[int], intermediate_size: int, expert_offset: int, local_num_experts: int, block_shape: list[int], - routed_scaling: float = 1.0, + routed_scaling: Optional[float] = 1.0, routing_method_type: int = 2, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe assert top_k <= global_num_experts assert top_k <= 10 - assert topk_group <= 4 - assert global_num_experts > num_expert_group - assert global_num_experts % num_expert_group == 0 + # assert topk_group <= 4 + # assert global_num_experts > num_expert_group + # assert global_num_experts % num_expert_group == 0 assert global_num_experts % 4 == 0 - assert top_k < (topk_group * global_num_experts / num_expert_group) + # assert top_k < (topk_group * global_num_experts / num_expert_group) assert block_shape == [128, 128] # Routing kernel expects #experts <= #threads 512 assert global_num_experts <= 512 diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 09afaf128c1f..f0374887787f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1232,16 +1232,17 @@ def apply( if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - assert ( - renormalize and use_grouped_topk and custom_routing_function is None - ) + # assert ( + # renormalize and use_grouped_topk and custom_routing_function is None + # ) e_score_correction_bias = ( e_score_correction_bias.to(x.dtype) if e_score_correction_bias is not None else None ) + routing_method_type = getattr(layer, "routing_method_type", 2) return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32), + routing_logits=router_logits.to(torch.float32) if routing_method_type == 2 else router_logits, routing_bias=e_score_correction_bias, x=x, w13_weight=layer.w13_weight, @@ -1255,9 +1256,9 @@ def apply( intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, + block_shape=self.weight_block_size, routed_scaling=routed_scaling_factor, - routing_method_type=getattr(layer, "routing_method_type", 2), + routing_method_type=routing_method_type, ) else: assert not renormalize and custom_routing_function is not None diff --git a/vllm/model_executor/layers/quantization/modelopt.py 
b/vllm/model_executor/layers/quantization/modelopt.py index e14753c60c48..e0424b75580f 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1341,6 +1341,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( ): from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( + get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, get_w2_permute_indices_with_cache, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 50ea049c3d5a..edfb8b561eb5 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -27,20 +27,23 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): + from flashinfer import next_positive_power_of_2 # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. tile_tokens_dim = 8 - # from flashinfer import next_positive_power_of_2 - - # # Guess tokens per expert assuming perfect expert distribution first. - # num_tokens_per_expert = (num_tokens * top_k) // num_experts - # # And pad the number to the next power of 2. - # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) - # # Cap to 8-64 tokens per CTA tile as it's the range supported by the - # # kernel. - # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + imbalance_factor = 1.3 + # Calculate the number of tokens per expert + # assuming perfect distribution. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # Apply the imbalance factor. + num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor) + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-max_tile_tokens_dim tokens per CTA tile + # as it's the range supported by the kernel. 
+ tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) return tile_tokens_dim From af97ceddd8d566970a2b9123f59be9a25a888c86 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Mon, 20 Oct 2025 13:31:06 -0700 Subject: [PATCH 03/13] update code Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index a20783bbfd1f..893702f9c33b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -33,7 +33,7 @@ def flashinfer_fused_moe_blockscale_fp8( routing_method_type: int = 2, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe - + topk_group = topk_group if topk_group is not None else 0 assert top_k <= global_num_experts assert top_k <= 10 # assert topk_group <= 4 From 9a35fba4eb4fe44b15aaf58b8a0fb4ab533a9f2c Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 22 Oct 2025 11:45:44 -0700 Subject: [PATCH 04/13] update work Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../model_executor/layers/fused_moe/config.py | 6 ++++-- .../layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- .../model_executor/layers/quantization/fp8.py | 21 +++++++++++-------- vllm/model_executor/models/qwen3_moe.py | 3 +++ vllm/model_executor/models/qwen3_next.py | 19 ++++++++++++++--- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 868791eb074e..a7bd64b1c65e 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -104,10 +104,12 @@ class RoutingMethodType(IntEnum): DeepSeekV3 = (2,) # Llama4: Top1 -> Sigmoid Llama4 = (3,) - # Qwen3: Softmax -> TopK -> Renormalize + # RenormalizeNaive: Softmax -> TopK -> Renormalize RenormalizeNaive = (4,) + # TopK: TopK (no softmax) + TopK = (5,) # Unspecified - Unspecified = 5.0 + Unspecified = 6.0 @dataclass diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 893702f9c33b..cb141cc9bfbb 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -49,7 +49,7 @@ def flashinfer_fused_moe_blockscale_fp8( # NOTE: scales of hidden states have to be transposed! 
a_sf_t = a_sf.t().contiguous() return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, + routing_logits=routing_logits, routing_bias=routing_bias, hidden_states=a_q, hidden_states_scale=a_sf_t, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f0374887787f..8f188c27da6e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -167,14 +167,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: logger.info_once("Using DeepGEMM backend for FP8 MoE") return Fp8MoeBackend.DEEPGEMM - # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights - if ( - current_platform.is_cuda() - and current_platform.is_device_capability(100) - and block_quant - ): - logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") - return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM + # # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights + # if ( + # current_platform.is_cuda() + # and current_platform.is_device_capability(100) + # and block_quant + # ): + # logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") + # return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM # default to Triton logger.info_once("Using Triton backend for FP8 MoE") @@ -1302,7 +1302,10 @@ def apply( ) topk_weights, topk_ids, zero_expert_result = select_result - + # if (topk_ids.shape[0] <100): + # print("=== MoE Routing Results ===") + # print(f"topk_ids: {topk_ids}") + # print(f"topk_weights: {topk_weights}") if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts, diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index a7e6772bb708..cfe027c6317f 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -43,6 +43,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType + from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -171,6 +173,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, + routing_method_type=RoutingMethodType.RenormalizeNaive, ) self.gate = ReplicatedLinear( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index a685174956c1..a23b692b51d1 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -103,6 +103,7 @@ class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + self.prefix_print = prefix config = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config quant_config = vllm_config.quant_config @@ -174,7 +175,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.Renormalize, + routing_method_type=RoutingMethodType.RenormalizeNaive, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -182,10 
+183,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) +<<<<<<< HEAD if self.experts.is_internal_router: # In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts( @@ -197,8 +198,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) +======= + # print(self.prefix_print) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) +>>>>>>> 9d88f1762 (update work) if self.shared_expert is not None: + # if ("model.layers.0." in self.prefix_print or "model.layers.1." in self.prefix_print or "model.layers.47." in self.prefix_print): + # print(f"shared_expert: {final_hidden_states[0]}") + # print(f"routed_expert: {final_hidden_states[1]}") final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.is_sequence_parallel: @@ -904,7 +916,7 @@ def forward( residual: torch.Tensor | None, positions: torch.Tensor = None, **kwargs: object, - ): + ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1035,6 +1047,7 @@ def forward( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) + # print("="*60) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: From 28a9697d04cc49f73e609f6ae1cd4fc0877b81aa Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:53:04 -0700 Subject: [PATCH 05/13] add Qwen3 and fix lint Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 12 +++++----- .../model_executor/layers/quantization/fp8.py | 22 ++++++++++--------- .../layers/quantization/modelopt.py | 1 - .../quantization/utils/flashinfer_utils.py | 1 + vllm/model_executor/models/qwen3_moe.py | 1 - vllm/model_executor/models/qwen3_next.py | 2 +- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index cb141cc9bfbb..bea3a9b6c76b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import torch -from typing import Optional + from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -23,16 +24,17 @@ def flashinfer_fused_moe_blockscale_fp8( w2_weight_scale_inv: torch.Tensor, global_num_experts: int, top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], + num_expert_group: int | None, + topk_group: int | None, intermediate_size: int, expert_offset: int, local_num_experts: int, block_shape: list[int], - routed_scaling: Optional[float] = 1.0, + routed_scaling: float | None = 1.0, routing_method_type: int = 2, ) -> torch.Tensor: from 
vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + topk_group = topk_group if topk_group is not None else 0 assert top_k <= global_num_experts assert top_k <= 10 @@ -49,7 +51,7 @@ def flashinfer_fused_moe_blockscale_fp8( # NOTE: scales of hidden states have to be transposed! a_sf_t = a_sf.t().contiguous() return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, + routing_logits=routing_logits, routing_bias=routing_bias, hidden_states=a_q, hidden_states_scale=a_sf_t, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8f188c27da6e..7f6ffec7d4df 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -167,14 +167,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: logger.info_once("Using DeepGEMM backend for FP8 MoE") return Fp8MoeBackend.DEEPGEMM - # # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights - # if ( - # current_platform.is_cuda() - # and current_platform.is_device_capability(100) - # and block_quant - # ): - # logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") - # return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM + # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and block_quant + ): + logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") + return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM # default to Triton logger.info_once("Using Triton backend for FP8 MoE") @@ -1242,7 +1242,9 @@ def apply( ) routing_method_type = getattr(layer, "routing_method_type", 2) return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits.to(torch.float32) if routing_method_type == 2 else router_logits, + routing_logits=router_logits.to(torch.float32) + if routing_method_type == 2 + else router_logits, routing_bias=e_score_correction_bias, x=x, w13_weight=layer.w13_weight, @@ -1256,7 +1258,7 @@ def apply( intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, + block_shape=self.weight_block_size, routed_scaling=routed_scaling_factor, routing_method_type=routing_method_type, ) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index e0424b75580f..e14753c60c48 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1341,7 +1341,6 @@ def prepare_static_weights_for_trtllm_fp4_moe( ): from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( - get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, get_w2_permute_indices_with_cache, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index edfb8b561eb5..53f6a2b24e13 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -28,6 +28,7 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): from flashinfer import next_positive_power_of_2 + # FlashInfer 0.2.10 has issues with larger tile sizes. 
Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index cfe027c6317f..538a0ea40940 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import RoutingMethodType - from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index a23b692b51d1..00384b9c716e 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -916,7 +916,7 @@ def forward( residual: torch.Tensor | None, positions: torch.Tensor = None, **kwargs: object, - ): + ): if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) From 77b6bef9cf8c1fee5d027b9f77f0ad4979018649 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 24 Oct 2025 17:44:22 -0700 Subject: [PATCH 06/13] lint Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 4 ---- vllm/model_executor/layers/quantization/fp8.py | 11 +---------- vllm/model_executor/models/qwen3_next.py | 3 --- 3 files changed, 1 insertion(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index bea3a9b6c76b..b5c0df461fc3 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -38,11 +38,7 @@ def flashinfer_fused_moe_blockscale_fp8( topk_group = topk_group if topk_group is not None else 0 assert top_k <= global_num_experts assert top_k <= 10 - # assert topk_group <= 4 - # assert global_num_experts > num_expert_group - # assert global_num_experts % num_expert_group == 0 assert global_num_experts % 4 == 0 - # assert top_k < (topk_group * global_num_experts / num_expert_group) assert block_shape == [128, 128] # Routing kernel expects #experts <= #threads 512 assert global_num_experts <= 512 diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7f6ffec7d4df..0242dcb80a93 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1226,15 +1226,10 @@ def apply( assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) - # assert scoring_func == "sigmoid", ( - # f"Expected 'sigmoid' scoring func but got {scoring_func}" - # ) + if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - # assert ( - # renormalize and use_grouped_topk and custom_routing_function is None - # ) e_score_correction_bias = ( e_score_correction_bias.to(x.dtype) if e_score_correction_bias is not None @@ -1304,10 +1299,6 @@ def apply( ) topk_weights, topk_ids, zero_expert_result = select_result - # if (topk_ids.shape[0] <100): - # print("=== MoE Routing Results ===") - # print(f"topk_ids: {topk_ids}") - # print(f"topk_weights: {topk_weights}") if self.rocm_aiter_moe_enabled: from 
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 00384b9c716e..529bcba77a19 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -208,9 +208,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: >>>>>>> 9d88f1762 (update work) if self.shared_expert is not None: - # if ("model.layers.0." in self.prefix_print or "model.layers.1." in self.prefix_print or "model.layers.47." in self.prefix_print): - # print(f"shared_expert: {final_hidden_states[0]}") - # print(f"routed_expert: {final_hidden_states[1]}") final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.is_sequence_parallel: From 907e19bec1e60fe010f404a4d1f5a54c4bfa821b Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Wed, 29 Oct 2025 18:50:11 -0700 Subject: [PATCH 07/13] format Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../layers/fused_moe/flashinfer_trtllm_moe.py | 1 - vllm/model_executor/layers/quantization/fp8.py | 1 + vllm/model_executor/models/qwen3_next.py | 11 +---------- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index b5c0df461fc3..e9c02ec1364c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0242dcb80a93..8af257d5724d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1299,6 +1299,7 @@ def apply( ) topk_weights, topk_ids, zero_expert_result = select_result + if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_fused_experts, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 529bcba77a19..08dad481fa3f 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -183,10 +183,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) -<<<<<<< HEAD if self.experts.is_internal_router: # In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts( @@ -198,14 +198,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) -======= - # print(self.prefix_print) - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) ->>>>>>> 9d88f1762 (update work) if self.shared_expert is not None: final_hidden_states = final_hidden_states[0] + 
final_hidden_states[1] @@ -1044,7 +1036,6 @@ def forward( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) - # print("="*60) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: From 970a9198f466af1d59099fe7554133a0368a9020 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:53:13 -0700 Subject: [PATCH 08/13] per comment Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 4 +--- .../layers/quantization/utils/flashinfer_utils.py | 1 + vllm/model_executor/models/qwen3_next.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index e9c02ec1364c..4dbbd02cd368 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -62,9 +62,7 @@ def flashinfer_fused_moe_blockscale_fp8( local_expert_offset=expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim( - x.shape[0], top_k, global_num_experts - ), + tile_tokens_dim=None, routing_method_type=routing_method_type, use_shuffled_weight=False, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 53f6a2b24e13..e49d374f154d 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -34,6 +34,7 @@ def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): # with the necessary kernels is released. tile_tokens_dim = 8 + # A factor considering tokens are not perfectly balanced among experts. imbalance_factor = 1.3 # Calculate the number of tokens per expert # assuming perfect distribution. 
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 08dad481fa3f..0fc964b7b710 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -103,7 +103,6 @@ class Qwen3NextSparseMoeBlock(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.prefix_print = prefix config = vllm_config.model_config.hf_config parallel_config = vllm_config.parallel_config quant_config = vllm_config.quant_config From 711ac5dfbaa6c07235e45ad31c620976dc1872d2 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 31 Oct 2025 11:41:19 -0700 Subject: [PATCH 09/13] per comment Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 4dbbd02cd368..409fc55b6937 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -3,6 +3,7 @@ import torch +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( calculate_tile_tokens_dim, @@ -30,7 +31,7 @@ def flashinfer_fused_moe_blockscale_fp8( local_num_experts: int, block_shape: list[int], routed_scaling: float | None = 1.0, - routing_method_type: int = 2, + routing_method_type: int = RoutingMethodType.DeepSeekV3, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe @@ -85,7 +86,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( local_num_experts: int, block_shape: list[int], routed_scaling: float = 1.0, - routing_method_type: int = 2, + routing_method_type: int = RoutingMethodType.DeepSeekV3, ) -> torch.Tensor: return torch.empty_like(x) From b6cd21b61b72a674183931d08ded4fd35aa87610 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 31 Oct 2025 22:55:39 -0700 Subject: [PATCH 10/13] update dtype Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 4 ++-- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 409fc55b6937..51e06ac54f49 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -30,8 +30,8 @@ def flashinfer_fused_moe_blockscale_fp8( expert_offset: int, local_num_experts: int, block_shape: list[int], - routed_scaling: float | None = 1.0, routing_method_type: int = RoutingMethodType.DeepSeekV3, + routed_scaling: float | None = 1.0, ) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe @@ -85,8 +85,8 @@ def flashinfer_fused_moe_blockscale_fp8_fake( expert_offset: int, local_num_experts: int, block_shape: list[int], + routing_method_type: int, routed_scaling: float = 1.0, - routing_method_type: int = RoutingMethodType.DeepSeekV3, 
) -> torch.Tensor: return torch.empty_like(x) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5331a2e5999e..8d969d5e661b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1218,7 +1218,7 @@ def __init__( zero_expert_type: str | None = None, expert_mapping: list[tuple[str, str, int, str]] | None = None, n_shared_experts: int | None = None, - routing_method_type: int = 2, + routing_method_type: int | None = None, ): super().__init__() diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8af257d5724d..a2db45246292 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -27,6 +27,7 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, + RoutingMethodType, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe @@ -1238,7 +1239,7 @@ def apply( routing_method_type = getattr(layer, "routing_method_type", 2) return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( routing_logits=router_logits.to(torch.float32) - if routing_method_type == 2 + if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits, routing_bias=e_score_correction_bias, x=x, @@ -1254,8 +1255,8 @@ def apply( expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, - routed_scaling=routed_scaling_factor, routing_method_type=routing_method_type, + routed_scaling=routed_scaling_factor, ) else: assert not renormalize and custom_routing_function is not None From 01b4a9e70e59a39bfcae033bb61253b757edda8b Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Tue, 4 Nov 2025 09:37:51 -0800 Subject: [PATCH 11/13] fix conflict Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/layer.py | 21 +++++++++++++++++-- vllm/model_executor/models/qwen3_next.py | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8d969d5e661b..1a50d50a1494 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -30,6 +30,7 @@ FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, biased_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton @@ -1392,14 +1393,30 @@ def __init__( self.e_score_correction_bias = e_score_correction_bias self.apply_router_weight_on_input = apply_router_weight_on_input self.activation = activation - # Optional routing method id for backends that support multiple types - self.routing_method_type = routing_method_type if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError( "Only softmax scoring function is supported for non-grouped topk." 
) + # ToDo: Better logic to determine the routing method type + if routing_method_type is not None: + self.routing_method_type = routing_method_type + else: + if scoring_func == "sigmoid": + if self.use_grouped_topk: + self.routing_method_type = RoutingMethodType.DeepSeekV3 + elif self.top_k == 1: + self.routing_method_type = RoutingMethodType.Llama4 + elif self.scoring_func == "softmax": + self.routing_method_type = ( + RoutingMethodType.Renormalize + if not self.renormalize + else RoutingMethodType.RenormalizeNaive + ) + else: + self.routing_method_type = RoutingMethodType.TopK + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 0fc964b7b710..f1dc0b745467 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -34,11 +34,11 @@ fused_recurrent_gated_delta_rule, ) from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.layernorm import ( GemmaRMSNorm as Qwen3NextRMSNorm, ) from vllm.model_executor.layers.layernorm import RMSNormGated -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, From 9361c868c7a8bdbc33e051d5f56dcad63e57e58e Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:02:54 -0800 Subject: [PATCH 12/13] remove getattr Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a2db45246292..9aa319382e52 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1236,7 +1236,7 @@ def apply( if e_score_correction_bias is not None else None ) - routing_method_type = getattr(layer, "routing_method_type", 2) + routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( routing_logits=router_logits.to(torch.float32) if routing_method_type == RoutingMethodType.DeepSeekV3 From ec5ba87c944343a5417b175dd5700fcb7bf83bf0 Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Fri, 7 Nov 2025 18:18:06 -0800 Subject: [PATCH 13/13] update routing Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- vllm/model_executor/models/qwen3_moe.py | 2 +- vllm/model_executor/models/qwen3_next.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 538a0ea40940..d57b82cb0227 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -172,7 +172,7 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.RenormalizeNaive, + routing_method_type=RoutingMethodType.Renormalize, ) self.gate = ReplicatedLinear( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index f1dc0b745467..555708331825 100644 --- 
a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -174,7 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.RenormalizeNaive, + routing_method_type=RoutingMethodType.Renormalize, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
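
For reference, a minimal standalone sketch of the routing-method selection this series introduces: FusedMoE.__init__ now accepts an optional routing_method_type and, when none is given, infers one from scoring_func, use_grouped_topk, top_k and renormalize (patch 11); Qwen3-MoE and Qwen3-Next pass RoutingMethodType.Renormalize explicitly (patch 13), and the FP8 block-scale path forwards the value to flashinfer_trtllm_fp8_block_scale_moe, casting router logits to float32 only for DeepSeekV3 routing. The helper name infer_routing_method below is illustrative only (not part of the patch), and the sketch returns Unspecified on the sigmoid / non-grouped / top_k > 1 branch, which the patch leaves unassigned.

from enum import IntEnum

class RoutingMethodType(IntEnum):
    Default = 0           # Softmax -> TopK
    Renormalize = 1       # TopK -> Softmax
    DeepSeekV3 = 2        # Sigmoid -> routing-bias add -> grouped TopK
    Llama4 = 3            # Top1 -> Sigmoid
    RenormalizeNaive = 4  # Softmax -> TopK -> Renormalize
    TopK = 5              # TopK only (no softmax)
    Unspecified = 6

def infer_routing_method(
    scoring_func: str,
    use_grouped_topk: bool,
    top_k: int,
    renormalize: bool,
) -> RoutingMethodType:
    # Mirrors the fallback added to FusedMoE.__init__ when the model does not
    # pass an explicit routing_method_type.
    if scoring_func == "sigmoid":
        if use_grouped_topk:
            return RoutingMethodType.DeepSeekV3
        if top_k == 1:
            return RoutingMethodType.Llama4
        return RoutingMethodType.Unspecified  # branch not covered by the patch
    if scoring_func == "softmax":
        return (
            RoutingMethodType.Renormalize
            if not renormalize
            else RoutingMethodType.RenormalizeNaive
        )
    return RoutingMethodType.TopK

# DeepSeek-style grouped sigmoid routing keeps the previous default (2),
# while Qwen3 models now route with Renormalize instead of RenormalizeNaive.
assert infer_routing_method("sigmoid", True, 8, False) is RoutingMethodType.DeepSeekV3
assert RoutingMethodType.Renormalize == 1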
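
The series also reworks calculate_tile_tokens_dim in flashinfer_utils.py (patches 02/05/08): instead of a fixed tile of 8, it estimates tokens per expert, applies a 1.3 imbalance factor, rounds up with flashinfer's next_positive_power_of_2, and clamps to the kernel-supported 8-64 range; patch 09 then has the block-scale FP8 entry point pass tile_tokens_dim=None so FlashInfer chooses the tile itself, leaving the helper for other callers. Below is a self-contained sketch with a pure-Python stand-in for next_positive_power_of_2 (the real helper comes from flashinfer).

def next_positive_power_of_2(x: int) -> int:
    # Stand-in for flashinfer.next_positive_power_of_2: smallest power of
    # two >= x, with a floor of 1.
    return 1 if x < 1 else 1 << (x - 1).bit_length()

def calculate_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    imbalance_factor = 1.3  # experts are never perfectly load-balanced
    tokens_per_expert = (num_tokens * top_k) // num_experts
    tokens_per_expert = int(tokens_per_expert * imbalance_factor)
    tile = next_positive_power_of_2(tokens_per_expert)
    return min(max(tile, 8), 64)  # CTA tile range supported by the kernel

# e.g. 4096 tokens, top_k=8, 128 experts -> 256 tokens/expert -> capped at 64
assert calculate_tile_tokens_dim(4096, 8, 128) == 64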