11 changes: 7 additions & 4 deletions vllm/model_executor/layers/rotary_embedding/common.py
@@ -77,7 +77,11 @@ def dispatch_rotary_emb_function(
    if current_platform.is_cuda():
        return apply_rotary_emb

    if current_platform.is_rocm():
Contributor:
@vllmellm @DarkLight1337 What do you think? Should we keep the Triton rotary embedding for backwards compatibility for now, by adding a torch.compiler.is_compiling() condition instead, since there are other model definition files that are not using torch.compile yet? (See the sketch after this list.)

For example, model definition files that are not yet using torch.compile:

  • vllm/model_executor/models/dots_ocr.py
  • vllm/model_executor/models/ernie45_vl.py
  • vllm/model_executor/models/glm4_1v.py
  • vllm/model_executor/models/qwen2_vl.py
  • vllm/model_executor/models/siglip2navit.py
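
For concreteness, here is a minimal sketch of where the proposed torch.compiler.is_compiling() check would sit. The helper names (current_platform, find_spec, apply_rotary_emb_torch, flash_attn's apply_rotary) are taken from this diff; the logging and exact fallback handling of the real patch are omitted, so treat this as an illustration rather than the final change.

from importlib.util import find_spec

import torch

from vllm.model_executor.layers.rotary_embedding.common import apply_rotary_emb_torch
from vllm.platforms import current_platform


def dispatch_rotary_emb_function(default=None):
    # Eager ROCm path: keep the flash_attn Triton kernel for backwards
    # compatibility with models that are not covered by torch.compile yet.
    if current_platform.is_rocm() and not torch.compiler.is_compiling():
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary
            return apply_rotary
    # Compiled path (or flash_attn missing): plain PyTorch implementation.
    return default if default is not None else apply_rotary_emb_torch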

@tjtanaa (Contributor), Oct 30, 2025:
Another thing to note at this point: torch.compile is currently disabled because of issues with dynamic slicing.

So if we add the torch.compiler.is_compiling() condition, we don't need to postpone this PR, while still avoiding a performance regression in both cases (torch.compile enabled and disabled).

Right now new fixes are being explored, e.g. #27764.

Contributor:
My two cents: I see the benefit of gating with torch.compiler.is_compiling(), especially for models we don't have covered with torch.compile yet, so I'm in favor of that approach as opposed to removing the Triton path completely.

Member:
+1 on adding a condition that checks whether we're running under torch.compile.

On a side note: is this PR really considered a bugfix? I thought that after #27760 things should already work on AMD. Did I miss something here?

@tjtanaa (Contributor), Oct 30, 2025:
@ywang96 PR #27760 is a bugfix that reverts the torch.compile PR. If torch.compile is re-enabled, AMD needs this bugfix PR to address the incompatibility of the flash_attn rotary-embedding Triton kernel with torch.compile. This comment thread is discussing how to handle the models that don't have torch.compile support yet.

Contributor:
And in this PR we found that the torch.compile-d rotary embedding is faster than the Triton implementation on ROCm. I expect the same to hold on CUDA.
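
A rough micro-benchmark sketch of that comparison. The tensor shapes, device, and timing loop are illustrative assumptions (and apply_rotary_emb_torch is assumed importable from common.py, as the snippets in this thread suggest), not the setup behind this PR's numbers.

import torch

from vllm.model_executor.layers.rotary_embedding.common import apply_rotary_emb_torch

# Assumed vision-tower-like shapes: x is (batch, seq, heads, head_dim),
# cos/sin are (seq, head_dim // 2). Purely for illustration.
t = torch.randn(1, 4096, 16, 80, device="cuda")
freqs = torch.randn(4096, 40, device="cuda")
cos, sin = freqs.cos(), freqs.sin()

compiled = torch.compile(apply_rotary_emb_torch)
compiled(t, cos, sin)  # warm-up so compilation time is not measured


def time_ms(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(t, cos, sin)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


print("eager PyTorch :", time_ms(apply_rotary_emb_torch), "ms")
print("torch.compile :", time_ms(compiled), "ms")

# Optionally compare against the flash_attn Triton-backed kernel when it is
# installed (import path assumed from upstream flash-attn).
try:
    from flash_attn.layers.rotary import apply_rotary_emb as fa_rotary
    print("flash_attn    :", time_ms(fa_rotary), "ms")
except ImportError:
    pass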

@ywang96 (Member), Oct 30, 2025:
Yeah, I was just confirming whether the current main branch is broken on AMD (because that would affect our release). It seems it's not, since torch.compile is currently disabled, correct?

Contributor:
@ywang96 On the current main branch, the unit tests are healthy and the Qwen3 VL accuracy is normal. 👍

@vllmellm (Contributor, Author), Oct 31, 2025:

@tjtanaa @Lucaskabela those model paths are not using this dispatch function; it is only used in qwen2_vl.py. qwen2_5_vl.py and glm4_1v.py import the apply_rotary_pos_emb_vision function from qwen2_vl.py, and the dispatch happens inside that function:

def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = rotary_emb_function(t_, cos, sin).type_as(t)
    return output

@DarkLight1337 @Lucaskabela @tjtanaa should we keep the Triton embedding function here only for that one remaining case (glm4_1v.py), or not?
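
For reference, a small usage sketch of the shared helper described above, separate from the diff below. The module path follows the file names mentioned in this thread, and the tensor shapes are made-up placeholders.

import torch

from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision

# Hypothetical shapes purely to exercise the helper; real callers pass the
# vision attention's q/k and the precomputed rotary frequencies.
t = torch.randn(1, 64, 4, 32, device="cuda")  # assumed (batch, seq, heads, head_dim)
freqs = torch.randn(64, 16, device="cuda")    # assumed (seq, head_dim // 2)
out = apply_rotary_pos_emb_vision(t, freqs)   # the dispatch happens inside this call
print(out.shape)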

    # If torch.compile is not enabled, use the rotary embedding function from
    # the flash_attn package; otherwise use the naive PyTorch implementation,
    # which is faster when torch.compile is enabled.
    if current_platform.is_rocm() and not torch.compiler.is_compiling():
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

@@ -87,11 +91,10 @@ def dispatch_rotary_emb_function(
"flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings."
)

    if default is not None:
        return default
    else:
        return apply_rotary_emb_torch

    return apply_rotary_emb_torch


# yarn functions
2 changes: 1 addition & 1 deletion vllm/model_executor/models/glm4_1v.py
@@ -370,7 +370,7 @@ def forward(
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
dropout_p=0.0,
causal=False,
)
