(Dont Merge) Add rwkv6 #6749

Closed · wants to merge 17 commits
77 changes: 77 additions & 0 deletions tests/models/test_mamba.py
@@ -0,0 +1,77 @@
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/test_mamba.py`.
"""
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from .utils import check_outputs_equal

MODELS = [
"state-spaces/mamba-370m-hf",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
# Tokenize the input prompt with truncation
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs["input_ids"].to(model.device)

# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)

outputs.append((generated_ids[0].tolist(), generated_text))

return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"

hf_outputs = generate_greedy(model, example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
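
Review note: a minimal sketch (not part of the diff) of exercising the HF reference helper above outside of pytest. It assumes the repository root is on PYTHONPATH so the `tests` package is importable; the prompt string is only an example.

```python
# Illustrative only: call the HF-side greedy helper defined in
# tests/models/test_mamba.py without the pytest fixtures.
from tests.models.test_mamba import generate_greedy

outputs = generate_greedy(
    "state-spaces/mamba-370m-hf",     # same model as MODELS above
    ["The capital of France is"],     # example prompt, not from the suite
    max_tokens=16,
)
token_ids, text = outputs[0]
print(text)
```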
167 changes: 167 additions & 0 deletions vllm/attention/backends/placeholder_attn.py
@@ -0,0 +1,167 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)

# Placeholder attention backend for models like Mamba that don't have attention.
# Mainly exists to sidestep get_attn_backend.
# The attention metadata is still needed for Mamba.


class PlaceholderAttentionBackend(AttentionBackend):
"""Placeholder backend for when no attention is needed."""

@staticmethod
def get_name() -> str:
return "No attention"

@staticmethod
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
return PlaceholderAttentionImpl

@staticmethod
def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
return PlaceholderAttentionMetadata

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (1, 1, 1, 1, 1)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
return

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
return


@dataclass
class PlaceholderAttentionMetadata(AttentionMetadata):
"""Attention metadata for prefill and decode batched together."""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]

# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]

# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# The 2nd dimension is padded up to max_blocks_per_seq if the batch is
# CUDA-graph captured.
block_tables: Optional[torch.Tensor]

# Whether or not the CUDA graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool

_cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
_cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None

@property
def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_prefills == 0:
return None

if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata

assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None

self._cached_prefill_metadata = PlaceholderAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata

@property
def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None

if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None

self._cached_decode_metadata = PlaceholderAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata


class PlaceholderAttentionImpl(AttentionImpl):

def __init__(self, *args, **kwargs) -> None:
return

def forward(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError
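
Review note: a small sketch of how the two cached properties above split a mixed batch. It assumes the base `AttentionMetadata` dataclass contributes `num_prefills`, `num_prefill_tokens`, `num_decode_tokens`, and `slot_mapping`, as the slicing code suggests; all tensor values are made up.

```python
# Hypothetical example: one prefill request (4 new tokens) batched with one
# decode request (1 new token). Values are illustrative only.
import torch

from vllm.attention.backends.placeholder_attn import (
    PlaceholderAttentionMetadata)

meta = PlaceholderAttentionMetadata(
    num_prefills=1,
    num_prefill_tokens=4,
    num_decode_tokens=1,
    slot_mapping=torch.zeros(5, dtype=torch.long),
    seq_lens=[4, 7],
    seq_lens_tensor=torch.tensor([4, 7]),
    max_query_len=4,
    max_prefill_seq_len=4,
    max_decode_seq_len=7,
    query_start_loc=torch.tensor([0, 4, 5]),
    seq_start_loc=torch.tensor([0, 4, 11]),
    context_lens_tensor=torch.tensor([0, 6]),
    block_tables=torch.empty((2, 0), dtype=torch.int32),
    use_cuda_graph=False,
)

prefill = meta.prefill_metadata  # restricted to the prefill request
decode = meta.decode_metadata    # restricted to the decode request
assert prefill is not None and prefill.num_decode_tokens == 0
assert decode is not None and decode.num_prefills == 0
```

Note that, per the code above, the decode view reuses `use_cuda_graph` while the prefill view always disables it.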
31 changes: 31 additions & 0 deletions vllm/config.py
@@ -276,6 +276,19 @@ def verify_with_parallel_config(
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")

def is_attention_free(self) -> bool:
"""Returns True if the model has no attention, i.e. the model has no
state that grows with the size of the context.
"""

# Return true if the model is mamba.
# This check should be augmented with more models in the future,
# and made more robust if possible.
if hasattr(self.hf_text_config,
"model_type") and self.hf_text_config.model_type == 'mamba':
return True
return False

def get_hf_config_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled."""

@@ -309,6 +322,10 @@ def get_head_size(self) -> int:
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256

if self.is_attention_free():
return 0

if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
@@ -340,6 +357,9 @@ def get_total_num_kv_heads(self) -> int:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)

if self.is_attention_free():
return 0

attributes = [
# For Falcon:
"n_head_kv",
@@ -390,6 +410,11 @@ def contains_seqlen_agnostic_layers(
def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)

if self.is_attention_free():
assert (self.hf_config.model_type == "mamba")
return ["mamba"] * num_layers

# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
@@ -428,6 +453,7 @@ def __init__(
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
is_attention_free: bool,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
@@ -437,6 +463,7 @@
self.swap_space_bytes = swap_space * _GB
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
@@ -731,6 +758,8 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
is_attention_free: True if the running model does not have state that
grows as the context size increases.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
num_lookahead_slots: The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
@@ -753,6 +782,7 @@ def __init__(self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
is_attention_free: bool,
use_v2_block_manager: bool = False,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
@@ -779,6 +809,7 @@

self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.is_attention_free = is_attention_free
self.use_v2_block_manager = use_v2_block_manager
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
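
Review note: a hedged sketch of how the new `is_attention_free` flag is passed in. The numeric values are arbitrary, and it assumes the remaining `SchedulerConfig` parameters keep their defaults.

```python
# Illustrative only: building a SchedulerConfig for an attention-free model
# (e.g. Mamba) with the flag added in this PR. The numbers are arbitrary.
from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(
    max_num_batched_tokens=2048,
    max_num_seqs=16,
    max_model_len=2048,
    is_attention_free=True,  # steers the scheduler to the placeholder manager
)
assert scheduler_config.is_attention_free
```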
8 changes: 4 additions & 4 deletions vllm/core/interfaces.py
@@ -35,10 +35,10 @@ def get_block_space_manager_class(version: str):
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2

if version == "embedding":
from vllm.core.embedding_model_block_manager import (
EmbeddingModelBlockSpaceManager)
return EmbeddingModelBlockSpaceManager
if version == "placeholder":
from vllm.core.placeholder_block_space_manager import (
PlaceholderBlockSpaceManager)
return PlaceholderBlockSpaceManager

raise ValueError(f"Unknown version {version=}")

vllm/core/placeholder_block_space_manager.py (renamed from vllm/core/embedding_model_block_manager.py)
@@ -4,9 +4,10 @@
from vllm.sequence import Sequence, SequenceGroup


class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
"""An embedding version of BlockSpaceManager for use in environments
with embedding models where block management is not required.
class PlaceholderBlockSpaceManager(BlockSpaceManager):
"""A version of BlockSpaceManager for use in environments
where block management is not required.
For example: embedding models or attention-free models like Mamba.

This class provides the same interface as BlockSpaceManager, but its
methods perform no actions or return simple values like True in specific
@@ -37,7 +38,7 @@ def append_slots(
seq: Sequence,
num_lookahead_slots: int,
) -> List[Tuple[int, int]]:
return None # type: ignore
return []

def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
5 changes: 3 additions & 2 deletions vllm/core/scheduler.py
@@ -279,8 +279,9 @@ def __init__(
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
if self.scheduler_config.embedding_mode:
version = "embedding"
if (self.scheduler_config.embedding_mode
or self.scheduler_config.is_attention_free):
version = "placeholder"

BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version)
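
Review note: a runnable recap of the selection path the last two diffs introduce, under the assumption that `scheduler_config` exposes `use_v2_block_manager`, `embedding_mode`, and the new `is_attention_free`, as shown above. An attention-free model ends up on the placeholder block manager.

```python
# Illustrative only: mirrors the version-selection logic in Scheduler.__init__
# after this PR; `scheduler_config` stands in for a real SchedulerConfig.
from vllm.core.interfaces import BlockSpaceManager


def pick_block_manager_cls(scheduler_config):
    version = "v1"
    if scheduler_config.use_v2_block_manager:
        version = "v2"
    if (scheduler_config.embedding_mode
            or scheduler_config.is_attention_free):
        version = "placeholder"
    return BlockSpaceManager.get_block_space_manager_class(version)


# For is_attention_free=True this returns PlaceholderBlockSpaceManager from
# vllm/core/placeholder_block_space_manager.py.
```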