Commit be4dc1a

add Qwen3 and fix lint

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent bcebcc6 commit be4dc1a

File tree: 6 files changed, +21 -18 lines

vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

Lines changed: 7 additions & 5 deletions
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
 import torch
-from typing import Optional
+
 from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim,
@@ -23,16 +24,17 @@ def flashinfer_fused_moe_blockscale_fp8(
     w2_weight_scale_inv: torch.Tensor,
     global_num_experts: int,
     top_k: int,
-    num_expert_group: Optional[int],
-    topk_group: Optional[int],
+    num_expert_group: int | None,
+    topk_group: int | None,
     intermediate_size: int,
     expert_offset: int,
     local_num_experts: int,
     block_shape: list[int],
-    routed_scaling: Optional[float] = 1.0,
+    routed_scaling: float | None = 1.0,
     routing_method_type: int = 2,
 ) -> torch.Tensor:
     from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
+
     topk_group = topk_group if topk_group is not None else 0
     assert top_k <= global_num_experts
     assert top_k <= 10
@@ -49,7 +51,7 @@ def flashinfer_fused_moe_blockscale_fp8(
     # NOTE: scales of hidden states have to be transposed!
     a_sf_t = a_sf.t().contiguous()
     return flashinfer_trtllm_fp8_block_scale_moe(
-        routing_logits=routing_logits,
+        routing_logits=routing_logits,
         routing_bias=routing_bias,
         hidden_states=a_q,
         hidden_states_scale=a_sf_t,
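The signature changes in this file swap typing.Optional for the PEP 604 union syntax available on Python 3.10+. A minimal sketch of the equivalence, using a hypothetical helper rather than code from this file:

def clamp_topk_group(topk_group: int | None) -> int:
    # Same None handling as the wrapper above: treat a missing group count as 0.
    return topk_group if topk_group is not None else 0

# Pre-3.10 spelling of the identical annotation:
# from typing import Optional
# def clamp_topk_group(topk_group: Optional[int]) -> int: ...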

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 12 additions & 10 deletions
@@ -160,14 +160,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         logger.info_once("Using DeepGEMM backend for FP8 MoE")
         return Fp8MoeBackend.DEEPGEMM
 
-    # # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
-    # if (
-    #     current_platform.is_cuda()
-    #     and current_platform.is_device_capability(100)
-    #     and block_quant
-    # ):
-    #     logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
-    #     return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    if (
+        current_platform.is_cuda()
+        and current_platform.is_device_capability(100)
+        and block_quant
+    ):
+        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
 
     # default to Triton
     logger.info_once("Using Triton backend for FP8 MoE")
@@ -1230,7 +1230,9 @@ def apply(
         )
         routing_method_type = getattr(layer, "routing_method_type", 2)
         return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-            routing_logits=router_logits.to(torch.float32) if routing_method_type == 2 else router_logits,
+            routing_logits=router_logits.to(torch.float32)
+            if routing_method_type == 2
+            else router_logits,
             routing_bias=e_score_correction_bias,
             x=x,
             w13_weight=layer.w13_weight,
@@ -1244,7 +1246,7 @@ def apply(
             intermediate_size=layer.intermediate_size_per_partition,
             expert_offset=layer.ep_rank * layer.local_num_experts,
             local_num_experts=layer.local_num_experts,
-            block_shape=self.weight_block_size,
+            block_shape=self.weight_block_size,
             routed_scaling=routed_scaling_factor,
             routing_method_type=routing_method_type,
         )
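The reflowed routing_logits argument in the first hunk of apply() encodes a small dispatch: the logits are cast to float32 only when routing_method_type is 2, otherwise they pass through unchanged. A stand-alone restatement of that expression; the helper name is illustrative and not part of vLLM:

import torch

def prepare_routing_logits(
    router_logits: torch.Tensor, routing_method_type: int
) -> torch.Tensor:
    # Mirrors the call-site expression: cast only for routing method type 2.
    return (
        router_logits.to(torch.float32)
        if routing_method_type == 2
        else router_logits
    )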

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 0 additions & 1 deletion
@@ -1349,7 +1349,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
     ):
         from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            get_w2_permute_indices_with_cache,
             _maybe_get_cached_w3_w1_permute_indices,
             get_w2_permute_indices_with_cache,
         )
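The modelopt.py change only drops a name that was listed twice inside a single from-import, which linters report as a redundant redefinition. A minimal repro of the pattern, using a standard-library module purely for illustration:

# Before: the same name appears twice in one import statement.
from math import (
    floor,
    sqrt,
    sqrt,  # duplicate; the linter flags this redefinition
)

# After: each imported name appears once.
from math import floor, sqrt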

vllm/model_executor/layers/quantization/utils/flashinfer_utils.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ class FlashinferMoeBackend(Enum):
 
 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
     from flashinfer import next_positive_power_of_2
+
     # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
     # TODO: Revert this to dynamic calculation once a new version of FlashInfer
     # with the necessary kernels is released.
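For context on the comment being edited here: the function currently pins the tile size to 8 because of the FlashInfer 0.2.10 limitation. The sketch below shows roughly what the dynamic calculation mentioned in the TODO could look like; the exact formula and clamping range are assumptions, and next_positive_power_of_2 is reimplemented locally so the snippet stands alone:

def calculate_tile_tokens_dim_sketch(
    num_tokens: int, top_k: int, num_experts: int
) -> int:
    def next_positive_power_of_2(x: int) -> int:
        # Local stand-in for flashinfer.next_positive_power_of_2.
        return 1 if x < 1 else 1 << (x - 1).bit_length()

    # Estimate how many tokens land on each expert, round up to a power of
    # two, and clamp to a kernel-friendly range (assumed bounds).
    tokens_per_expert = max(1, (num_tokens * top_k) // num_experts)
    return min(max(next_positive_power_of_2(tokens_per_expert), 8), 64)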

vllm/model_executor/models/qwen3_moe.py

Lines changed: 0 additions & 1 deletion
@@ -44,7 +44,6 @@
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
-
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,

vllm/model_executor/models/qwen3_next.py

Lines changed: 1 addition & 1 deletion
@@ -885,7 +885,7 @@ def forward(
         residual: torch.Tensor | None,
         positions: torch.Tensor = None,
         **kwargs: object,
-    ):
+    ):
         if residual is None:
             residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
