
Conversation

@izhuhaoran (Contributor) commented Oct 16, 2025

Purpose

This PR fuses QNorm, KNorm, and RoPE into a single CUDA kernel for the Qwen3 model, improving inference performance. Users can enable this fusion by setting VLLM_FUSE_QKNORM_AND_ROPE=1.
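
As a usage illustration (not part of this PR's diff), one way to opt in from offline Python is to set the flag before vLLM reads its environment variables; the model and TP setting below simply mirror the benchmark setup:

```python
import os

# Opt in to the fused QK Norm + RoPE kernel (off by default).
os.environ["VLLM_FUSE_QKNORM_AND_ROPE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-30B-A3B-FP8", tensor_parallel_size=2)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```

When serving, the same variable can simply be set in the environment of the `vllm serve` (or benchmark) command instead.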

Test Result

GPU Trace

  • Main: (GPU trace screenshot)
  • This PR: (GPU trace screenshot)

Bench Serve

setting: NVIDIA H20, TP=2, model=qwen3-30b-a3b-fp8, num_prompts=32, max-concurrency=8, input_len=out_len=1024
result: mean TTFT improves from 388.56 ms to 192.26 ms, and mean TPOT from 10.29 ms to 9.93 ms

  • Main
============ Serving Benchmark Result ============
Successful requests:                     32        
Maximum request concurrency:             8         
Benchmark duration (s):                  43.66     
Total input tokens:                      32620     
Total generated tokens:                  32768     
Request throughput (req/s):              0.73      
Output token throughput (tok/s):         750.48    
Peak output token throughput (tok/s):    792.00    
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          1497.57   
---------------Time to First Token----------------
Mean TTFT (ms):                          388.56    
Median TTFT (ms):                        223.63    
P99 TTFT (ms):                           1070.85   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.29     
Median TPOT (ms):                        10.26     
P99 TPOT (ms):                           10.99     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.29     
Median ITL (ms):                         10.23     
P99 ITL (ms):                            10.70     
==================================================
  • This PR
============ Serving Benchmark Result ============
Successful requests:                     32        
Maximum request concurrency:             8         
Benchmark duration (s):                  41.40     
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              0.77      
Output token throughput (tok/s):         791.47    
Peak output token throughput (tok/s):    824.00    
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          1582.95   
---------------Time to First Token----------------
Mean TTFT (ms):                          192.26    
Median TTFT (ms):                        207.59    
P99 TTFT (ms):                           223.33    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.93      
Median TPOT (ms):                        9.88      
P99 TPOT (ms):                           10.13     
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.93      
Median ITL (ms):                         9.89      
P99 ITL (ms):                            10.45     
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
@izhuhaoran changed the title from "[Kernel][Perf] fuse QNorm KNorm and RoPE into a single cuda kernel for Qwen Model" to "[Kernel][Perf] fuse QK Norm and RoPE into a single cuda kernel for Qwen Model" on Oct 16, 2025
The mergify bot added the ci/build and qwen (Related to Qwen models) labels on Oct 16, 2025
@izhuhaoran changed the title to "[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model" on Oct 16, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a fused CUDA kernel for Q-Norm, K-Norm, and RoPE operations for the Qwen3 model, which shows significant performance improvements. The implementation is well-structured and follows patterns from existing high-performance kernels.

I've identified a critical issue regarding the use of this fused kernel with dynamic RoPE scaling methods, which could lead to incorrect results. I've also pointed out a minor issue with const correctness in the CUDA code.

Overall, this is a great performance optimization. Addressing the identified issues will make it robust and ready for merging.

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Comment on lines +1314 to +1317
# If set, use the fused QKNorm and RoPE kernel
"VLLM_FUSE_QKNORM_AND_ROPE": lambda: bool(
    int(os.getenv("VLLM_FUSE_QKNORM_AND_ROPE", "0"))
),
Contributor

Are there any downsides to using the fused implementation? Just asking because I'm wondering whether it would make sense to enable it by default or even remove the config option.

Contributor Author

This implementation does not support RoPE variants with custom forward implementations that differ from RotaryEmbedding, such as DualChunkRotaryEmbedding(CustomOp).

Although Qwen3 might not use this type of RoPE, I've added the VLLM_FUSE_QKNORM_AND_ROPE environment variable for safety. Additionally, in qwen3.py & qwen3_moe.py, I added an isinstance(self.rotary_emb, RotaryEmbedding) check to filter out unsupported RoPE implementations, as sketched below.
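
For illustration, the gate might be wrapped like this (helper name is hypothetical; the PR places the condition directly in the model code):

```python
import vllm.envs as envs
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


def can_fuse_qknorm_rope(rotary_emb) -> bool:
    """Hypothetical helper mirroring the check described above."""
    # Only the plain RotaryEmbedding forward is covered by the fused kernel;
    # custom-forward variants such as DualChunkRotaryEmbedding are filtered out.
    return bool(envs.VLLM_FUSE_QKNORM_AND_ROPE) and isinstance(
        rotary_emb, RotaryEmbedding
    )
```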

@mgoin (Member) Oct 16, 2025

@izhuhaoran Do you think we could formulate this as a custom torch.compile pass that substitutes in the CUDA kernel? This would generalize across arches and also take care of greedily enabling it by default when valid for Qwen.
Might be tough for this particular op, but worth a try. cc @ProExpertProg

Contributor Author

Agreed, I will try it.

Collaborator

Also, please compare performance with torch 2.9, as we improved the Inductor-generated kernels.

Contributor Author

> Also, please compare performance with torch 2.9, as we improved the Inductor-generated kernels.

OK, I will bench again with torch 2.9.

Contributor Author

> @izhuhaoran Do you think we could formulate this as a custom torch.compile pass that substitutes in the CUDA kernel? This would generalize across arches and also take care of greedily enabling it by default when valid for Qwen. Might be tough for this particular op, but worth a try. cc @ProExpertProg

I tried converting this fusion into a custom torch.compile pass, referencing existing fusion passes. However, my initial pass pattern isn't triggering—no fusion replacement occurs (both e2e tests and unit tests fail).

The draft code is here: https://github.com/izhuhaoran/vllm/tree/fuse-qknorm-rope-compile

I don't have too much experience with torch.compile, so I might be missing something. @mgoin @ProExpertProg Could someone help me convert this into a custom pass?
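
For reference, here is a rough sketch of the kind of pattern/replacement pair such a pass would register (purely illustrative, not the code in the draft branch; the replacement just reuses the reference math so the sketch stays runnable, whereas the real pass would call the fused CUDA op):

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Per-head RMSNorm over the last (head_dim) dimension.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def qk_norm_rope_pattern(q, k, q_weight, k_weight, cos, sin):
    # Unfused sequence the pass would search for in the traced graph:
    # per-head RMSNorm on Q and K, then (NeoX-style) RoPE on both.
    # cos/sin are assumed to be already expanded to head_dim.
    q, k = rms_norm(q, q_weight), rms_norm(k, k_weight)
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k


def qk_norm_rope_replacement(q, k, q_weight, k_weight, cos, sin):
    # In the real pass this would invoke the fused QKNorm+RoPE CUDA op
    # from this PR; reusing the reference keeps the sketch self-contained.
    return qk_norm_rope_pattern(q, k, q_weight, k_weight, cos, sin)
```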

Collaborator

Yep, responded on Slack - can you open a draft PR?

Contributor Author

> Yep, responded on Slack - can you open a draft PR?

The follow-up draft PR is #27165. Thanks for your help!

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
@izhuhaoran (Contributor Author) commented Oct 18, 2025

Here is the new evaluation with torch 2.9:

Bench Serve

setting: NVIDIA H20, TP=2, model=qwen3-30b-a3b-fp8, num_prompts=32, max-concurrency=32, input_len=out_len=1024
result: mean TTFT improves from 835.54 ms to 826.41 ms, and mean TPOT from 15.33 ms to 15.01 ms

  • Main
============ Serving Benchmark Result ============
Successful requests:                     32        
Maximum request concurrency:             32        
Benchmark duration (s):                  16.54     
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              1.93      
Output token throughput (tok/s):         1980.78   
Peak output token throughput (tok/s):    2400.00   
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          3961.56   
---------------Time to First Token----------------
Mean TTFT (ms):                          835.54    
Median TTFT (ms):                        536.59    
P99 TTFT (ms):                           2029.73   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          15.33     
Median TPOT (ms):                        15.62     
P99 TPOT (ms):                           15.98     
---------------Inter-token Latency----------------
Mean ITL (ms):                           15.33     
Median ITL (ms):                         14.23     
P99 ITL (ms):                            18.98     
==================================================
  • This PR
============ Serving Benchmark Result ============
Successful requests:                     32        
Maximum request concurrency:             32        
Benchmark duration (s):                  16.20     
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              1.97      
Output token throughput (tok/s):         2022.19   
Peak output token throughput (tok/s):    2496.00   
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          4044.37   
---------------Time to First Token----------------
Mean TTFT (ms):                          826.41    
Median TTFT (ms):                        531.50    
P99 TTFT (ms):                           2006.07   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          15.01     
Median TPOT (ms):                        15.30     
P99 TPOT (ms):                           15.66     
---------------Inter-token Latency----------------
Mean ITL (ms):                           15.01     
Median ITL (ms):                         13.99     
P99 ITL (ms):                            15.57     
==================================================

GPU Timeline Profile

  • Main: (GPU timeline screenshot)
  • This PR: (GPU timeline screenshot)

CC @ProExpertProg, we can see that the fused qknorm_rope kernel is still better than torch.compile (2.9).

@ProExpertProg (Collaborator)

Those are great results! Just one question: I see gaps between the kernels in the profile; are you using cudagraphs when profiling? If not, could you show a profile with cudagraphs?

@izhuhaoran (Contributor Author)

> Those are great results! Just one question: I see gaps between the kernels in the profile; are you using cudagraphs when profiling? If not, could you show a profile with cudagraphs?

Yes, I was already using CUDA graphs (full cudagraph mode) when profiling. The gaps you see are only about 0.5 µs, which is quite common even in CUDA graph mode (and would be much larger without CUDA graphs). The gaps appear more prominent mainly because of the zoom level of the visualization. Here's the original profile for reference:

(original profile screenshot)
