
Conversation

@jiahanc (Contributor) commented Oct 24, 2025

Purpose

  • Integrate multiple routing methods for FP8 flashinfer trtllm MOE; currently only DeepSeekV3 and Llama4 are supported
  • Add FP8 flashinfer trtllm MOE support for Qwen3 and Qwen3-Next
  • Bump the flashinfer version to 0.5.0rc1 (required for this PR)

Test Plan

Qwen3-Next-80B-A3B-Instruct-FP8 on 2xB200 TP2

VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=latency VLLM_USE_DEEP_GEMM=0 VLLM_USE_TRTLLM_ATTENTION=0 VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \
    --max-num-batched-tokens 8192 \
    --max-model-len 16384 \
    --no-enable-prefix-caching \
    --async-scheduling \
    --compilation_config.pass_config.enable_fi_allreduce_fusion true \
    --compilation_config.pass_config.enable_noop true \
    --compilation_config.cudagraph_mode FULL_DECODE_ONLY \
    --compilation_config.splitting_ops [] \
    -tp 2 
lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct-FP8,base_url=http://0.0.0.0:8000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5

Qwen3-30B-A3B-Instruct-2507-FP8 on 2xB200 TP2

VLLM_USE_FLASHINFER_MOE_FP8=1 VLLM_FLASHINFER_MOE_BACKEND=latency VLLM_USE_DEEP_GEMM=0 vllm serve Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 \
    --max-num-batched-tokens 8192 \
    --max-model-len 16384 \
    --no-enable-prefix-caching \
    --async-scheduling \
    --compilation_config.pass_config.enable_fi_allreduce_fusion true \
    --compilation_config.pass_config.enable_noop true \
    --compilation_config.cudagraph_mode FULL_DECODE_ONLY \
    --compilation_config.splitting_ops [] \
    -tp 2
lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-30B-A3B-Instruct-2507-FP8,base_url=http://0.0.0.0:8000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5

Test Result

Qwen3-Next-80B-A3B-Instruct-FP8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9515|±  |0.0084|
|     |       |strict-match    |     5|exact_match|↑  |0.9197|±  |0.0106|

Qwen3-30B-A3B-Instruct-2507-FP8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9379|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.9364|±  |0.0095|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@jiahanc jiahanc changed the title [Performance] Support flashinfer TRTLLM MOE on Qwen3 and Qwen3next [Performance] Support flashinfer FP8 TRTLLM MOE on Qwen3 and Qwen3next Oct 24, 2025
@mergify mergify bot added the qwen Related to Qwen models label Oct 24, 2025
@jiahanc jiahanc changed the title [Performance] Support flashinfer FP8 TRTLLM MOE on Qwen3 and Qwen3next [Performance] Support flashinfer FP8 TRTLLM MOE on Qwen3 and Qwen-3next Oct 24, 2025
@jiahanc jiahanc changed the title [Performance] Support flashinfer FP8 TRTLLM MOE on Qwen3 and Qwen-3next [Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next Oct 24, 2025
@jiahanc jiahanc force-pushed the qwen3next_trtllmgen_moe branch from 9aaf36c to aa947da Compare October 24, 2025 23:53
@jiahanc jiahanc force-pushed the qwen3next_trtllmgen_moe branch from c3863df to 15b457c Compare October 28, 2025 17:15
mergify bot commented Oct 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jiahanc.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 29, 2025
@jiahanc jiahanc force-pushed the qwen3next_trtllmgen_moe branch from 15b457c to fccb4d0 Compare October 29, 2025 21:10
@mergify mergify bot removed the needs-rebase label Oct 29, 2025
@jiahanc jiahanc marked this pull request as ready for review October 30, 2025 16:50
@mergify mergify bot added the ci/build label Oct 30, 2025
@jiahanc (Contributor, Author) commented Oct 30, 2025

Qwen3-Next-80B-A3B-Instruct-FP8 on 1xB200 1k/1k benchmark
[benchmark results image]

@jiahanc jiahanc force-pushed the qwen3next_trtllmgen_moe branch from 08dcd1b to 2b9022e Compare October 30, 2025 17:00
@jiahanc (Contributor, Author) commented Oct 30, 2025

@mgoin @pavanimajety could you help review the PR?
The pre-commit failure is unrelated to any changed file; it might have been introduced by other PRs.

@mxz297 (Contributor) commented Oct 31, 2025

If this PR is merged, can vLLM still run with an older flashinfer? We are internally just upgrading to flashinfer nightly-v0.4.1-20251027, and this seems to bump the flashinfer version again. Is it possible to consider some backward compatibility with older flashinfer versions?

cc @houseroad @yeqcharlotte

@jiahanc (Contributor, Author) commented Oct 31, 2025

> If this PR is merged, can vLLM still run with an older flashinfer? We are internally just upgrading to flashinfer nightly-v0.4.1-20251027, and this seems to bump the flashinfer version again. Is it possible to consider some backward compatibility with older flashinfer versions?
>
> cc @houseroad @yeqcharlotte

Hi @mxz297,
There is no API change compared to v0.4.1. :)

Comment on lines 1226 to 1229
```diff
 routing_method_type = getattr(layer, "routing_method_type", 2)
 return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-    routing_logits=router_logits.to(torch.float32),
+    routing_logits=router_logits.to(torch.float32)
+    if routing_method_type == 2
```
Member:

We should use the enum rather than raw int
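
For illustration, a minimal self-contained sketch of that suggestion, using the `RoutingMethodType` enum quoted below; the helper function and the float32 condition are illustrative, not the PR's actual code:

```python
from enum import IntEnum

import torch


class RoutingMethodType(IntEnum):
    # Only the member relevant to this call site; the full enum is
    # quoted in the review thread below.
    DeepSeekV3 = 2


def routing_logits_for(layer, router_logits: torch.Tensor) -> torch.Tensor:
    # Using the enum member instead of the magic number 2 makes the
    # DeepSeekV3 default visible at the call site.
    routing_method_type = getattr(
        layer, "routing_method_type", RoutingMethodType.DeepSeekV3
    )
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        return router_logits.to(torch.float32)
    return router_logits
```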

Comment on lines +95 to +112
```python
# The type of method in top-K routing.
# Please keep this in sync with the counterpart defined in
# https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
    # Default: Softmax -> TopK
    Default = 0
    # Renormalize: TopK -> Softmax
    Renormalize = 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = 3
    # RenormalizeNaive: Softmax -> TopK -> Renormalize
    RenormalizeNaive = 4
    # TopK: TopK (no softmax)
    TopK = 5
    # Unspecified
    Unspecified = 6
```
Member:

I like the idea of having a routing method type so we can reduce the need for hacks like checking the Llama 4 custom routing function within the quant method.
However, I think directly tying the values to the flashinfer trtllm fused MoE is short-sighted if we are to leverage this across the codebase. If we do this right, we can actually remove other arguments we have in FusedMoE, such as renormalize.
So I think this is the important design change in the PR. We could derive the routing type based on existing arguments, and of course allow for an explicit override. I'm interested to hear @bnellnm's thoughts too.

I don't necessarily want to block this PR on getting the final design right, but I do want agreement on my other comments that this makes sense as a more explicit control for routing types across all fused MoE methods.

Contributor:

I agree with @mgoin that it would be nice to derive the routing type from existing arguments. Would it make more sense to have a collection of router objects/functions that could be passed in directly?
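
For concreteness, one hypothetical shape that idea could take, with a small registry of routing callables (keys, names, and the two routing functions are illustrative, not an actual vLLM interface):

```python
from typing import Callable, Dict, Tuple

import torch

# Each routing callable takes router logits and top_k and returns
# (topk_weights, topk_ids).
RoutingFn = Callable[[torch.Tensor, int], Tuple[torch.Tensor, torch.Tensor]]


def softmax_topk(logits: torch.Tensor, top_k: int):
    # Default scheme: Softmax -> TopK.
    weights, ids = torch.topk(logits.softmax(dim=-1), top_k, dim=-1)
    return weights, ids


def topk_softmax(logits: torch.Tensor, top_k: int):
    # Renormalize scheme: TopK -> Softmax over the selected experts.
    weights, ids = torch.topk(logits, top_k, dim=-1)
    return weights.softmax(dim=-1), ids


ROUTERS: Dict[str, RoutingFn] = {
    "default": softmax_topk,
    "renormalize": topk_softmax,
}
```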

@jiahanc (Contributor, Author) commented Nov 3, 2025:

Added logic in FusedMoE to check the routing method given the params.
Using this for other backends etc. might need more discussion and design, which is probably out of scope for this PR :)
@mgoin please let me know if this makes sense 😄

```python
if self.use_grouped_topk:
    self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
    self.routing_method_type = RoutingMethodType.Llama4
```
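
For context, a rough self-contained sketch of what such a derivation could look like; the fallback branches beyond the two quoted above are assumptions, not necessarily what the PR implements:

```python
from enum import IntEnum


class RoutingMethodType(IntEnum):
    Default = 0       # Softmax -> TopK
    Renormalize = 1   # TopK -> Softmax
    DeepSeekV3 = 2    # grouped top-k (sigmoid + routing bias)
    Llama4 = 3        # Top1 -> Sigmoid


def derive_routing_method(
    use_grouped_topk: bool, top_k: int, renormalize: bool
) -> RoutingMethodType:
    if use_grouped_topk:
        # Grouped top-k corresponds to DeepSeekV3-style routing.
        return RoutingMethodType.DeepSeekV3
    if top_k == 1:
        # Top-1 routing corresponds to the Llama4 scheme.
        return RoutingMethodType.Llama4
    # Assumed fallback: renormalized vs. plain softmax top-k.
    return (
        RoutingMethodType.Renormalize
        if renormalize
        else RoutingMethodType.Default
    )
```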
Contributor:

Not directly related to (or suggested for) this PR but we could also set apply_weights_on_input to True for this case and get rid of the runtime parameter. cc @mgoin

@jiahanc (Contributor, Author) commented Nov 3, 2025:

What is this param used for? I searched the codebase and only found supports_apply_weight_on_input, not apply_weights_on_input.

Contributor:

> What is this param used for? I searched the codebase and only found supports_apply_weight_on_input, not apply_weights_on_input.

It's actually apply_router_weight_on_input; I just couldn't remember the exact name when I wrote the comment. AFAIK, it is only used for Llama when topk == 1, so I was wondering if we could detect and store it here while deriving the routing method. We could also remove it as an extra argument to apply. You don't need to make this change for this PR; I just wanted to point it out.
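
A small sketch of that idea, mirroring the quoted snippet; the attribute names are hypothetical and the change is not part of this PR:

```python
# Hypothetical: fix the flag while deriving the routing method, so
# apply() no longer needs apply_router_weight_on_input at runtime.
if self.top_k == 1:
    self.routing_method_type = RoutingMethodType.Llama4
    # Top-1 (Llama-style) routing folds the router weight into the
    # input activations before the experts run.
    self.apply_router_weight_on_input = True
```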

@jiahanc jiahanc force-pushed the qwen3next_trtllmgen_moe branch from 608f7e8 to 2c2e9b3 Compare November 3, 2025 21:55
@jiahanc (Contributor, Author) commented Nov 4, 2025

@mgoin could you help re-review?

@jiahanc jiahanc force-pushed the qwen3next_trtllmgen_moe branch from 2c2e9b3 to 648547a Compare November 4, 2025 17:38
@jiahanc (Contributor, Author) commented Nov 4, 2025

Blocked by flashinfer-ai/flashinfer#2032. Don't merge before the issue is fixed.

@bnellnm (Contributor) left a comment:

LGTM!

mergify bot commented Nov 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jiahanc.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
