diff --git a/tests/models/test_mamba.py b/tests/models/test_mamba.py new file mode 100644 index 00000000000..509027681f4 --- /dev/null +++ b/tests/models/test_mamba.py @@ -0,0 +1,77 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +Run `pytest tests/models/test_mamba.py`. +""" +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .utils import check_outputs_equal + +MODELS = [ + "state-spaces/mamba-370m-hf", +] + + +# Use lower-level interfaces to create this greedy generator, as mamba will +# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. +def generate_greedy(model_name, example_prompts, max_tokens): + # Create a text generation pipeline + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + # Generate texts from the prompts + outputs = [] + for prompt in example_prompts: + # Tokenize the input prompt with truncation + inputs = tokenizer(prompt, return_tensors="pt", truncation=True) + input_ids = inputs["input_ids"].to(model.device) + + # Generate text using the model's generate method directly + generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_text = tokenizer.decode(generated_ids[0], + skip_special_tokens=True) + + outputs.append((generated_ids[0].tolist(), generated_text)) + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + assert dtype == "float" + + hf_outputs = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py new file mode 100644 index 00000000000..f5728756c6e --- /dev/null +++ b/vllm/attention/backends/placeholder_attn.py @@ -0,0 +1,167 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) + +# Placeholder attention backend for models like Mamba that don't have attention. +# Mainly exists to sidestep get_attn_backend. +# The attention metadata is still needed for Mamba. 
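+# As wired up in model_runner.py later in this patch, the runner falls back to
+# this backend when the model reports no attention heads, e.g. (sketch only):
+#
+#     attn_backend = get_attn_backend(...) if num_attn_heads \
+#         else PlaceholderAttentionBackend()
+#
+# Attention-free models still consume the metadata below; for example,
+# MambaMixer.forward in this patch reads prefill_metadata and seq_lens from it.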
+ + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "No attention" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
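+    # For this placeholder backend the flag is simply threaded through:
+    # prefill_metadata below always reports use_cuda_graph=False, while
+    # decode_metadata forwards the value unchanged.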
+ use_cuda_graph: bool + + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/config.py b/vllm/config.py index de7bb3943a4..26ad889cceb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -276,6 +276,19 @@ def verify_with_parallel_config( raise ValueError( "BitAndBytes quantization with TP or PP is not supported yet.") + def is_attention_free(self) -> bool: + """Returns True if the model has no attention, i.e. the model has no + state that grows with the size of the context. + """ + + # Return true if the model is mamba. + # This check should be augmented with more models in the future, + # and made more robust if possible. 
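+        # One possible generalization (hypothetical, not part of this change):
+        #   ATTENTION_FREE_MODEL_TYPES = {"mamba"}
+        #   return getattr(self.hf_text_config, "model_type", None) \
+        #       in ATTENTION_FREE_MODEL_TYPES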
+ if hasattr(self.hf_text_config, + "model_type") and self.hf_text_config.model_type == 'mamba': + return True + return False + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -309,6 +322,10 @@ def get_head_size(self) -> int: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 + + if self.is_attention_free(): + return 0 + if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. @@ -340,6 +357,9 @@ def get_total_num_kv_heads(self) -> int: return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) + if self.is_attention_free(): + return 0 + attributes = [ # For Falcon: "n_head_kv", @@ -390,6 +410,11 @@ def contains_seqlen_agnostic_layers( def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) + + if self.is_attention_free(): + assert (self.hf_config.model_type == "mamba") + return ["mamba"] * num_layers + # Transformers supports layers_block_type @property return getattr(self.hf_config, "layers_block_type", ["attention"] * num_layers) @@ -428,6 +453,7 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, + is_attention_free: bool, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -437,6 +463,7 @@ def __init__( self.swap_space_bytes = swap_space * _GB self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype + self.is_attention_free = is_attention_free self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self._verify_args() @@ -731,6 +758,8 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). + is_attention_free: True if the running model does not have state that + grows as the context size increases. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. num_lookahead_slots: The number of slots to allocate per sequence per step, beyond the known token ids. 
This is used in speculative @@ -753,6 +782,7 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + is_attention_free: bool, use_v2_block_manager: bool = False, num_lookahead_slots: int = 0, delay_factor: float = 0.0, @@ -779,6 +809,7 @@ def __init__(self, self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len + self.is_attention_free = is_attention_free self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795..d964898af19 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,10 +35,10 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 - if version == "embedding": - from vllm.core.embedding_model_block_manager import ( - EmbeddingModelBlockSpaceManager) - return EmbeddingModelBlockSpaceManager + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager raise ValueError(f"Unknown version {version=}") diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/placeholder_block_space_manager.py similarity index 90% rename from vllm/core/embedding_model_block_manager.py rename to vllm/core/placeholder_block_space_manager.py index f2d67306d7c..a71e6f79b6d 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -4,9 +4,10 @@ from vllm.sequence import Sequence, SequenceGroup -class EmbeddingModelBlockSpaceManager(BlockSpaceManager): - """An embedding version of BlockSpaceManager for use in environments - with embedding models where block management is not required. +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: embedding models or attention-free models like Mamba. 
This class provides the same interface as BlockSpaceManager, but its methods perform no actions or return simple values like True in specific @@ -37,7 +38,7 @@ def append_slots( seq: Sequence, num_lookahead_slots: int, ) -> List[Tuple[int, int]]: - return None # type: ignore + return [] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e59c5e0f74..d92dce77b3c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -279,8 +279,9 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" + if (self.scheduler_config.embedding_mode + or self.scheduler_config.is_attention_free): + version = "placeholder" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b972573c025..20bfd71221e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -664,6 +664,7 @@ def create_engine_config(self, ) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free(), num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching) @@ -708,6 +709,7 @@ def create_engine_config(self, ) -> EngineConfig: max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, + is_attention_free=model_config.is_attention_free(), use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=(self.num_lookahead_slots if speculative_config is None else diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622221d2dd1..f1ce03171eb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -260,6 +260,9 @@ def __init__( ) if not self.model_config.embedding_mode: + # For all decoders including attention-free models like mamba, + # this must call _initialize_kv_caches, as this is where model + # warmup and CUDA graphs creation happens. self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 87508a1168e..7ce2b6fa5c9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -68,13 +68,13 @@ "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM") } _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } - _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} # Architecture -> type. diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 6fdacd44697..b0b614d6b52 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -152,7 +152,7 @@ class HasInnerState(Protocol): """ A flag that indicates this model has inner state. Models that has inner state usually need access to the scheduler_config - for max_num_seqs ,etc... (Currently only used by Jamba) + for max_num_seqs ,etc... 
(Currently used by Jamba and Mamba) """ def __init__(self, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py new file mode 100644 index 00000000000..49cfd5c1868 --- /dev/null +++ b/vllm/model_executor/models/mamba.py @@ -0,0 +1,702 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from torch import nn +from torch.nn.parameter import Parameter +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + + # TODO: ?? + #self.use_bias = config.mamba_proj_bias + self.use_bias = False + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. 
+ # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. 
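+        # 3.b. discretize time_step via dt_proj; the dt_proj bias and the
+        # softplus are applied inside the kernels below (time_proj_bias with
+        # dt_softplus=True / delta_softplus=True), which is why dt_proj was
+        # created with skip_bias_add=True.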
+ + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + if attn_metadata.prefill_metadata is not None: + offset = 0 + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) + offset += prompt_len + else: + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class MambaForCausalLM(nn.Module, HasInnerState): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + cache_config: 
Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = self.backbone.embeddings + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. + self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if not self.mamba_cache: + self._prepare_mamba_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size, + finished_requests_ids) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + else: + # CUDA graph capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) + self.current_indices = indices + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) + + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) + + return hidden_states + + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): + for i, offset in enumerate(indices): + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) + + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + 
self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, + finished_requests_ids: List[str] + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: + indices_for_current_run = [] + for request_id, seqs_id in request_ids_to_seq_ids.items(): + if request_id in finished_requests_ids: + # Do not allocate cache for requests that run + # and finish right after + continue + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + request_id, seqs_id) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_mamba_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + cg_batch_size = input_buffers['input_ids'].shape[0] + ( + current_mamba_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + cg_batch_size, + finished_requests_ids) + self.current_indices = indices + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) + + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the MambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ + self._copy_mamba_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"]) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
+ """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) + + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.mamba_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_batch_size) + ].index(True) + return first_free_index + return 0 + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def _prepare_mamba_cache(self): + dtype = self.lm_head.weight.dtype + num_mamba_layers = self.config.num_hidden_layers + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config else + max(_BATCH_SIZES_TO_CAPTURE)) + 10 + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/rwkv_6.py b/vllm/model_executor/models/rwkv_6.py new file mode 100644 index 00000000000..bf8c9f628cd --- /dev/null +++ b/vllm/model_executor/models/rwkv_6.py @@ -0,0 +1,377 @@ +# coding=utf-8 +"""PyTorch RWKV6 model.(native PyTorch version)""" +""" +author: @Zhiyuan Li +email: +date: 2024-07-22 +""" +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + + +import torch +import torch.nn as nn +from transformers import RwkvConfig +from vllm.config import LoRAConfig, CacheConfig, SchedulerConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method +KVCache = Tuple[torch.Tensor, torch.Tensor] + +@dataclass +class RwkvCacheParams: + is_prompt: bool = False + ssm_state: torch.Tensor = torch.Tensor() + + +class Rwkv_Block(MyModule): + def __init__(self, block_w: dict, hidden_size: int, n_head: int): + super().__init__() + self.hidden_size = hidden_size + self.n_head = n_head + self.head_size = hidden_size // n_head + + self.ln1 = nn.LayerNorm(hidden_size) + self.ln1.weight = nn.Parameter(block_w['ln1.weight']) + self.ln1.bias = nn.Parameter(block_w['ln1.bias']) + self.ln2 = nn.LayerNorm(hidden_size) + self.ln2.weight = nn.Parameter(block_w['ln2.weight']) + self.ln2.bias = nn.Parameter(block_w['ln2.bias']) + + + self.silu = nn.SiLU(inplace=False) + + self.att_time_maa_x = nn.Parameter(block_w['att.time_maa_x']) + self.att_time_maa_w = nn.Parameter(block_w['att.time_maa_w']) + self.att_time_maa_k = nn.Parameter(block_w['att.time_maa_k']) + self.att_time_maa_v = nn.Parameter(block_w['att.time_maa_v']) + self.att_time_maa_r = nn.Parameter(block_w['att.time_maa_r']) + self.att_time_maa_g = nn.Parameter(block_w['att.time_maa_g']) + self.att_time_maa_w1 = nn.Parameter(block_w['att.time_maa_w1']) + self.att_time_maa_w2 = nn.Parameter(block_w['att.time_maa_w2']) + self.att_time_decay = nn.Parameter(block_w['att.time_decay']) + self.att_time_decay_w1 = nn.Parameter(block_w['att.time_decay_w1']) + self.att_time_decay_w2 = nn.Parameter(block_w['att.time_decay_w2']) + self.att_time_faaaa = nn.Parameter(block_w['att.time_faaaa']) + self.att_receptance = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_receptance.weight = nn.Parameter(block_w['att.receptance.weight']) + 
self.att_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_key.weight = nn.Parameter(block_w['att.key.weight']) + self.att_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_value.weight = nn.Parameter(block_w['att.value.weight']) + self.att_output = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_output.weight = nn.Parameter(block_w['att.output.weight']) + self.att_gate = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_gate.weight = nn.Parameter(block_w['att.gate.weight']) + + + self.att_group_norm = nn.GroupNorm(num_groups=n_head, num_channels=hidden_size, eps=1e-5, affine=True) + self.att_group_norm.weight = nn.Parameter(block_w['att.ln_x.weight']) + self.att_group_norm.bias = nn.Parameter(block_w['att.ln_x.bias']) + + self.ffn_time_maa_k = nn.Parameter(block_w['ffn.time_maa_k']) + self.ffn_time_maa_r = nn.Parameter(block_w['ffn.time_maa_r']) + self.ffn_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ffn_key.weight = nn.Parameter(block_w['ffn.key.weight']) + self.ffn_receptance = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ffn_receptance.weight = nn.Parameter(block_w['ffn.receptance.weight']) + self.ffn_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ffn_value.weight = nn.Parameter(block_w['ffn.value.weight']) + + @MyFunction + def channel_mixing(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + i0 = (2 + self.head_size) * i + 0 + sx = state[:, i0] - x + state[:, i0] = x + xk = x + sx * self.ffn_time_maa_k + xr = x + sx * self.ffn_time_maa_r + r = torch.sigmoid(self.ffn_receptance(xr)) + k = torch.relu(self.ffn_key(xk)).pow(2) + output = r * self.ffn_value(k) + return output, state + + @MyFunction + def channel_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + i0 = (2 + self.head_size) * i + 0 + + sx_lerp = torch.empty_like(x) + sx_lerp[:, 0] = state[:, i0] - x[:, 0] + sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:] + + state[:, i0] = x[:, -1] + + xk = x + sx_lerp * self.ffn_time_maa_k + xr = x + sx_lerp * self.ffn_time_maa_r + + r = torch.sigmoid(self.ffn_receptance(xr)) # [Batch, L, hiddle_size] + k = torch.relu(self.ffn_key(xk)).pow(2) + + output = r * self.ffn_value(k) + return output, state + + def time_mixing(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + batch_size, H, S = x.size(0), self.n_head, self.head_size + x, state, g = self.time_mixing_jit(x, state, i, batch_size, H, S) + + x = self.time_mixing_jit2(x, g) + + return x, state + + @MyFunction + def time_mixing_jit(self, x: torch.Tensor, state: torch.Tensor, i: int, + batch_size: int, H: int, S: int): + i1 = (2 + S) * i + 1 # i is the block number + + sx = state[:, i1] - x + state[:, i1] = x # Information is compressed to position 1 on each layer + + xxx = x + sx * self.att_time_maa_x + xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, 5, 1, -1) + xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, 5, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=1) + + xw = x + sx * (self.att_time_maa_w + mw) + xk = x + sx * (self.att_time_maa_k + mk) + xv = x + sx * (self.att_time_maa_v + mv) + xr = x + sx * (self.att_time_maa_r + mr) + xg = x + sx * (self.att_time_maa_g + mg) + + # calculate w, r, k, v, g + w = (self.att_time_decay + (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2)) + w = -torch.exp(w.view(batch_size, H, S, 1)) + + r = 
self.att_receptance(xr).view(batch_size, H, 1, S) + k = self.att_key(xk).view(batch_size, H, S, 1) + v = self.att_value(xv).view(batch_size, H, 1, S) + g = self.silu(self.att_gate(xg)) + + # Update state using attention mechanism + s = state[:, (2+S)*i+2:(2+S)*(i+1), :].view(batch_size, H, S, S) + a = k @ v + x = r @ (self.att_time_faaaa * a + s) + s = a + torch.exp(w) * s + # Update the attention parameters of the i-th layer STATE + state[:, (2+S)*i+2:(2+S)*(i+1), :] = s.view(batch_size, S, -1) + return x, state, g + + @MyFunction + def time_mixing_jit2(self, x:torch.Tensor, g): + return self.att_output(self.att_group_norm(x.flatten(start_dim=1)) * g) + + def time_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + batch_size, L, H, S = x.size(0), x.size(1), self.n_head, self.head_size + x, state, g = self.time_mixing_parallel_jit1(x, state, i, batch_size, L, H, S) + + x = self.time_mixing_parallel_jit2(x, g, batch_size, L) + + return x, state + + @MyFunction + def time_mixing_parallel_jit1(self, x: torch.Tensor, state: torch.Tensor, i: int, + batch_size: int, L: int, H: int, S: int): + i1 = (2 + S) * i + 1 + sx_lerp = torch.empty_like(x) + sx_lerp[:, 0] = state[:, i1] - x[:, 0] + + sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:] + + state[:, i1] = x[:, -1] + + xxx = x + sx_lerp * self.att_time_maa_x # torch.Size([B, L, hiddle_size]) + xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, L, 5, 1, -1) # att_time_maa_w1: [hiddle_size, 160] + xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, L, 5, -1) # [Batch, L, 5, hiddle_size] + + mw, mk, mv, mr, mg = xxx.unbind(dim=2) # [10, 100, hiddle_size] + + xw = x + sx_lerp * (self.att_time_maa_w + mw) # torch.Size([B, L, hiddle_size]) + xk = x + sx_lerp * (self.att_time_maa_k + mk) + xv = x + sx_lerp * (self.att_time_maa_v + mv) + xr = x + sx_lerp * (self.att_time_maa_r + mr) + xg = x + sx_lerp * (self.att_time_maa_g + mg) + + w = (self.att_time_decay + (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2)) + w = -torch.exp(w.view(batch_size, L, H, S, 1)) + + r = self.att_receptance(xr).view(batch_size, L, H, 1, S) + k = self.att_key(xk).view(batch_size, L, H, S, 1) + v = self.att_value(xv).view(batch_size, L, H, 1, S) + g = self.silu(self.att_gate(xg)) # [10, 100, hiddle_size] + # TODO, apply kernel here, cuda or fla(triton) + + + w = torch.exp(w) + s = state[:, (2+S)*i+2:(2+S)*(i+1)].view(batch_size, H, S, S) + a = k @ v # a: [batch_size, L, H, S, S] + + state_s = torch.zeros(batch_size, L, H, S, S, dtype=x.dtype, device=x.device) + state_s[:, 0] = s + + for l in range(L-1): + s = a[:, l] + w[:, l] * s + state_s[:, l+1] = s + s = a[:, -1] + w[:, -1] * s + + state[:, (2+S)*i+2:(2+S)*(i+1)] = s.view(batch_size, S, -1) + + x = r @ (self.att_time_faaaa * a + state_s) + return x, state, g + + @MyFunction + def time_mixing_parallel_jit2(self, x: torch.Tensor, g: torch.Tensor, batch_size: int, L:int): + return self.att_output(self.att_group_norm(x.flatten(start_dim=2).view(batch_size * L, -1)).view(batch_size, L, -1) * g) + + @torch.no_grad() + def forward(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + x_time, state = self.time_mixing(self.ln1(x), state, i) + x = x + x_time + x_channel, state = self.channel_mixing(self.ln2(x), state, i) + x = x + x_channel + + return x, state + + @torch.no_grad() + def forward_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + # Time mixing + x_time, state = self.time_mixing_parallel(self.ln1(x), state, i) + 
x = x + x_time + + # Channel mixing + x_channel, state = self.channel_mixing_parallel(self.ln2(x), state, i) + x = x + x_channel + + return x, state + +class RwkvModel(MyModule): + def __init__(self, args: dict): + super().__init__() + self.args = args + self.load_params() + self.eval() + + + + def load_params(self, load_from_file: bool = True, w: dict = None): + # TODO: vllm + if load_from_file: + if not self.args['MODEL_NAME'].endswith('.pth'): + self.args['MODEL_NAME'] += '.pth' + w = torch.load(self.args['MODEL_NAME'], map_location="cpu") + else: + assert w is not None + + self.num_layer = 0 + for k in w.keys(): + if '.time_' in k: w[k] = w[k].squeeze() + if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1) + if "blocks" in k: self.num_layer = max(self.num_layer, int(k.split(".")[1])) + self.num_layer += 1 + + self.n_head = w['blocks.0.att.time_faaaa'].shape[0] + self.n_embd = w['blocks.0.ln1.weight'].shape[0] + self.head_size = self.n_embd // self.n_head + self.state_size = [self.num_layer * (2 + self.head_size), self.n_embd] + + self.emb = nn.Embedding.from_pretrained(w['emb.weight'], freeze=True) + + + self.ln0 = nn.LayerNorm(self.n_embd) + self.ln0.weight = nn.Parameter(w['blocks.0.ln0.weight']) + self.ln0.bias = nn.Parameter(w['blocks.0.ln0.bias']) + + + self.blocks = nn.ModuleList() + + for i in range(self.num_layer): + block_w = {k[len(f'blocks.{i}.'):]: v for k, v in w.items() if f'blocks.{i}.' in k} + self.blocks.append(Rwkv_Block(block_w, self.n_embd, self.n_head, self.args)) + + + self.ln_out = nn.LayerNorm(self.n_embd) + self.ln_out.weight = nn.Parameter(w['ln_out.weight']) + self.ln_out.bias = nn.Parameter(w['ln_out.bias']) + + + self.head = nn.Linear(self.n_embd, self.args['vocab_size'], bias=False) + self.head.weight = nn.Parameter(w['head.weight']) + + + @torch.no_grad() + def forward(self, + input_ids: torch.Tensor, + state: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + x = self.forward_jit1(input_ids) + + if attn_metadata.prefill_metadata is not None: + # Prefill phase + for i, block in enumerate(self.blocks): + x, state = block.forward_parallel(x, state, i) + else: + # Decoding phase + for i, block in enumerate(self.blocks): + x, state = block(x, state, i) + + x = self.forward_jit2(x) + + return x, state + + + @MyFunction + def forward_jit1(self, token: torch.Tensor) -> torch.Tensor: + return self.ln0(self.emb(token)) + + @MyFunction + def forward_jit2(self, x: torch.Tensor) -> torch.Tensor: + return self.head(self.ln_out(x)) + + @torch.no_grad() + def forward_parallel_slices(self, + input_ids: torch.Tensor, + state: torch.Tensor, + attn_metadata: AttentionMetadata, + slice_len: int = 64) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prefill forward with chunks of the RWKV6 model. + Args: + x (torch.Tensor): Input tensor, shape [Batch, N_embd]. + state (torch.Tensor): Hidden state tensor, shape [Batch, State Size, N_embd]. + i (int): Time index. + Returns: + torch.Tensor: Forward pass result tensor, shape same as input x. + """ + # FIXME! 
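+        # Chunked prefill: feed the prompt through the model in windows of up
+        # to slice_len tokens, carrying the recurrent state across chunks; the
+        # logits returned below come from the last chunk processed.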
+ data_len = input_ids.shape[1] + for i in range((data_len-2)//slice_len+1): + start = i*slice_len + end = min((i+1)*slice_len, data_len) + input_ids_ith_chunk = input_ids[:, start:end] + token_out, state = self.forward(input_ids_ith_chunk, state, attn_metadata) + + return token_out, state + + def init_state(self, batch_size: int) -> torch.Tensor: + state = torch.zeros(batch_size, self.state_size[0], self.state_size[1]) + return state + diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 205b4f58f7a..2f4a0657c3f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,8 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -234,7 +236,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) if num_attn_heads else None + ) if num_attn_heads else PlaceholderAttentionBackend() # Multi-modal data support self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ @@ -1506,8 +1508,9 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) + if self.backend_name != "No attention": + self.input_buffers["slot_mapping"].copy_( + attn_metadata.slot_mapping, non_blocking=True) if self.backend_name != "flashinfer": self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 56d8587f8f0..f80b8be89a8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -190,11 +190,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes() - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -211,6 +215,7 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, + self.cache_config.is_attention_free, self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -361,14 +366,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): "`dtype` flag in CLI, for example: --dtype=half.") -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, max_model_len) -> None: - if num_gpu_blocks <= 0: + if is_attention_free and num_gpu_blocks != 0: + raise ValueError("No memory should be allocated for the cache blocks " + f"for an attention-free model, but {num_gpu_blocks}" 
+ "blocks are allocated.") + if not is_attention_free and num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: + if not is_attention_free and max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({max_model_len}) " "is larger than the maximum number of tokens that can be "