Commit addcaae

v1: Add Whisper model support (encoder-decoder)
This brings Whisper support to V1 to close one of the remaining feature gaps with V0. Most of the changes apply to encoder-decoder models generally, though Whisper is the only encoder-decoder model explicitly tested here and the only one updated to support V1.

**Whisper Model Implementation:**

- Remove the SupportsV0Only interface constraint to enable V1 compatibility
- Update get_multimodal_embeddings() to return the list format required by V1

**Flash Attention Backend:**

- Add encoder attention metadata fields (encoder_seq_start_loc, max_encoder_seq_len, cross_slot_mapping)
- Implement encoder self-attention support without using the KV cache
- Add cross-attention support for encoder-decoder models with proper KV cache handling

**KV Cache Manager:**

- Introduce CrossAttentionManager for handling the cross-attention KV cache in encoder-decoder models
- Add CrossAttentionSpec for cross-attention cache specification with encoder-based sizing
- Implement allocate_slots_for_cross_attn() for static encoder-length-based allocation
- Add cross-attention block allocation logic separate from decoder token growth

**Scheduler:**

- Disable prefix caching for encoder-decoder models
- Implement cross-attention block allocation during request scheduling
- Add cross-attention block tracking in state management

**GPU Model Runner:**

- Add encoder input extraction for audio feature processing
- Implement encoder attention metadata building for both self-attention and cross-attention
- Add cross-attention KV cache group handling with proper slot mapping
- Modify input batch creation to accommodate encoder sequence lengths
- Add encoder input processing in the forward pass with proper device/dtype handling
- Update profiling and memory management for encoder-decoder models

The implementation maintains backward compatibility while adding comprehensive encoder-decoder support, with particular focus on Whisper's audio processing pipeline and the cross-attention between encoder and decoder.

Related to:

- V0 deprecation: #18571
- 2025 Q3 roadmap: #20336

Signed-off-by: Russell Bryant <rbryant@redhat.com>
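As background for the cross-attention cache changes described above, here is a minimal sketch (not vLLM code) of the sizing rule the commit relies on: decoder self-attention KV blocks grow as tokens are generated, while cross-attention KV blocks are a single static allocation sized by the encoder length. The 16-token block size below is an illustrative default, not taken from this commit.

```python
from math import ceil

BLOCK_SIZE = 16  # illustrative paged-KV block size


def decoder_blocks(num_prompt_tokens: int, num_generated_tokens: int) -> int:
    # Decoder self-attention KV keeps growing as decoding proceeds.
    return ceil((num_prompt_tokens + num_generated_tokens) / BLOCK_SIZE)


def cross_attn_blocks(max_encoder_len: int) -> int:
    # Cross-attention KV stores the encoder K/V exactly once, so it can be
    # allocated up front and never grows with decoder tokens.
    return ceil(max_encoder_len / BLOCK_SIZE)


# Whisper's encoder produces at most 1500 positions (max_source_positions),
# so under this block size each request needs a fixed ceil(1500 / 16) = 94
# cross-attention blocks, regardless of output length.
assert cross_attn_blocks(1500) == 94
```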
1 parent f0caa0b commit addcaae

File tree

13 files changed (+541 lines, -78 lines)

vllm/attention/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -14,7 +14,6 @@
     "AttentionMetadata",
     "AttentionType",
     "AttentionMetadataBuilder",
-    "Attention",
     "AttentionState",
     "get_attn_backend",
 ]

vllm/inputs/preprocess.py

Lines changed: 0 additions & 6 deletions

@@ -869,9 +869,6 @@ def preprocess(
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -903,9 +900,6 @@ async def preprocess_async(
         [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
         """
         if self.model_config.is_encoder_decoder:
-            assert not return_mm_hashes, (
-                "Multimodal hashes for encoder-decoder models should not be ",
-                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(prompt)

vllm/model_executor/models/whisper.py

Lines changed: 4 additions & 5 deletions

@@ -42,7 +42,7 @@
 from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
-                         SupportsTranscription, SupportsV0Only)
+                         SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
                     make_layers)
 
@@ -790,7 +790,7 @@ def _get_prompt_updates(
                                         info=WhisperProcessingInfo,
                                         dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
-                                      SupportsMultiModal, SupportsV0Only):
+                                      SupportsMultiModal):
     packed_modules_mapping = {
         "self_attn.qkv_proj": [
             "self_attn.q_proj",
@@ -916,10 +916,9 @@ def get_language_model(self) -> torch.nn.Module:
 
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
-        # TODO: This method does not obey the interface for SupportsMultiModal.
-        # Refactor this once encoder/decoder support is implemented in V1.
+        # Required as part of SupportsMultiModal interface.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return self.model.get_encoder_outputs(audio_input["input_features"])
+        return [self.model.get_encoder_outputs(audio_input["input_features"])]
 
     def get_input_embeddings(
         self,
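The get_multimodal_embeddings() change above is a return-type change: V1's multimodal path expects a sequence of embeddings, one entry per multimodal item, rather than a bare tensor. A hedged sketch of the distinction (shapes are illustrative, not taken from the model code):

```python
import torch

# Illustrative encoder output for one audio clip:
# (encoder positions, hidden size) -- the exact shape comes from the model.
encoder_out = torch.randn(1500, 1280)

# V0-style return: the tensor itself.
v0_embeddings = encoder_out

# V1-style return, as in the updated method: a list with one entry per
# multimodal item (here, a single audio input).
v1_embeddings = [encoder_out]

for item in v1_embeddings:
    print(item.shape)
```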

vllm/v1/attention/backends/flash_attn.py

Lines changed: 38 additions & 19 deletions

@@ -131,6 +131,15 @@ class FlashAttentionMetadata:
     max_num_splits: int = 0
 
     causal: bool = True
+    # Begin encoder attn & enc/dec cross-attn fields...
+
+    # (batch_size + 1,). The cumulative sequence lengths of the encoder
+    # sequences in the batch, used to index into sequence. E.g., if the sequence
+    # length is [4, 6], it is [0, 4, 10].
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+    cross_slot_mapping: Optional[torch.Tensor] = None
 
 
 def _get_sliding_window_configs(
@@ -209,7 +218,13 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
+
+        if (common_attn_metadata.cross_slot_mapping is not None
+                and common_attn_metadata.max_encoder_seq_len is not None):
+            # ENCODER_DECODER cross-attention
+            max_seq_len = common_attn_metadata.max_encoder_seq_len
+        else:
+            max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
@@ -329,7 +344,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
             max_num_splits=max_num_splits,
-            causal=causal)
+            causal=causal,
+            # Encoder/cross-attention fields
+            encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc,
+            max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len,
+            cross_slot_mapping=common_attn_metadata.cross_slot_mapping,
+        )
         return attn_metadata
 
     def can_run_in_cudagraph(
@@ -378,13 +398,6 @@ def __init__(
 
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type not in [
-                AttentionType.DECODER, AttentionType.ENCODER_ONLY
-        ]:
-            raise NotImplementedError("Encoder/decoder cross-attention "
-                                      "is not implemented for "
-                                      "FlashAttentionImpl")
-
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -442,7 +455,7 @@ def forward(
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         # Handle encoder attention differently - no KV cache needed
-        if attn_type in (AttentionType.ENCODER_ONLY, ):
+        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
@@ -454,20 +467,26 @@ def forward(
        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        if (self.kv_sharing_target_layer_name is None and (key is not None)
+                and (value is not None)):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
+            if attn_type == AttentionType.ENCODER_DECODER:
+                updated_slot_mapping = attn_metadata.cross_slot_mapping
+            else:
+                updated_slot_mapping = attn_metadata.slot_mapping
+
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
-                attn_metadata.slot_mapping,
+                updated_slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
@@ -491,7 +510,7 @@ def forward(
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
@@ -510,9 +529,9 @@ def forward(
                softcap=self.logits_soft_cap,
                scheduler_metadata=scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
+                q_descale=layer._q_scale,
+                k_descale=layer._k_scale,
+                v_descale=layer._v_scale,
                num_splits=attn_metadata.max_num_splits,
            )
            return output
@@ -538,9 +557,9 @@ def forward(
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
-            q_descale=layer._q_scale.expand(descale_shape),
-            k_descale=layer._k_scale.expand(descale_shape),
-            v_descale=layer._v_scale.expand(descale_shape),
+            q_descale=layer._q_scale.expand(descale_shape),
+            k_descale=layer._k_scale.expand(descale_shape),
+            v_descale=layer._v_scale.expand(descale_shape),
        )
        return output
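The encoder_seq_start_loc field added above follows the usual cumulative-lengths convention for variable-length attention. A small sketch of how such a tensor can be derived from per-request encoder lengths (an illustration, not the metadata builder's code):

```python
import torch

# Per-request encoder sequence lengths for a toy batch of two requests.
encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)

# (batch_size + 1,) cumulative start locations -> [0, 4, 10], matching the
# example in the FlashAttentionMetadata comment above.
encoder_seq_start_loc = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32),
])

max_encoder_seq_len = int(encoder_seq_lens.max())  # 6

print(encoder_seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)
```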

vllm/v1/attention/backends/utils.py

Lines changed: 8 additions & 0 deletions

@@ -61,6 +61,14 @@ class CommonAttentionMetadata:
 
     causal: bool
 
+    # Encoder/cross-attention specific fields (optional)
+    encoder_seq_start_loc: Optional[torch.Tensor] = None
+    """(batch_size + 1,), cumulative encoder sequence lengths"""
+    max_encoder_seq_len: Optional[int] = None
+    """Maximum encoder sequence length in batch"""
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    """Slot mapping for cross-attention KV cache"""
+
 
 M = TypeVar("M")

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 22 additions & 9 deletions

@@ -6,7 +6,7 @@
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
-    FullAttentionManager, get_manager_for_kv_cache_spec)
+    CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.request import Request
 
@@ -43,9 +43,12 @@ def __init__(
         ) for i, kv_cache_group in enumerate(
             self.kv_cache_config.kv_cache_groups))
 
-    def get_num_blocks_to_allocate(
-            self, request_id: str, num_tokens: int,
-            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
+    def get_num_blocks_to_allocate(self,
+                                   request_id: str,
+                                   num_tokens: int,
+                                   new_computed_blocks: tuple[
+                                       list[KVCacheBlock], ...],
+                                   cross_attn: bool = False) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
 
@@ -61,8 +64,14 @@ def get_num_blocks_to_allocate(
         """
         num_blocks_to_allocate = 0
         for i, manager in enumerate(self.single_type_managers):
-            num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
-                request_id, num_tokens, new_computed_blocks[i])
+            if cross_attn and isinstance(manager, CrossAttentionManager):
+                # For cross-attention, we issue a single static allocation
+                # of blocks based on the number of encoder input tokens.
+                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
+                    request_id, num_tokens, [])
+            elif not cross_attn:
+                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
+                    request_id, num_tokens, new_computed_blocks[i])
         return num_blocks_to_allocate
 
     def save_new_computed_blocks(
@@ -80,8 +89,11 @@ def save_new_computed_blocks(
             manager.save_new_computed_blocks(request_id,
                                              new_computed_blocks[i])
 
-    def allocate_new_blocks(self, request_id: str,
-                            num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
+    def allocate_new_blocks(
+            self,
+            request_id: str,
+            num_tokens: int,
+            cross_attn: bool = False) -> tuple[list[KVCacheBlock], ...]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
         token slots.
@@ -95,7 +107,8 @@ def allocate_new_blocks(self, request_id: str,
             The new allocated blocks.
         """
         return tuple(
-            manager.allocate_new_blocks(request_id, num_tokens)
+            (manager.allocate_new_blocks(request_id, num_tokens) if isinstance(
+                manager, CrossAttentionManager) == cross_attn else [])
             for manager in self.single_type_managers)
 
     def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
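The `isinstance(manager, CrossAttentionManager) == cross_attn` test above routes each allocation to either the cross-attention manager or the other managers, never both. A self-contained toy (toy classes, not the vLLM managers) showing the selection behavior:

```python
class ToyManager:
    def allocate_new_blocks(self, request_id: str, num_tokens: int) -> list[int]:
        return [num_tokens]  # stand-in for real KV cache blocks


class ToyCrossAttentionManager(ToyManager):
    pass


managers = [ToyManager(), ToyCrossAttentionManager()]


def allocate(cross_attn: bool) -> tuple[list[int], ...]:
    # Mirrors the coordinator: a manager participates only when its
    # "is cross-attention" property matches the cross_attn flag.
    return tuple(
        (m.allocate_new_blocks("req-0", 8)
         if isinstance(m, ToyCrossAttentionManager) == cross_attn else [])
        for m in managers)


print(allocate(cross_attn=False))  # ([8], []) -> decoder managers only
print(allocate(cross_attn=True))   # ([], [8]) -> cross-attention manager only
```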

vllm/v1/core/kv_cache_manager.py

Lines changed: 39 additions & 0 deletions

@@ -307,6 +307,45 @@ def allocate_slots(
 
         return KVCacheBlocks(new_blocks)
 
+    def allocate_slots_for_cross_attn(
+        self,
+        request: Request,
+        num_encoder_tokens: int,
+    ) -> Optional[KVCacheBlocks]:
+        """Add slots for cross-attention blocks.
+
+        This is separate from the main `allocate_slots` function because
+        cross-attention blocks are allocated based on the max encoder length,
+        which is a static value. The number of blocks to allocate is not
+        affected by the number of decoder tokens.
+
+        Args:
+            request: The request to allocate slots.
+            num_encoder_tokens: The number of tokens sent to the encoder.
+
+        Returns:
+            A list of new allocated blocks.
+        """
+        if num_encoder_tokens == 0:
+            raise ValueError("num_encoder_tokens must be greater than 0")
+
+        num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
+            request_id=request.request_id,
+            num_tokens=num_encoder_tokens,
+            new_computed_blocks=tuple(),
+            cross_attn=True,
+        )
+
+        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
+            # Cannot allocate new blocks
+            return None
+
+        new_blocks = self.coordinator.allocate_new_blocks(request.request_id,
+                                                          num_encoder_tokens,
+                                                          cross_attn=True)
+
+        return KVCacheBlocks(new_blocks)
+
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
         We free the blocks in reverse order so that he tail blocks are evicted
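A toy model of the allocate_slots_for_cross_attn() contract shown above, under the assumption of a simple free-block counter: the allocation happens once per request, sized by the encoder length, and a None return signals that the request cannot be scheduled yet. This is a sketch, not the KVCacheManager implementation.

```python
from typing import Optional


class ToyCrossAttnAllocator:
    """Toy stand-in for the new allocate_slots_for_cross_attn() behavior."""

    def __init__(self, num_free_blocks: int, block_size: int = 16):
        self.num_free_blocks = num_free_blocks
        self.block_size = block_size

    def allocate_slots_for_cross_attn(
            self, num_encoder_tokens: int) -> Optional[list[int]]:
        if num_encoder_tokens == 0:
            raise ValueError("num_encoder_tokens must be greater than 0")
        needed = -(-num_encoder_tokens // self.block_size)  # ceiling division
        if needed > self.num_free_blocks:
            return None  # caller treats this as "cannot schedule yet"
        self.num_free_blocks -= needed
        return list(range(needed))


allocator = ToyCrossAttnAllocator(num_free_blocks=100)
blocks = allocator.allocate_slots_for_cross_attn(1500)  # Whisper encoder length
print(len(blocks) if blocks is not None else "cannot schedule")  # 94
```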

vllm/v1/core/sched/scheduler.py

Lines changed: 30 additions & 3 deletions

@@ -19,7 +19,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
-from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
@@ -58,6 +58,7 @@ def __init__(
         self.parallel_config = vllm_config.parallel_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
+        self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder
 
         # include_finished_set controls whether a separate set of finished
         # request ids should be included in the EngineCoreOutputs returned
@@ -150,11 +151,17 @@ def __init__(
             self.use_eagle = True
             self.num_lookahead_tokens = self.num_spec_tokens
 
+        enable_caching = self.cache_config.enable_prefix_caching or False
+        if self.is_encoder_decoder:
+            # prefix caching for encoder-decoder models is not currently
+            # supported
+            enable_caching = False
+
         # Create the KV cache manager.
         self.kv_cache_manager = KVCacheManager(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
-            enable_caching=self.cache_config.enable_prefix_caching,
+            enable_caching=enable_caching,
             caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
             use_eagle=self.use_eagle,
             log_stats=self.log_stats,
@@ -399,6 +406,7 @@ def schedule(self) -> SchedulerOutput:
 
             encoder_inputs_to_schedule = None
             new_encoder_budget = encoder_budget
+            new_cross_blocks: Optional[KVCacheBlocks] = None
 
             # KVTransfer: loading remote KV, do not allocate for new work.
             if load_kv_async:
@@ -436,6 +444,22 @@ def schedule(self) -> SchedulerOutput:
                 if num_new_tokens == 0:
                     # The request cannot be scheduled.
                     break
+                if self.is_encoder_decoder:
+                    # For encoder-decoder models, we allocate slots for
+                    # the cross-attention blocks based on the max
+                    # encoder length. This is a single static allocation
+                    # and does not grow with the number of decoder
+                    # tokens.
+                    max_encoder_len = (self.vllm_config.model_config.
+                                       hf_config.max_source_positions)
+                    new_cross_blocks = (self.kv_cache_manager.
+                                        allocate_slots_for_cross_attn(
+                                            request,
+                                            max_encoder_len,
+                                        ))
+                    if new_cross_blocks is None:
+                        # The request cannot be scheduled.
+                        break
 
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
@@ -454,9 +478,12 @@ def schedule(self) -> SchedulerOutput:
                 # This information is used to determine if a load is
                 # needed for this request.
                 if self.connector is not None:
+                    update_blocks = new_computed_blocks + new_blocks
+                    if new_cross_blocks is not None:
+                        update_blocks += new_cross_blocks
                     self.connector.update_state_after_alloc(
                         request,
-                        new_computed_blocks + new_blocks,
+                        update_blocks,
                         num_external_computed_tokens,
                     )
 