Commit a1ae906

v1: Add Whisper model support (encoder-decoder)
This brings Whisper support to V1 to close one of the remaining feature gaps with V0. Most of the changes apply to encoder-decoder models generally, though Whisper is the only model explicitly tested and the only encoder-decoder model updated to support V1.

**Whisper Model Implementation:**
- Remove the SupportsV0Only interface constraint to enable V1 compatibility
- Update get_multimodal_embeddings() to return the list format required by V1

**Flash Attention Backend:**
- Add encoder attention metadata fields (encoder_seq_start_loc, max_encoder_seq_len, cross_slot_mapping)
- Implement encoder self-attention support without using the KV cache
- Add cross-attention support for encoder-decoder models with proper KV cache handling

**KV Cache Manager:**
- Introduce CrossAttentionManager for handling the cross-attention KV cache in encoder-decoder models
- Add CrossAttentionSpec for cross-attention cache specification with encoder-based sizing
- Implement allocate_slots_for_cross_attn() for static, encoder-length-based allocation
- Add cross-attention block allocation logic separate from decoder token growth

**Scheduler:**
- Disable prefix caching for encoder-decoder models
- Implement cross-attention block allocation during request scheduling
- Add cross-attention block tracking in state management

**GPU Model Runner:**
- Add encoder input extraction for audio feature processing
- Implement encoder attention metadata building for both self-attention and cross-attention
- Add cross-attention KV cache group handling with proper slot mapping
- Modify input batch creation to accommodate encoder sequence lengths
- Add encoder input processing in the forward pass with proper device/dtype handling
- Update profiling and memory management for encoder-decoder models

The implementation maintains backward compatibility while adding comprehensive encoder-decoder support, with particular focus on Whisper's audio processing pipeline and the cross-attention mechanism between encoder and decoder.

Related to:
- V0 deprecation: #18571
- 2025 Q3 roadmap: #20336

Signed-off-by: Russell Bryant <rbryant@redhat.com>
1 parent 14bf19e commit a1ae906
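
As a rough illustration of the "static encoder-length-based allocation" described in the message above: cross-attention K/V are computed once from the encoder output, so the cross-attention cache for a request can be sized up front from the encoder length and never grows during decoding. The helper below is a hypothetical sketch, not the actual CrossAttentionManager or allocate_slots_for_cross_attn() code.

# Hypothetical sketch (not the vLLM source): sizing the cross-attention
# KV cache from the encoder length alone.
from math import ceil

def cross_attn_blocks_needed(encoder_seq_len: int, block_size: int) -> int:
    # The encoder output length is fixed once the request is scheduled
    # (1500 frames for a 30 s Whisper clip), so the allocation is static
    # and does not grow with decoded tokens.
    return ceil(encoder_seq_len / block_size)

# Example: ceil(1500 / 16) = 94 blocks, allocated once per request.
print(cross_attn_blocks_needed(1500, 16))  # 94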

14 files changed: +660 −70 lines

vllm/attention/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@
     "AttentionMetadata",
     "AttentionType",
     "AttentionMetadataBuilder",
-    "Attention",
     "AttentionState",
     "get_attn_backend",
 ]

vllm/attention/backends/flash_attn.py

Lines changed: 1 addition & 3 deletions
@@ -1000,6 +1000,4 @@ def _get_causal_option(attn_type: str) -> bool:
         attention (i.e., not encoder, encoder-only, or encoder-decoder),
         otherwise returns `False`.
     """
-    return not (attn_type == AttentionType.ENCODER
-                or attn_type == AttentionType.ENCODER_ONLY
-                or attn_type == AttentionType.ENCODER_DECODER)
+    return attn_type == AttentionType.DECODER

vllm/inputs/preprocess.py

Lines changed: 0 additions & 6 deletions
@@ -869,9 +869,6 @@ def preprocess(
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -903,9 +900,6 @@ async def preprocess_async(
         [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
         """
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(prompt)

vllm/model_executor/models/whisper.py

Lines changed: 4 additions & 5 deletions
@@ -42,7 +42,7 @@
 from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
-                         SupportsTranscription, SupportsV0Only)
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)
 
@@ -790,7 +790,7 @@ def _get_prompt_updates(
                                         info=WhisperProcessingInfo,
                                         dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal, SupportsV0Only):
+                                      SupportsMultiModal):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
@@ -916,10 +916,9 @@ def get_language_model(self) -> torch.nn.Module:
 
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
-        # TODO: This method does not obey the interface for SupportsMultiModal.
-        # Refactor this once encoder/decoder support is implemented in V1.
+        # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return self.model.get_encoder_outputs(audio_input["input_features"])
+        return [self.model.get_encoder_outputs(audio_input["input_features"])]
 
     def get_input_embeddings(
         self,
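
For context on the list return value above: V1's multimodal interface expects one embedding tensor per multimodal item, so the single Whisper audio clip's encoder output is wrapped in a one-element list. The snippet below is a toy sketch of that contract with placeholder shapes, not vLLM code.

# Toy sketch of the "one tensor per multimodal item" contract assumed above;
# tensor contents and sizes here are placeholders.
import torch

def fake_encoder_outputs(num_frames: int = 1500,
                         hidden_size: int = 384) -> torch.Tensor:
    return torch.zeros(num_frames, hidden_size)

embeddings = [fake_encoder_outputs()]   # shape of what the method now returns
merged = torch.cat(embeddings, dim=0)   # a caller can concatenate per-item tensors
print(len(embeddings), merged.shape)    # 1 torch.Size([1500, 384])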

vllm/v1/attention/backends/flash_attn.py

Lines changed: 155 additions & 13 deletions
@@ -130,6 +130,24 @@ class FlashAttentionMetadata:
     prefix_scheduler_metadata: Optional[torch.Tensor] = None
     max_num_splits: int = 0
 
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into sequence. E.g., if the sequence
+    # length is [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
+
+    @property
+    def is_all_encoder_attn_metadata_set(self) -> bool:
+        """
+        All attention metadata required for encoder attention is set.
+        """
+        return (self.encoder_seq_start_loc is not None
+                and self.max_encoder_seq_len is not None)
+
 
 def _get_sliding_window_configs(
         vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
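
The encoder_seq_start_loc layout follows the usual varlen convention: a leading zero followed by the running sum of per-sequence encoder lengths. A small sketch matching the [4, 6] -> [0, 4, 10] example in the comment above (illustrative only, not the metadata builder code):

# Illustrative only: building cumulative encoder start locations from
# per-request encoder lengths.
import torch

encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
])
max_encoder_seq_len = int(encoder_seq_lens.max())

print(encoder_seq_start_loc)   # tensor([ 0,  4, 10], dtype=torch.int32)
print(max_encoder_seq_len)     # 6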
@@ -207,7 +225,13 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -326,6 +350,10 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
         )
         return attn_metadata
 
@@ -375,17 +403,31 @@ def __init__(
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "FlashAttentionImpl")
+        self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
             and not flash_attn_supports_fp8():
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device.")
 
+    @staticmethod
+    def _get_causal_option(attn_type: str) -> bool:
+        """
+        Determine whether the given attention type is suitable for causal
+        attention mechanisms.
+
+        Args:
+            attn_type (AttentionType): The type of attention being evaluated
+
+        Returns:
+            bool: Returns `True` if the attention type is suitable for causal
+            attention (i.e., not encoder, encoder-only, or encoder-decoder),
+            otherwise returns `False`.
+        """
+        return not (attn_type == AttentionType.ENCODER
+                    or attn_type == AttentionType.ENCODER_ONLY
+                    or attn_type == AttentionType.ENCODER_DECODER)
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -422,6 +464,14 @@ def forward(
             # Profiling run.
             return output
 
+        # Validate attention metadata based on attention type
+        attn_type = self.attn_type
+        if (attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_DECODER,
+                          AttentionType.ENCODER_ONLY)
+                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
+            raise AttributeError("Encoder attention requires setting "
+                                 "encoder metadata attributes.")
+
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
@@ -432,22 +482,40 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Handle encoder attention differently - no KV cache needed
+        if attn_type == AttentionType.ENCODER:
+            # For encoder attention,
+            # we use direct Q, K, V tensors without caching
+            return self._forward_encoder_attention(query[:num_actual_tokens],
+                                                   key[:num_actual_tokens],
+                                                   value[:num_actual_tokens],
+                                                   output[:num_actual_tokens],
+                                                   attn_metadata, layer)
+
+        # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
             # not padded. However, we don't need to do key[:num_actual_tokens]
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
             reshape_and_cache_flash(
                 key,
                 value,
                 key_cache,
                 value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                 self.kv_cache_dtype,
                 layer._k_scale,
                 layer._v_scale,
@@ -471,7 +539,7 @@ def forward(
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
@@ -483,7 +551,7 @@
                 seqused_k=seqused_k,
                 max_seqlen_k=max_seqlen_k,
                 softmax_scale=self.scale,
-                causal=True,
+                causal=FlashAttentionImpl._get_causal_option(attn_type),
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
@@ -518,12 +586,86 @@ def forward(
             fa_version=self.vllm_flash_attn_version,
             prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
             suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale,
-            k_descale=layer._k_scale,
-            v_descale=layer._v_scale,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
         )
         return output
 
+    def _forward_encoder_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        layer: torch.nn.Module,
+    ) -> torch.Tensor:
+        """Forward pass for encoder attention without KV cache.
+
+        Args:
+            query: shape = [num_encoder_tokens, num_heads, head_size]
+            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
+            output: shape = [num_encoder_tokens, num_heads, head_size]
+            attn_metadata: Encoder attention metadata
+            layer: The attention layer
+        """
+        # For encoder attention, process FP8 quantization if needed
+        if self.kv_cache_dtype.startswith("fp8"):
+            num_tokens, num_heads, head_size = query.shape
+            query, _ = ops.scaled_fp8_quant(
+                query.reshape(
+                    (num_tokens, num_heads * head_size)).contiguous(),
+                layer._q_scale)
+            query = query.reshape((num_tokens, num_heads, head_size))
+
+            num_kv_tokens, num_kv_heads, head_size = key.shape
+            key, _ = ops.scaled_fp8_quant(
+                key.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._k_scale)
+            key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+            value, _ = ops.scaled_fp8_quant(
+                value.reshape(
+                    (num_kv_tokens, num_kv_heads * head_size)).contiguous(),
+                layer._v_scale)
+            value = value.reshape((num_kv_tokens, num_kv_heads, head_size))
+
+        # Use encoder-specific metadata for sequence information
+        cu_seqlens_q = attn_metadata.encoder_seq_start_loc
+        cu_seqlens_k = attn_metadata.encoder_seq_start_loc
+        max_seqlen_q = attn_metadata.max_encoder_seq_len
+        max_seqlen_k = attn_metadata.max_encoder_seq_len
+
+        descale_shape = (
+            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
+            self.num_kv_heads)
+
+        # Call flash attention directly on Q, K, V tensors
+        flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=output,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            softmax_scale=self.scale,
+            causal=False,  # Encoder attention is bidirectional
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
+        )
+
+        return output
+
 
 def use_cascade_attention(
     common_prefix_len: int,
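
Taken together, the forward() changes above amount to a three-way dispatch on the attention type. The snippet below is a conceptual summary only (plain strings instead of AttentionType, no real tensors), not the backend implementation:

# Conceptual sketch (not the backend code) of how the attention type steers
# the forward path: encoder self-attention skips the KV cache entirely,
# cross-attention writes K/V through cross_slot_mapping, and only pure
# decoder attention runs causally.
ENCODER, ENCODER_DECODER, DECODER = "encoder", "encoder_decoder", "decoder"

def plan_forward(attn_type: str) -> dict:
    return {
        "uses_kv_cache": attn_type in (DECODER, ENCODER_DECODER),
        "slot_mapping": ("cross_slot_mapping" if attn_type == ENCODER_DECODER
                         else "slot_mapping" if attn_type == DECODER
                         else None),
        "causal": attn_type == DECODER,
    }

for t in (ENCODER, ENCODER_DECODER, DECODER):
    print(t, plan_forward(t))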

vllm/v1/attention/backends/utils.py

Lines changed: 8 additions & 0 deletions
@@ -59,6 +59,14 @@ class CommonAttentionMetadata:
     block_table_tensor: torch.Tensor
     slot_mapping: torch.Tensor
 
+    # Encoder/cross-attention specific fields (optional)
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    """(batch_size + 1,), cumulative encoder sequence lengths"""
+    max_encoder_seq_len: Optional[int] = None
+    """Maximum encoder sequence length in batch"""
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    """Slot mapping for cross-attention KV cache"""
+
 
 M = TypeVar("M")
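
The cross_slot_mapping field pairs each encoder position of a request with a physical slot in the cross-attention KV cache. Below is a hypothetical sketch of how such a mapping could be derived from a request's cross-attention block IDs using the usual slot = block_id * block_size + offset convention; it is not the vLLM scheduler code.

# Illustrative sketch (not vLLM code): mapping encoder positions to
# physical cache slots for cross-attention.
import torch

def build_cross_slot_mapping(cross_block_ids: list[int],
                             encoder_seq_len: int,
                             block_size: int) -> torch.Tensor:
    blocks = torch.tensor(cross_block_ids, dtype=torch.int64)
    positions = torch.arange(encoder_seq_len, dtype=torch.int64)
    # Each position lands in block positions // block_size at offset
    # positions % block_size within that block.
    return blocks[positions // block_size] * block_size + positions % block_size

# Example: encoder length 6, block size 4, physical blocks [7, 2]
print(build_cross_slot_mapping([7, 2], 6, 4))
# tensor([28, 29, 30, 31,  8,  9])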
