From f53dc96bdba99c24f181c90f33dc327b725097e5 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 2 Jul 2025 00:41:57 +0000 Subject: [PATCH 01/18] v1: Add Whisper model support (encoder-decoder) This brings Whisper support to V1 to close one of the remaining feature gaps with V0. Most of the changes apply to encoder-decoder models generally, though Whisper is the only one explicitly tested and is the only encoder-decoder model updated to support V1. **Whisper Model Implementation:** - Remove SupportsV0Only interface constraint to enable V1 compatibility - Update get_multimodal_embeddings() to return list format required by V1 **Flash Attention Backend:** - Add encoder attention metadata fields (encoder_seq_start_loc, max_encoder_seq_len, cross_slot_mapping) - Implement encoder self-attention support without using KV cache - Add cross-attention support for encoder-decoder models with proper KV cache handling **KV Cache Manager:** - Introduce CrossAttentionManager for handling cross-attention KV cache in encoder-decoder models - Add CrossAttentionSpec for cross-attention cache specification with encoder-based sizing - Implement allocate_slots_for_cross_attn() for static encoder-length-based allocation - Add cross-attention block allocation logic separate from decoder token growth **Scheduler:** - Disable prefix caching for encoder-decoder models - Implement cross-attention block allocation during request scheduling - Add cross-attention block tracking in state management **GPU Model Runner:** - Add encoder input extraction for audio features processing - Implement encoder attention metadata building for both self-attention and cross-attention - Add cross-attention KV cache group handling with proper slot mapping - Modify input batch creation to accommodate encoder sequence lengths - Add encoder input processing in forward pass with proper device/dtype handling - Update profiling and memory management for encoder-decoder models The implementation maintains backward compatibility while adding comprehensive encoder-decoder support, with particular focus on Whisper's audio processing pipeline and cross-attention mechanisms between encoder and decoder. 
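As an illustration of the sizing logic described above (not part of the patch itself): the cross-attention KV cache for a request is a single static allocation derived from the encoder length, so the block count never grows as decoder tokens are generated. A minimal sketch, assuming Whisper's max_source_positions of 1500 and a hypothetical block_size of 16 (names are illustrative, not the exact vLLM API):

    def cross_attn_blocks_needed(max_encoder_len: int, block_size: int) -> int:
        # Ceil-divide the encoder length by the block size; unlike decoder
        # self-attention, this count is fixed for the lifetime of the request.
        return -(-max_encoder_len // block_size)

    # Whisper example: 1500 encoder positions / 16 tokens per block -> 94 blocks,
    # allocated once when the request is first scheduled.
    assert cross_attn_blocks_needed(1500, 16) == 94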
Related to: - V0 deprecation: #18571 - 2025 Q3 roadmap: #20336 Signed-off-by: Russell Bryant --- vllm/attention/__init__.py | 1 - vllm/inputs/preprocess.py | 6 - vllm/model_executor/models/whisper.py | 9 +- vllm/v1/attention/backends/flash_attn.py | 45 ++- vllm/v1/attention/backends/utils.py | 8 + vllm/v1/core/kv_cache_coordinator.py | 31 +- vllm/v1/core/kv_cache_manager.py | 39 +++ vllm/v1/core/sched/scheduler.py | 33 +- vllm/v1/core/single_type_kv_cache_manager.py | 56 ++- vllm/v1/engine/processor.py | 5 - vllm/v1/kv_cache_interface.py | 18 + vllm/v1/worker/gpu_model_runner.py | 350 +++++++++++++++++-- vllm/v1/worker/utils.py | 14 +- 13 files changed, 542 insertions(+), 73 deletions(-) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 344040586a53..dcb2aa68fbee 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -14,7 +14,6 @@ "AttentionMetadata", "AttentionType", "AttentionMetadataBuilder", - "Attention", "AttentionState", "get_attn_backend", ] diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index de5dc0876651..672174991525 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -841,9 +841,6 @@ def preprocess( ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( @@ -873,9 +870,6 @@ async def preprocess_async( [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. """ if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async(prompt) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index ca02ecd828ba..36695229a8ee 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -42,7 +42,7 @@ from vllm.transformers_utils.processor import cached_get_processor from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription, SupportsV0Only) + SupportsTranscription) from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, make_layers) @@ -744,7 +744,7 @@ def _get_prompt_updates( info=WhisperProcessingInfo, dummy_inputs=WhisperDummyInputsBuilder) class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal, SupportsV0Only): + SupportsMultiModal): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -871,10 +871,9 @@ def get_language_model(self) -> torch.nn.Module: def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: - # TODO: This method does not obey the interface for SupportsMultiModal. - # Refactor this once encoder/decoder support is implemented in V1. + # Required as part of SupportsMultiModal interface. 
audio_input = self._parse_and_validate_audio_input(**kwargs) - return self.model.get_encoder_outputs(audio_input["input_features"]) + return [self.model.get_encoder_outputs(audio_input["input_features"])] def get_input_embeddings( self, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 95ba56b35937..5185bb0110cf 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -139,6 +139,15 @@ class FlashAttentionMetadata: max_num_splits: int = 0 causal: bool = True + # Begin encoder attn & enc/dec cross-attn fields... + + # (batch_size + 1,). The cumulative sequence lengths of the encoder + # sequences in the batch, used to index into sequence. E.g., if the sequence + # length is [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + cross_slot_mapping: Optional[torch.Tensor] = None def _get_sliding_window_configs( @@ -220,7 +229,13 @@ def build(self, num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + + if (common_attn_metadata.cross_slot_mapping is not None + and common_attn_metadata.max_encoder_seq_len is not None): + # ENCODER_DECODER cross-attention + max_seq_len = common_attn_metadata.max_encoder_seq_len + else: + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -347,7 +362,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - causal=causal) + causal=causal, + # Encoder/cross-attention fields + encoder_seq_start_loc=common_attn_metadata.encoder_seq_start_loc, + max_encoder_seq_len=common_attn_metadata.max_encoder_seq_len, + cross_slot_mapping=common_attn_metadata.cross_slot_mapping, + ) return attn_metadata def can_run_in_cudagraph( @@ -397,13 +417,6 @@ def __init__( FlashAttentionBackend.validate_head_size(head_size) - if attn_type not in [ - AttentionType.DECODER, AttentionType.ENCODER_ONLY - ]: - raise NotImplementedError("Encoder/decoder cross-attention " - "is not implemented for " - "FlashAttentionImpl") - self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ @@ -469,7 +482,7 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens # Handle encoder attention differently - no KV cache needed - if attn_type in (AttentionType.ENCODER_ONLY, ): + if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching return self._forward_encoder_attention(query[:num_actual_tokens], @@ -481,7 +494,8 @@ def forward( # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: + if (self.kv_sharing_target_layer_name is None and (key is not None) + and (value is not None)): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
# NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -489,12 +503,17 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. + if attn_type == AttentionType.ENCODER_DECODER: + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + updated_slot_mapping = attn_metadata.slot_mapping + reshape_and_cache_flash( key, value, key_cache, value_cache, - attn_metadata.slot_mapping, + updated_slot_mapping, self.kv_cache_dtype, layer._k_scale, layer._v_scale, @@ -520,7 +539,7 @@ def forward( block_table = attn_metadata.block_table scheduler_metadata = attn_metadata.scheduler_metadata - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) flash_attn_varlen_func( q=query[:num_actual_tokens], diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index e23dd8bc5bbb..41344c1fbfd9 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -64,6 +64,14 @@ class CommonAttentionMetadata: causal: bool = True + # Encoder/cross-attention specific fields (optional) + encoder_seq_start_loc: Optional[torch.Tensor] = None + """(batch_size + 1,), cumulative encoder sequence lengths""" + max_encoder_seq_len: Optional[int] = None + """Maximum encoder sequence length in batch""" + cross_slot_mapping: Optional[torch.Tensor] = None + """Slot mapping for cross-attention KV cache""" + @dataclass class UbatchSlice: diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index f3a16d64e19f..a98f364a4dc7 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,7 +6,7 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, get_manager_for_kv_cache_spec) + CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.request import Request @@ -44,9 +44,12 @@ def __init__( ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) - def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: + def get_num_blocks_to_allocate(self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[ + list[KVCacheBlock], ...], + cross_attn: bool = False) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -62,8 +65,14 @@ def get_num_blocks_to_allocate( """ num_blocks_to_allocate = 0 for i, manager in enumerate(self.single_type_managers): - num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + if cross_attn and isinstance(manager, CrossAttentionManager): + # For cross-attention, we issue a single static allocation + # of blocks based on the number of encoder input tokens. 
+ num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, []) + elif not cross_attn: + num_blocks_to_allocate += manager.get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks[i]) return num_blocks_to_allocate def save_new_computed_blocks( @@ -81,8 +90,11 @@ def save_new_computed_blocks( manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> tuple[list[KVCacheBlock], ...]: + def allocate_new_blocks( + self, + request_id: str, + num_tokens: int, + cross_attn: bool = False) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -96,7 +108,8 @@ def allocate_new_blocks(self, request_id: str, The new allocated blocks. """ return tuple( - manager.allocate_new_blocks(request_id, num_tokens) + (manager.allocate_new_blocks(request_id, num_tokens) if isinstance( + manager, CrossAttentionManager) == cross_attn else []) for manager in self.single_type_managers) def cache_blocks(self, request: Request, block_hashes: list[BlockHash], diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ce333dbe61a1..7677311dda75 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -304,6 +304,45 @@ def allocate_slots( return KVCacheBlocks(new_blocks) + def allocate_slots_for_cross_attn( + self, + request: Request, + num_encoder_tokens: int, + ) -> Optional[KVCacheBlocks]: + """Add slots for cross-attention blocks. + + This is separate from the main `allocate_slots` function because + cross-attention blocks are allocated based on the max encoder length, + which is a static value. The number of blocks to allocate is not + affected by the number of decoder tokens. + + Args: + request: The request to allocate slots. + num_encoder_tokens: The number of tokens sent to the encoder. + + Returns: + A list of new allocated blocks. + """ + if num_encoder_tokens == 0: + raise ValueError("num_encoder_tokens must be greater than 0") + + num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_encoder_tokens, + new_computed_blocks=tuple(), + cross_attn=True, + ) + + if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + # Cannot allocate new blocks + return None + + new_blocks = self.coordinator.allocate_new_blocks(request.request_id, + num_encoder_tokens, + cross_attn=True) + + return KVCacheBlocks(new_blocks) + def free(self, request: Request) -> None: """Free the blocks allocated for the request. 
We free the blocks in reverse order so that he tail blocks are evicted diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d39aea1f2d11..94c689928ecb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,7 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -58,6 +58,7 @@ def __init__( self.parallel_config = vllm_config.parallel_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager + self.is_encoder_decoder = vllm_config.model_config.is_encoder_decoder # include_finished_set controls whether a separate set of finished # request ids should be included in the EngineCoreOutputs returned @@ -150,11 +151,17 @@ def __init__( self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens + enable_caching = self.cache_config.enable_prefix_caching or False + if self.is_encoder_decoder: + # prefix caching for encoder-decoder models is not currently + # supported + enable_caching = False + # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, + enable_caching=enable_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, @@ -399,6 +406,7 @@ def schedule(self) -> SchedulerOutput: encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget + new_cross_blocks: Optional[KVCacheBlocks] = None # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: @@ -436,6 +444,22 @@ def schedule(self) -> SchedulerOutput: if num_new_tokens == 0: # The request cannot be scheduled. break + if self.is_encoder_decoder: + # For encoder-decoder models, we allocate slots for + # the cross-attention blocks based on the max + # encoder length. This is a single static allocation + # and does not grow with the number of decoder + # tokens. + max_encoder_len = (self.vllm_config.model_config. + hf_config.max_source_positions) + new_cross_blocks = (self.kv_cache_manager. + allocate_slots_for_cross_attn( + request, + max_encoder_len, + )) + if new_cross_blocks is None: + # The request cannot be scheduled. + break new_blocks = self.kv_cache_manager.allocate_slots( request, @@ -454,9 +478,12 @@ def schedule(self) -> SchedulerOutput: # This information is used to determine if a load is # needed for this request. 
if self.connector is not None: + update_blocks = new_computed_blocks + new_blocks + if new_cross_blocks is not None: + update_blocks += new_cross_blocks self.connector.update_state_after_alloc( request, - new_computed_blocks + new_blocks, + update_blocks, num_external_computed_tokens, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8f310023a8cd..f9da0a4696cb 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -9,8 +9,9 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + CrossAttentionSpec, FullAttentionSpec, + KVCacheSpec, MambaSpec, + SlidingWindowSpec) from vllm.v1.request import Request @@ -560,11 +561,62 @@ def allocate_new_blocks(self, request_id: str, return new_blocks +class CrossAttentionManager(SingleTypeKVCacheManager): + """Manager for cross-attention KV cache in encoder-decoder models.""" + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + # We do not allocate blocks as decoder tokens are generated, so this + # method is not relevant. + pass + + def cache_blocks(self, request: Request, block_hashes: list[BlockHash], + num_tokens: int) -> None: + # We do not cache blocks for cross-attention to be shared between + # requests, so this method is not relevant. + pass + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + # Cross-attention blocks contain request-specific encoder states + # and are not shared between different requests + return 0 + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> tuple[list[KVCacheBlock], ...]: + assert isinstance(kv_cache_spec, CrossAttentionSpec), ( + "CrossAttentionManager can only be used for cross-attention groups" + ) + # Cross-attention does not benefit from prefix caching since: + # 1. Encoder states are unique per request (different audio/image + # inputs) + # 2. Encoder states are computed once per request, not incrementally + # 3. No reusable prefix exists between different multimodal inputs + # Return empty blocks to indicate no cache hits + return tuple([] for _ in range(len(kv_cache_group_ids))) + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Cross-attention blocks represent encoder states which are needed + # for the entire decoding process, so no blocks should be skipped + pass + + spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, + CrossAttentionSpec: CrossAttentionManager, } diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6e37ebeb8778..48c9cf393c94 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -234,7 +234,6 @@ def process_inputs( ) -> tuple[Optional[str], EngineCoreRequest]: # TODO(woosuk): Support pooling models. - # TODO(woosuk): Support encoder-decoder models. 
self._validate_lora(lora_request) self._validate_params(params, lora_request) if trace_headers is not None: @@ -273,10 +272,6 @@ def process_inputs( encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - # TODO: Impl encoder-decoder - if encoder_inputs is not None: - raise NotImplementedError - sampling_params = None pooling_params = None if isinstance(params, SamplingParams): diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4ff96f9786b8..c7c3cf8f7727 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -202,6 +202,24 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: return self.page_size_bytes +@dataclass(frozen=True) +class CrossAttentionSpec(AttentionSpec): + """ + KV cache spec for cross-attention layers in encoder-decoder models. + """ + + @property + def type_id(self) -> str: + return f"cross_attention_{self.block_size}_{self.page_size_bytes}" + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + # For cross-attention, we need to cache encoder states + # Use max_source_positions for encoder length (e.g., 1500 for Whisper) + max_encoder_len = ( + vllm_config.model_config.hf_config.max_source_positions) + return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 08b253dcdb35..99f2c162dc68 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -19,6 +19,7 @@ import vllm.envs as envs from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, @@ -56,8 +57,8 @@ reorder_batch_to_split_decodes_and_prefills) from vllm.v1.kv_cache_interface import (AttentionSpec, ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheSpec, MambaSpec, + CrossAttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -150,6 +151,21 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope + encoder_compute_budget, encoder_cache_size = compute_encoder_budget( + model_config=model_config, + scheduler_config=scheduler_config, + mm_registry=self.mm_registry, + ) + self.max_num_encoder_input_tokens = encoder_compute_budget + self.encoder_cache_size = encoder_cache_size + if self.model_config.is_encoder_decoder: + # If specified in the model config, this attribute defines the + # maximum length of the encoder input. + self.max_encoder_len = getattr(self.model_config.hf_config, + 'max_source_positions', 0) + else: + self.max_encoder_len = 0 + # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -207,7 +223,9 @@ def __init__( # the block_sizes in the kv cache config. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + # We need to use the encoder length for encoder-decoer + # because of KV cache for cross-attention. 
+ max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -788,16 +806,24 @@ def _prepare_inputs( # Prepare encoder attention metadata separately # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model: - common_attn_metadata, encoder_attn_metadata = \ - self._build_encoder_only_attn_metadata( - scheduler_output) + if self.is_encoder_only_model or ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + if self.is_encoder_only_model: + common_attn_metadata, encoder_attn_metadata = \ + self._build_encoder_only_attn_metadata( + scheduler_output) + else: + common_attn_metadata, encoder_attn_metadata = \ + self._build_enc_dec_attn_metadata( + scheduler_output) # Add encoder attention metadata for all encoder layers attention_layers = get_layers_from_vllm_config( self.vllm_config, Attention) for layer_name, attn_module in attention_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: + if attn_module.attn_type in (AttentionType.ENCODER_ONLY, + AttentionType.ENCODER): attn_metadata[layer_name] = encoder_attn_metadata # Prepare the attention metadata for each KV cache group and make layers @@ -828,6 +854,12 @@ def _prepare_inputs( causal=True, ) + is_enc_dec = isinstance(kv_cache_group_spec.kv_cache_spec, + CrossAttentionSpec) + if is_enc_dec: + _, encoder_attn_metadata = self._build_enc_dec_attn_metadata( + scheduler_output, common_attn_metadata) + if self.speculative_config and \ spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata @@ -845,10 +877,11 @@ def _prepare_inputs( builder, ) - attn_metadata_i = (builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) + attn_metadata_i = (encoder_attn_metadata + if is_enc_dec else builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + )) fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill @@ -1212,6 +1245,99 @@ def _gather_mm_embeddings( mm_embeds.append(mm_embeds_item) return mm_embeds + def _extract_encoder_inputs( + self, + scheduler_output: "SchedulerOutput", + ) -> dict[str, torch.Tensor]: + """Extract encoder inputs for encoder-decoder models like Whisper. + + This method extracts audio input features and creates encoder positions + from scheduled encoder inputs. These are only needed when the encoder + needs to process new MM inputs (typically on the first processing step). 
+ """ + input_features_list = [] + total_encoder_tokens = 0 + + for req_id, encoder_input_ids in ( + scheduler_output.scheduled_encoder_inputs.items()): + req_state = self.requests[req_id] + + for mm_input_id in encoder_input_ids: + if mm_input_id < len(req_state.mm_inputs): + mm_input = req_state.mm_inputs[mm_input_id] + # Extract input_features from MM input kwargs + if "input_features" in mm_input: + features = mm_input["input_features"] + input_features_list.append(features) + # Calculate encoder sequence length for this input + total_encoder_tokens += ( + self._get_encoder_sequence_length(features)) + + if not input_features_list: + return {} + + # Process and concatenate input features + input_features = self._process_input_features(input_features_list) + + # Move input_features to the correct device and dtype + input_features = input_features.to(device=self.device, + dtype=self.model_config.dtype) + + # Create encoder positions (similar to how V0 does it) + encoder_positions = torch.arange(total_encoder_tokens, + dtype=torch.long, + device=self.device) + + # Create encoder input_ids (dummy tokens for encoder) + encoder_input_ids = torch.zeros(total_encoder_tokens, + dtype=torch.long, + device=self.device) + + return { + "input_features": input_features, + "encoder_input_ids": encoder_input_ids, + "encoder_positions": encoder_positions, + } + + def _get_encoder_sequence_length( + self, features: Union[torch.Tensor, list]) -> int: + """Get the encoder sequence length for the given features.""" + # For Whisper: use max_source_positions from config + # which represents the encoder sequence length + encoder_seq_len = getattr(self.model_config.hf_config, + 'max_source_positions', 1500) + + if isinstance(features, list): + return len(features) * encoder_seq_len + else: + return encoder_seq_len + + def _process_input_features(self, + input_features_list: list) -> torch.Tensor: + """Process and concatenate input features into a single tensor.""" + if len(input_features_list) == 1 and isinstance( + input_features_list[0], torch.Tensor): + input_features = input_features_list[0] + # Ensure we have the correct 4D shape + # [batch, channels, mel_bins, time] + if input_features.dim() == 3: + # Add batch dim: [ch, mel, time] -> [1, ch, mel, time] + input_features = input_features.unsqueeze(0) + else: + # Handle list of tensors + processed_features = [] + for feat in input_features_list: + if isinstance(feat, torch.Tensor): + # Ensure 4D shape + if feat.dim() == 3: + feat = feat.unsqueeze(0) + processed_features.append(feat) + else: + processed_features.append(torch.stack(feat)) + input_features = torch.cat(processed_features) + + return input_features + def get_model(self) -> nn.Module: return self.model @@ -1479,14 +1605,16 @@ def execute_model( # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if self.is_multimodal_model: + if (self.is_multimodal_model + and not self.model_config.is_encoder_decoder): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - if self.is_multimodal_model and get_pp_group().is_first_rank: + if self.is_multimodal_model and get_pp_group().is_first_rank and ( + not self.model_config.is_encoder_decoder): # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. 
@@ -1537,6 +1665,12 @@ def execute_model( ), self.maybe_get_kv_connector_output( scheduler_output) as kv_connector_output: + extra_kwargs: dict = {} + if (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + encoder_inputs = self._extract_encoder_inputs(scheduler_output) + extra_kwargs.update(encoder_inputs) + model_output = self.model( input_ids=input_ids, positions=positions, @@ -1546,6 +1680,7 @@ def execute_model( model_mm_kwargs, device=self.device, ), + **extra_kwargs, ) if self.use_aux_hidden_state_outputs: @@ -2209,7 +2344,8 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - if self.is_multimodal_model: + if (self.is_multimodal_model + and not self.model_config.is_encoder_decoder): input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] model_mm_kwargs = self._dummy_mm_kwargs(num_reqs) @@ -2417,7 +2553,7 @@ def _dummy_pooler_run( def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. - if self.is_multimodal_model: + if self.is_multimodal_model and not self.model_config.is_encoder_decoder: mm_budget = self.mm_budget assert mm_budget is not None @@ -2701,7 +2837,7 @@ def may_reinitialize_input_batch(self, "for more details.") self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, + max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -2953,9 +3089,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: self.shared_kv_cache_layers[layer_name] = kv_tgt_layer continue - # TODO: Support other attention modules, e.g., cross-attention - # TODO(lucas): move the attention specs into the model layers like - # the attention backends if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( @@ -2981,12 +3114,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, use_mla=use_mla) + elif attn_module.attn_type == AttentionType.ENCODER_DECODER: + kv_cache_spec[layer_name] = CrossAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError else: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") @@ -3065,3 +3203,171 @@ def _build_encoder_only_attn_metadata( common_prefix_len=0, # No cascade for encoder common_attn_metadata=common_metadata, ) + + def _build_enc_dec_attn_metadata( + self, + scheduler_output: "SchedulerOutput", + common_attn_metadata: Optional[CommonAttentionMetadata] = None + ) -> tuple[CommonAttentionMetadata, dict[str, Any]]: + """Prepare encoder attention metadata for encoder-decoder models. 
+ + Args: + scheduler_output: Scheduler output + common_attn_metadata: Optional common attention metadata for + cross-attention + + Returns: + tuple: (CommonAttentionMetadata, encoder attention metadata) + """ + # Get encoder input information from scheduled encoder inputs + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + + # Calculate encoder sequence lengths and cross slot mappings + encoder_seq_lens = [] + cross_slot_mapping = [] + num_encoder_tokens = 0 + + for req_id in scheduled_encoder_inputs: + encoder_seq_len = self.max_encoder_len + encoder_seq_lens.append(encoder_seq_len) + num_encoder_tokens += encoder_seq_len + + # Build cross slot mapping for cross-attention + if common_attn_metadata is not None: + cross_slot_mapping.extend( + self._get_cross_slot_mapping(req_id, encoder_seq_len)) + + # Create encoder sequence start locations (cumulative sum) + encoder_seq_start_loc = [0] + for seq_len in encoder_seq_lens: + encoder_seq_start_loc.append(encoder_seq_start_loc[-1] + seq_len) + + # Convert to tensors + encoder_seq_lens_tensor = torch.tensor(encoder_seq_lens, + dtype=torch.int32, + device=self.device) + encoder_seq_start_loc_tensor = torch.tensor(encoder_seq_start_loc, + dtype=torch.int32, + device=self.device) + + # Build common metadata based on attention type + is_cross_attention = common_attn_metadata is not None + common_metadata = self._build_encoder_common_metadata( + encoder_seq_lens, encoder_seq_lens_tensor, + encoder_seq_start_loc_tensor, num_encoder_tokens, + common_attn_metadata, is_cross_attention) + + # Set encoder fields + common_metadata.encoder_seq_start_loc = encoder_seq_start_loc_tensor + common_metadata.max_encoder_seq_len = self.max_encoder_len + + # Add cross slot mapping for cross-attention + if is_cross_attention: + common_metadata.cross_slot_mapping = torch.tensor( + cross_slot_mapping, dtype=torch.int64, device=self.device) + + # Use the first attention metadata builder + builder = self.attn_metadata_builders[0] + return common_metadata, builder.build( + common_prefix_len=0, # No cascade for encoder + common_attn_metadata=common_metadata, + ) + + def _get_cross_slot_mapping(self, req_id: str, + encoder_seq_len: int) -> list[int]: + """Get cross-attention slot mapping for a request.""" + req_state = self.requests.get(req_id) + if req_state is None: + # During memory profiling or if request not found + return [PAD_SLOT_ID] * encoder_seq_len + + # Find the KV cache group that uses CrossAttentionSpec + cross_attn_group_idx = None + for i, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, CrossAttentionSpec): + cross_attn_group_idx = i + break + + if (cross_attn_group_idx is None + or cross_attn_group_idx >= len(req_state.block_ids)): + return [PAD_SLOT_ID] * encoder_seq_len + + # Get cross attention block IDs and calculate slot mapping + cross_block_ids = req_state.block_ids[cross_attn_group_idx] + block_size = self.kv_cache_config.kv_cache_groups[ + cross_attn_group_idx].kv_cache_spec.block_size + + slot_mapping = [] + for i in range(encoder_seq_len): + block_number = cross_block_ids[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + return slot_mapping + + def _build_encoder_common_metadata( + self, encoder_seq_lens: list[int], + encoder_seq_lens_tensor: torch.Tensor, + encoder_seq_start_loc_tensor: torch.Tensor, + num_encoder_tokens: int, + common_attn_metadata: 
Optional[CommonAttentionMetadata], + is_cross_attention: bool) -> CommonAttentionMetadata: + """Build common attention metadata for encoder attention.""" + if is_cross_attention: + # ENCODER_DECODER cross-attention - use decoder metadata as base + assert common_attn_metadata is not None, ( + "common_attn_metadata must be provided for cross-attention") + + seq_lens_tensor = torch.full( + (common_attn_metadata.num_reqs, ), + self.max_encoder_len, + dtype=torch.int32, + device=self.device, + ) + seq_lens_cpu = torch.full( + (common_attn_metadata.num_reqs, ), + self.max_encoder_len, + dtype=torch.int32, + device="cpu", + ) + return CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + query_start_loc_cpu=common_attn_metadata.query_start_loc_cpu, + seq_lens=seq_lens_tensor, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + max_query_len=common_attn_metadata.max_query_len, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + causal=False, + ) + else: + # ENCODER self-attention - create new metadata + dummy_block_table = torch.zeros((len(encoder_seq_lens), 1), + dtype=torch.int32, + device=self.device) + dummy_slot_mapping = torch.zeros((num_encoder_tokens, ), + dtype=torch.int32, + device=self.device) + dummy_computed_tokens = torch.zeros((len(encoder_seq_lens), ), + dtype=torch.int32, + device="cpu") + + return CommonAttentionMetadata( + query_start_loc=encoder_seq_start_loc_tensor, + query_start_loc_cpu=encoder_seq_start_loc_tensor.cpu(), + seq_lens=encoder_seq_lens_tensor, + seq_lens_cpu=encoder_seq_lens_tensor.cpu(), + num_computed_tokens_cpu=dummy_computed_tokens, + num_reqs=len(encoder_seq_lens), + num_actual_tokens=num_encoder_tokens, + max_query_len=self.max_encoder_len, + block_table_tensor=dummy_block_table, + slot_mapping=dummy_slot_mapping, + causal=False, + ) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index e7079235d651..5a9453b46fc5 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -277,14 +277,14 @@ def bind_kv_cache( index2name[extract_layer_index(layer_name)].append(layer_name) for layer_index in sorted(index2name.keys()): + # Some models (like encoder-decoder models) may have multiple + # layers with the same index, so we need to append all of them. + # For an encoder-decoder model, each decoder layer has + # self-attention (AttentionType.DECODER) + # and cross-attention (AttentionType.ENCODER_DECODER). layer_names = index2name[layer_index] - if len(layer_names) > 1: - # One typical case is encoder-decoder model, e.g., bart. - # The cross attention and self attention in the same decoder layer - # has different layer_name but the same layer_index. 
- raise NotImplementedError - layer_name = layer_names[0] - runner_kv_caches.append(kv_caches[layer_name]) + for layer_name in layer_names: + runner_kv_caches.append(kv_caches[layer_name]) # Bind kv_caches to forward context for layer_name, kv_cache in kv_caches.items(): From 76cb36affb7f67a6d8c215279a159b2619c1193b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 30 Jul 2025 13:25:20 +0000 Subject: [PATCH 02/18] Improvements from NickLucche - fix encoder budget for encoder-decoder model to return actual encoder max tokens - disable encoder chunking due to bidirectional attention (disable_chunked_mm_input flag) - disabling features solely in config.py and not in scheduler.py - Abstracting a few parts that are too whisper specific: - max_source_positions is a whisper thing. I've replaced it with a method provided by the MM interface, more below. (the way num_features tokens are counted assuming fixed encoder input size could also use reviewing) - Clarify in code the role of SupportsMultiModal: I believe to have a consistent way of using the MM path in code we need enc-dec models to implement this interface so that it can function in all places where an encoder input length is expected like above (eg scheduling path). Ideally one would have a separate interface for encoder-decoder models only, but I am not sure it's worth the effort rn. - minor corrections of comments referring to eventual v1 enc-dec support These updates come from #22018 Signed-off-by: NickLucche --- tests/v1/test_oracle.py | 1 - vllm/config.py | 33 +++++++++++++++++++++------ vllm/model_executor/models/whisper.py | 5 ++-- vllm/multimodal/registry.py | 14 ++++++++++++ vllm/multimodal/utils.py | 2 +- vllm/v1/core/encoder_cache_manager.py | 1 - vllm/v1/core/sched/scheduler.py | 17 +++++--------- vllm/v1/kv_cache_interface.py | 7 +++--- vllm/v1/worker/gpu_model_runner.py | 32 ++++++++------------------ 9 files changed, 62 insertions(+), 50 deletions(-) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index a756c89b520f..e774e3979dd0 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -10,7 +10,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine UNSUPPORTED_MODELS_V1 = [ - "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder ] diff --git a/vllm/config.py b/vllm/config.py index 7147702eddde..511d1615c99b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -34,6 +34,7 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.platforms import current_platform from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -4872,22 +4873,40 @@ def __post_init__(self): disable_chunked_prefill_reasons: list[str] = [] - if self.model_config and self.model_config.pooler_config: - pooling_type = self.model_config.pooler_config.pooling_type - if pooling_type is None or pooling_type.lower() != "last": + if self.model_config: + if self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + elif self.model_config.is_encoder_decoder: + self.scheduler_config.max_num_encoder_input_tokens = \ + 
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + logger.debug( + "Encoder-decoder model detected: setting " + "`max_num_encoder_input_tokens` to encoder length (%s)", + self.scheduler_config.max_num_encoder_input_tokens) + self.scheduler_config.disable_chunked_mm_input = True disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") + "Encoder-decoder models do not support chunked prefill nor" + " prefix caching; disabling both.") if disable_chunked_prefill_reasons: for reason in disable_chunked_prefill_reasons: logger.info(reason) self.scheduler_config.chunked_prefill_enabled = False self.scheduler_config.long_prefill_token_threshold = 0 - self.scheduler_config.max_num_batched_tokens = max( + new_max_num_batched_tokens = max( self.scheduler_config.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) - + if (new_max_num_batched_tokens + != self.scheduler_config.max_num_batched_tokens): + logger.info("Updating max_num_batched_tokens from %d to %d", + self.scheduler_config.max_num_batched_tokens, + new_max_num_batched_tokens) + self.scheduler_config.max_num_batched_tokens = \ + new_max_num_batched_tokens if self.cache_config is not None: self.cache_config.enable_prefix_caching = False diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 36695229a8ee..23573cb99b8f 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -880,9 +880,8 @@ def get_input_embeddings( input_ids: torch.Tensor, multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: - # TODO: This method just returns the decoder sequence embeddings since - # Whisper does not have encoder text tokens. Refactor this once - # encoder/decoder support is implemented in V1. + # This method just returns the decoder sequence embeddings since + # Whisper does not have encoder text tokens. return self.model.decoder.get_input_embeddings(input_ids) def _parse_and_validate_audio_input( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 565d54e1a264..ffb691f2cb9f 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -315,3 +315,17 @@ def get_encoder_dummy_data( ) return dummy_data + + def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: + """ + Get the maximum length of the encoder input for encoder-decoder models. + """ + if not model_config.is_encoder_decoder: + return 0 + max_tokens = self.\ + get_max_tokens_per_item_by_nonzero_modality(model_config) + assert len(max_tokens) == 1, "Encoder-decoder models are expected \ + to implement the multimodal interface with at most one modality." 
+ + first_modality = next(iter(max_tokens)) + return max_tokens[first_modality] \ No newline at end of file diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 8dfbc6503520..6a4a998c7e59 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -489,4 +489,4 @@ def fetch_video( "video": video_io_kwargs } media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) - return media_connector.fetch_video(video_url) \ No newline at end of file + return media_connector.fetch_video(video_url) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 67ea3b007ece..6666dcd5d09d 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -192,7 +192,6 @@ def compute_encoder_budget( if not model_config.is_multimodal_model: return 0, 0 - # TODO: handle encoder-decoder models once we support them. ( encoder_compute_budget, encoder_cache_size, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 94c689928ecb..595ee20af465 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -132,8 +132,8 @@ def __init__( ) # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed). Currently, we assume that the encoder also - # has the Transformer architecture (e.g., ViT). + # projector if needed) for MM models as well as encoder-decoder + # transformers. self.max_num_encoder_input_tokens = encoder_compute_budget # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 @@ -151,17 +151,11 @@ def __init__( self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens - enable_caching = self.cache_config.enable_prefix_caching or False - if self.is_encoder_decoder: - # prefix caching for encoder-decoder models is not currently - # supported - enable_caching = False - # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - enable_caching=enable_caching, + enable_caching=self.cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, @@ -450,8 +444,9 @@ def schedule(self) -> SchedulerOutput: # encoder length. This is a single static allocation # and does not grow with the number of decoder # tokens. - max_encoder_len = (self.vllm_config.model_config. - hf_config.max_source_positions) + max_encoder_len = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len( + self.vllm_config.model_config) new_cross_blocks = (self.kv_cache_manager. allocate_slots_for_cross_attn( request, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index c7c3cf8f7727..60342c714d92 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -11,6 +11,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.utils import cdiv, get_dtype_size logger = init_logger(__name__) @@ -214,9 +215,9 @@ def type_id(self) -> str: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # For cross-attention, we need to cache encoder states - # Use max_source_positions for encoder length (e.g., 1500 for Whisper) - max_encoder_len = ( - vllm_config.model_config.hf_config.max_source_positions) + # Get encoder length (e.g., 1500 for Whisper). 
+ max_encoder_len = MULTIMODAL_REGISTRY.\ + get_encdec_max_encoder_len(vllm_config.model_config) return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 99f2c162dc68..b6a9857ffc6d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -158,13 +158,9 @@ def __init__( ) self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_cache_size = encoder_cache_size - if self.model_config.is_encoder_decoder: - # If specified in the model config, this attribute defines the - # maximum length of the encoder input. - self.max_encoder_len = getattr(self.model_config.hf_config, - 'max_source_positions', 0) - else: - self.max_encoder_len = 0 + # Maximum length of the encoder input, only for encoder-decoder models. + self.max_encoder_len = self.mm_registry.\ + get_encdec_max_encoder_len(model_config) # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) @@ -1263,6 +1259,7 @@ def _extract_encoder_inputs( req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + # TODO (NickLucche) this is very whisper specific atm, refactor if mm_input_id < len(req_state.mm_inputs): mm_input = req_state.mm_inputs[mm_input_id] # Extract input_features from MM input kwargs @@ -1270,8 +1267,10 @@ def _extract_encoder_inputs( features = mm_input["input_features"] input_features_list.append(features) # Calculate encoder sequence length for this input - total_encoder_tokens += ( - self._get_encoder_sequence_length(features)) + num_features = len(features) if isinstance( + features, list) else 1 + total_encoder_tokens += num_features * \ + self.max_encoder_len if not input_features_list: return {} @@ -1283,7 +1282,7 @@ def _extract_encoder_inputs( input_features = input_features.to(device=self.device, dtype=self.model_config.dtype) - # Create encoder positions (similar to how V0 does it) + # Create encoder positions encoder_positions = torch.arange(total_encoder_tokens, dtype=torch.long, device=self.device) @@ -1299,19 +1298,6 @@ def _extract_encoder_inputs( "encoder_positions": encoder_positions, } - def _get_encoder_sequence_length( - self, features: Union[torch.Tensor, list]) -> int: - """Get the encoder sequence length for the given features.""" - # For Whisper: use max_source_positions from config - # which represents the encoder sequence length - encoder_seq_len = getattr(self.model_config.hf_config, - 'max_source_positions', 1500) - - if isinstance(features, list): - return len(features) * encoder_seq_len - else: - return encoder_seq_len - def _process_input_features(self, input_features_list: list) -> torch.Tensor: """Process and concatenate input features into a single tensor.""" From 222f6787eff7e14f0f45179c4193683af43afcb8 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 5 Aug 2025 00:37:09 +0000 Subject: [PATCH 03/18] scheduler: Disable encoder cache manager for encoder-decoder We do not cache encoder outputs for encoder-decoder models. The scheduler code was tracking encoder data as if we were. Make it explicit that there is no encoder cache for encoder-decoder. 
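A condensed sketch of the pattern this commit applies in the scheduler (simplified; the real diff below guards every call site):

    # The encoder cache manager becomes Optional. Encoder-decoder models skip
    # it entirely: encoder outputs are processed once per request, directly
    # into the cross-attention KV cache, so there is nothing to cache or free.
    self.encoder_cache_manager: Optional[EncoderCacheManager] = None
    if not self.is_encoder_decoder:
        self.encoder_cache_manager = EncoderCacheManager(
            cache_size=encoder_cache_size)

    # Call sites then check for None before allocating or freeing, e.g.:
    if self.encoder_cache_manager:
        for i in encoder_inputs_to_schedule:
            self.encoder_cache_manager.allocate(request, i)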
Signed-off-by: Russell Bryant --- vllm/v1/core/sched/scheduler.py | 39 +++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 595ee20af465..05beb6472a61 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -138,8 +138,13 @@ def __init__( # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager: Optional[EncoderCacheManager] = None + if not self.is_encoder_decoder: + # An encoder-decoder model does not use the encoder cache. + # It uses bidirectional attention and inputs are only + # processed once per request. + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config @@ -313,8 +318,9 @@ def schedule(self) -> SchedulerOutput: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. - for i in encoder_inputs_to_schedule: - self.encoder_cache_manager.allocate(request, i) + if self.encoder_cache_manager: + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget # Record the LoRAs in scheduled_running_reqs @@ -524,8 +530,9 @@ def schedule(self) -> SchedulerOutput: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. - for i in encoder_inputs_to_schedule: - self.encoder_cache_manager.allocate(request, i) + if self.encoder_cache_manager: + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget # Put back any skipped requests at the head of the waiting queue @@ -571,6 +578,10 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_block_ids, ) + if self.encoder_cache_manager: + free_encoder_input_ids = self.encoder_cache_manager.get_freed_ids() + else: + free_encoder_input_ids = [] scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -584,7 +595,7 @@ def schedule(self) -> SchedulerOutput: # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + free_encoder_input_ids=free_encoder_input_ids, structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -731,7 +742,8 @@ def _try_schedule_encoder_inputs( # in the decoder's KV cache. continue - if self.encoder_cache_manager.has_cache(request, i): + if (self.encoder_cache_manager + and self.encoder_cache_manager.has_cache(request, i)): # The encoder input is already computed and cached. continue @@ -745,7 +757,8 @@ def _try_schedule_encoder_inputs( num_new_tokens = start_pos - num_computed_tokens break - if (not self.encoder_cache_manager.can_allocate(request, i) + if ((self.encoder_cache_manager + and not self.encoder_cache_manager.can_allocate(request, i)) or num_encoder_tokens > encoder_budget): # The encoder cache is full or the encoder budget is exhausted. 
# NOTE(woosuk): We assume that the encoder input tokens should @@ -958,6 +971,8 @@ def _update_request_with_output( return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: + if not self.encoder_cache_manager: + return cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) # OPTIMIZATION: Avoid list(set) if the set is empty. @@ -970,7 +985,8 @@ def _free_encoder_inputs(self, request: Request) -> None: mm_positions = request.mm_positions[input_id] start_pos = mm_positions.offset num_tokens = mm_positions.length - if start_pos + num_tokens <= request.num_computed_tokens: + if (self.encoder_cache_manager + and start_pos + num_tokens <= request.num_computed_tokens): # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( @@ -1034,7 +1050,8 @@ def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() delay_free_blocks, kv_xfer_params = self._connector_finished(request) - self.encoder_cache_manager.free(request) + if self.encoder_cache_manager: + self.encoder_cache_manager.free(request) request_id = request.request_id self.finished_req_ids.add(request_id) if self.finished_req_ids_dict is not None: From ed2b2068b2a430862c1ae44dfcc5c78bfe679b0d Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 5 Aug 2025 14:05:11 +0000 Subject: [PATCH 04/18] fix changes from a bad rebase Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b6a9857ffc6d..b74f4a267997 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -151,13 +151,6 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=model_config, - scheduler_config=scheduler_config, - mm_registry=self.mm_registry, - ) - self.max_num_encoder_input_tokens = encoder_compute_budget - self.encoder_cache_size = encoder_cache_size # Maximum length of the encoder input, only for encoder-decoder models. self.max_encoder_len = self.mm_registry.\ get_encdec_max_encoder_len(model_config) @@ -2539,7 +2532,8 @@ def _dummy_pooler_run( def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. 
- if self.is_multimodal_model and not self.model_config.is_encoder_decoder: + if (self.is_multimodal_model + and not self.model_config.is_encoder_decoder): mm_budget = self.mm_budget assert mm_budget is not None From 96b4b510f91fa90cf30a1c982e93b0b51829cb4a Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 5 Aug 2025 14:14:10 +0000 Subject: [PATCH 05/18] disable bart tests and examples, only whisper is working Signed-off-by: Russell Bryant --- .buildkite/test-pipeline.yaml | 3 +-- examples/offline_inference/encoder_decoder.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e139c6b30586..aa4c77100022 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -297,7 +297,6 @@ steps: - python3 offline_inference/vision_language_pooling.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - - python3 offline_inference/encoder_decoder.py - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py @@ -500,7 +499,7 @@ steps: - vllm/ - tests/encoder_decoder commands: - - pytest -v -s encoder_decoder + - echo TODO # TODO: bart is not yet supported in V1 - label: OpenAI-Compatible Tool Use # 20 min mirror_hardwares: [amdexperimental] diff --git a/examples/offline_inference/encoder_decoder.py b/examples/offline_inference/encoder_decoder.py index 0da6fa5c4af5..f54a7ccac323 100644 --- a/examples/offline_inference/encoder_decoder.py +++ b/examples/offline_inference/encoder_decoder.py @@ -3,6 +3,8 @@ """ Demonstrate prompting of text-to-text encoder/decoder models, specifically BART + +NOTE: This example is not yet supported in V1. 
""" from vllm import LLM, SamplingParams From a53e04494f1221bdfdbc18ef7073cff756d03557 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 00:14:39 +0000 Subject: [PATCH 06/18] Remove V0 specific encoder-decoder worker test Signed-off-by: Russell Bryant --- .../test_encoder_decoder_model_runner.py | 648 ------------------ 1 file changed, 648 deletions(-) delete mode 100644 tests/worker/test_encoder_decoder_model_runner.py diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py deleted file mode 100644 index 35ac90b38e84..000000000000 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ /dev/null @@ -1,648 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import itertools - -import pytest -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.platforms import current_platform -from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad -from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner - -BATCH_SIZES = [1, 4, 16, 64, 256] - - -def _create_model_runner(model: str, *args, - **kwargs) -> EncoderDecoderModelRunner: - engine_args = EngineArgs(model, *args, **kwargs) - engine_config = engine_args.create_engine_config() - model_runner = EncoderDecoderModelRunner( - vllm_config=engine_config, - is_driver_worker=True, - ) - return model_runner - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -def test_empty_seq_group(): - """Verify prepare prompt and decode returns empty output - for empty seq group list""" - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - model_input = model_runner._prepare_model_input_tensors( - seq_group_metadata_list) - ( - input_tokens, - input_positions, - encoder_input_tokens, - encoder_input_positions, - attn_metadata, - return_seq_lens, - ) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.encoder_input_tokens, - model_input.encoder_input_positions, - model_input.attn_metadata, - model_input.seq_lens, - ) - assert input_tokens is None - assert input_positions is None - assert encoder_input_tokens is None - assert encoder_input_positions is None - assert attn_metadata is None - assert return_seq_lens is None - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -def test_prepare_prompt(batch_size): - ''' - Test the ability of the encoder/decoder model runner subclass to - produce prefill-phase model inputs & attention metadata. - - Test behavior: - - * Instantiate BART base model & enc/dec model runner - * Construct sequence-group metadata for dummy prompts - * Test that encoder attention, decoder self-attention, - and encoder/decoder cross-attention inputs are correct - - Arguments: - - * batch_size - * backend_name: The attention backend under test - * enforce_eager: Enforce eager mode if True (i.e. 
no CUDAGraph) - ''' - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = {0: [1]} - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == seq_data.get_len() - seq_group_metadata_list.append(seq_group_metadata) - - # Build - # * Decoder model inputs - # * Decoder self-attention KV caching data structures - # * Encoder model inputs - # * Encoder/decoder cross-attention KV caching data structures - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for prompts. - # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills > 0 - assert attn_metadata.num_decode_tokens == 0 - assert torch.equal(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == max(seq_lens) - assert attn_metadata.max_decode_seq_len == 0 - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. 
- start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - start_loc.append(start_idx) - assert torch.equal( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - - # Test decoder seq start locs & context lengths - - assert torch.equal( - attn_metadata.seq_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - assert torch.equal( - attn_metadata.context_lens_tensor, - torch.zeros(attn_metadata.context_lens_tensor.shape[0], - dtype=torch.int, - device=device), - ) - - # Verify block tables are correct for prompts - # - Decoder self-attention - expected = torch.tensor( - [[] for _ in range(len(seq_group_metadata_list))], - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Cuda graph should not be used for prefill. - assert attn_metadata.use_cuda_graph is False - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == sum(seq_lens) - assert len(input_positions) == sum(seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == sum(encoder_seq_lens) - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) - - # Test that vLLM sampling infrastructure chooses the correct - # sequence positions at which to sample (i.e. the end of - # each sequence) in the prefill phase - - expected_selected_token_indices = [] - selected_token_start_idx = 0 - for seq_len in seq_lens: - # Compute the index offset of the final token in each - # prompt (recall that the prompts are concatenated) - expected_selected_token_indices.append(selected_token_start_idx + - seq_len - 1) - selected_token_start_idx += seq_len - - sampling_metadata = model_input.sampling_metadata - actual = sampling_metadata.selected_token_indices - expected = torch.tensor( - expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) - assert torch.equal(actual, expected) - - -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="CPU backend is currently " - "unsupported for encoder/ " - "decoder models") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) -def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): - ''' - Test the ability of the encoder/decoder model runner subclass to - produce decode-phase model inputs & attention metadata. 
- - Test behavior: - - * Instantiate BART base model & enc/dec model runner - * Construct sequence-group metadata for dummy prompts - * Test that encoder attention, decoder self-attention, - and encoder/decoder cross-attention inputs are correct - - Arguments: - - * batch_size - * multiple_seqs_per_seq_group - * backend_name: The attention backend under test - * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) - ''' - - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=True, - ) - - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - block_tables = { - 0: [1], - 1: [3] - } if multiple_seqs_per_seq_group else { - 0: [1] - } - cross_block_table = [2] - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={ - 0: seq_data, - 1: seq_data - } if multiple_seqs_per_seq_group else {0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_group_metadata_list.append(seq_group_metadata) - seq_lens.extend( - [seq_len for _ in range(len(seq_group_metadata.seq_data))]) - encoder_seq_lens.extend( - [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - - # Build - # * Decoder model inputs - # * Decoder self-attention KV caching data structures - # * Encoder model inputs - # * Encoder/decoder cross-attention KV caching data structures - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - assert return_seq_lens == seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify input metadata is correct for decode phase. - # - Decoder attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.equal(attn_metadata.seq_lens_tensor, - torch.tensor(seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) - - # Test decoder subquery start locs. 
- start_idx = 0 - start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += 1 - start_loc.append(start_idx) - assert torch.equal( - attn_metadata.query_start_loc, - torch.tensor(start_loc, dtype=torch.int32, device=device), - ) - - # Test decoder seq start locs. Note that for normal prefill it is - # equivalent to query_start_loc. - start_idx = 0 - seq_start_loc = [start_idx] - for seq_len in seq_lens: - start_idx += seq_len - seq_start_loc.append(start_idx) - - # Test seq_start_loc and context lengths - - assert torch.equal( - attn_metadata.seq_start_loc, - torch.tensor(seq_start_loc, dtype=torch.int32, device=device), - ) - assert torch.equal( - attn_metadata.context_lens_tensor, - torch.tensor([seq_len - 1 for seq_len in seq_lens], - dtype=torch.int, - device=device)) - - # Verify block tables are correct for prompts - # - Decoder self-attention - flattened_block_tables = [ - block_table for block_table in block_tables.values() - ] - expected = torch.tensor(flattened_block_tables * - len(seq_group_metadata_list), - dtype=torch.int32, - device=model_runner.device) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention - expected = torch.tensor([ - cross_block_table for seq_group_metadata in seq_group_metadata_list - for _ in range(len(seq_group_metadata.seq_data)) - ], - dtype=torch.int32, - device=model_runner.device) - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Model runner's CUDAGraph setting should be propagated to attention - # metadata. - assert attn_metadata.use_cuda_graph is False - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(seq_lens) - assert len(input_positions) == len(seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_tokens) == 0 - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) - - # Test that vLLM sampling infrastructure chooses the correct - # sequence positions at which to sample (i.e. the end of - # each sequence) in the decode phase - - expected_selected_token_indices = [] - for selected_token_start_idx, seq_len in enumerate(seq_lens): - # Compute the index offset of the final token in each - # sequence's decoded outputs; since a single token is - # decoded per iteration per sequence, then the length - # of the decoded tokens for a given sequence is 1 and - # the final index offset into a given sequence's - # generated tokens is 0 (i.e. 
the expected sampling index - # for a given sequence is just `selected_token_start_idx`) - expected_selected_token_indices.append(selected_token_start_idx) - - sampling_metadata = model_input.sampling_metadata - actual = sampling_metadata.selected_token_indices - expected = torch.tensor( - expected_selected_token_indices, - device=actual.device, - dtype=actual.dtype, - ) - assert torch.equal(actual, expected) - - -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False]) -def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): - """ - Tests that for encoder-decoder models with CUDA Graph capture and replay - enabled, the tensors used during the decode phase are correctly padded - for varying input batch sizes. - """ - model_runner = _create_model_runner( - "facebook/bart-base", - seed=0, - dtype="float16", - max_num_batched_tokens=100000, - max_num_seqs=100000, - enable_chunked_prefill=False, - enforce_eager=False, - ) - block_tables = { - 0: [1], - 1: [3] - } if multiple_seqs_per_seq_group else { - 0: [1] - } - seq_lens: list[int] = [] - encoder_seq_lens: list[int] = [] - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - - cross_block_table = [2] - expanded_batch_size = 0 - for i in range(batch_size): - # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(seq_len)) - encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) - seq_group_metadata = SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=False, - seq_data={ - 0: seq_data, - 1: seq_data - } if multiple_seqs_per_seq_group else {0: seq_data}, - sampling_params=SamplingParams(temperature=0), - block_tables=block_tables, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - ) - assert seq_group_metadata.token_chunk_size == 1 - seq_lens.extend( - [seq_len for _ in range(len(seq_group_metadata.seq_data))]) - encoder_seq_lens.extend( - [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))]) - expanded_batch_size = expanded_batch_size + len( - seq_group_metadata.seq_data) - seq_group_metadata_list.append(seq_group_metadata) - - model_input = model_runner.prepare_model_input(seq_group_metadata_list) - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - return_seq_lens = model_input.seq_lens - slot_mapping = attn_metadata.slot_mapping - encoder_input_tokens = model_input.encoder_input_tokens - encoder_input_positions = model_input.encoder_input_positions - cross_slot_mapping = attn_metadata.cross_slot_mapping - - # With CUDA Graph capture and replay enabled, the decoder and encoder - # input sequences will be padded. Create the expected padded tensors - # accordingly. 
- graph_batch_size = model_runner.vllm_config.pad_for_cudagraph( - expanded_batch_size) - cuda_graph_pad_size = graph_batch_size - expanded_batch_size - padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) - padded_encoder_seq_lens = encoder_seq_lens + list( - itertools.repeat(1, cuda_graph_pad_size)) - - assert return_seq_lens == padded_seq_lens - assert len(slot_mapping) == len(input_tokens) - assert len(cross_slot_mapping) == len(encoder_input_tokens) - - # Verify attention metadata - device = model_runner.device - assert attn_metadata.num_prefills == 0 - assert attn_metadata.num_decode_tokens > 0 - assert torch.equal( - attn_metadata.seq_lens_tensor, - torch.tensor(padded_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.seq_lens == padded_seq_lens - assert attn_metadata.max_prefill_seq_len == 0 - assert attn_metadata.max_decode_seq_len == max(seq_lens) - # - Encoder attention metadata - assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens - assert torch.equal( - attn_metadata.encoder_seq_lens_tensor, - torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int)) - assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens) - assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens) - - # Verify block tables are correct for prompts - # - Decoder self-attention. Pad the block tables as expected. - flattened_block_tables = [ - block_table for _ in range(len(seq_group_metadata_list)) - for block_table in block_tables.values() - ] - flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)]) - expected = make_tensor_with_pad( - flattened_block_tables, - max_len=64, - pad=0, - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.block_tables, - expected, - ) - # - Encoder/decoder cross-attention. Pad the cross-attention block tables - # as expected. - expected = [ - cross_block_table for seq_group_metadata in seq_group_metadata_list - for _ in range(len(seq_group_metadata.seq_data)) - ] - expected.extend([[] for _ in range(cuda_graph_pad_size)]) - expected = make_tensor_with_pad( - expected, - max_len=64, - pad=0, - dtype=torch.int32, - device=model_runner.device, - ) - assert torch.equal( - attn_metadata.cross_block_tables, - expected, - ) - - # Model runner's CUDAGraph setting should be propagated to attention - # metadata. 
- assert attn_metadata.use_cuda_graph is True - - # Verify the lengths of input tokens & positions - # - Decoder - assert len(input_tokens) == len(padded_seq_lens) - assert len(input_positions) == len(padded_seq_lens) - # -- An indirect check that model_input.input_tokens - # and model_input.input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - input_tokens, - input_positions, - ) - # - Encoder - assert len(encoder_input_tokens) == 0 - assert len(encoder_input_tokens) == 0 - # -- An indirect check that model_input.encoder_input_tokens - # and model_input.encoder_input_positions are correct - - # by design of the test, the input tokens are - # equal to the input position values, so if - # the model_input data structure has the correct - # values then these two should be equal - assert torch.equal( - encoder_input_tokens, - encoder_input_positions, - ) From 70f5f8a2125ef2435960ba6a544394565e98eff8 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 00:17:25 +0000 Subject: [PATCH 07/18] skip bart test since bart is not in v1 yet Signed-off-by: Russell Bryant --- tests/entrypoints/openai/test_encoder_decoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py index 9c2aef23e877..75612962c95f 100644 --- a/tests/entrypoints/openai/test_encoder_decoder.py +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -30,6 +30,7 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.skip(reason="bart is not yet supported in V1") async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): completion = await client.completions.create(model=model_name, prompt="Hello, my name is", From e3bb05bb6abcc0792e2efb8e35857116e98c1bc3 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 00:22:42 +0000 Subject: [PATCH 08/18] skip another bart test not supported in V1 Signed-off-by: Russell Bryant --- .buildkite/test-pipeline.yaml | 2 +- tests/encoder_decoder/test_e2e_correctness.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa4c77100022..6a1b12fc6e07 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -499,7 +499,7 @@ steps: - vllm/ - tests/encoder_decoder commands: - - echo TODO # TODO: bart is not yet supported in V1 + - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min mirror_hardwares: [amdexperimental] diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 8b99d9d6e21f..3cf4c377fb58 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -63,6 +63,7 @@ def clear_cache(): current_platform.is_cpu(), reason="CPU backend is not currently supported with encoder/decoder models" ) +@pytest.mark.skip(reason="bart not supported in V1") def test_encoder_decoder_e2e( hf_runner, vllm_runner, From 1d6a71a4a8aeec291624854a28fdd004bc3ffbd8 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 00:24:13 +0000 Subject: [PATCH 09/18] skip another bart test not supported in V1 Signed-off-by: Russell Bryant --- tests/models/language/generation/test_bart.py | 2 ++ 1 
file changed, 2 insertions(+) diff --git a/tests/models/language/generation/test_bart.py b/tests/models/language/generation/test_bart.py index b4c771840196..22ceb27869ac 100644 --- a/tests/models/language/generation/test_bart.py +++ b/tests/models/language/generation/test_bart.py @@ -178,6 +178,7 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) +@pytest.mark.skip(reason="bart not supported in V1") def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None: @@ -201,6 +202,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model, @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.CUSTOM]) +@pytest.mark.skip(reason="bart not supported in V1") def test_models_distributed(hf_runner, vllm_runner, example_encoder_decoder_prompts, distributed_executor_backend, model, dtype, From 0d26c402e1caea82dad1c98afdad2ab92a1f7f7c Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 10:23:40 +0000 Subject: [PATCH 10/18] Speed up test_transcription_api_correctness.py Load the tokenizer once instead of blocking async processing and loading it for every request. This vastly speeds up this test as it runs with 500+ samples. Signed-off-by: Russell Bryant --- .../correctness/test_transcription_api_correctness.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 58195f98bd35..0d0ce0be8c5f 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -49,8 +49,7 @@ async def transcribe_audio(client, tokenizer, y, sr): return latency, num_output_tokens, transcription.text -async def bound_transcribe(model_name, sem, client, audio, reference): - tokenizer = AutoTokenizer.from_pretrained(model_name) +async def bound_transcribe(sem, client, tokenizer, audio, reference): # Use semaphore to limit concurrent requests. async with sem: result = await transcribe_audio(client, tokenizer, *audio) @@ -63,15 +62,19 @@ async def bound_transcribe(model_name, sem, client, audio, reference): async def process_dataset(model, client, data, concurrent_request): sem = asyncio.Semaphore(concurrent_request) + # Load tokenizer once outside the loop + tokenizer = AutoTokenizer.from_pretrained(model) + # Warmup call as the first `librosa.load` server-side is quite slow. 
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] - _ = await bound_transcribe(model, sem, client, (audio, sr), "") + _ = await bound_transcribe(sem, client, tokenizer, (audio, sr), "") tasks: list[asyncio.Task] = [] for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(model, sem, client, (audio, sr), sample["text"])) + bound_transcribe(sem, client, tokenizer, (audio, sr), + sample["text"])) tasks.append(task) return await asyncio.gather(*tasks) From a0595798ace9af1414a93c0ebc3751eccbb8132c Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 10:57:37 -0400 Subject: [PATCH 11/18] Update vllm/v1/core/sched/scheduler.py Signed-off-by: Russell Bryant --- vllm/v1/core/sched/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 05beb6472a61..d6da7e9d9b5c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -985,8 +985,7 @@ def _free_encoder_inputs(self, request: Request) -> None: mm_positions = request.mm_positions[input_id] start_pos = mm_positions.offset num_tokens = mm_positions.length - if (self.encoder_cache_manager - and start_pos + num_tokens <= request.num_computed_tokens): + if start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. self.encoder_cache_manager.free_encoder_input( From 073ec3364dba88a44ce741e2637db25d3f152c1d Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Wed, 6 Aug 2025 14:54:08 +0000 Subject: [PATCH 12/18] test_transcription_api_correctness: set a more generous read timeout Signed-off-by: Russell Bryant --- .../openai/correctness/test_transcription_api_correctness.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 0d0ce0be8c5f..2793263c554a 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -40,6 +40,10 @@ async def transcribe_audio(client, tokenizer, y, sr): model=tokenizer.name_or_path, language="en", temperature=0.0, + # 15 minutes + # The default is too aggressive in some cases, + # depending on the test environment. + timeout=900, ) end_time = time.perf_counter() # NOTE there's no streaming in transcriptions, can't measure ttft From f8826720572fe1f6637ebf67255f6b40afdc3f29 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 7 Aug 2025 23:29:39 +0000 Subject: [PATCH 13/18] Ensure encoder inputs are only processed once Signed-off-by: Russell Bryant --- vllm/v1/core/sched/scheduler.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d6da7e9d9b5c..2bbcbe19661d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -737,7 +737,18 @@ def _try_schedule_encoder_inputs( if start_pos >= num_computed_tokens + num_new_tokens: # The encoder input is not needed in this step. break - if start_pos + num_encoder_tokens <= num_computed_tokens: + if self.is_encoder_decoder and start_pos < num_computed_tokens: + # Encoder input has already been computed + # The calculation here is a bit different. 
We don't turn encoder + # output into tokens that get processed by the decoder and + # reflected in num_computed_tokens. Instead, start_pos reflects + # the position where we need to ensure we calculate encoder + # inputs. This should always be 0 to ensure we calculate encoder + # inputs before running the decoder. Once we've calculated some + # decoder tokens (num_computed_tokens > 0), then we know we + # already calculated encoder inputs and can skip here. + continue + elif start_pos + num_encoder_tokens <= num_computed_tokens: # The encoder input is already computed and stored # in the decoder's KV cache. continue From 94e41ce44951ee6357e26dcb531ff28e12116fa3 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 8 Aug 2025 00:21:49 +0000 Subject: [PATCH 14/18] fix a problem from the last rebase Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b74f4a267997..9e4636e00e77 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -866,11 +866,11 @@ def _prepare_inputs( builder, ) - attn_metadata_i = (encoder_attn_metadata - if is_enc_dec else builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - )) + attn_metadata_i = ( + encoder_attn_metadata if is_enc_dec else builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + )) fast_prefill_metadata = attn_metadata_i if (self.cache_config.kv_sharing_fast_prefill @@ -3247,7 +3247,7 @@ def _build_enc_dec_attn_metadata( cross_slot_mapping, dtype=torch.int64, device=self.device) # Use the first attention metadata builder - builder = self.attn_metadata_builders[0] + builder = self.attn_groups[0][0].metadata_builder return common_metadata, builder.build( common_prefix_len=0, # No cascade for encoder common_attn_metadata=common_metadata, From 957da64ec2dc64ea9478ea5e8fdee5301bc25a0d Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 8 Aug 2025 18:12:00 +0000 Subject: [PATCH 15/18] Simplify how to disable encoder cache for encoder-decoder Signed-off-by: Russell Bryant --- vllm/v1/core/encoder_cache_manager.py | 12 ++++++++++- vllm/v1/core/sched/scheduler.py | 31 +++++++++------------------ 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 6666dcd5d09d..471b276c7381 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -38,6 +38,7 @@ class EncoderCacheManager: Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. + is_encoder_decoder: Whether the model is an encoder-decoder model. Attributes: cache_size: Total cache capacity in encoder tokens @@ -48,7 +49,7 @@ class EncoderCacheManager: This is cleared after every call to get_freed_ids(). """ - def __init__(self, cache_size: int): + def __init__(self, cache_size: int, is_encoder_decoder: bool): self.cache_size = cache_size self.num_free_slots = cache_size # req_id -> cached input ids @@ -56,6 +57,12 @@ def __init__(self, cache_size: int): # list of [req_id, input_id] self.freed: list[tuple[str, int]] = [] + # Whether the model is an encoder-decoder model. + # If so, we don't need to cache encoder inputs. + # We handle it here instead of in the scheduler to keep + # the scheduler logic simpler. 
+ self.is_encoder_decoder = is_encoder_decoder + def has_cache(self, request: Request, input_id: int) -> bool: """Check if encoder output for a specific multimodal input is cached. @@ -99,6 +106,9 @@ def allocate(self, request: Request, input_id: int) -> None: This method assumes can_allocate() returned True for the same request and input_id. It will reduce available cache space. """ + if self.is_encoder_decoder: + return + req_id = request.request_id if req_id not in self.cached: self.cached[req_id] = set() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2bbcbe19661d..993c0fc8ff68 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -138,13 +138,9 @@ def __init__( # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager: Optional[EncoderCacheManager] = None - if not self.is_encoder_decoder: - # An encoder-decoder model does not use the encoder cache. - # It uses bidirectional attention and inputs are only - # processed once per request. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager( + cache_size=encoder_cache_size, + is_encoder_decoder=self.is_encoder_decoder) speculative_config = vllm_config.speculative_config @@ -318,9 +314,8 @@ def schedule(self) -> SchedulerOutput: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. - if self.encoder_cache_manager: - for i in encoder_inputs_to_schedule: - self.encoder_cache_manager.allocate(request, i) + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget # Record the LoRAs in scheduled_running_reqs @@ -530,9 +525,8 @@ def schedule(self) -> SchedulerOutput: scheduled_encoder_inputs[request.request_id] = ( encoder_inputs_to_schedule) # Allocate the encoder cache. - if self.encoder_cache_manager: - for i in encoder_inputs_to_schedule: - self.encoder_cache_manager.allocate(request, i) + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget # Put back any skipped requests at the head of the waiting queue @@ -578,10 +572,7 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_block_ids, ) - if self.encoder_cache_manager: - free_encoder_input_ids = self.encoder_cache_manager.get_freed_ids() - else: - free_encoder_input_ids = [] + free_encoder_input_ids = self.encoder_cache_manager.get_freed_ids() scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -753,8 +744,7 @@ def _try_schedule_encoder_inputs( # in the decoder's KV cache. continue - if (self.encoder_cache_manager - and self.encoder_cache_manager.has_cache(request, i)): + if self.encoder_cache_manager.has_cache(request, i): # The encoder input is already computed and cached. continue @@ -768,8 +758,7 @@ def _try_schedule_encoder_inputs( num_new_tokens = start_pos - num_computed_tokens break - if ((self.encoder_cache_manager - and not self.encoder_cache_manager.can_allocate(request, i)) + if (not self.encoder_cache_manager.can_allocate(request, i) or num_encoder_tokens > encoder_budget): # The encoder cache is full or the encoder budget is exhausted. 
# NOTE(woosuk): We assume that the encoder input tokens should From 16f3d88891acb32916f3d1dab45974b93a409897 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 8 Aug 2025 13:40:07 -0400 Subject: [PATCH 16/18] Update vllm/v1/worker/gpu_model_runner.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Nicolò Lucchesi Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9e4636e00e77..f886cc9917e5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -795,15 +795,12 @@ def _prepare_inputs( # Prepare encoder attention metadata separately # (encoder layers are not in KV cache groups) - if self.is_encoder_only_model or ( - self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs): - if self.is_encoder_only_model: - common_attn_metadata, encoder_attn_metadata = \ + if self.is_encoder_only_model: + common_attn_metadata, encoder_attn_metadata = \ self._build_encoder_only_attn_metadata( scheduler_output) - else: - common_attn_metadata, encoder_attn_metadata = \ + elif self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: + common_attn_metadata, encoder_attn_metadata = \ self._build_enc_dec_attn_metadata( scheduler_output) From f20828cd052435c49b3f37406ff9e7eb01632ea4 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 8 Aug 2025 18:15:35 +0000 Subject: [PATCH 17/18] Fix formatting error from web based commit Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f886cc9917e5..f6aea58b6437 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -799,8 +799,9 @@ def _prepare_inputs( common_attn_metadata, encoder_attn_metadata = \ self._build_encoder_only_attn_metadata( scheduler_output) - elif self.model_config.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: - common_attn_metadata, encoder_attn_metadata = \ + elif (self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs): + common_attn_metadata, encoder_attn_metadata = \ self._build_enc_dec_attn_metadata( scheduler_output) From 1728e068a0c6385c441fe558f2356d452ded9cd4 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Fri, 8 Aug 2025 18:22:59 +0000 Subject: [PATCH 18/18] drop unused encoder inputs for whisper Signed-off-by: Russell Bryant --- vllm/v1/worker/gpu_model_runner.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f6aea58b6437..a21934685535 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1273,20 +1273,8 @@ def _extract_encoder_inputs( input_features = input_features.to(device=self.device, dtype=self.model_config.dtype) - # Create encoder positions - encoder_positions = torch.arange(total_encoder_tokens, - dtype=torch.long, - device=self.device) - - # Create encoder input_ids (dummy tokens for encoder) - encoder_input_ids = torch.zeros(total_encoder_tokens, - dtype=torch.long, - device=self.device) - return { "input_features": input_features, - "encoder_input_ids": encoder_input_ids, - "encoder_positions": encoder_positions, } def 
_process_input_features(self,