diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 4fdc7a3cf709..b6007b9f4630 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -3,12 +3,12 @@ import argparse import datetime import os -import re from typing import Union import albumentations import numpy as np import rasterio +import regex as re import torch from einops import rearrange from terratorch.datamodules import Sen1Floods11NonGeoDataModule diff --git a/tests/conftest.py b/tests/conftest.py index a18dbf58c803..fd4956bdb24c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1062,8 +1062,17 @@ def score( return [req_output.outputs.score for req_output in req_outputs] def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - executor = self.llm.llm_engine.model_executor - return executor.apply_model(func) + if hasattr(self.llm.llm_engine, "model_executor"): + # This works either in V0 or in V1 with + # VLLM_ENABLE_V1_MULTIPROCESSING=0 + executor = self.llm.llm_engine.model_executor + return executor.apply_model(func) + + # This works in V1 with VLLM_ALLOW_INSECURE_SERIALIZATION=1 + def _apply_model(self): + return func(self.get_model()) + + return self.llm.llm_engine.collective_rpc(_apply_model) def __enter__(self): return self diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index aae9a4d1ef11..667a63e76932 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -22,10 +22,12 @@ @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") -def test_model_loading_with_params(vllm_runner): +def test_model_loading_with_params(vllm_runner, monkeypatch): """ Test parameter weight loading with tp>1. """ + # to use apply_model + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_name=MODEL_NAME, revision=REVISION, dtype="float16", @@ -61,10 +63,12 @@ def check_model(model): @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") -def test_roberta_model_loading_with_params(vllm_runner): +def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): """ Test parameter weight loading with tp>1. """ + # to use apply_model + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_name=MODEL_NAME_ROBERTA, revision=REVISION_ROBERTA, dtype="float16", @@ -101,10 +105,12 @@ def check_model(model): @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") -def test_facebook_roberta_model_loading_with_params(vllm_runner): +def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch): """ Test loading roberta-base model with no lm_head. 
""" + # to use apply_model + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") model_name = "FacebookAI/roberta-base" with vllm_runner(model_name=model_name, dtype="float16", diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index cc9e4102d5b7..ba42e389fc15 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -39,17 +39,9 @@ def v1(run_with_both_engines): pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] - pytest.param( - "BAAI/bge-base-en-v1.5", - marks=[ - # CPU only supports V1 - pytest.mark.core_model, - pytest.mark.skip_v1 - ]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2", - marks=[pytest.mark.skip_v1]), - pytest.param("intfloat/multilingual-e5-small", - marks=[pytest.mark.skip_v1]), + pytest.param("BAAI/bge-base-en-v1.5", marks=[pytest.mark.core_model]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2"), + pytest.param("intfloat/multilingual-e5-small"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", marks=[pytest.mark.skip_v1]), # [Cross-Encoder] diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 16c711407aea..a4681baa51ef 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -23,6 +23,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model_info", EMBEDDING_MODELS) def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 69bd4a2060ae..ae2ab6e6413c 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -93,6 +93,7 @@ def create_common_attn_metadata( max_query_len=max_query_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + causal=True, ) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index b4d4348c7fd9..cc59287a9fbe 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -13,7 +13,6 @@ "openai/whisper-large-v3", # transcription "facebook/bart-large-cnn", # encoder decoder "state-spaces/mamba-130m-hf", # mamba1 - "BAAI/bge-m3", # embedding ] MODEL = "meta-llama/Llama-3.2-1B-Instruct" diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py index 0b892bd9dffd..00d98a873a31 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/test_utils.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import re - import pytest +import regex as re import requests import torch diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 709968004718..b7dbff397d2f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1649,7 +1649,8 @@ def _set_default_args_v1(self, usage_context: UsageContext, if (self.max_num_seqs is None and usage_context in default_max_num_seqs): - self.max_num_seqs = default_max_num_seqs[usage_context] + self.max_num_seqs = min(default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize) logger.debug("Setting max_num_seqs to %d for %s usage context.", self.max_num_seqs, use_context_value) diff --git a/vllm/model_executor/models/bert.py 
b/vllm/model_executor/models/bert.py index c3066aaa2b87..504621c8abd8 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -12,7 +12,6 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, @@ -60,7 +59,6 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -119,7 +117,6 @@ def forward( return pooled_output -@support_torch_compile class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -337,6 +334,7 @@ def forward(self, hidden_states: torch.Tensor, return hidden_states +@support_torch_compile class BertModel(nn.Module, SupportsQuant): is_pooling_model = True @@ -368,13 +366,9 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - attn_metadata = get_forward_context().attn_metadata - assert hasattr(attn_metadata, "seq_lens_tensor") - hidden_states = self.embeddings( - input_ids=input_ids, - seq_lens=attn_metadata.seq_lens_tensor, - position_ids=position_ids, - token_type_ids=token_type_ids) + hidden_states = self.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states) def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -447,7 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): +class BertEmbeddingModel(nn.Module, SupportsQuant): """A model that uses Bert to provide embedding functionalities. 
This class encapsulates the BertModel and provides an interface for @@ -474,11 +468,13 @@ def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.model(input_ids=input_ids, position_ids=positions, + token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index c6b411644034..77e072c79275 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,6 +9,7 @@ from transformers import RobertaConfig from vllm.config import VllmConfig +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -51,33 +52,12 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - seq_lens: torch.Tensor, position_ids: torch.Tensor, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() inputs_embeds = self.word_embeddings(input_ids) - # Replace position ids because in RoBERTa models - # they have to start at padding_idx + 1 and ignore - # existing padding tokens - # References: - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 - # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - seq_lens_list = seq_lens.tolist() - new_pos_list = [] - for positions, tokens in zip(position_ids.split(seq_lens_list), - input_ids.split(seq_lens_list)): - # Verify assumption that incoming position are - # always a sequence from 0 to N. - expected_pos = torch.arange(positions.size()[0], - dtype=torch.long, - device=inputs_embeds.device) - assert torch.equal(positions, expected_pos) - new_pos_list.append( - create_position_ids_from_input_ids(tokens, self.padding_idx)) - position_ids = torch.cat(new_pos_list) - # Position embeddings. position_embeddings = self.position_embeddings(position_ids) if token_type_ids is None: @@ -119,6 +99,32 @@ class RobertaEmbeddingModel(BertEmbeddingModel): _pooler: An instance of Pooler used for pooling operations. """ + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # Fix Roberta positions here outside of the CUDA graph. + # Because we need to extract the sequences from + # input_ids, the control flow is data dependent.
+ replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) + + return self.model(input_ids=input_ids, + position_ids=positions, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> Union[BertModel, BertWithRope]: @@ -175,6 +181,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + self.padding_idx = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, @@ -216,6 +223,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) return self.roberta(input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds, @@ -245,3 +255,36 @@ def create_position_ids_from_input_ids(input_ids, past_key_values_length) * mask return incremental_indices.long() + padding_idx + + +def replace_roberta_positions(input_ids: torch.Tensor, + position_ids: torch.Tensor, + padding_idx: int) -> None: + + seq_lens: Optional[torch.Tensor] = None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: # can be None during warmup + if isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values())) + # TODO: remove "seq_lens_tensor" after V0 is removed + seq_lens = getattr(attn_metadata, "seq_lens_tensor", + getattr(attn_metadata, "seq_lens", None)) + + if seq_lens is not None: + assert isinstance(seq_lens, torch.Tensor) + + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens + # References: + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 + # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 + token_list = torch.split(input_ids[:torch.sum(seq_lens)], + seq_lens.tolist()) + + offset = 0 + for tokens in token_list: + length = tokens.shape[0] + position_ids[offset:offset+length] = \ + create_position_ids_from_input_ids(tokens, padding_idx) + offset = offset + length diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5fe274f2c65b..7c8a5e056fea 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -130,6 +130,8 @@ class FlashAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None max_num_splits: int = 0 + causal: bool = True + def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: @@ -213,6 +215,7 @@ def build(self, seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + causal = common_attn_metadata.causal # the overhead of the aot schedule is not worth it for spec-decode aot_schedule = self.aot_schedule and not fast_build @@ -288,7 +291,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_query_len=max_query_len, seqlens=seq_lens, 
max_seq_len=max_seq_len, - causal=True) + causal=causal) if self.use_full_cuda_graph: assert scheduler_metadata is not None @@ -326,7 +329,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - ) + causal=causal) return attn_metadata def can_run_in_cudagraph( @@ -375,11 +378,14 @@ def __init__( FlashAttentionBackend.validate_head_size(head_size) - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/decoder cross-attention " + "is not implemented for " "FlashAttentionImpl") + + self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ and not flash_attn_supports_fp8(): @@ -422,6 +428,8 @@ def forward( # Profiling run. return output + attn_type = self.attn_type + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead @@ -432,6 +440,18 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens + + # Handle encoder attention differently - no KV cache needed + if attn_type in (AttentionType.ENCODER_ONLY, ): + # For encoder attention, + # we use direct Q, K, V tensors without caching + return self._forward_encoder_attention(query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, layer) + + # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) if self.kv_sharing_target_layer_name is None: @@ -483,7 +503,7 @@ def forward( seqused_k=seqused_k, max_seqlen_k=max_seqlen_k, softmax_scale=self.scale, - causal=True, + causal=attn_metadata.causal, alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, @@ -524,6 +544,63 @@ def forward( ) return output + def _forward_encoder_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + layer: torch.nn.Module, + ) -> torch.Tensor: + """Forward pass for encoder attention without KV cache. 
+ + Args: + query: shape = [num_encoder_tokens, num_heads, head_size] + key: shape = [num_encoder_tokens, num_kv_heads, head_size] + value: shape = [num_encoder_tokens, num_kv_heads, head_size] + output: shape = [num_encoder_tokens, num_heads, head_size] + attn_metadata: Encoder attention metadata + layer: The attention layer + """ + # FP8 quantization is not yet supported for encoder attention + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "quantization is not supported for encoder attention") + + # Use encoder-specific metadata for sequence information + cu_seqlens_q = attn_metadata.query_start_loc + cu_seqlens_k = attn_metadata.query_start_loc + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_query_len + + descale_shape = ( + cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] + self.num_kv_heads) + + # Call flash attention directly on Q, K, V tensors + flash_attn_varlen_func( + q=query, + k=key, + v=value, + out=output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=False, # Encoder attention is bidirectional + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + softcap=self.logits_soft_cap, + fa_version=self.vllm_flash_attn_version, + q_descale=layer._q_scale.expand(descale_shape), + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + + return output + def use_cascade_attention( common_prefix_len: int, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index fc8649d587ee..b13362f8a8d8 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -59,6 +59,8 @@ class CommonAttentionMetadata: block_table_tensor: torch.Tensor slot_mapping: torch.Tensor + causal: bool = True + M = TypeVar("M") @@ -395,6 +397,7 @@ def make_local_attention_virtual_batches( max_query_len=seqlens_q_local.max(), block_table_tensor=block_table_local, slot_mapping=common_attn_metadata.slot_mapping, + causal=True, ) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 4124ee05326c..57f60c4b289b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -111,6 +111,12 @@ def __init__(self, "compatibility may not be maintained.", vllm_config.scheduler_config.scheduler_cls) + if len(kv_cache_config.kv_cache_groups) == 0: + # Encoder models without KV cache don't support + # chunked prefill. But do SSM models?
+ logger.info("Disabling chunked prefill for model without KVCache") + vllm_config.scheduler_config.chunked_prefill_enabled = False + self.scheduler: SchedulerInterface = Scheduler( vllm_config=vllm_config, kv_cache_config=kv_cache_config, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 967847c02ff2..63f6fc276189 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -330,6 +330,7 @@ def prepare_inputs( max_query_len=new_query_len_per_req.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index ca94ac8c6054..6b2b50a57e1f 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -4,6 +4,7 @@ from typing import Any import torch +import torch.nn as nn from vllm.config import VllmConfig from vllm.logger import init_logger @@ -59,6 +60,9 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.scheduler_config, self.lora_config, self.device) + def get_model(self) -> nn.Module: + return self.model + def warming_up_model(self) -> None: logger.info("Warming up model for the compilation...") # Only generate graph for the generic shape diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5fe594db667a..8fe074b9bfd2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -126,6 +126,7 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.is_pooling_model = model_config.pooler_config is not None + self.is_encoder_only_model = False self.model_supports_multimodal_raw_input = ( model_config.model_supports_multimodal_raw_input) self.max_model_len = model_config.max_model_len @@ -735,6 +736,21 @@ def _prepare_inputs( spec_decode_common_attn_metadata = None attn_metadata: dict[str, Any] = {} + + # Prepare encoder attention metadata separately + # (encoder layers are not in KV cache groups) + if self.is_encoder_only_model: + common_attn_metadata, encoder_attn_metadata = \ + self._build_encoder_only_attn_metadata( + scheduler_output) + + # Add encoder attention metadata for all encoder layers + attention_layers = get_layers_from_vllm_config( + self.vllm_config, Attention) + for layer_name, attn_module in attention_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_metadata[layer_name] = encoder_attn_metadata + # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -760,6 +776,7 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, block_table_tensor=blk_table_tensor, slot_mapping=slot_mapping, + causal=True, ) if self.speculative_config and \ @@ -2102,7 +2119,8 @@ def _dummy_run( block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. 
- block_table[kv_cache_group_id].slot_mapping[:num_tokens]) + block_table[kv_cache_group_id].slot_mapping[:num_tokens], + causal=True) attn_metadata_i = self.attn_metadata_builders[ kv_cache_group_id].build_for_cudagraph_capture( @@ -2466,6 +2484,49 @@ def freeze_gc(): logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def _initialize_single_attn_backend( + self, kv_cache_spec: KVCacheSpec + ) -> tuple[AttentionBackend, AttentionMetadataBuilder]: + if isinstance(kv_cache_spec, AttentionSpec): + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=kv_cache_spec.use_mla, + ) + if attn_backend_i is None: + error_msg = (f"Error with get_attn_backend: " + f"{kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + elif isinstance(kv_cache_spec, MambaSpec): + attn_backend_i = Mamba2AttentionBackend + else: + raise ValueError( + f"Unknown KV cache spec type: {type(kv_cache_spec)}") + + attn_metadata_builder_i = attn_backend_i.get_builder_cls()( + kv_cache_spec, + self.vllm_config, + self.device, + ) + + if (self.full_cuda_graph + and not attn_metadata_builder_i.full_cudagraph_supported): + raise ValueError( + f"Full CUDAGraph not supported for " + f"{attn_backend_i.__name__}. Turn off CompilationConfig." + f"full_cuda_graph or use a different attention backend.") + return attn_backend_i, attn_metadata_builder_i + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. @@ -2476,48 +2537,45 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - attn_backend_i = get_attn_backend( - kv_cache_spec.head_size, - self.dtype, - kv_cache_spec.dtype, - kv_cache_spec.block_size, - self.model_config.is_attention_free, - use_mla=kv_cache_spec.use_mla, - ) - if attn_backend_i is None: - error_msg = (f"Error with get_attn_backend: " - f"{kv_cache_spec.head_size=}, " - f"{self.dtype=}, {kv_cache_spec.dtype=}, " - f"{kv_cache_spec.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{kv_cache_spec.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 " - "GPUModelRunner.") - elif isinstance(kv_cache_spec, MambaSpec): - attn_backend_i = Mamba2AttentionBackend - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") - - attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - kv_cache_spec, - self.vllm_config, - self.device, - ) - - if (self.full_cuda_graph - and not attn_metadata_builder_i.full_cudagraph_supported): - raise ValueError( - f"Full CUDAGraph not supported for " - f"{attn_backend_i.__name__}. Turn off CompilationConfig." 
- f"full_cuda_graph or use a different attention backend.") + attn_backend_i, attn_metadata_builder_i = \ + self._initialize_single_attn_backend(kv_cache_spec) self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + if len(self.attn_backends) > 0: + return + + # Check if model is encoder-only + block_size = self.vllm_config.cache_config.block_size + use_mla = self.vllm_config.model_config.use_mla + attn_specs = list[AttentionSpec]() + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for attn_module in attn_layers.values(): + + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + assert attn_module.sliding_window is None, "Sliding " + "window attention is not supported for encoder-only models" + + attn_specs.append( + FullAttentionSpec(block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla)) + else: + raise ValueError("Expected only encoder-only layers") + + if len(attn_specs) > 0: + assert len(attn_specs) == len(attn_layers), \ + "All or none of the layers are expected to be encoder-only" + + attn_backend, attn_metadata_builder = \ + self._initialize_single_attn_backend(attn_specs[0]) + self.attn_backends.append(attn_backend) + self.attn_metadata_builders.append(attn_metadata_builder) + self.is_encoder_only_model = True + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2833,3 +2891,53 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec + + def _build_encoder_only_attn_metadata( + self, scheduler_output: "SchedulerOutput") -> \ + tuple[CommonAttentionMetadata, Any]: + """Prepare encoder attention metadata for encoder-only models. + + Args: + scheduler_output: Scheduler output + + Returns: + dict[str, Any]: Encoder attention metadata + """ + num_reqs = self.input_batch.num_reqs + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + max_num_scheduled_tokens = max(tokens) + + # Use the first attention metadata builder + # to create encoder attention metadata + builder = self.attn_metadata_builders[0] + + dummy_block_table = torch.zeros((num_reqs, 1), + dtype=torch.int32, + device=self.device) + dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ), + dtype=torch.int32, + device=self.device) + + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + block_table_tensor=dummy_block_table, + slot_mapping=dummy_slot_mapping, + causal=False, + ) + + return common_metadata, builder.build( + common_prefix_len=0, # No cascade for encoder + common_attn_metadata=common_metadata, + )