
Conversation

@vllmellm
Contributor

@vllmellm vllmellm commented Oct 29, 2025

Purpose

This pull request changes which rotary embedding function is selected on the ROCm platform. The previously selected flash_attn.ops.triton.rotary kernel is not compatible with torch.compile, and since torch.compile support was enabled for the Qwen-VL models, this incompatibility had to be addressed. After registering the rotary embedding operation as a custom op via direct_register_custom_ops in vLLM, benchmarking showed that the naive PyTorch implementation of the rotary embedding is actually faster on ROCm. This PR therefore makes ROCm use the native PyTorch implementation when torch.compile is enabled, while models that do not enable torch.compile keep using the rotary embedding from flash_attn.ops.triton.rotary.
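
For readers skimming the diff, the resulting selection logic is roughly sketched below. This is an illustrative simplification, not the merged code: dispatch_rotary_emb_function, apply_rotary_emb_torch, apply_rotary_emb, and current_platform are names that appear in this thread, while the exact Triton import and the fallback behaviour are assumptions.

def dispatch_rotary_emb_function(default=None):
    # CUDA keeps using flash-attn's apply_rotary_emb; this PR does not touch it.
    if current_platform.is_cuda():
        return apply_rotary_emb

    if current_platform.is_rocm():
        if torch.compiler.is_compiling():
            # The flash_attn Triton rotary kernel cannot be traced by
            # torch.compile, and the compiled native implementation also
            # benchmarks faster, so use the PyTorch path while compiling.
            return apply_rotary_emb_torch
        # Models that still run eagerly (see the list discussed later in
        # this thread) keep the Triton kernel to avoid a regression.
        from flash_attn.ops.triton.rotary import apply_rotary
        return apply_rotary

    return default if default is not None else apply_rotary_emb_torch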

Benchmark setting:

  1. Serve the Qwen/Qwen2.5-VL-3B-Instruct model

vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
    -tp 1 \
    --port 9099 \
    --trust-remote-code --swap-space 16 --distributed-executor-backend mp

  2. Run the benchmark

vllm bench serve \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --port 9099 \
    --dataset-path lmarena-ai/VisionArena-Chat \
    --hf-split train \
    --num-prompts 1000

Benchmark results

| Metric | Triton Version | Naive PyTorch Version |
| --- | --- | --- |
| Successful requests | 1000 | 1000 |
| Failed requests | 0 | 0 |
| Benchmark duration (s) | 120.18 | 118.13 |
| Total input tokens | 94,327 | 94,327 |
| Total generated tokens | 106,302 | 105,808 |
| Request throughput (req/s) | 8.32 | 8.46 |
| Output token throughput (tok/s) | 884.54 | 895.66 |
| Peak output token throughput (tok/s) | 3,358.00 | 7,338.00 |
| Peak concurrent requests | 1,000.00 | 1,000.00 |
| Total token throughput (tok/s) | 1,669.43 | 1,694.13 |
| Time to First Token | | |
| Mean TTFT (ms) | 70,512.10 | 69,237.83 |
| Median TTFT (ms) | 71,570.31 | 70,478.33 |
| P99 TTFT (ms) | 118,872.41 | 116,393.34 |
| Time per Output Token (excl. 1st token) | | |
| Mean TPOT (ms) | 1.31 | 1.91 |
| Median TPOT (ms) | 0.00 | 0.00 |
| P99 TPOT (ms) | 16.24 | 27.02 |
| Inter-token Latency | | |
| Mean ITL (ms) | 25.33 | 20.14 |
| Median ITL (ms) | 5.97 | 9.19 |
| P99 ITL (ms) | 427.85 | 381.43 |

This pull request also fixes the bug below, which occurs when running GLM vision models such as zai-org/GLM-4.5V-FP8 with VLLM_ROCM_USE_AITER=1.

The dropout_p value is changed from 0 to 0.0 to match the float dtype required by the aiter flash_attn_varlen_func kernel.

(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/app/fix/vllm/vllm/model_executor/models/glm4_1v.py", line 363, in forward
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     output = self.flash_attn_varlen_func(
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/usr/local/lib/python3.12/dist-packages/aiter/ops/mha.py", line 1445, in flash_attn_varlen_func
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     return FlashAttnVarlenFunc.apply(
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]            ^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/usr/local/lib/python3.12/dist-packages/torch/autograd/function.py", line 576, in apply
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     return super().apply(*args, **kwargs)  # type: ignore[misc]
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/usr/local/lib/python3.12/dist-packages/aiter/ops/mha.py", line 1244, in forward
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]                                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/usr/local/lib/python3.12/dist-packages/aiter/ops/mha.py", line 960, in _flash_attn_varlen_forward
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     out, softmax_lse, S_dmask, rng_state = mha_varlen_fwd(
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]                                            ^^^^^^^^^^^^^^^
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 628, in wrapper
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     func.arg_checked = check_args()
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]                        ^^^^^^^^^^^^
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]   File "/usr/local/lib/python3.12/dist-packages/aiter/jit/core.py", line 599, in check_args
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703]     raise TypeError(
(Worker pid=18972) ERROR 10-31 09:32:27 [multiproc_executor.py:703] TypeError: dropout_p needs to be <class 'float'> but got <class 'int'>
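
The fix itself is a one-token change at the call site in glm4_1v.py. The sketch below is only illustrative: the surrounding arguments are abbreviated and their names follow the usual flash-attn varlen signature rather than the exact call in the file; the PR's change is the dropout_p literal.

output = self.flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    dropout_p=0.0,  # was 0 (int); aiter's mha_varlen_fwd arg check requires a Python float
    causal=False,
)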

Test Plan

Accuracy test on Qwen/Qwen2.5-VL-3B-Instruct using the MMMU dataset in the mistral-eval repo:

VLLM_USE_V1=1 VLLM_ROCM_USE_AITER=1 vllm serve Qwen/Qwen2.5-VL-3B-Instruct -tp 1 --port 9099 --trust-remote-code --swap-space 16 --distributed-executor-backend mp

python -m eval.run eval_vllm \
    --model_name Qwen/Qwen2.5-VL-3B-Instruct \
    --url http://0.0.0.0:9099 \
    --output_dir ./logs \
    --eval_name "mmmu"

Test Result

Metrics:
{
    "explicit_prompt_relaxed_correctness": 0.47,
    "anywhere_in_answer_relaxed_correctness": 0.4722222222222222
}


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm requested a review from sighingnow as a code owner October 29, 2025 14:53
@mergify mergify bot added the qwen (Related to Qwen models) and rocm (Related to AMD ROCm) labels Oct 29, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a torch.compile compatibility issue on ROCm for ViT rotary embeddings by wrapping the flash_attn Triton kernel in a vLLM custom op. The refactoring of apply_rotary_pos_emb_vision into a common utility is a good improvement for code reuse.

However, I've identified a critical issue where the refactoring exposes a latent bug that will cause a TypeError on CUDA platforms due to a function signature mismatch. My review includes a suggested fix for this issue.
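
A note on the "custom op" wrapping mentioned above: vLLM registers the kernel through its internal direct_register_custom_ops helper (per the PR description). The sketch below shows the same idea using PyTorch's public torch.library.custom_op API instead, purely to illustrate why wrapping helps: torch.compile treats the registered op as opaque and traces through the fake implementation instead of the Triton kernel. The op name, shapes, and fake implementation here are assumptions, not the PR's code.

import torch
from torch.library import custom_op

@custom_op("vllm_sketch::triton_rotary", mutates_args=())
def triton_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Inside the op body we are free to call the non-traceable Triton kernel.
    from flash_attn.ops.triton.rotary import apply_rotary
    return apply_rotary(x, cos, sin)

@triton_rotary.register_fake
def _(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only stand-in used while tracing under torch.compile.
    return torch.empty_like(x)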

@vllmellm vllmellm marked this pull request as draft October 29, 2025 15:18
…ginal code

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@vllmellm vllmellm marked this pull request as ready for review October 30, 2025 08:14
@DarkLight1337
Member

cc @zhewenl @ywang96 @lgeiger @JartX

@DarkLight1337
Member

@zhewenl can you try this PR while re-enabling torch.compile and see if it works for you?

@zhewenl
Collaborator

zhewenl commented Oct 30, 2025

@zhewenl can you try this PR while re-enabling torch.compile and see if it works for you?

@DarkLight1337 @vllmellm I have verified with bench serve (https://gist.github.com/zhewenl/41fc928c427ff5da7a2df25376d6d136), and also verified that tests/models/multimodal/generation/test_qwen2_5_vl.py is fixed with the changes in this PR (https://gist.github.com/zhewenl/13ea86832d504c6045d4e7a2cf60badd).

if current_platform.is_cuda():
    return apply_rotary_emb

if current_platform.is_rocm():
Contributor

@vllmellm @DarkLight1337 What do you think? Should we keep the Triton rotary embedding for backwards compatibility for now, by instead adding a torch.compiler.is_compiling() condition, since there are other model definition files that are not using torch.compile yet?

E.g. model definition files that are not using torch.compile

  • vllm/model_executor/models/dots_ocr.py
  • vllm/model_executor/models/ernie45_vl.py
  • vllm/model_executor/models/glm4_1v.py
  • vllm/model_executor/models/qwen2_vl.py
  • vllm/model_executor/models/siglip2navit.py

Contributor

@tjtanaa tjtanaa Oct 30, 2025

Another thing to note at this point is that torch.compile has been disabled because of issues with dynamic slicing.

So if we add the torch.compiler.is_compiling() condition, we don't need to postpone this PR, and there is no performance regression whether torch.compile is enabled or disabled.

Right now, they are exploring new fixes e.g. #27764

Contributor

My two cents: I see the benefit of gating with torch.compiler.is_compiling(), especially for models we don't have covered with torch.compile yet, so I am in favor of that approach as opposed to completely removing the Triton rotary embedding.

Member

+1 on adding a condition that checks whether we are running under torch.compile.

On a side note - is this PR really considered a bugfix? I thought after #27760 things should already work on AMD. Did I miss something here?

Contributor

@tjtanaa tjtanaa Oct 30, 2025

@ywang96 PR #27760 is a bugfix that reverts the torch.compile PR. If torch.compile is re-enabled, AMD needs this bugfix PR to address the incompatibility of the flash_attn rotary embedding Triton kernel with torch.compile. The discussion in this comment thread is about handling the case where other models still don't have torch.compile support.

Contributor

And in this PR we found that the torch.compile-d rotary embedding is faster than the Triton implementation on ROCm. I think it will be the same case on CUDA.

Member

@ywang96 ywang96 Oct 30, 2025

Yeah, I was just confirming whether the current main branch is broken on AMD (because it'll affect our release), and it seems that it's not, because we currently have it disabled. Correct?

Contributor

@ywang96 On the current main branch, the unit tests are healthy and Qwen3 VL accuracy is normal. 👍

Contributor Author

@vllmellm vllmellm Oct 31, 2025

@tjtanaa @Lucaskabela Those model paths are not using this dispatching function; it is only used in qwen2_vl.py. Meanwhile, qwen2_5_vl.py and glm4_1v.py import the apply_rotary_pos_emb_vision function from qwen2_vl.py, and the dispatch is used inside that function.

def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output
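
For reference, the call sites in the other ViT backbones then look roughly like this (illustrative only; the import path follows the comment above, and the tensor names and shapes are assumptions):

from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision

# q, k: per-head query/key tensors; rotary_pos_emb: frequencies for the sequence
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)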

@DarkLight1337 @Lucaskabela @tjtanaa Should we keep the Triton embedding function here only for the one remaining case, glm4_1v.py, or not?

…mpatible now

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Contributor

@Lucaskabela Lucaskabela left a comment

I think this approach LGTM

Member

@ywang96 ywang96 left a comment

Given that #27764 is merged, we should also merge this in.

@ywang96 ywang96 added the ready (ONLY add when PR is ready to merge/full CI is needed) label Nov 3, 2025
@ywang96 ywang96 merged commit b13a447 into vllm-project:main Nov 4, 2025
54 checks passed
omerpaz95 pushed a commit to omerpaz95/vllm that referenced this pull request Nov 4, 2025
…ity on ROCm (vllm-project#27748)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

Labels

  • qwen: Related to Qwen models
  • ready: ONLY add when PR is ready to merge/full CI is needed
  • rocm: Related to AMD ROCm

6 participants