
Conversation

izhuhaoran
Contributor

@izhuhaoran izhuhaoran commented Oct 19, 2025

Purpose

This PR is a follow-up to #27018. It fuses QNorm, KNorm, and RoPE into a single CUDA kernel for the Qwen3 model to improve inference performance. We convert this fusion into a custom torch.compile pass, which users can enable with:

 --compilation-config='{"use_inductor": 1,  "pass_config": {"enable_qk_norm_rope_fusion": 1}}'

For more details, see #27018.
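
The same option can also be set through the offline API. A minimal sketch, assuming the dict form of compilation_config and an illustrative model name (only pass_config.enable_qk_norm_rope_fusion is introduced by this PR):

    from vllm import LLM

    # Model name is illustrative; any Qwen3 checkpoint with QK norm applies.
    llm = LLM(
        model="Qwen/Qwen3-8B",
        compilation_config={
            "use_inductor": True,
            "pass_config": {"enable_qk_norm_rope_fusion": True},
        },
    )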


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused CUDA kernel for QK Normalization and RoPE for the Qwen model, aiming to improve inference performance. The fusion is implemented as a torch.compile pass. The changes include the CUDA kernel, its PyTorch bindings, the fusion pass logic, and integration into the model and build system. A new test is also added to verify the fusion.

The overall approach is solid and follows existing patterns in the codebase for custom ops and fusions. However, I've found a critical issue in the fusion pass implementation that causes the fusion to produce incorrect results. The output of the fused operation is not correctly propagated in the graph, making the fusion effectively a no-op. Please see the detailed comment for the fix.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +148 to +162
def apply_qk_norm_rope(self, qkv, positions):
    if self.use_fused_qk_norm_rope:
        ops.fused_qk_norm_rope(
            qkv,
            self.num_heads,
            self.num_kv_heads,
            self.num_kv_heads,
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.cos_sin_cache,
            self.rotary_emb.is_neox_style,
            positions.view(-1),
        )


P1: Cast RoPE cache to tensor dtype before fused op

When VLLM_FUSE_QKNORM_AND_ROPE is enabled, the new fast path calls ops.fused_qk_norm_rope using self.rotary_emb.cos_sin_cache directly. RotaryEmbedding initializes this buffer in float32 and only casts it to the query’s dtype inside RotaryEmbedding.forward() via _match_cos_sin_cache_dtype. Because the fused path bypasses that method, the buffer typically remains float32 while the kernel enforces bfloat16 (CHECK_INPUT(cos_sin_cache, torch::kBFloat16)), causing an immediate TORCH_CHECK failure the first time the fused kernel is used. Before invoking the fused op, the RoPE cache should be converted to the same dtype/device as qkv just as the unfused path does.
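
A minimal sketch of the kind of cast being suggested (the helper name below is hypothetical; the existing _match_cos_sin_cache_dtype mentioned above could be reused instead):

    import torch

    def _ensure_cos_sin_cache_dtype(rotary_emb, qkv: torch.Tensor) -> None:
        # Align the RoPE cache with qkv's dtype/device before invoking the fused
        # kernel, mirroring what the unfused RotaryEmbedding.forward() path does.
        cache = rotary_emb.cos_sin_cache
        if cache.dtype != qkv.dtype or cache.device != qkv.device:
            rotary_emb.cos_sin_cache = cache.to(dtype=qkv.dtype, device=qkv.device)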


Comment on lines +306 to +319
def apply_qk_norm_rope(self, qkv, positions):
    if self.use_fused_qk_norm_rope:
        ops.fused_qk_norm_rope(
            qkv,
            self.num_heads,
            self.num_kv_heads,
            self.num_kv_heads,
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.cos_sin_cache,
            self.rotary_emb.is_neox_style,
            positions.view(-1),


P1: Ensure MoE RoPE cache matches fused kernel dtype

The MoE variant has the same issue: the fused path invokes ops.fused_qk_norm_rope without first aligning self.rotary_emb.cos_sin_cache to the query tensor’s dtype/device. The buffer starts as float32, while the CUDA kernel checks for bfloat16, so enabling the fused kernel leads to a runtime TORCH_CHECK error before any computation occurs. Mirror the unfused path by calling _match_cos_sin_cache_dtype (or otherwise casting) before the fused call.


Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Comment on lines +1317 to +1320
# If set, use the fuse QKNorm and RoPE kernel
"VLLM_FUSE_QKNORM_AND_ROPE": lambda: bool(
    int(os.getenv("VLLM_FUSE_QKNORM_AND_ROPE", "0"))
),
Contributor

We should use pass config instead of env vars.

Contributor Author

We should use pass config instead of env vars.

The PR is still WIP. This environment variable comes from #27018; this PR is converting it into a custom pass. Users will need to use (as you mentioned):
--compilation_config='{"use_inductor": 1, "pass_config": {"enable_qk_norm_rope_fusion": 1}}'
Once this PR is complete, the old env-var code will be cleaned up.

@ZJY0516
Contributor

ZJY0516 commented Oct 19, 2025

call_function  split_with_sizes        aten.split_with_sizes.default       (mm_3, [4096, 1024, 1024], -1)                                   {}
call_function  getitem_6               <built-in function getitem>         (split_with_sizes, 0)                                            {}
call_function  getitem_7               <built-in function getitem>         (split_with_sizes, 1)                                            {}
call_function  getitem_8               <built-in function getitem>         (split_with_sizes, 2)                                            {}
call_function  empty                   aten.empty.memory_format            ([arg1_1, 32, 128],)                                             {'dtype': torch.bfloat16, 'layout': torch.strided, 'device': device(type='cuda', index=0), 'pin_memory': False}
call_function  permute_4               aten.permute.default                (empty, [0, 1, 2])                                               {}
call_function  view_1                  aten.reshape.default                (getitem_6, [arg1_1, 32, 128])                                   {}
call_function  clone                   aten.clone.default                  (view_1,)                                                        {'memory_format': torch.contiguous_format}
call_function  auto_functionalized_2   auto_functionalized                 (<OpOverload(op='_C.rms_norm', overload='default')>,)            {'result': permute_4, 'input': clone, 'weight': arg9_1, 'epsilon': 1e-06}
call_function  getitem_10              <built-in function getitem>         (auto_functionalized_2, 1)                                       {}
call_function  empty_1                 aten.empty.memory_format            ([arg1_1, 8, 128],)                                              {'dtype': torch.bfloat16, 'layout': torch.strided, 'device': device(type='cuda', index=0), 'pin_memory': False}
call_function  permute_5               aten.permute.default                (empty_1, [0, 1, 2])                                             {}
call_function  view_3                  aten.reshape.default                (getitem_7, [arg1_1, 8, 128])                                    {}
call_function  clone_1                 aten.clone.default                  (view_3,)                                                        {'memory_format': torch.contiguous_format}
call_function  auto_functionalized_3   auto_functionalized                 (<OpOverload(op='_C.rms_norm', overload='default')>,)            {'result': permute_5, 'input': clone_1, 'weight': arg10_1, 'epsilon': 1e-06}
call_function  getitem_12              <built-in function getitem>         (auto_functionalized_3, 1)                                       {}
call_function  view_5                  aten.reshape.default                (getitem_10, [arg1_1, 4096])                                     {}
call_function  view_6                  aten.reshape.default                (getitem_12, [arg1_1, 1024])                                     {}
call_function  auto_functionalized_4   auto_functionalized                 (<OpOverload(op='_C.rotary_embedding', overload='default')>,)    {'positions': arg11_1, 'query': view_5, 'key': view_6, 'head_size': 128, 'cos_sin_cache': arg13_1, 'is_neox': True}

The target graph for replacement is quite large. Using pattern matching here, as we do in other passes, may not scale effectively and could become a maintenance burden.
Do you have any suggestions? @ProExpertProg

@izhuhaoran
Contributor Author

The target graph for replacement is quite large. Using pattern matching here, as we do in other passes, may not scale effectively and could become a maintenance burden. Do you have any suggestions? @ProExpertProg

I'm also aware of the same issue, which makes the pattern extremely hacky. In my initial implementation, when enable_qk_norm_rope_fusion is enabled, I set rms_norm and rope as custom ops, but this isn't the optimal solution either. I'm wondering whether we should abandon converting this fusion into a custom pass and directly use the implementation from #27018 (which is also the current state in TRT-LLM). I'd like to hear your thoughts on this, @ProExpertProg.

Collaborator

@ProExpertProg ProExpertProg left a comment


This looks like the right approach! Once you're done, please clean up the code and add E2E performance and lm-eval numbers.

"input_global_scale",
),
)
# # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
Collaborator

Is this supposed to be removed or uncommented?

Comment on lines +86 to +93
# split qkv -> q,k,v
# q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
split_tuple = SPLIT_SIZES_OP(
    qkv, [self.q_size, self.kv_size, self.kv_size], -1
)
q = operator.getitem(split_tuple, 0)
k = operator.getitem(split_tuple, 1)
v = operator.getitem(split_tuple, 2)
Collaborator

I think this should work; pattern tracing is very close to forward-code tracing:

Suggested change
-# split qkv -> q,k,v
-# q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-split_tuple = SPLIT_SIZES_OP(
-    qkv, [self.q_size, self.kv_size, self.kv_size], -1
-)
-q = operator.getitem(split_tuple, 0)
-k = operator.getitem(split_tuple, 1)
-v = operator.getitem(split_tuple, 2)
+# split qkv -> q,k,v
+q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

q_out = EMPTY_LIKE_OP(q_by_head)
q_by_head_contiguous = CONTIGUOUS_OP(q_by_head)

qn = auto_functionalized(
Collaborator

Please use MatcherRMSNorm so that we can match even with rms_norm disabled (i.e., using the torch implementation in forward_native).

]

# # Register variants across rope ops and with/without contiguous()
# # Ensure view ops are canonicalized to reshape in the traced pattern
Collaborator

Was this not needed?

Comment on lines +272 to +273
if not current_platform.is_cuda_alike():
    return
Collaborator

Silent disablement is not good; this should not be enabled at all on non-CUDA-alike platforms.
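
A minimal sketch of failing loudly instead (function name and call site are hypothetical; current_platform.is_cuda_alike() is the check already used in the diff above):

    from vllm.platforms import current_platform

    def validate_qk_norm_rope_fusion(enabled: bool) -> None:
        # Reject the config up front rather than silently skipping the pass later.
        if enabled and not current_platform.is_cuda_alike():
            raise ValueError(
                "enable_qk_norm_rope_fusion is only supported on CUDA-alike platforms."
            )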

"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
)
return
layer_name, layer = next(iter(attn_layers.items()))
Collaborator

This will only register the pattern using one layer; is that intended? Are we sure this will always pick the same layer?

Collaborator

I see now that you don't care which layer it is because you only need the shapes; please add a comment explaining that.

    rope_op: torch._ops.OpOverload,
    is_neox: bool,
) -> None:
    self.layer = layer
Collaborator

Are you just using layer for these sizes? If yes, don't save the layer object on the pattern object; just extract the size properties.
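
A minimal sketch of that suggestion (class and field names are illustrative, not the PR's actual pattern class): keep only the shape/config fields needed for tracing instead of a reference to the whole attention layer.

    class QkNormRopePatternShapes:
        # Holds only what the traced pattern needs; no attention-layer reference.
        def __init__(
            self,
            num_heads: int,
            num_kv_heads: int,
            head_dim: int,
            eps: float,
            is_neox: bool,
        ) -> None:
            self.num_heads = num_heads
            self.num_kv_heads = num_kv_heads
            self.head_dim = head_dim
            self.eps = eps
            self.is_neox = is_neox
            # Derived sizes used when building the q/k/v split in the pattern.
            self.q_size = num_heads * head_dim
            self.kv_size = num_kv_heads * head_dim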

Comment on lines +551 to +553
if self.pass_config.enable_qk_norm_rope_fusion:
    self.custom_ops.append("+rms_norm")
    self.custom_ops.append("+rotary_embedding")
Collaborator

Let's try to remove this requirement: definitely for rms_norm, and hopefully for RoPE as well, although RoPE is less important. I assume all custom RoPE ops would be fused away anyway, right?

rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
# Determine if we can use fused QK norm + RoPE
Collaborator

Please remove the model definition changes now that we have the fusion pass.

rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
# Determine if we can use fused QK norm + RoPE
Collaborator

Here as well


Labels

ci/build, qwen (Related to Qwen models)

3 participants