[Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm #27748
Conversation
Code Review
This pull request correctly addresses a torch.compile compatibility issue on ROCm for ViT rotary embeddings by wrapping the flash_attn Triton kernel in a vLLM custom op. The refactoring of apply_rotary_pos_emb_vision into a common utility is a good improvement for code reuse.
However, I've identified a critical issue where the refactoring exposes a latent bug that will cause a TypeError on CUDA platforms due to a function signature mismatch. My review includes a suggested fix for this issue.
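For anyone unfamiliar with the custom-op approach, here is a minimal sketch of what hiding a kernel behind a vLLM custom op can look like. This is not the PR's actual code: the helper `direct_register_custom_op` is assumed to be importable from `vllm.utils` with roughly the parameters shown, the wrapper names are illustrative, and a plain placeholder body stands in for the flash_attn Triton kernel.

```python
import torch

# Assumption: direct_register_custom_op is importable from vllm.utils and accepts
# (op_name, op_func, mutates_args, fake_impl, ...); names below are illustrative.
from vllm.utils import direct_register_custom_op


def _vit_rotary_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real op would call the flash_attn Triton rotary kernel here.
    return x.clone()


def _vit_rotary_fake(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Shape-only (meta) implementation so torch.compile can trace through the op
    # without executing the kernel.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="vit_rotary_emb",
    op_func=_vit_rotary_impl,
    mutates_args=[],
    fake_impl=_vit_rotary_fake,
)

# Callers would then go through the registered op (e.g. torch.ops.vllm.vit_rotary_emb),
# so the Triton code stays opaque to the compiler.
```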
…ginal code Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@zhewenl can you try this PR while re-enabling torch.compile and see if it works for you?
@DarkLight1337, @vllmellm I have verified with bench serve (https://gist.github.com/zhewenl/41fc928c427ff5da7a2df25376d6d136), and also verified
```python
if current_platform.is_cuda():
    return apply_rotary_emb


if current_platform.is_rocm():
```
@vllmellm @DarkLight1337 What do you think? Should we keep the Triton rotary embedding for backwards compatibility at this moment, by adding a `torch.compiler.is_compiling()` condition instead, since there are other model definition files that are not using torch.compile yet?
E.g. model definition files that are not using torch.compile:
- vllm/model_executor/models/dots_ocr.py
- vllm/model_executor/models/ernie45_vl.py
- vllm/model_executor/models/glm4_1v.py
- vllm/model_executor/models/qwen2_vl.py
- vllm/model_executor/models/siglip2navit.py
Another thing to note at this point is that torch.compile has been disabled because there are issues with dynamic slicing.
So if we add the `torch.compiler.is_compiling()` condition, we don't need to postpone this PR, while ensuring no performance regression whether torch.compile is enabled or disabled.
Right now they are exploring new fixes, e.g. #27764.
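A rough sketch of what that gating could look like (the dispatcher shape is illustrative rather than the exact vLLM code, the flash_attn module path follows the PR description, and the `apply_rotary` symbol name is an assumption; the CUDA branch is omitted for brevity):

```python
import torch
from vllm.platforms import current_platform


def dispatch_rotary_emb_function(default):
    # Illustrative dispatcher: keep the Triton kernel for eager-mode models on ROCm,
    # but fall back to the PyTorch implementation whenever torch.compile is tracing.
    if current_platform.is_rocm():
        if torch.compiler.is_compiling():
            return default
        try:
            # Module path per the PR description; the exact symbol name is an assumption.
            from flash_attn.ops.triton.rotary import apply_rotary
            return apply_rotary
        except ImportError:
            return default
    return default
```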
My two cents: I see the benefit of gating with torch.compiler.is_compiling(), especially for models we don't have covered with torch.compile yet, so I am in favor of that approach as opposed to completely removing it.
+1 on adding a condition on checking if it's torch compiled.
On a side note - is this PR really considered a bugfix? I thought after #27760 things should already work on AMD. Did I miss something here?
@ywang96 PR #27760 is a bugfix that reverts the torch.compile PR. If torch.compile is re-enabled, AMD needs this bugfix PR to address the incompatibility of the flash_attn rotary embedding Triton kernel with torch.compile. The discussion in this comment thread is about handling the case where other models still don't have torch.compile support.
And in this PR we found that the torch.compile-d rotary embedding is faster than the Triton implementation on ROCm. I think it will be the same case on CUDA.
Yea I was just confirming whether the current main branch is broken on AMD (because it'll affect our release), and it seems that it's not because we currently disabled it, correct?
@ywang96 On the current main branch, the unit tests are healthy and the Qwen3 VL accuracy is normal. 👍
@tjtanaa @Lucaskabela Those model paths are not using this dispatching function; it is only used in qwen2_vl.py. However, qwen2_5_vl.py and glm4_1v.py import the apply_rotary_pos_emb_vision function from qwen2_vl.py, and the dispatch is used inside that function.
vllm/vllm/model_executor/models/qwen2_vl.py
Lines 314 to 320 in 3696050
```python
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output
```
@DarkLight1337 @Lucaskabela @tjtanaa Should we keep the Triton embedding function here only for the one remaining case, glm4_1v.py, or not?
…mpatible now Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
I think this approach LGTM
Given that #27764 is merged - we should also merge this in
…ity on ROCm (vllm-project#27748) Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Purpose
This pull request changes the rotary embedding function previously selected for the ROCm platform, `flash_attn.ops.triton.rotary`, because it is not compatible with `torch.compile`. Since `torch.compile` support was enabled for Qwen-VL models, it became necessary to address this incompatibility. After registering the rotary embedding operation using `direct_register_custom_op` within vLLM, benchmarking revealed that the naive PyTorch implementation of the rotary embedding is actually faster on ROCm. Therefore, this PR updates the ROCm platform to use the native PyTorch implementation when `torch.compile` is enabled, and to keep using the rotary embedding from `flash_attn.ops.triton.rotary` for models that don't enable `torch.compile`.
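For reference, a minimal rotate-half sketch of the kind of native PyTorch rotary application described above (function and variable names are illustrative, not the exact vLLM `apply_rotary_emb_torch`):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., head_dim); cos/sin broadcastable to (..., head_dim // 2).
    cos = torch.cat((cos, cos), dim=-1)
    sin = torch.cat((sin, sin), dim=-1)
    return x * cos + rotate_half(x) * sin


# Example: per-position frequencies applied to (seq_len, num_heads, head_dim) activations.
seq_len, num_heads, head_dim = 16, 4, 64
x = torch.randn(seq_len, num_heads, head_dim)
freqs = torch.randn(seq_len, head_dim // 2)
out = rotary_emb_torch(x, freqs.cos()[:, None, :], freqs.sin()[:, None, :])
```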
benchmark setting:

```
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
    -tp 1 \
    --port 9099 \
    --trust-remote-code --swap-space 16 --distributed-executor-backend mp

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --port 9099 \
    --dataset-path lmarena-ai/VisionArena-Chat \
    --hf-split train \
    --num-prompts 1000
```

benchmark results
This pull request also fixes the bug below when running GLM vision models such as `zai-org/GLM-4.5V-FP8` with `VLLM_ROCM_USE_AITER=1`: the `dropout_p` value is changed from `0` to `0.0` to match the `float` dtype required by the `flash_attn_varlen_func` AITER kernel.

Test Plan
lm_eval test on Qwen/Qwen2.5-VL-3B-Instruct using the `mmmu` dataset in the mistral-eval repo:

```
VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 vllm serve Qwen/Qwen2.5-VL-3B-Instruct -tp 1 --port 9099 --trust-remote-code --swap-space 16 --distributed-executor-backend mp

python -m eval.run eval_vllm \
    --model_name Qwen/Qwen2.5-VL-3B-Instruct \
    --url http://0.0.0.0:9099 \
    --output_dir ./logs \
    --eval_name "mmmu"
```

Test Result
Metrics:

```
{
    "explicit_prompt_relaxed_correctness": 0.47,
    "anywhere_in_answer_relaxed_correctness": 0.4722222222222222
}
```