
Commit bc3e0dd

Isotr0py authored and sumitd2 committed
[Model] Support SDPA attention for Molmo vision backbone (vllm-project#9410)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 1e97d33 commit bc3e0dd

File tree

3 files changed: +61 -78 lines changed

vllm/model_executor/models/molmo.py

Lines changed: 15 additions & 37 deletions
```diff
@@ -1,4 +1,3 @@
-import logging
 import math
 import re
 from array import array
@@ -14,10 +13,8 @@
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -43,12 +40,11 @@
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import make_layers
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
-from vllm.platforms import current_platform
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
-log = logging.getLogger(__name__)
+from .utils import get_vit_attn_backend
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -190,35 +186,12 @@ def __init__(
         )
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    log.warning(
-                        "Current Molmo implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Molmo does not support {selected_backend} backend now.")
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Molmo does not support {self.attn_backend} backend now.")
 
     def forward(self,
                 inputs_q: torch.Tensor,
@@ -240,10 +213,15 @@ def forward(self,
         xk = xk.view(*kv_shape)
         xv = xv.view(*kv_shape)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
            from flash_attn import flash_attn_func
            output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
-        else:
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
+                          for x in (xq, xk, xv))
+            output = F.scaled_dot_product_attention(xq, xk, xv)
+            output = rearrange(output, "b h s d -> b s h d ")
+        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
 
```
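As a side note, the new TORCH_SDPA branch simply transposes query/key/value from (batch, seq, heads, head_dim) into the (batch, heads, seq, head_dim) layout that `torch.nn.functional.scaled_dot_product_attention` expects, runs non-causal attention with no dropout, and transposes back. A minimal standalone sketch of that pattern follows; the tensor shapes are toy values, not Molmo's actual vision configuration.

```python
# Standalone sketch of the TORCH_SDPA branch added to molmo.py above.
# Shapes are illustrative only; the real head count/size come from the config.
import torch
import torch.nn.functional as F
from einops import rearrange

batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
xq = torch.randn(batch, seq_len, num_heads, head_dim)
xk = torch.randn(batch, seq_len, num_heads, head_dim)
xv = torch.randn(batch, seq_len, num_heads, head_dim)

# SDPA wants (batch, heads, seq, head_dim), so rearrange in and out.
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in (xq, xk, xv))
out = F.scaled_dot_product_attention(q, k, v)  # non-causal, dropout_p=0.0
out = rearrange(out, "b h s d -> b s h d")

print(out.shape)  # torch.Size([2, 16, 8, 64])
```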

vllm/model_executor/models/qwen2_vl.py

Lines changed: 12 additions & 40 deletions
```diff
@@ -39,10 +39,8 @@
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
@@ -63,14 +61,13 @@
                              MultiModalInputs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import get_processor
-from vllm.utils import is_cpu
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, get_vit_attn_backend,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
 
 logger = init_logger(__name__)
@@ -215,37 +212,12 @@ def __init__(
                                       quant_config=quant_config)
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.has_device_capability(80)
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    logger.warning(
-                        "Current Qwen2-VL implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Qwen2-VL does not support {selected_backend} backend now."
-                )
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
     def forward(
         self,
@@ -274,7 +246,7 @@ def forward(
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
             from flash_attn import flash_attn_varlen_func
@@ -295,7 +267,7 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
-        elif is_cpu():
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
             attention_mask = torch.zeros([1, seq_length, seq_length],
@@ -310,7 +282,7 @@ def forward(
                                                  attention_mask,
                                                  dropout_p=0.0)
             context_layer = rearrange(output, "b h s d -> b s h d ")
-        else:
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
```
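For Qwen2-VL, the TORCH_SDPA branch is the path that was previously gated on `is_cpu()`: it builds an attention mask over the flattened patch sequence so that patches attend only within their own image, then calls SDPA. The following is a hedged, standalone sketch of that masking idea with made-up sequence lengths; it omits the surrounding model code and uses a boolean mask rather than reproducing the exact construction in qwen2_vl.py.

```python
# Hedged sketch: block-diagonal SDPA mask built from cumulative patch counts
# (cu_seqlens), in the spirit of the TORCH_SDPA branch above. Toy values only.
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 32
cu_seqlens = [0, 5, 12]          # two images: 5 patches and 7 patches
seq_length = cu_seqlens[-1]

q = torch.randn(1, num_heads, seq_length, head_dim)
k = torch.randn(1, num_heads, seq_length, head_dim)
v = torch.randn(1, num_heads, seq_length, head_dim)

# True where attention is allowed: each patch sees only its own image.
attention_mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool)
for lo, hi in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    attention_mask[..., lo:hi, lo:hi] = True

out = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 4, 12, 32])
```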

vllm/model_executor/models/utils.py

Lines changed: 34 additions & 1 deletion
```diff
@@ -8,15 +8,22 @@
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (_Backend, backend_name_to_enum,
+                                     get_global_forced_attn_backend)
 from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available
+from vllm.utils import is_cpu, is_pin_memory_available
+
+logger = init_logger(__name__)
 
 WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
@@ -487,3 +494,29 @@ def __getattr__(self, key: str):
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         llm = super().__getattr__(self.model_name)
         return llm(*args, **kwargs)
+
+
+def get_vit_attn_backend() -> _Backend:
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                logger.warning(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif is_cpu():
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
```
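With the selection logic centralized in `get_vit_attn_backend()`, both vision modules reduce to the same pattern: call the helper, then reject any backend they cannot handle. The toy class below restates that usage outside the model files; the class name is hypothetical, but the imports, the resolution order noted in the comment, and the membership check all come from the diffs above.

```python
# Illustrative consumer of get_vit_attn_backend(); the class is made up,
# but the helper, the _Backend enum, and the check mirror this commit.
from vllm.attention.selector import _Backend
from vllm.model_executor.models.utils import get_vit_attn_backend


class ToyVisionAttention:

    def __init__(self) -> None:
        # Resolution order inside the helper: globally forced backend first,
        # then the VLLM_ATTENTION_BACKEND env var, then auto-detection
        # (flash-attn on SM80+ GPUs, SDPA on CPU, xformers otherwise).
        self.attn_backend: _Backend = get_vit_attn_backend()
        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
        }:
            raise RuntimeError(
                f"This vision module does not support {self.attn_backend}.")
```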
