
Commit 6c0186f

Lucaskabela authored and zhaozuy committed
[Bugfix][Qwen][Multimodal] Move Qwen2_5_vl sdpa to custom op and reenable compile (vllm-project#27764)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
1 parent 29970f2 commit 6c0186f

File tree

2 files changed, +69 -28 lines changed


vllm/attention/ops/vit_attn_wrappers.py

Lines changed: 53 additions & 0 deletions

@@ -14,6 +14,7 @@

 import einops
 import torch
+import torch.nn.functional as F

 from vllm.utils.torch_utils import direct_register_custom_op

@@ -123,3 +124,55 @@ def vit_flash_attn_wrapper(
     return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
         q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa
     )
+
+
+# TODO: Once we have a torch 2.10, we can use tensor slices
+# so we won't need to wrap this in custom ops
+def torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    outputs = []
+    for i in range(1, len(cu_seqlens)):
+        start_idx = cu_seqlens[i - 1]
+        end_idx = cu_seqlens[i]
+        q_i = q[:, start_idx:end_idx]
+        k_i = k[:, start_idx:end_idx]
+        v_i = v[:, start_idx:end_idx]
+        q_i, k_i, v_i = (
+            einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
+        )
+        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
+        output_i = einops.rearrange(output_i, "b h s d -> b s h d")
+        outputs.append(output_i)
+    context_layer = torch.cat(outputs, dim=1)
+    context_layer = einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
+    return context_layer
+
+
+def torch_sdpa_wrapper_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    b, s, h, d = q.shape
+    return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device)
+
+
+direct_register_custom_op(
+    op_name="torch_sdpa_wrapper",
+    op_func=torch_sdpa_wrapper,
+    fake_impl=torch_sdpa_wrapper_fake,
+)
+
+
+def vit_torch_sdpa_wrapper(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
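
For reference, here is a minimal standalone sketch of what the new torch_sdpa_wrapper op computes: per-sequence scaled-dot-product attention over a packed token dimension, with sequence boundaries given by cu_seqlens. The helper name packed_sdpa_reference, the (batch=1, 7 tokens, 2 heads, head_dim=8) shapes, and the boundaries [0, 4, 7] are illustrative assumptions, not vLLM code.

import torch
import torch.nn.functional as F

def packed_sdpa_reference(q, k, v, cu_seqlens):
    # q/k/v: (batch, total_tokens, num_heads, head_dim); cu_seqlens: cumulative boundaries
    outputs = []
    bounds = cu_seqlens.tolist()
    for start, end in zip(bounds[:-1], bounds[1:]):
        # slice one sequence's tokens and move heads ahead of tokens for SDPA
        q_i, k_i, v_i = (t[:, start:end].transpose(1, 2) for t in (q, k, v))
        o_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(o_i.transpose(1, 2))    # back to (batch, tokens, heads, head_dim)
    out = torch.cat(outputs, dim=1)            # re-pack along the token dimension
    return out.permute(1, 0, 2, 3).flatten(2)  # (total_tokens, batch, heads * head_dim)

q = k = v = torch.randn(1, 7, 2, 8)
out = packed_sdpa_reference(q, k, v, torch.tensor([0, 4, 7]))
print(out.shape)  # torch.Size([7, 1, 16])

The (total_tokens, batch, hidden) output layout matches the shape returned by torch_sdpa_wrapper_fake above, which is all the compiler needs to know about the op when tracing.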

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 16 additions & 28 deletions

@@ -46,6 +46,7 @@
 from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
+    vit_torch_sdpa_wrapper,
     vit_xformers_attn_wrapper,
 )
 from vllm.compilation.decorators import support_torch_compile
@@ -442,41 +443,28 @@ def forward(
             q = q.contiguous()
             k = k.contiguous()
             v = v.contiguous()
-            outputs = []
-            for i in range(1, len(cu_seqlens)):
-                start_idx = cu_seqlens[i - 1]
-                end_idx = cu_seqlens[i]
-                q_i = q[:, start_idx:end_idx]
-                k_i = k[:, start_idx:end_idx]
-                v_i = v[:, start_idx:end_idx]
-                q_i, k_i, v_i = (
-                    einops.rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i]
-                )
-                output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
-                output_i = einops.rearrange(output_i, "b h s d -> b s h d")
-                outputs.append(output_i)
-            context_layer = torch.cat(outputs, dim=1)
-            context_layer = einops.rearrange(
-                context_layer, "b s h d -> s b (h d)"
-            ).contiguous()
+            context_layer = vit_torch_sdpa_wrapper(
+                q,
+                k,
+                v,
+                cu_seqlens,
+            )
         elif self.attn_backend == _Backend.XFORMERS:
             context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)

         output, _ = self.proj(context_layer)
         return output


-# (FIXME): Enable this after dynamic slicing is fixed
-# See https://github.com/vllm-project/vllm/pull/27760
-# @support_torch_compile(
-#     dynamic_arg_dims={
-#         "x": 0,
-#         "cu_seqlens": 0,
-#         "rotary_pos_emb": 0,
-#         "seqlens": 0,
-#     },
-#     mark_unbacked_dims={"seqlens": 0},
-# )
+@support_torch_compile(
+    dynamic_arg_dims={
+        "x": 0,
+        "cu_seqlens": 0,
+        "rotary_pos_emb": 0,
+        "seqlens": 0,
+    },
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen2_5_VisionBlock(nn.Module):
     def __init__(
         self,