21 changes: 21 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union

import torch
@@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
return a_shape, w_shape


# The type of routing method used for top-K expert selection.
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = 0
# Renormalize: TopK -> Softmax
Renormalize = 1
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = 2
# Llama4: Top1 -> Sigmoid
Llama4 = 3
# RenormalizeNaive: Softmax -> TopK -> Renormalize
RenormalizeNaive = 4
# TopK: TopK (no softmax)
TopK = 5
# Unspecified
Unspecified = 6


@dataclass
class FusedMoEQuantDesc:
"""
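As a rough illustration of how the routing variants named in the enum comments above differ, here is a minimal PyTorch sketch. It is an assumption for illustration only, not the FlashInfer kernel or vLLM's routing code; the function name route_reference is made up, and the grouped DeepSeekV3 variant is omitted for brevity.

# Illustrative sketch only -- not part of this diff.
import torch

def route_reference(logits: torch.Tensor, top_k: int, method: str):
    # logits: [num_tokens, num_experts] router outputs.
    if method == "Default":              # Softmax -> TopK
        weights, ids = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
    elif method == "Renormalize":        # TopK -> Softmax over the selected logits
        vals, ids = torch.topk(logits, top_k, dim=-1)
        weights = torch.softmax(vals, dim=-1)
    elif method == "RenormalizeNaive":   # Softmax -> TopK -> renormalize to sum to 1
        weights, ids = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
    elif method == "Llama4":             # Top1 -> Sigmoid
        vals, ids = torch.topk(logits, 1, dim=-1)
        weights = torch.sigmoid(vals)
    elif method == "TopK":               # TopK only, no softmax
        weights, ids = torch.topk(logits, top_k, dim=-1)
    else:
        raise ValueError(f"unhandled routing method: {method}")
    return weights, ids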
26 changes: 12 additions & 14 deletions vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -3,6 +3,7 @@

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
@@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routed_scaling: float = 1.0,
routing_method_type: int = RoutingMethodType.DeepSeekV3,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert top_k <= 10
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 256
assert global_num_experts <= 256
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512

a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
@@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim(
x.shape[0], top_k, global_num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
tile_tokens_dim=None,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)

@@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)
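For context on the per_token_group_quant_fp8(x, block_shape[1]) call and the transposed-scales note in the diff above, the sketch below shows roughly what per-token group FP8 quantization computes for 128-wide groups. It is a reference approximation under the assumption of e4m3 quantization, not vLLM's actual kernel, and the name per_token_group_quant_fp8_ref is invented here.

# Reference sketch only -- assumes float8_e4m3fn and group_size == block_shape[1] == 128.
import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    num_tokens, hidden_size = x.shape
    xg = x.view(num_tokens, hidden_size // group_size, group_size).float()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per (token, group), sized so the group's max magnitude maps to the FP8 max.
    scales = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_q = (xg / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # The TRT-LLM kernel expects the activation scales laid out transposed,
    # hence the NOTE in the diff above.
    return x_q.view(num_tokens, hidden_size), scales.squeeze(-1)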
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -30,6 +30,7 @@
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@@ -1218,6 +1219,7 @@ def __init__(
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
):
super().__init__()

@@ -1397,6 +1399,24 @@ def __init__(
"Only softmax scoring function is supported for non-grouped topk."
)

# TODO: better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK

self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
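The routing-method selection added to FusedMoE.__init__ above can be read as the following standalone sketch. Same mapping, restated outside the class for readability; the helper name infer_routing_method is invented here. Note that the in-class code leaves the attribute unset for sigmoid scoring with top_k > 1 and no grouped top-k, whereas this sketch falls through to TopK in that case.

# Condensed restatement of the selection above -- illustrative only.
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType

def infer_routing_method(scoring_func: str, use_grouped_topk: bool,
                         top_k: int, renormalize: bool) -> RoutingMethodType:
    if scoring_func == "sigmoid":
        if use_grouped_topk:
            return RoutingMethodType.DeepSeekV3
        if top_k == 1:
            return RoutingMethodType.Llama4
    elif scoring_func == "softmax":
        return (RoutingMethodType.Renormalize if not renormalize
                else RoutingMethodType.RenormalizeNaive)
    return RoutingMethodType.TopK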
14 changes: 7 additions & 7 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -27,6 +27,7 @@
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -1226,22 +1227,20 @@ def apply(
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)

if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401

assert (
renormalize and use_grouped_topk and custom_routing_function is None
)
e_score_correction_bias = (
e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
Expand All @@ -1256,6 +1255,7 @@ def apply(
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor,
)
else:
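The conditional cast above only promotes the router logits to float32 for DeepSeek-style routing; all other routing methods pass the logits through in their native dtype. A minimal restatement, with the helper name prepare_routing_logits invented for illustration:

# Illustrative helper mirroring the conditional above -- not part of the diff.
import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType

def prepare_routing_logits(router_logits: torch.Tensor,
                           routing_method_type: int) -> torch.Tensor:
    # DeepSeekV3-style routing (sigmoid + bias + grouped top-k) gets fp32 logits;
    # everything else keeps the router's output dtype (e.g. bfloat16).
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        return router_logits.to(torch.float32)
    return router_logits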
vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):


def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2

# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
# with the necessary kernels is released.
tile_tokens_dim = 8

# from flashinfer import next_positive_power_of_2

# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
# A factor considering tokens are not perfectly balanced among experts.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-max_tile_tokens_dim tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

return tile_tokens_dim

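The restored tile-size calculation above can be checked by hand: with num_tokens=512, top_k=8 and num_experts=256, a perfectly balanced load gives 512 * 8 // 256 = 16 tokens per expert, the 1.3 imbalance factor raises that to 20, the next power of two is 32, and 32 already lies inside the [8, 64] clamp. The reference below spells out the assumed behavior of flashinfer.next_positive_power_of_2; the name next_positive_power_of_2_ref and its body are an assumption for illustration, not FlashInfer's code.

# Assumed behavior of flashinfer's next_positive_power_of_2 -- illustration only.
def next_positive_power_of_2_ref(x: int) -> int:
    # Smallest power of two that is >= x, with a floor of 1.
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()

assert next_positive_power_of_2_ref(20) == 32
assert next_positive_power_of_2_ref(16) == 16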
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen3_moe.py
@@ -43,6 +43,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -171,6 +172,7 @@ def __init__(
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)

self.gate = ReplicatedLinear(
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen3_next.py
@@ -34,6 +34,7 @@
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
@@ -173,6 +174,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: