
Commit bcebcc6

update work
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
1 parent 6edee14 commit bcebcc6

File tree: 5 files changed (+36 −15 lines)

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 4 additions & 2 deletions
@@ -104,10 +104,12 @@ class RoutingMethodType(IntEnum):
     DeepSeekV3 = (2,)
     # Llama4: Top1 -> Sigmoid
     Llama4 = (3,)
-    # Qwen3: Softmax -> TopK -> Renormalize
+    # RenormalizeNaive: Softmax -> TopK -> Renormalize
     RenormalizeNaive = (4,)
+    # TopK: TopK (no softmax)
+    TopK = (5,)
     # Unspecified
-    Unspecified = 5.0
+    Unspecified = 6.0
 
 
 @dataclass
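
For orientation, here is a minimal PyTorch sketch of what the two routing comments above describe: RenormalizeNaive applies Softmax -> TopK -> Renormalize, while TopK selects experts from the raw logits with no softmax. The helper names are hypothetical and this is only an illustrative reading of the enum comments, not the fused-MoE kernel implementation.

    import torch

    def route_renormalize_naive(router_logits: torch.Tensor, top_k: int):
        # Softmax over experts, keep the top-k probabilities, then renormalize
        # the selected weights so they sum to 1 per token.
        probs = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

    def route_topk(router_logits: torch.Tensor, top_k: int):
        # Plain TopK on the raw logits, no softmax (RoutingMethodType.TopK above).
        return torch.topk(router_logits, top_k, dim=-1)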

vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def flashinfer_fused_moe_blockscale_fp8(
     # NOTE: scales of hidden states have to be transposed!
     a_sf_t = a_sf.t().contiguous()
     return flashinfer_trtllm_fp8_block_scale_moe(
-        routing_logits=routing_logits,
+        routing_logits=routing_logits,
         routing_bias=routing_bias,
         hidden_states=a_q,
         hidden_states_scale=a_sf_t,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 12 additions & 9 deletions
@@ -160,14 +160,14 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         logger.info_once("Using DeepGEMM backend for FP8 MoE")
         return Fp8MoeBackend.DEEPGEMM
 
-    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
-    if (
-        current_platform.is_cuda()
-        and current_platform.is_device_capability(100)
-        and block_quant
-    ):
-        logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
-        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
+    # # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
+    # if (
+    #     current_platform.is_cuda()
+    #     and current_platform.is_device_capability(100)
+    #     and block_quant
+    # ):
+    #     logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
+    #     return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
 
     # default to Triton
     logger.info_once("Using Triton backend for FP8 MoE")

@@ -1294,7 +1294,10 @@ def apply(
         # can override fused_experts or cutlass but not rocm or marlin.
         #
         topk_weights, topk_ids, zero_expert_result = select_result
-
+        # if (topk_ids.shape[0] <100):
+        #     print("=== MoE Routing Results ===")
+        #     print(f"topk_ids: {topk_ids}")
+        #     print(f"topk_weights: {topk_weights}")
         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
                 rocm_aiter_fused_experts,
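
For context, a condensed sketch of the backend-selection order this hunk leaves in place: DeepGEMM when available, and with the CUTLASS branch commented out, SM100 block-quant cases fall through to Triton. The enum below is a simplified stand-in; the TRITON member name and the deep_gemm_available flag are assumptions for illustration, not vLLM's exact API.

    from enum import Enum

    class Fp8MoeBackendSketch(Enum):  # simplified stand-in for vLLM's Fp8MoeBackend
        DEEPGEMM = "deepgemm"
        CUTLASS_BLOCK_SCALED_GROUPED_GEMM = "cutlass"
        TRITON = "triton"

    def pick_fp8_moe_backend(deep_gemm_available: bool, is_sm100: bool, block_quant: bool):
        if deep_gemm_available:
            return Fp8MoeBackendSketch.DEEPGEMM
        # The CUTLASS branch is commented out above, so SM100 + block-quant
        # no longer short-circuits here and falls through to the default.
        # if is_sm100 and block_quant:
        #     return Fp8MoeBackendSketch.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
        return Fp8MoeBackendSketch.TRITON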

vllm/model_executor/models/qwen3_moe.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
+
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,

@@ -171,6 +173,7 @@ def __init__(
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
+            routing_method_type=RoutingMethodType.RenormalizeNaive,
         )
 
         self.gate = ReplicatedLinear(
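
Tying this to the config.py change, a hypothetical dispatch table pairing the enum members with the toy routing helpers sketched earlier; FusedMoE does its routing inside the fused kernels, so this mapping is purely illustrative.

    from vllm.model_executor.layers.fused_moe.config import RoutingMethodType

    # Hypothetical mapping onto the toy routing helpers from the config.py sketch above.
    ROUTING_FNS = {
        RoutingMethodType.RenormalizeNaive: route_renormalize_naive,  # Softmax -> TopK -> Renormalize
        RoutingMethodType.TopK: route_topk,                           # plain TopK, no softmax
    }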

vllm/model_executor/models/qwen3_next.py

Lines changed: 16 additions & 3 deletions
@@ -101,6 +101,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
+        self.prefix_print = prefix
         config = vllm_config.model_config.hf_config
         parallel_config = vllm_config.parallel_config
         quant_config = vllm_config.quant_config

@@ -172,18 +173,18 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
-            routing_method_type=RoutingMethodType.Renormalize,
+            routing_method_type=RoutingMethodType.RenormalizeNaive,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
+<<<<<<< HEAD
         if self.experts.is_internal_router:
             # In this case, the gate/router runs inside the FusedMoE class
             final_hidden_states = self.experts(

@@ -195,8 +196,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = self.experts(
                 hidden_states=hidden_states, router_logits=router_logits
             )
+=======
+        # print(self.prefix_print)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+>>>>>>> 9d88f1762 (update work)
 
         if self.shared_expert is not None:
+            # if ("model.layers.0." in self.prefix_print or "model.layers.1." in self.prefix_print or "model.layers.47." in self.prefix_print):
+            #     print(f"shared_expert: {final_hidden_states[0]}")
+            #     print(f"routed_expert: {final_hidden_states[1]}")
             final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
 
         if self.is_sequence_parallel:

@@ -873,7 +885,7 @@ def forward(
         residual: torch.Tensor | None,
         positions: torch.Tensor = None,
         **kwargs: object,
-    ):
+    ):
         if residual is None:
             residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

@@ -1004,6 +1016,7 @@ def forward(
                 {"hidden_states": hidden_states, "residual": residual}
             )
         hidden_states, _ = self.norm(hidden_states, residual)
+        # print("="*60)
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
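
For readability, a minimal sketch of the two forward paths that the conflicting hunks encode: on the HEAD side the gate/router runs inside the experts module, while the incoming side computes router_logits with an explicit gate and passes them in. The class below is a simplified stand-in, not the Qwen3NextSparseMoeBlock implementation.

    import torch
    import torch.nn as nn

    class SparseMoeForwardSketch(nn.Module):
        """Simplified stand-in for the two routing paths in the hunk above."""

        def __init__(self, gate: nn.Module, experts: nn.Module, internal_router: bool):
            super().__init__()
            self.gate = gate              # projects hidden states to per-expert logits
            self.experts = experts        # fused-MoE-like module
            self.internal_router = internal_router

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            if self.internal_router:
                # HEAD side: the gate/router runs inside the experts module.
                return self.experts(hidden_states=hidden_states)
            # Incoming side: compute router_logits explicitly and pass them along.
            router_logits = self.gate(hidden_states)  # (num_tokens, n_experts)
            return self.experts(hidden_states=hidden_states, router_logits=router_logits)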
