Commit b13a447

[Bugfix][ROCm] Fix ViT rotary embeddings for torch.compile compatibility on ROCm (#27748)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent 7956b0c commit b13a447
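
For context, the fix hinges on torch.compiler.is_compiling(), the public PyTorch API that reports whether the calling code is currently being traced by torch.compile (or torch.export). A minimal standalone demonstration, not vllm code:

import torch


def which_path(x):
    # True only while Dynamo is tracing this function, so the compiled
    # artifact bakes in the first branch while eager execution takes the second.
    if torch.compiler.is_compiling():
        return x + 1.0  # compile-time branch
    return x - 1.0      # eager branch


x = torch.zeros(2)
print(which_path(x))                 # eager: tensor([-1., -1.])
print(torch.compile(which_path)(x))  # compiled: tensor([1., 1.])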

2 files changed: +8 −5 lines

vllm/model_executor/layers/rotary_embedding/common.py

Lines changed: 7 additions & 4 deletions
@@ -77,7 +77,11 @@ def dispatch_rotary_emb_function(
     if current_platform.is_cuda():
         return apply_rotary_emb
 
-    if current_platform.is_rocm():
+    # if torch compile is not enabled
+    # use rotary embedding function from flash_attn package
+    # otherwise use the naive pytorch embedding implementation
+    # is faster when torch compile is enabled.
+    if current_platform.is_rocm() and not torch.compiler.is_compiling():
         if find_spec("flash_attn") is not None:
             from flash_attn.ops.triton.rotary import apply_rotary
 
@@ -87,11 +91,10 @@ def dispatch_rotary_emb_function(
                 "flash_attn is not installed. Falling back to PyTorch "
                 "implementation for rotary embeddings."
             )
-
     if default is not None:
         return default
-    else:
-        return apply_rotary_emb_torch
+
+    return apply_rotary_emb_torch
 
 
 # yarn functions
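
Taken together, the two hunks leave the dispatch logic roughly as sketched below. This is a minimal, self-contained approximation rather than the vllm source: the is_rocm parameter, the print-based warning, and the simplified apply_rotary_emb_torch are hypothetical stand-ins for vllm's current_platform helpers, logger, and real rotary kernel.

from importlib.util import find_spec

import torch


def apply_rotary_emb_torch(x, cos, sin):
    # Naive rotary embedding stand-in: rotate feature pairs
    # (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def dispatch_rotary_emb_function(is_rocm: bool, default=None):
    # Per the commit's comment: outside torch.compile, prefer the flash_attn
    # Triton rotary kernel on ROCm; under torch.compile, skip it because the
    # naive PyTorch implementation is faster when compiled.
    if is_rocm and not torch.compiler.is_compiling():
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary
            return apply_rotary
        print("flash_attn is not installed. Falling back to PyTorch "
              "implementation for rotary embeddings.")

    if default is not None:
        return default

    # Unified fall-through introduced by the second hunk: every remaining
    # case lands on the PyTorch implementation.
    return apply_rotary_emb_torch

The restructured tail (dropping the else:) keeps a single fall-through: when the ROCm branch is skipped during compilation, control reaches the same default/PyTorch fallback as every other platform.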

vllm/model_executor/models/glm4_1v.py

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ def forward(
             cu_seqlens_k=cu_seqlens,
             max_seqlen_q=max_seqlen,
             max_seqlen_k=max_seqlen,
-            dropout_p=0,
+            dropout_p=0.0,
             causal=False,
         )
 
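The second file's change is a one-token type fix: dropout_p is passed as the float literal 0.0 instead of the int 0. The likely motivation, inferred rather than stated in the commit, is that strict scalar type handling (in the attention op's schema, or in torch.compile's specialization of Python scalars) accepts a float but not an int. A toy illustration with a hypothetical float-only kernel, not the actual attention call:

import torch


def kernel(x: torch.Tensor, dropout_p: float):
    # Stand-in for an op whose schema requires a Python float.
    assert isinstance(dropout_p, float), "dropout_p must be a float"
    return x * (1.0 - dropout_p)


compiled = torch.compile(kernel)
x = torch.ones(4)
print(compiled(x, 0.0))  # OK: float matches the schema, like dropout_p=0.0
# compiled(x, 0)         # AssertionError: int fails the check, like dropout_p=0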
