[Kernel][Perf] fuse QK Norm and RoPE into one cuda kernel for Qwen Model #27165
Conversation
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Code Review
This pull request introduces a fused CUDA kernel for QK Normalization and RoPE for the Qwen model, aiming to improve inference performance. The fusion is implemented as a torch.compile pass. The changes include the CUDA kernel, its PyTorch bindings, the fusion pass logic, and integration into the model and build system. A new test is also added to verify the fusion.
The overall approach is solid and follows existing patterns in the codebase for custom ops and fusions. However, I've found a critical issue in the fusion pass implementation that causes the fusion to produce incorrect results. The output of the fused operation is not correctly propagated in the graph, making the fusion effectively a no-op. Please see the detailed comment for the fix.
💡 Codex Review
Here are some automated review suggestions for this pull request.
def apply_qk_norm_rope(self, qkv, positions):
    if self.use_fused_qk_norm_rope:
        ops.fused_qk_norm_rope(
            qkv,
            self.num_heads,
            self.num_kv_heads,
            self.num_kv_heads,
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.cos_sin_cache,
            self.rotary_emb.is_neox_style,
            positions.view(-1),
        )
Cast RoPE cache to tensor dtype before fused op
When VLLM_FUSE_QKNORM_AND_ROPE is enabled, the new fast path calls ops.fused_qk_norm_rope using self.rotary_emb.cos_sin_cache directly. RotaryEmbedding initializes this buffer in float32 and only casts it to the query's dtype inside RotaryEmbedding.forward() via _match_cos_sin_cache_dtype. Because the fused path bypasses that method, the buffer typically remains float32 while the kernel enforces bfloat16 (CHECK_INPUT(cos_sin_cache, torch::kBFloat16)), causing an immediate TORCH_CHECK failure the first time the fused kernel is used. Before invoking the fused op, the RoPE cache should be converted to the same dtype/device as qkv, just as the unfused path does.
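A minimal sketch of the suggested fix, based on the snippet above (calling _match_cos_sin_cache_dtype directly from this method is an assumption taken from the comment, not verified against the PR's final code):

```python
def apply_qk_norm_rope(self, qkv, positions):
    if self.use_fused_qk_norm_rope:
        # Assumption: reuse the helper the unfused path relies on so the
        # cached cos/sin table matches qkv's dtype/device before the
        # bfloat16-only kernel checks it.
        self.rotary_emb._match_cos_sin_cache_dtype(qkv)
        ops.fused_qk_norm_rope(
            qkv,
            self.num_heads,
            self.num_kv_heads,
            self.num_kv_heads,
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.cos_sin_cache,
            self.rotary_emb.is_neox_style,
            positions.view(-1),
        )
```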
def apply_qk_norm_rope(self, qkv, positions):
    if self.use_fused_qk_norm_rope:
        ops.fused_qk_norm_rope(
            qkv,
            self.num_heads,
            self.num_kv_heads,
            self.num_kv_heads,
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.cos_sin_cache,
            self.rotary_emb.is_neox_style,
            positions.view(-1),
Ensure MoE RoPE cache matches fused kernel dtype
The MoE variant has the same issue: the fused path invokes ops.fused_qk_norm_rope without first aligning self.rotary_emb.cos_sin_cache to the query tensor's dtype/device. The buffer starts as float32, while the CUDA kernel checks for bfloat16, so enabling the fused kernel leads to a runtime TORCH_CHECK error before any computation occurs. Mirror the unfused path by calling _match_cos_sin_cache_dtype (or otherwise casting) before the fused call.
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
# If set, use the fuse QKNorm and RoPE kernel
"VLLM_FUSE_QKNORM_AND_ROPE": lambda: bool(
    int(os.getenv("VLLM_FUSE_QKNORM_AND_ROPE", "0"))
),
We should use pass config instead of env vars.
We should use pass config instead of env vars.
The PR is still WIP. This environment variable comes from #27018; this PR is attempting to convert that fusion into a custom pass. Users will need to enable it with (as you mentioned):
--compilation_config='{"use_inductor": 1, "pass_config": {"enable_qk_norm_rope_fusion": 1}}'
Once this PR is complete, the old environment-variable code will be cleaned up.
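For the offline API, enabling the pass might look like the sketch below (the pass_config key is taken from the command above; the model name and the dict form of compilation_config are assumptions for illustration, not prescribed by this PR):

```python
from vllm import LLM

# Hypothetical usage once the custom pass lands in this PR.
llm = LLM(
    model="Qwen/Qwen3-8B",  # example model only
    compilation_config={
        "use_inductor": 1,
        "pass_config": {"enable_qk_norm_rope_fusion": 1},
    },
)
```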
The target graph for replacement is quite large. Using pattern matching here, as we do in other passes, may not scale effectively and could become a maintenance burden.
I'm also aware of the same issue, which makes the pattern extremely hacky. So in my initial implementation, when enabling enable_qk_norm_rope_fusion, I set rms_norm and rope as custom ops, but this isn't the optimal solution either. I'm wondering if we should consider abandoning the conversion of this fusion into a custom pass and directly use the implementation from #27018 (which is also the current state in TRT-LLM)? I'd like to hear your thoughts on this, @ProExpertProg.
This looks like the right approach! Once you're done, please clean up the code and add E2E performance and lm-eval numbers.
"input_global_scale", | ||
), | ||
) | ||
# # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper. |
Is this supposed to be removed or uncommented?
# split qkv -> q,k,v
# q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
split_tuple = SPLIT_SIZES_OP(
    qkv, [self.q_size, self.kv_size, self.kv_size], -1
)
q = operator.getitem(split_tuple, 0)
k = operator.getitem(split_tuple, 1)
v = operator.getitem(split_tuple, 2)
I think this should work; pattern tracing is very close to forward-code tracing:
Suggested change: replace the SPLIT_SIZES_OP / operator.getitem sequence above with a direct split:
# split qkv -> q,k,v
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q_out = EMPTY_LIKE_OP(q_by_head)
q_by_head_contiguous = CONTIGUOUS_OP(q_by_head)

qn = auto_functionalized(
Please use MatcherRMSNorm so that we can match even with rms_norm disabled (using the torch impl in forward_native).
]

# # Register variants across rope ops and with/without contiguous()
# # Ensure view ops are canonicalized to reshape in the traced pattern
Was this not needed?
if not current_platform.is_cuda_alike():
    return
Silent disablement is not good; the pass should not be enabled at all on non-CUDA-alike platforms.
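One way to make that explicit, as a sketch (the pass_config attribute name and where this check lives are assumptions based on the surrounding discussion):

```python
from vllm.platforms import current_platform

# Reject the configuration up front instead of silently returning from the
# pass on unsupported platforms.
if pass_config.enable_qk_norm_rope_fusion and not current_platform.is_cuda_alike():
    raise ValueError(
        "enable_qk_norm_rope_fusion is only supported on CUDA-alike platforms"
    )
```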
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered." | ||
) | ||
return | ||
layer_name, layer = next(iter(attn_layers.items())) |
This will only register the pattern using one layer; is that intended? And are we sure this will always pick the same layer?
I see now that you don't care which layer is picked because you only need the shapes; please add a comment explaining that.
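For example, something along these lines (a sketch; variable names follow the snippet above):

```python
# Any attention layer works here: the pattern only needs shape metadata
# (head counts and head_dim), which is shared across the model's layers.
layer_name, layer = next(iter(attn_layers.items()))
```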
    rope_op: torch._ops.OpOverload,
    is_neox: bool,
) -> None:
    self.layer = layer
Are you just using layer for these sizes? If so, don't save the layer object on the pattern object; just extract the size properties.
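A possible shape for that change (attribute names are assumptions based on the snippets above):

```python
# Keep only the shape information the pattern needs instead of holding a
# reference to the whole attention layer.
self.num_heads = layer.num_heads
self.num_kv_heads = layer.num_kv_heads
self.head_dim = layer.head_dim
```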
if self.pass_config.enable_qk_norm_rope_fusion:
    self.custom_ops.append("+rms_norm")
    self.custom_ops.append("+rotary_embedding")
Let's try to remove this requirement: definitely for rms_norm, and hopefully for RoPE as well, although RoPE is less important. I assume all custom RoPE ops would be fused away anyway, right?
    rope_scaling=rope_scaling,
    dual_chunk_attention_config=dual_chunk_attention_config,
)
# Determine if we can use fused QK norm + RoPE
Please remove the model-definition changes now that we have the fusion pass.
    rope_scaling=rope_scaling,
    dual_chunk_attention_config=dual_chunk_attention_config,
)
# Determine if we can use fused QK norm + RoPE
Here as well
Purpose
This PR is a follow-up to #27018. It fuses QNorm, KNorm, and RoPE into a single CUDA kernel for the Qwen3 model, improving inference performance. We convert this fusion into a custom torch.compile pass; users can enable it with:
--compilation_config='{"use_inductor": 1, "pass_config": {"enable_qk_norm_rope_fusion": 1}}'
See #27018 for more details.