 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
                             MultiModalInputs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import get_processor
-from vllm.utils import is_cpu
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, get_vit_attn_backend,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
 
 logger = init_logger(__name__)
@@ -215,37 +212,12 @@ def __init__(
                                       quant_config=quant_config)
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.has_device_capability(80)
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    logger.warning(
-                        "Current Qwen2-VL implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Qwen2-VL does not support {selected_backend} backend now."
-                )
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
     def forward(
         self,
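Note: `get_vit_attn_backend()` (imported from `.utils` above) now owns the backend detection that was previously inlined in this constructor. The sketch below is only a reconstruction from the removed logic, not the actual helper in `.utils`; the `_sketch` name and the exact fallback order are assumptions.

```python
# Hedged reconstruction of the removed inline detection; the real
# get_vit_attn_backend() may differ.
from typing import Optional

import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum,
                                     get_global_forced_attn_backend)
from vllm.platforms import current_platform
from vllm.utils import is_cpu


def get_vit_attn_backend_sketch() -> _Backend:
    # 1. Respect a globally forced backend, then the VLLM_ATTENTION_BACKEND
    #    environment variable.
    selected: Optional[_Backend] = get_global_forced_attn_backend()
    if selected is None and envs.VLLM_ATTENTION_BACKEND is not None:
        selected = backend_name_to_enum(envs.VLLM_ATTENTION_BACKEND)
    if selected is not None:
        return selected
    # 2. CPU execution uses PyTorch SDPA.
    if is_cpu():
        return _Backend.TORCH_SDPA
    # 3. On Ampere or newer GPUs (compute capability >= 8.0), prefer
    #    flash-attn when it is installed; otherwise, and on Volta/Turing,
    #    fall back to xformers.
    if current_platform.has_device_capability(80):
        from transformers.utils import is_flash_attn_2_available
        if is_flash_attn_2_available():
            return _Backend.FLASH_ATTN
    return _Backend.XFORMERS
```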
@@ -274,7 +246,7 @@ def forward(
         q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
             from flash_attn import flash_attn_varlen_func
@@ -295,7 +267,7 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
-        elif is_cpu():
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
             attention_mask = torch.zeros([1, seq_length, seq_length],
@@ -310,7 +282,7 @@ def forward(
                                                      attention_mask,
                                                      dropout_p=0.0)
             context_layer = rearrange(output, "b h s d -> b s h d ")
-        else:
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
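With this change the forward path branches on `self.attn_backend` instead of a `_use_flash_attn` flag plus an `is_cpu()` check. Below is a minimal, self-contained sketch of the tensor layout the `TORCH_SDPA` branch relies on; the shapes and the all-True mask are illustrative assumptions, not taken from the commit.

```python
# Standalone sketch: SDPA expects (batch, heads, seq, head_dim), while the
# vision attention keeps (batch, seq, heads, head_dim), hence the rearranges
# around the call. Shapes are illustrative only.
import torch
import torch.nn.functional as F
from einops import rearrange

batch, seq, heads, head_dim = 1, 16, 4, 32
q, k, v = (torch.randn(batch, seq, heads, head_dim) for _ in range(3))

# Boolean mask: True marks positions that are allowed to attend to each other.
attention_mask = torch.ones([1, seq, seq], dtype=torch.bool)

q_, k_, v_ = (rearrange(x, "b s h d -> b h s d") for x in (q, k, v))
out = F.scaled_dot_product_attention(q_, k_, v_, attention_mask, dropout_p=0.0)
context = rearrange(out, "b h s d -> b s h d")
print(context.shape)  # torch.Size([1, 16, 4, 32])
```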