From ce630ea78004d33733229a9d783dda98c96ba227 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 8 Jul 2024 14:55:58 +0000 Subject: [PATCH 01/15] WiP adding support for Mamba --- examples/offline_inference.py | 4 +- vllm/attention/backends/no_attention.py | 161 ++++++ vllm/attention/selector.py | 2 + vllm/config.py | 12 + vllm/core/scheduler.py | 4 + vllm/engine/llm_engine.py | 6 +- vllm/model_executor/models/__init__.py | 3 +- vllm/model_executor/models/mamba.py | 704 ++++++++++++++++++++++++ vllm/sequence.py | 3 + vllm/worker/model_runner.py | 7 +- vllm/worker/worker.py | 16 +- 11 files changed, 911 insertions(+), 11 deletions(-) create mode 100644 vllm/attention/backends/no_attention.py create mode 100644 vllm/model_executor/models/mamba.py diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479..f64082ac0fb 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -8,10 +8,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=0.0, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="state-spaces/mamba-370m-hf") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/attention/backends/no_attention.py b/vllm/attention/backends/no_attention.py new file mode 100644 index 00000000000..c42f39789eb --- /dev/null +++ b/vllm/attention/backends/no_attention.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass, fields +from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, + TypeVar) +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +import torch + + +class NoAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "No attention" + + @staticmethod + def get_impl_cls() -> Type["NoAttentionImpl"]: + return NoAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["NoAttentionMetadata"]: + return NoAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class NoAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. 
E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["NoAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = NoAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class NoAttentionImpl(AttentionImpl): + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index ae63eb1d48f..93c39e45635 100644 --- 
a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -35,6 +35,8 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" + import pdb + pdb.set_trace() if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") diff --git a/vllm/config.py b/vllm/config.py index 1ea28887968..940dea2dbb0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -312,6 +312,12 @@ def get_head_size(self) -> int: # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 + + if hasattr(self.hf_text_config, "model_type" + ) and self.hf_text_config.model_type == 'mamba': + # Is this going to explode + return 0 + if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. @@ -342,6 +348,8 @@ def get_total_num_kv_heads(self) -> int: if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) + if self.hf_config.model_type == "mamba": + return 0 attributes = [ # For Falcon: @@ -393,6 +401,10 @@ def contains_seqlen_agnostic_layers( def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) + + if self.hf_config.model_type == "mamba": + return ["mamba"] * num_layers + # Transformers supports layers_block_type @property return getattr(self.hf_config, "layers_block_type", ["attention"] * num_layers) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9e626b28839..3382f19d254 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -695,6 +695,7 @@ def _schedule_prefills( # If the sequence group cannot be allocated, stop. can_allocate = self.block_manager.can_allocate(seq_group) + can_allocate = True #TODO HACK TMS if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -757,6 +758,8 @@ def _schedule_default(self) -> SchedulerOutputs: decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. """ + import pdb + pdb.set_trace() # Include running requests to the budget. budget = SchedulingBudget( token_budget=self.scheduler_config.max_num_batched_tokens, @@ -1054,6 +1057,7 @@ def free_finished_seq_groups(self) -> None: if not seq_group.is_finished()) def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + return #TODO TMS HACK self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index de7604ece7c..6a45524c0fe 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -252,9 +252,13 @@ def __init__( load_config=load_config, ) - if not self.model_config.embedding_mode: + if self.model_config.get_num_attention_layers(parallel_config) == 0: + self.cache_config.num_gpu_blocks = 0 + self.cache_config.num_cpu_blocks = 0 + elif not self.model_config.embedding_mode: self._initialize_kv_caches() + # If usage stat is enabled, collect relevant info. 
if is_usage_stats_enabled(): from vllm.model_executor.model_loader import ( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a4fe18d52d6..0cca07dc567 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -63,7 +63,8 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM") } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py new file mode 100644 index 00000000000..156d045d296 --- /dev/null +++ b/vllm/model_executor/models/mamba.py @@ -0,0 +1,704 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from torch import nn +from torch.nn.parameter import Parameter +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. 
A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + + # TODO: ?? + #self.use_bias = config.mamba_proj_bias + self.use_bias = False + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.layer_norm_epsilon) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.layer_norm_epsilon) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.layer_norm_epsilon) + + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + if attn_metadata.prefill_metadata is not None: + offset = 0 + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) + offset += prompt_len + else: + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm( + hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + 
config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + return hidden_states + +class MambaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. 
+ self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if not self.mamba_cache: + self._prepare_mamba_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + else: + # CUDA graph capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) + self.current_indices = indices + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) + + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) + + return hidden_states + + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): + for i, offset in enumerate(indices): + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) + + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, 
torch.Tensor], List[int]]: + indices_for_current_run = [] + for request_id, seqs_id in request_ids_to_seq_ids.items(): + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + request_id, seqs_id) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_mamba_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = len(request_ids_to_seq_ids) + ( + current_mamba_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + self.current_indices = indices + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) + + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the MambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ + self._copy_mamba_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"]) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
+ """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) + + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.mamba_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_batch_size) + ].index(True) + return first_free_index + return 0 + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def _prepare_mamba_cache(self): + dtype = self.lm_head.weight.dtype + num_mamba_layers = self.config.num_hidden_layers + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for k, v in params_dict.items(): + print(k) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index d200115aa09..5ea623f1cc3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -643,6 +643,9 @@ def __init__( encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: + import pdb + pdb.set_trace() + self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d0c82d6bbed..ac257d77bd5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,6 +23,8 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.no_attention import NoAttentionBackend + from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, SchedulerConfig) @@ -222,7 +224,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) if num_attn_heads else None + ) if num_attn_heads else NoAttentionBackend() # Multi-modal data support self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ @@ -395,6 +397,9 @@ def _prepare_model_input_tensors( block_aligned_sliding_window = \ sliding_window_blocks * self.block_size + import pdb + pdb.set_trace() + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 58707269bd6..d0be7811012 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -184,11 +184,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: "not properly cleaned up before initializing the vLLM instance.") cache_block_size = self.get_cache_block_size_bytes() - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) if self.model_runner.lora_manager: @@ -209,7 +213,7 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - + self._init_cache_engine() self._warm_up_model() From 6c59b06a569e3d2abcbb90d6507e118bc483d6c4 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 9 Jul 2024 21:50:18 +0000 Subject: [PATCH 02/15] wip --- vllm/config.py | 3 +++ vllm/core/scheduler.py | 5 ++--- vllm/sequence.py | 5 +---- vllm/worker/model_runner.py | 3 --- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 940dea2dbb0..d31433552a8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -798,6 +798,9 @@ def __init__(self, if enable_chunked_prefill: logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + #TODO: already perfect + self.its_mamba = True + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len 
self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 3382f19d254..f671da01e69 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -270,6 +270,8 @@ def __init__( version = "v2" if self.scheduler_config.embedding_mode: version = "embedding" + if self.scheduler_config.its_mamba: + version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) @@ -758,8 +760,6 @@ def _schedule_default(self) -> SchedulerOutputs: decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. """ - import pdb - pdb.set_trace() # Include running requests to the budget. budget = SchedulingBudget( token_budget=self.scheduler_config.max_num_batched_tokens, @@ -1057,7 +1057,6 @@ def free_finished_seq_groups(self) -> None: if not seq_group.is_finished()) def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - return #TODO TMS HACK self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING diff --git a/vllm/sequence.py b/vllm/sequence.py index 5ea623f1cc3..753d7875cac 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -643,9 +643,6 @@ def __init__( encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, ) -> None: - import pdb - pdb.set_trace() - self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data @@ -660,7 +657,7 @@ def __init__( self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample - + # The number of speculative tokens adopted in this request. # None means specuative decoding is not used. # Zero means speculative decoding is disabled for some reasons. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ac257d77bd5..30f2b1d366c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -397,9 +397,6 @@ def _prepare_model_input_tensors( block_aligned_sliding_window = \ sliding_window_blocks * self.block_size - import pdb - pdb.set_trace() - for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt From eb9bf348032b51520c313a55d7813b43567e5763 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 10 Jul 2024 21:04:07 +0000 Subject: [PATCH 03/15] WIP -- runs through. Generates tokens. Bad tokens. 
--- vllm/attention/backends/no_attention.py | 6 +++--- vllm/attention/selector.py | 2 -- vllm/config.py | 2 ++ vllm/core/embedding_model_block_manager.py | 2 +- vllm/engine/arg_utils.py | 1 + vllm/engine/llm_engine.py | 9 ++++----- vllm/model_executor/models/mamba.py | 23 +++++++++++++--------- vllm/worker/model_runner.py | 7 ++++--- vllm/worker/worker.py | 7 ++++--- 9 files changed, 33 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/no_attention.py b/vllm/attention/backends/no_attention.py index c42f39789eb..25e239b603b 100644 --- a/vllm/attention/backends/no_attention.py +++ b/vllm/attention/backends/no_attention.py @@ -89,7 +89,7 @@ class NoAttentionMetadata(AttentionMetadata): use_cuda_graph: bool _cached_prefill_metadata: Optional["NoAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["NoAttentionMetadata"] = None @property def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: @@ -125,7 +125,7 @@ def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + def decode_metadata(self) -> Optional["NoAttentionMetadata"]: if self.num_decode_tokens == 0: return None @@ -134,7 +134,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: assert self.block_tables is not None assert self.seq_lens_tensor is not None - self._cached_decode_metadata = FlashAttentionMetadata( + self._cached_decode_metadata = NoAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 93c39e45635..ae63eb1d48f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -35,8 +35,6 @@ def get_attn_backend( is_blocksparse: bool = False, ) -> Type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" - import pdb - pdb.set_trace() if is_blocksparse: logger.info("Using BlocksparseFlashAttention backend.") diff --git a/vllm/config.py b/vllm/config.py index d31433552a8..0ea675cbe86 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -443,6 +443,7 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, + cache_grows: bool, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -452,6 +453,7 @@ def __init__( self.swap_space_bytes = swap_space * _GB self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype + self.cache_grows = cache_grows self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self._verify_args() diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f2d67306d7c..43a9f9de676 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -37,7 +37,7 @@ def append_slots( seq: Sequence, num_lookahead_slots: int, ) -> List[Tuple[int, int]]: - return None # type: ignore + return [] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: pass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index afa6892d49e..20c010a09b0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -650,6 +650,7 @@ def create_engine_config(self, ) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, 
cache_dtype=self.kv_cache_dtype, + cache_grows=False, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6a45524c0fe..07b6912576a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -252,10 +252,9 @@ def __init__( load_config=load_config, ) - if self.model_config.get_num_attention_layers(parallel_config) == 0: - self.cache_config.num_gpu_blocks = 0 - self.cache_config.num_cpu_blocks = 0 - elif not self.model_config.embedding_mode: + if not self.model_config.embedding_mode: + # TODO: Even for mamba, we must initialize the KV caches, + # Because model warmup and CUDA graphs are created here. self._initialize_kv_caches() @@ -852,7 +851,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: 0].schedule() finished_requests_ids = self.scheduler[ 0].get_and_reset_finished_requests_ids() - + if not scheduler_outputs.is_empty(): execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 156d045d296..1e5632816b1 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -422,15 +422,19 @@ def __init__( self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - ) + + #TODO: this ends up all 0s -- we don't put anything in here when loading weights. + #TODO: Does mamba share weights between the lm head and embeddings? +# self.lm_head = ParallelLMHead( +# self.unpadded_vocab_size, +# config.hidden_size, +# org_num_embeddings=config.vocab_size, +# padding_size=DEFAULT_VOCAB_PADDING_SIZE +# # We need bigger padding if using lora for kernel +# # compatibility +# if not lora_config else lora_config.lora_vocab_padding_size, +# ) + self.lm_head = self.backbone.embeddings # Current step used indices self.current_indices: List[int] = [] # Used to track and store by the Mamba cache between steps. @@ -451,6 +455,7 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): + if not self.mamba_cache: self._prepare_mamba_cache() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 30f2b1d366c..006eb3c2d2d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -933,7 +933,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "You can also reduce the `max_num_seqs` as needed " "to decrease memory usage.") start_time = time.perf_counter() - + # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() @@ -1410,8 +1410,9 @@ def forward( # Copy the input tensors to the input buffers. 
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) + if self.backend_name != "No attention": + self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, + non_blocking=True) if self.backend_name != "flashinfer": self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d0be7811012..52ac40a5dca 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -209,6 +209,7 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, + self.cache_config.cache_grows, self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -346,14 +347,14 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): "`dtype` flag in CLI, for example: --dtype=half.") -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, cache_grows, max_model_len) -> None: - if num_gpu_blocks <= 0: + if num_gpu_blocks <= 0 and cache_grows: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len: + if max_model_len > max_seq_len and cache_grows: raise ValueError( f"The model's max seq len ({max_model_len}) " "is larger than the maximum number of tokens that can be " From 320f79b348694175e80745a2f0de10fe4874d7f7 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 15 Jul 2024 19:43:43 +0000 Subject: [PATCH 04/15] Good output for mamba-370m --- examples/offline_inference.py | 3 ++- vllm/model_executor/models/mamba.py | 26 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index f64082ac0fb..6e6f91542bb 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,4 +1,5 @@ from vllm import LLM, SamplingParams +import torch # Sample prompts. prompts = [ @@ -11,7 +12,7 @@ sampling_params = SamplingParams(temperature=0.0, top_p=0.95) # Create an LLM. -llm = LLM(model="state-spaces/mamba-370m-hf") +llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 1e5632816b1..fa65f081bca 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -131,16 +131,19 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) self.activation = config.hidden_act - self.dt_layernorm = RMSNorm(self.time_step_rank, - eps=config.layer_norm_epsilon) - self.b_layernorm = RMSNorm(self.ssm_state_size, - eps=config.layer_norm_epsilon) - self.c_layernorm = RMSNorm(self.ssm_state_size, - eps=config.layer_norm_epsilon) + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. 
+ #self.dt_layernorm = RMSNorm(self.time_step_rank, + # eps=config.layer_norm_epsilon) + #self.b_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) + #self.c_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) @@ -180,9 +183,12 @@ def mamba_forward(self, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1, ) - time_step = self.dt_layernorm(time_step.contiguous()) - B = self.b_layernorm(B.contiguous()) - C = self.c_layernorm(C.contiguous()) + + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. + # time_step = self.dt_layernorm(time_step.contiguous()) + # B = self.b_layernorm(B.contiguous()) + # C = self.c_layernorm(C.contiguous()) discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -382,6 +388,8 @@ def forward( ssm_state=current_ssm_state, ) hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states class MambaForCausalLM(nn.Module): From 5ab6622f2d143e3242e44f1bd25c2faceaa93b1a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 15:37:34 +0000 Subject: [PATCH 05/15] wip --- examples/offline_inference.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 6e6f91542bb..cb74561f35e 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -12,7 +12,9 @@ sampling_params = SamplingParams(temperature=0.0, top_p=0.95) # Create an LLM. -llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) +#llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) +llm = LLM(model="state-spaces/mamba2-130m", dtype=torch.float32) + # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) From 25b54d95458670402908eb4b4f11f1da575b0f85 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:09:55 +0000 Subject: [PATCH 06/15] cleanup --- .../{no_attention.py => placeholder_attn.py} | 38 +- vllm/config.py | 38 +- vllm/core/scheduler.py | 6 +- vllm/engine/arg_utils.py | 3 +- vllm/engine/llm_engine.py | 6 +- vllm/model_executor/models/2 | 728 ++++++++++++++++++ vllm/model_executor/models/__init__.py | 1 - vllm/model_executor/models/mamba.py | 81 +- vllm/sequence.py | 2 +- vllm/worker/model_runner.py | 10 +- vllm/worker/worker.py | 14 +- 11 files changed, 822 insertions(+), 105 deletions(-) rename vllm/attention/backends/{no_attention.py => placeholder_attn.py} (82%) create mode 100644 vllm/model_executor/models/2 diff --git a/vllm/attention/backends/no_attention.py b/vllm/attention/backends/placeholder_attn.py similarity index 82% rename from vllm/attention/backends/no_attention.py rename to vllm/attention/backends/placeholder_attn.py index 25e239b603b..6bc766ba4e3 100644 --- a/vllm/attention/backends/no_attention.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,12 +1,15 @@ -from dataclasses import dataclass, fields -from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, - TypeVar) +from dataclasses import dataclass +from typing import (List, Optional, Tuple, Type) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) import torch +# Placeholder attention backend for models like Mamba that don't have attention. +# Mainly exists to sidestep get_attn_backend. +# The attention metadata is still needed for Mamba. -class NoAttentionBackend(AttentionBackend): + +class PlaceholderAttentionBackend(AttentionBackend): """Placeholder backend for when no attention is needed.""" @staticmethod @@ -14,12 +17,12 @@ def get_name() -> str: return "No attention" @staticmethod - def get_impl_cls() -> Type["NoAttentionImpl"]: - return NoAttentionImpl + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl @staticmethod - def get_metadata_cls() -> Type["NoAttentionMetadata"]: - return NoAttentionMetadata + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata @staticmethod def get_kv_cache_shape( @@ -47,7 +50,7 @@ def copy_blocks( @dataclass -class NoAttentionMetadata(AttentionMetadata): +class PlaceholderAttentionMetadata(AttentionMetadata): """Attention metadata for prefill and decode batched together.""" # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. @@ -88,11 +91,11 @@ class NoAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
use_cuda_graph: bool - _cached_prefill_metadata: Optional["NoAttentionMetadata"] = None - _cached_decode_metadata: Optional["NoAttentionMetadata"] = None + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None @property - def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_prefills == 0: return None @@ -106,7 +109,7 @@ def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: assert self.block_tables is not None assert self.seq_start_loc is not None - self._cached_prefill_metadata = NoAttentionMetadata( + self._cached_prefill_metadata = PlaceholderAttentionMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, @@ -125,7 +128,7 @@ def prefill_metadata(self) -> Optional["NoAttentionMetadata"]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["NoAttentionMetadata"]: + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: if self.num_decode_tokens == 0: return None @@ -134,7 +137,7 @@ def decode_metadata(self) -> Optional["NoAttentionMetadata"]: assert self.block_tables is not None assert self.seq_lens_tensor is not None - self._cached_decode_metadata = NoAttentionMetadata( + self._cached_decode_metadata = PlaceholderAttentionMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, @@ -153,9 +156,10 @@ def decode_metadata(self) -> Optional["NoAttentionMetadata"]: return self._cached_decode_metadata -class NoAttentionImpl(AttentionImpl): +class PlaceholderAttentionImpl(AttentionImpl): + def __init__(self, *args, **kwargs) -> None: return - + def forward(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/config.py b/vllm/config.py index 0ae4acf4403..26ad889cceb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -276,6 +276,19 @@ def verify_with_parallel_config( raise ValueError( "BitAndBytes quantization with TP or PP is not supported yet.") + def is_attention_free(self) -> bool: + """Returns True if the model has no attention, i.e. the model has no + state that grows with the size of the context. + """ + + # Return true if the model is mamba. + # This check should be augmented with more models in the future, + # and made more robust if possible. 
+ if hasattr(self.hf_text_config, + "model_type") and self.hf_text_config.model_type == 'mamba': + return True + return False + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -310,10 +323,8 @@ def get_head_size(self) -> int: # we need to pad head_size 192 to 256 return 256 - if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type == 'mamba': - # Is this going to explode - return 0 + if self.is_attention_free(): + return 0 if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim @@ -345,7 +356,8 @@ def get_total_num_kv_heads(self) -> int: if self.hf_config.model_type == "dbrx": return getattr(self.hf_config.attn_config, "kv_n_heads", self.hf_config.num_attention_heads) - if self.hf_config.model_type == "mamba": + + if self.is_attention_free(): return 0 attributes = [ @@ -398,8 +410,9 @@ def contains_seqlen_agnostic_layers( def get_layers_block_type(self, parallel_config: "ParallelConfig") -> List[str]: num_layers = self.get_num_layers(parallel_config) - - if self.hf_config.model_type == "mamba": + + if self.is_attention_free(): + assert (self.hf_config.model_type == "mamba") return ["mamba"] * num_layers # Transformers supports layers_block_type @property @@ -440,7 +453,7 @@ def __init__( gpu_memory_utilization: float, swap_space: int, cache_dtype: str, - cache_grows: bool, + is_attention_free: bool, num_gpu_blocks_override: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, @@ -450,7 +463,7 @@ def __init__( self.swap_space_bytes = swap_space * _GB self.num_gpu_blocks_override = num_gpu_blocks_override self.cache_dtype = cache_dtype - self.cache_grows = cache_grows + self.is_attention_free = is_attention_free self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self._verify_args() @@ -745,6 +758,8 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). + is_attention_free: True if the running model does not have state that + grows as the context size increases. use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. num_lookahead_slots: The number of slots to allocate per sequence per step, beyond the known token ids. 
This is used in speculative @@ -767,6 +782,7 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + is_attention_free: bool, use_v2_block_manager: bool = False, num_lookahead_slots: int = 0, delay_factor: float = 0.0, @@ -791,11 +807,9 @@ def __init__(self, if enable_chunked_prefill: logger.info("Chunked prefill is enabled (EXPERIMENTAL).") - #TODO: already perfect - self.its_mamba = True - self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len + self.is_attention_free = is_attention_free self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ae46f2f1a96..ef183cda9d6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -279,9 +279,8 @@ def __init__( version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" - if self.scheduler_config.embedding_mode: - version = "embedding" - if self.scheduler_config.its_mamba: + if (self.scheduler_config.embedding_mode + or self.scheduler_config.is_attention_free): version = "embedding" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( @@ -708,7 +707,6 @@ def _schedule_prefills( # If the sequence group cannot be allocated, stop. can_allocate = self.block_manager.can_allocate(seq_group) - can_allocate = True #TODO HACK TMS if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cd6912466a3..20bfd71221e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -664,7 +664,7 @@ def create_engine_config(self, ) -> EngineConfig: gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, - cache_grows=False, + is_attention_free=model_config.is_attention_free(), num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching) @@ -709,6 +709,7 @@ def create_engine_config(self, ) -> EngineConfig: max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, + is_attention_free=model_config.is_attention_free(), use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=(self.num_lookahead_slots if speculative_config is None else diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index af6191b7b35..c43f7fcb854 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -260,11 +260,11 @@ def __init__( ) if not self.model_config.embedding_mode: - # TODO: Even for mamba, we must initialize the KV caches, - # Because model warmup and CUDA graphs are created here. + # For all decoders including attention-free models like mamba, + # this must call _initialize_kv_caches, as this is where model + # warmup and CUDA graphs creation happens. self._initialize_kv_caches() - # If usage stat is enabled, collect relevant info. 
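The scheduler change above folds attention-free models into the same no-op block-management path already used for embedding models. A condensed sketch of the resulting selection logic (the function name is mine, for illustration only):

    def pick_block_manager_version(use_v2_block_manager: bool,
                                   embedding_mode: bool,
                                   is_attention_free: bool) -> str:
        version = "v2" if use_v2_block_manager else "v1"
        if embedding_mode or is_attention_free:
            # Neither case needs real block accounting; a later patch in
            # this series renames this version to "placeholder".
            version = "embedding"
        return version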
if is_usage_stats_enabled(): from vllm.model_executor.model_loader import ( diff --git a/vllm/model_executor/models/2 b/vllm/model_executor/models/2 new file mode 100644 index 00000000000..0452e9b3381 --- /dev/null +++ b/vllm/model_executor/models/2 @@ -0,0 +1,728 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +import pdb +import traceback +import inspect + +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from torch import nn +from torch.nn.parameter import Parameter +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +def function_in_stack(function_name): + stack = traceback.extract_stack() + for frame in stack: + if frame.name == function_name: + return True + return False + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + + # TODO: ?? 
+ #self.use_bias = config.mamba_proj_bias + self.use_bias = False + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + def weight_loader(param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param.data.copy_( + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, + dim=0)[tp_rank]) + + def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): + weight_loader(param, -torch.exp(loaded_weight.float())) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": weight_loader}) + set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. + #self.dt_layernorm = RMSNorm(self.time_step_rank, + # eps=config.layer_norm_epsilon) + #self.b_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) + #self.c_layernorm = RMSNorm(self.ssm_state_size, + # eps=config.layer_norm_epsilon) + + def mamba_forward(self, + hidden_states: torch.Tensor, + cache_params: MambaCacheParams = None): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + if cache_params is not None and not cache_params.is_prompt: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_state.copy_(conv_states) + + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + ) + + # 3. State Space Model sequence transformation + # 3.a. 
input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + + # Jamba has layer norms here. Mamba doesn't. + # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. + # time_step = self.dt_layernorm(time_step.contiguous()) + # B = self.b_layernorm(B.contiguous()) + # C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + if cache_params is not None and not cache_params.is_prompt: + scan_outputs = selective_state_update( + cache_params.ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + self.A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_state.copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] + return contextualized_states + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ): + if attn_metadata.prefill_metadata is not None: + offset = 0 + for i, prompt_len in enumerate( + attn_metadata.prefill_metadata.seq_lens): + cache = MambaCacheParams(True, + conv_state=conv_state[i].unsqueeze(0), + ssm_state=ssm_state[i].unsqueeze(0)) + hidden_states[offset:offset + prompt_len].copy_( + self.mamba_forward(hidden_states[offset:offset + + prompt_len].unsqueeze(0), + cache_params=cache)[0]) + offset += prompt_len + else: + cache = MambaCacheParams(False, + conv_state=conv_state, + ssm_state=ssm_state) + hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), + cache_params=cache) + hidden_states = hidden_states.squeeze(1) + + return hidden_states + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm( + hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + + + return hidden_states + +class MambaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + cache_config: Optional[CacheConfig] 
= None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + #TODO: this ends up all 0s -- we don't put anything in here when loading weights. + #TODO: Does mamba share weights between the lm head and embeddings? +# self.lm_head = ParallelLMHead( +# self.unpadded_vocab_size, +# config.hidden_size, +# org_num_embeddings=config.vocab_size, +# padding_size=DEFAULT_VOCAB_PADDING_SIZE +# # We need bigger padding if using lora for kernel +# # compatibility +# if not lora_config else lora_config.lora_vocab_padding_size, +# ) + self.lm_head = self.backbone.embeddings + # Current step used indices + self.current_indices: List[int] = [] + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Used as an input_buffer for the CUDA graph runs. + self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + + if not self.mamba_cache: + self._prepare_mamba_cache() + + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = input_ids.shape[0] + if attn_metadata.prefill_metadata: + batch_size = len(request_ids_to_seq_ids) + ( + current_seqlen_agnostic_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + else: + # CUDA graph capturing runs + current_seqlen_agnostic_cache, indices = ( + kwargs["seqlen_agnostic_capture_inputs"], + [], + ) + self.current_indices = indices + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) + + if "seqlen_agnostic_capture_inputs" not in kwargs: + self._copy_mamba_cache_by_indices(self.current_indices, + current_seqlen_agnostic_cache) + + return hidden_states + + def _copy_mamba_cache_by_indices( + self, indices: List[int], + current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): + for i, offset in enumerate(indices): + self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + + def _copy_mamba_cache(self, index_to: int, index_from: int, + from_buffer: Tuple[torch.Tensor, torch.Tensor]): + assert len(self.mamba_cache) > 0 + for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): + cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + non_blocking=True) + + def _assign_seq_id_to_mamba_cache(self, cur_rid: str, + seqs_id: List[int]) -> 
List[int]: + indices_for_current_run = [] + for seq_id in seqs_id: + if cur_rid not in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping[cur_rid] = {} + first_free_index = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + ## case of decoding n>1, copy prefill cache to decoding indices + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + first_free_index = self._first_free_index_in_mamba_cache() + index_exist = list(seq_ids2indices.values())[0] + self._copy_mamba_cache(index_from=index_exist, + index_to=first_free_index, + from_buffer=self.mamba_cache) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = first_free_index + index_for_current_run = first_free_index + else: + index_for_current_run = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + + indices_for_current_run.append(index_for_current_run) + return indices_for_current_run + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: + indices_for_current_run = [] + for request_id, seqs_id in request_ids_to_seq_ids.items(): + indices_for_current_run += self._assign_seq_id_to_mamba_cache( + request_id, seqs_id) + ## Pad the batch in case of running batch that was not captured via CG + padded_indices = indices_for_current_run.copy() + pad_index = self._first_free_index_in_mamba_cache() + + for _ in range(batch_size - len(indices_for_current_run)): + padded_indices.append(pad_index) + + conv_state = self.mamba_cache[0][:, padded_indices] + temporal_state = self.mamba_cache[1][:, padded_indices] + + return (conv_state, temporal_state), indices_for_current_run + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). + """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + batch_size = len(request_ids_to_seq_ids) + ( + current_mamba_cache, + indices, + ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + batch_size) + self.current_indices = indices + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) + + for input_buffer, current_cache_buffer in zip( + input_buffers["seqlen_agnostic_capture_inputs"], + current_mamba_cache): + input_buffer.copy_(current_cache_buffer, non_blocking=True) + + def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache from the CUDA graph input_buffers + back to the MambaForCausalLM.mamba_cache after CUDA + graph replay run is done. + """ + self._copy_mamba_cache_by_indices( + self.current_indices, + input_buffers["seqlen_agnostic_capture_inputs"]) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. 
+ """ + return tuple(buffer[:, :batch_size] + for buffer in self.mamba_gc_cache_buffer) + + def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache(self) -> int: + if self.mamba_cache: + max_possible_batch_size = self.mamba_cache[0].shape[1] + occupied = [ + id for seq_ids in self.mamba_cache_indices_mapping.values() + for id in seq_ids.values() + ] + first_free_index = [ + i not in occupied for i in range(max_possible_batch_size) + ].index(True) + return first_free_index + return 0 + + def _get_mamba_cache_shape( + self + ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def _prepare_mamba_cache(self): + dtype = self.lm_head.weight.dtype + num_mamba_layers = self.config.num_hidden_layers + max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() + assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: + buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) + setattr(self, buffername, buffer) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for k, v in params_dict.items(): + print(k) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 130f00e8645..7ce2b6fa5c9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -75,7 +75,6 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } - _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} # Architecture -> type. diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fa65f081bca..ca8d58fd3d6 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -12,25 +12,20 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs @@ -46,6 +41,7 @@ class MambaCacheParams: conv_state: torch.Tensor = torch.Tensor() ssm_state: torch.Tensor = torch.Tensor() + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class MambaMixer(nn.Module): """ @@ -131,15 +127,6 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): ) self.activation = config.hidden_act - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - #self.dt_layernorm = RMSNorm(self.time_step_rank, - # eps=config.layer_norm_epsilon) - #self.b_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - #self.c_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - def mamba_forward(self, hidden_states: torch.Tensor, cache_params: MambaCacheParams = None): @@ -184,11 +171,7 @@ def mamba_forward(self, dim=-1, ) - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - # time_step = self.dt_layernorm(time_step.contiguous()) - # B = self.b_layernorm(B.contiguous()) - # C = self.c_layernorm(C.contiguous()) + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. 
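For readers unfamiliar with what selective_scan_fn computes in the code below, the following is a rough, unoptimized reference of the recurrence, written from my understanding of the mamba_ssm reference implementation (the function name is mine; shapes follow this file's conventions: batch, channel dim, sequence length, SSM state size). The real kernel is fused and more careful numerically, so treat this only as a sketch of the semantics:

    import torch
    import torch.nn.functional as F

    def selective_scan_reference(u, delta, A, B, C, D, z,
                                 delta_bias=None, delta_softplus=True):
        # u, delta, z: (batch, dim, seqlen); A: (dim, state)
        # B, C: (batch, state, seqlen); D: (dim,)
        if delta_bias is not None:
            delta = delta + delta_bias[None, :, None]
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, seqlen = u.shape
        state = A.shape[1]
        x = u.new_zeros(batch, dim, state)  # hidden SSM state
        ys = []
        for t in range(seqlen):
            # Discretize A and B with the per-token step size delta_t,
            # advance the state, then read it out through C_t.
            dA = torch.exp(delta[:, :, t, None] * A)
            dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
            x = dA * x + dBu
            ys.append((x * C[:, None, :, t]).sum(dim=-1))
        y = torch.stack(ys, dim=-1) + u * D[None, :, None]  # skip connection
        return y * F.silu(z), x  # gated output and final state

The single-token decode path (selective_state_update) performs one step of this same recurrence against the cached ssm_state instead of looping over the sequence.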
discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -256,6 +239,7 @@ def forward( return hidden_states + class MambaMLP(nn.Module): def __init__( @@ -286,6 +270,7 @@ def forward(self, x): x, _ = self.down_proj(x) return x + class MambaDecoderLayer(nn.Module): def __init__(self, @@ -299,8 +284,7 @@ def __init__(self, self.mixer = MambaMixer(config, layer_idx) self.feed_forward = MambaMLP(config, quant_config=quant_config) - self.norm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -317,8 +301,7 @@ def forward( residual = hidden_states hidden_states = self.norm(hidden_states) else: - hidden_states, residual = self.norm( - hidden_states, residual) + hidden_states, residual = self.norm(hidden_states, residual) hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, ssm_state) @@ -328,6 +311,7 @@ def forward( hidden_states = self.feed_forward(hidden_states) return hidden_states, residual + class MambaModel(nn.Module): def __init__( @@ -355,12 +339,12 @@ def __init__( for i in range(config.num_hidden_layers): decoder_layers.append( MambaDecoderLayer(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) self.layers = nn.ModuleList(decoder_layers) self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + eps=config.layer_norm_epsilon) def forward( self, @@ -389,9 +373,9 @@ def forward( ) hidden_states, _ = self.norm_f(hidden_states, residual) - return hidden_states + class MambaForCausalLM(nn.Module): packed_modules_mapping = { "qkv_proj": [ @@ -424,24 +408,13 @@ def __init__( super().__init__() self.config = config self.backbone = MambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - #TODO: this ends up all 0s -- we don't put anything in here when loading weights. - #TODO: Does mamba share weights between the lm head and embeddings? -# self.lm_head = ParallelLMHead( -# self.unpadded_vocab_size, -# config.hidden_size, -# org_num_embeddings=config.vocab_size, -# padding_size=DEFAULT_VOCAB_PADDING_SIZE -# # We need bigger padding if using lora for kernel -# # compatibility -# if not lora_config else lora_config.lora_vocab_padding_size, -# ) self.lm_head = self.backbone.embeddings # Current step used indices self.current_indices: List[int] = [] @@ -493,9 +466,9 @@ def forward(self, self.current_indices = indices hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) + attn_metadata, + current_seqlen_agnostic_cache[0], + current_seqlen_agnostic_cache[1]) if "seqlen_agnostic_capture_inputs" not in kwargs: self._copy_mamba_cache_by_indices(self.current_indices, @@ -565,9 +538,9 @@ def _prepare_current_run_mamba_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). 
+ Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). """ assert all( key in kwargs @@ -591,7 +564,7 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): """ Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the MambaForCausalLM.mamba_cache after CUDA + back to the MambaForCausalLM.mamba_cache after CUDA graph replay run is done. """ self._copy_mamba_cache_by_indices( @@ -601,7 +574,7 @@ def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph + The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ return tuple(buffer[:, :batch_size] @@ -629,7 +602,6 @@ def _get_mamba_cache_shape( self ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size conv_state_shape = ( self.config.intermediate_size // world_size, self.config.conv_kernel, @@ -682,9 +654,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) - for k, v in params_dict.items(): - print(k) - for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/sequence.py b/vllm/sequence.py index 6753d7f86b6..1cebf68d463 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -684,7 +684,7 @@ def __init__( self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample - + # The number of speculative tokens adopted in this request. # None means specuative decoding is not used. # Zero means speculative decoding is disabled for some reasons. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 56c2693d661..459798e418c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,7 +23,7 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.no_attention import NoAttentionBackend +from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, @@ -236,7 +236,7 @@ def __init__( self.model_config.dtype, self.kv_cache_dtype, self.block_size, - ) if num_attn_heads else NoAttentionBackend() + ) if num_attn_heads else PlaceholderAttentionBackend() # Multi-modal data support self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ @@ -1016,7 +1016,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: "You can also reduce the `max_num_seqs` as needed " "to decrease memory usage.") start_time = time.perf_counter() - + # Prepare dummy inputs. These will be reused for all batch sizes. 
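The model-runner hunk above routes models that report zero attention heads to the placeholder backend. A condensed sketch of that decision (the helper name is mine and the argument list is abbreviated; the real call passes the full head, dtype, and block-size configuration through to get_attn_backend):

    from vllm.attention import get_attn_backend
    from vllm.attention.backends.placeholder_attn import (
        PlaceholderAttentionBackend)

    def choose_attn_backend(num_attn_heads: int, *args, **kwargs):
        # Attention-free models (e.g. Mamba) skip get_attn_backend entirely.
        if num_attn_heads:
            return get_attn_backend(num_attn_heads, *args, **kwargs)
        return PlaceholderAttentionBackend()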
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() @@ -1509,8 +1509,8 @@ def forward( self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) if self.backend_name != "No attention": - self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, - non_blocking=True) + self.input_buffers["slot_mapping"].copy_( + attn_metadata.slot_mapping, non_blocking=True) if self.backend_name != "flashinfer": self.input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 60f80b135e3..f80b8be89a8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -215,12 +215,12 @@ def initialize_cache(self, num_gpu_blocks: int, """ raise_if_cache_size_invalid(num_gpu_blocks, self.cache_config.block_size, - self.cache_config.cache_grows, + self.cache_config.is_attention_free, self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - + self._init_cache_engine() self._warm_up_model() @@ -366,14 +366,18 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): "`dtype` flag in CLI, for example: --dtype=half.") -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, cache_grows, +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, max_model_len) -> None: - if num_gpu_blocks <= 0 and cache_grows: + if is_attention_free and num_gpu_blocks != 0: + raise ValueError("No memory should be allocated for the cache blocks " + f"for an attention-free model, but {num_gpu_blocks}" + "blocks are allocated.") + if not is_attention_free and num_gpu_blocks <= 0: raise ValueError("No available memory for the cache blocks. 
" "Try increasing `gpu_memory_utilization` when " "initializing the engine.") max_seq_len = block_size * num_gpu_blocks - if max_model_len > max_seq_len and cache_grows: + if not is_attention_free and max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({max_model_len}) " "is larger than the maximum number of tokens that can be " From ebc12f1ee699f8d0b83dd5f75e402447729180fa Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:15:23 +0000 Subject: [PATCH 07/15] Rename embedding block space manager --- vllm/core/interfaces.py | 8 ++++---- ...lock_manager.py => placeholder_block_space_manager.py} | 7 ++++--- vllm/core/scheduler.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) rename vllm/core/{embedding_model_block_manager.py => placeholder_block_space_manager.py} (90%) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795..d964898af19 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,10 +35,10 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 - if version == "embedding": - from vllm.core.embedding_model_block_manager import ( - EmbeddingModelBlockSpaceManager) - return EmbeddingModelBlockSpaceManager + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager raise ValueError(f"Unknown version {version=}") diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/placeholder_block_space_manager.py similarity index 90% rename from vllm/core/embedding_model_block_manager.py rename to vllm/core/placeholder_block_space_manager.py index 43a9f9de676..a71e6f79b6d 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -4,9 +4,10 @@ from vllm.sequence import Sequence, SequenceGroup -class EmbeddingModelBlockSpaceManager(BlockSpaceManager): - """An embedding version of BlockSpaceManager for use in environments - with embedding models where block management is not required. +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: embedding models or attention-free models like Mamba. This class provides the same interface as BlockSpaceManager, but its methods perform no actions or return simple values like True in specific diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ef183cda9d6..f004df21169 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -281,7 +281,7 @@ def __init__( version = "v2" if (self.scheduler_config.embedding_mode or self.scheduler_config.is_attention_free): - version = "embedding" + version = "placeholder" BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( version) From ac60374b8637ea2ed00ebaa159c06979e38ada44 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:30:29 +0000 Subject: [PATCH 08/15] cleanup --- examples/offline_inference.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index cb74561f35e..9b758fa2479 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,5 +1,4 @@ from vllm import LLM, SamplingParams -import torch # Sample prompts. 
prompts = [ @@ -9,12 +8,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.0, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -#llm = LLM(model="state-spaces/mamba-370m-hf", dtype=torch.float32) -llm = LLM(model="state-spaces/mamba2-130m", dtype=torch.float32) - +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) From adb6713830e1f5c252f6de71d7173a55197dfe1d Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 17:31:51 +0000 Subject: [PATCH 09/15] remove file --- vllm/model_executor/models/2 | 728 ----------------------------------- 1 file changed, 728 deletions(-) delete mode 100644 vllm/model_executor/models/2 diff --git a/vllm/model_executor/models/2 b/vllm/model_executor/models/2 deleted file mode 100644 index 0452e9b3381..00000000000 --- a/vllm/model_executor/models/2 +++ /dev/null @@ -1,728 +0,0 @@ -# coding=utf-8 -"""PyTorch MAMBA model.""" -import pdb -import traceback -import inspect - -from dataclasses import dataclass -from typing import Dict, Iterable, List, Optional, Tuple - -import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from torch import nn -from torch.nn.parameter import Parameter -from transformers import MambaConfig - -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE - -KVCache = Tuple[torch.Tensor, torch.Tensor] - - -def function_in_stack(function_name): - stack = traceback.extract_stack() - for frame in stack: - if frame.name == function_name: - return True - return False - -@dataclass -class MambaCacheParams: - is_prompt: bool = False - conv_state: torch.Tensor = torch.Tensor() - ssm_state: torch.Tensor = torch.Tensor() - -# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer -class MambaMixer(nn.Module): - """ - Compute ∆, A, B, C, and D the state space parameters and compute - the `contextualized_states`. 
A, D are input independent - (see Mamba paper [1] Section 3.5.2 "Interpretation of A" - for why A isn't selective) ∆, B, C are input-dependent - (this is a key difference between Mamba and the linear time - invariant S4, and is why Mamba is called - **selective** state spaces) - """ - - def __init__(self, config: MambaConfig, layer_idx): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - self.intermediate_size = config.intermediate_size - self.time_step_rank = int(config.time_step_rank) - self.use_conv_bias = config.use_conv_bias - - # TODO: ?? - #self.use_bias = config.mamba_proj_bias - self.use_bias = False - - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.intermediate_size, - bias=self.use_conv_bias, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, - bias=self.use_bias) - # selective projection used to make dt, B and C input dependent - self.x_proj = RowParallelLinear( - self.intermediate_size, - self.time_step_rank + self.ssm_state_size * 2, - bias=False, - ) - # time step projection (discretization) - - # In the forward we need to apply dt_proj without the bias, - # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(self.time_step_rank, - self.intermediate_size, - bias=True, - skip_bias_add=True) - - def weight_loader(param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) - - def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): - weight_loader(param, -torch.exp(loaded_weight.float())) - - tp_size = get_tensor_model_parallel_world_size() - self.A = nn.Parameter( - torch.empty( - self.intermediate_size // tp_size, - self.ssm_state_size, - dtype=torch.float32, - )) - self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) - - set_weight_attrs(self.D, {"weight_loader": weight_loader}) - set_weight_attrs(self.A, {"weight_loader": A_weight_loader}) - - self.out_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=self.use_bias, - input_is_parallel=True, - ) - self.activation = config.hidden_act - - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - #self.dt_layernorm = RMSNorm(self.time_step_rank, - # eps=config.layer_norm_epsilon) - #self.b_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - #self.c_layernorm = RMSNorm(self.ssm_state_size, - # eps=config.layer_norm_epsilon) - - def mamba_forward(self, - hidden_states: torch.Tensor, - cache_params: MambaCacheParams = None): - - # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states)[0].transpose(1, 2) - hidden_states, gate = projected_states.chunk(2, dim=1) - - # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) - if cache_params is not None and not cache_params.is_prompt: - hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), - cache_params.conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - ) - hidden_states = hidden_states.unsqueeze(-1) - else: - if cache_params is not None: - conv_states = nn.functional.pad( - hidden_states, - (self.conv_kernel_size - hidden_states.shape[-1], 0)) - cache_params.conv_state.copy_(conv_states) - - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - ) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0] - - time_step, B, C = torch.split( - ssm_parameters, - [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1, - ) - - # Jamba has layer norms here. Mamba doesn't. - # TODO: Leaving these in for now, just as a placeholder in case mamba2 needs them. - # time_step = self.dt_layernorm(time_step.contiguous()) - # B = self.b_layernorm(B.contiguous()) - # C = self.c_layernorm(C.contiguous()) - - discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - time_proj_bias = (self.dt_proj.bias.float() if hasattr( - self.dt_proj, "bias") else None) - if cache_params is not None and not cache_params.is_prompt: - scan_outputs = selective_state_update( - cache_params.ssm_state, - hidden_states[..., 0], - discrete_time_step[..., 0], - self.A, - B[:, 0], - C[:, 0], - self.D, - gate[..., 0], - time_proj_bias, - dt_softplus=True, - ).unsqueeze(-1) - else: - scan_outputs, ssm_state = selective_scan_fn( - hidden_states, - discrete_time_step, - self.A, - B.transpose(1, 2), - C.transpose(1, 2), - self.D.float(), - gate, - time_proj_bias, - delta_softplus=True, - return_last_state=True, - ) - if ssm_state is not None and cache_params is not None: - cache_params.ssm_state.copy_(ssm_state) - - # 4. 
Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0] - return contextualized_states - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - ): - if attn_metadata.prefill_metadata is not None: - offset = 0 - for i, prompt_len in enumerate( - attn_metadata.prefill_metadata.seq_lens): - cache = MambaCacheParams(True, - conv_state=conv_state[i].unsqueeze(0), - ssm_state=ssm_state[i].unsqueeze(0)) - hidden_states[offset:offset + prompt_len].copy_( - self.mamba_forward(hidden_states[offset:offset + - prompt_len].unsqueeze(0), - cache_params=cache)[0]) - offset += prompt_len - else: - cache = MambaCacheParams(False, - conv_state=conv_state, - ssm_state=ssm_state) - hidden_states = self.mamba_forward(hidden_states.unsqueeze(1), - cache_params=cache) - hidden_states = hidden_states.squeeze(1) - - return hidden_states - -class MambaMLP(nn.Module): - - def __init__( - self, - config: MambaConfig, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - hidden_size = config.hidden_size - intermediate_size = config.intermediate_size - hidden_act = config.hidden_act - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x - -class MambaDecoderLayer(nn.Module): - - def __init__(self, - config: MambaConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: - super().__init__() - self.layer_idx = layer_idx - self.config = config - self.mixer = MambaMixer(config, layer_idx) - - self.feed_forward = MambaMLP(config, quant_config=quant_config) - self.norm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - **kwargs, - ): - if residual is None: - residual = hidden_states - hidden_states = self.norm(hidden_states) - else: - hidden_states, residual = self.norm( - hidden_states, residual) - - hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, - ssm_state) - # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) - hidden_states = self.feed_forward(hidden_states) - return hidden_states, residual - -class MambaModel(nn.Module): - - def __init__( - self, - config: MambaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - - self.embeddings = VocabParallelEmbedding( - self.vocab_size, - 
config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - - decoder_layers = [] - for i in range(config.num_hidden_layers): - decoder_layers.append( - MambaDecoderLayer(config, - layer_idx=i, - cache_config=cache_config, - quant_config=quant_config)) - self.layers = nn.ModuleList(decoder_layers) - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, - ) -> torch.Tensor: - hidden_states = self.embeddings(input_ids) - residual = None - - for i in range(len(self.layers)): - layer = self.layers[i] - current_ssm_state = ssm_state[i] - current_conv_state = conv_state[i] - - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - attn_metadata=attn_metadata, - residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, - ) - hidden_states, _ = self.norm_f(hidden_states, residual) - - - return hidden_states - -class MambaForCausalLM(nn.Module): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embeddings": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__( - self, - config: MambaConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.backbone = MambaModel(config, - cache_config=cache_config, - quant_config=quant_config, - lora_config=lora_config) - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - - #TODO: this ends up all 0s -- we don't put anything in here when loading weights. - #TODO: Does mamba share weights between the lm head and embeddings? -# self.lm_head = ParallelLMHead( -# self.unpadded_vocab_size, -# config.hidden_size, -# org_num_embeddings=config.vocab_size, -# padding_size=DEFAULT_VOCAB_PADDING_SIZE -# # We need bigger padding if using lora for kernel -# # compatibility -# if not lora_config else lora_config.lora_vocab_padding_size, -# ) - self.lm_head = self.backbone.embeddings - # Current step used indices - self.current_indices: List[int] = [] - # Used to track and store by the Mamba cache between steps. - self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Used as an input_buffer for the CUDA graph runs. 
- self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Maps between the request id and a dict that maps between the seq_id - # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) - self.sampler = Sampler() - - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[KVCache], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs): - - if not self.mamba_cache: - self._prepare_mamba_cache() - - if "seqlen_agnostic_capture_inputs" not in kwargs: - # We get here only on Prefill/Eager mode runs - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - batch_size = input_ids.shape[0] - if attn_metadata.prefill_metadata: - batch_size = len(request_ids_to_seq_ids) - ( - current_seqlen_agnostic_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - else: - # CUDA graph capturing runs - current_seqlen_agnostic_cache, indices = ( - kwargs["seqlen_agnostic_capture_inputs"], - [], - ) - self.current_indices = indices - - hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) - - if "seqlen_agnostic_capture_inputs" not in kwargs: - self._copy_mamba_cache_by_indices(self.current_indices, - current_seqlen_agnostic_cache) - - return hidden_states - - def _copy_mamba_cache_by_indices( - self, indices: List[int], - current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): - for i, offset in enumerate(indices): - self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) - - def _copy_mamba_cache(self, index_to: int, index_from: int, - from_buffer: Tuple[torch.Tensor, torch.Tensor]): - assert len(self.mamba_cache) > 0 - for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): - cache_t[:, index_to].copy_(from_buffer_t[:, index_from], - non_blocking=True) - - def _assign_seq_id_to_mamba_cache(self, cur_rid: str, - seqs_id: List[int]) -> List[int]: - indices_for_current_run = [] - for seq_id in seqs_id: - if cur_rid not in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping[cur_rid] = {} - first_free_index = self._first_free_index_in_mamba_cache() - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - ## case of decoding n>1, copy prefill cache to decoding indices - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - first_free_index = self._first_free_index_in_mamba_cache() - index_exist = list(seq_ids2indices.values())[0] - self._copy_mamba_cache(index_from=index_exist, - index_to=first_free_index, - from_buffer=self.mamba_cache) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - else: - index_for_current_run = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - - indices_for_current_run.append(index_for_current_run) - return indices_for_current_run - - def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int - ) -> Tuple[Tuple[torch.Tensor, 
torch.Tensor], List[int]]: - indices_for_current_run = [] - for request_id, seqs_id in request_ids_to_seq_ids.items(): - indices_for_current_run += self._assign_seq_id_to_mamba_cache( - request_id, seqs_id) - ## Pad the batch in case of running batch that was not captured via CG - padded_indices = indices_for_current_run.copy() - pad_index = self._first_free_index_in_mamba_cache() - - for _ in range(batch_size - len(indices_for_current_run)): - padded_indices.append(pad_index) - - conv_state = self.mamba_cache[0][:, padded_indices] - temporal_state = self.mamba_cache[1][:, padded_indices] - - return (conv_state, temporal_state), indices_for_current_run - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). - """ - assert all( - key in kwargs - for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) - request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - batch_size = len(request_ids_to_seq_ids) - ( - current_mamba_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) - self.current_indices = indices - finished_requests_ids = kwargs["finished_requests_ids"] - self._release_mamba_cache(finished_requests_ids) - - for input_buffer, current_cache_buffer in zip( - input_buffers["seqlen_agnostic_capture_inputs"], - current_mamba_cache): - input_buffer.copy_(current_cache_buffer, non_blocking=True) - - def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the MambaForCausalLM.mamba_cache after CUDA - graph replay run is done. - """ - self._copy_mamba_cache_by_indices( - self.current_indices, - input_buffers["seqlen_agnostic_capture_inputs"]) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - """ - Provide the CUDA graph capture runs with a buffer in adjusted size. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. 
- """ - return tuple(buffer[:, :batch_size] - for buffer in self.mamba_gc_cache_buffer) - - def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): - for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache(self) -> int: - if self.mamba_cache: - max_possible_batch_size = self.mamba_cache[0].shape[1] - occupied = [ - id for seq_ids in self.mamba_cache_indices_mapping.values() - for id in seq_ids.values() - ] - first_free_index = [ - i not in occupied for i in range(max_possible_batch_size) - ].index(True) - return first_free_index - return 0 - - def _get_mamba_cache_shape( - self - ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - conv_state_shape = ( - self.config.intermediate_size // world_size, - self.config.conv_kernel, - ) - temporal_state_shape = ( - self.config.intermediate_size // world_size, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - - def _prepare_mamba_cache(self): - dtype = self.lm_head.weight.dtype - num_mamba_layers = self.config.num_hidden_layers - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 - conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() - assert conv_state_shape is not None and temporal_state_shape is not None - for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: - buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(num_mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) - setattr(self, buffername, buffer) - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: Optional[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - for k, v in params_dict.items(): - print(k) - - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - if "A_log" in name: - name = name.replace("A_log", "A") - - if ".self_attn." in name: - name = name.replace(".self_attn", "") - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) From b733a840010c054f3bb069e49335e9c7926d5a35 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 18:26:46 +0000 Subject: [PATCH 10/15] format --- vllm/attention/backends/placeholder_attn.py | 6 ++++-- vllm/engine/llm_engine.py | 2 +- vllm/worker/model_runner.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 6bc766ba4e3..f5728756c6e 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -1,8 +1,10 @@ from dataclasses import dataclass -from typing import (List, Optional, Tuple, Type) +from typing import List, Optional, Tuple, Type + +import torch + from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -import torch # Placeholder attention backend for models like Mamba that don't have attention. # Mainly exists to sidestep get_attn_backend. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c43f7fcb854..f1ce03171eb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -261,7 +261,7 @@ def __init__( if not self.model_config.embedding_mode: # For all decoders including attention-free models like mamba, - # this must call _initialize_kv_caches, as this is where model + # this must call _initialize_kv_caches, as this is where model # warmup and CUDA graphs creation happens. self._initialize_kv_caches() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 459798e418c..2f4a0657c3f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,8 +23,8 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend - +from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) From fb846ce85cde68ce6b22fcab596ed0ac06fef601 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 21:57:51 +0000 Subject: [PATCH 11/15] apply fix from #6214 --- vllm/model_executor/models/mamba.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ca8d58fd3d6..a76c3757be7 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -538,20 +538,20 @@ def _prepare_current_run_mamba_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (MambaForCausalLM.mamba_gc_cache_buffer). + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (MambaForCausalLM.mamba_gc_cache_buffer). 
""" assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] - batch_size = len(request_ids_to_seq_ids) + cg_batch_size = input_buffers['input_ids'].shape[0] ( current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) + cg_batch_size) self.current_indices = indices finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) From d8017cb5044eb7c3458c2976e7aed9bf17753ace Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 16 Jul 2024 22:27:19 +0000 Subject: [PATCH 12/15] fixes from 6425 --- vllm/model_executor/models/interfaces.py | 2 +- vllm/model_executor/models/mamba.py | 32 ++++++++++++++++++------ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 6fdacd44697..b0b614d6b52 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -152,7 +152,7 @@ class HasInnerState(Protocol): """ A flag that indicates this model has inner state. Models that has inner state usually need access to the scheduler_config - for max_num_seqs ,etc... (Currently only used by Jamba) + for max_num_seqs ,etc... (Currently used by Jamba and Mamba) """ def __init__(self, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index a76c3757be7..49cfd5c1868 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -12,7 +12,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -27,10 +27,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -376,7 +378,7 @@ def forward( return hidden_states -class MambaForCausalLM(nn.Module): +class MambaForCausalLM(nn.Module, HasInnerState): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -404,9 +406,11 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, ) -> None: super().__init__() self.config = config + self.scheduler_config = scheduler_config self.backbone = MambaModel(config, cache_config=cache_config, quant_config=quant_config, @@ -436,7 +440,6 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs): - if not self.mamba_cache: self._prepare_mamba_cache() @@ -447,6 +450,7 @@ def forward(self, for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) 
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) @@ -454,7 +458,8 @@ def forward(self, current_seqlen_agnostic_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size) + batch_size, + finished_requests_ids) finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) else: @@ -518,10 +523,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str, return indices_for_current_run def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int + self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, + finished_requests_ids: List[str] ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: indices_for_current_run = [] for request_id, seqs_id in request_ids_to_seq_ids.items(): + if request_id in finished_requests_ids: + # Do not allocate cache for requests that run + # and finish right after + continue indices_for_current_run += self._assign_seq_id_to_mamba_cache( request_id, seqs_id) ## Pad the batch in case of running batch that was not captured via CG @@ -545,13 +555,16 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): assert all( key in kwargs for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + self._release_mamba_cache(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] ( current_mamba_cache, indices, ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size) + cg_batch_size, + finished_requests_ids) self.current_indices = indices finished_requests_ids = kwargs["finished_requests_ids"] self._release_mamba_cache(finished_requests_ids) @@ -615,9 +628,12 @@ def _get_mamba_cache_shape( def _prepare_mamba_cache(self): dtype = self.lm_head.weight.dtype num_mamba_layers = self.config.num_hidden_layers - max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10 + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config else + max(_BATCH_SIZES_TO_CAPTURE)) + 10 conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None + for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) + conv_state_shape, From 7ab2b9e7d3a2ce8648e35c9ab34bb1c627dbac2a Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 23 Jul 2024 19:59:54 +0000 Subject: [PATCH 13/15] add an integration test --- tests/models/test_mamba.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 tests/models/test_mamba.py diff --git a/tests/models/test_mamba.py b/tests/models/test_mamba.py new file mode 100644 index 00000000000..6a09d5f98f0 --- /dev/null +++ b/tests/models/test_mamba.py @@ -0,0 +1,79 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba. + +Run `pytest tests/models/test_mamba.py`. 
+""" +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline +import torch + +from .utils import check_outputs_equal + +MODELS = [ + "state-spaces/mamba-370m-hf", +] + +# Use lower-level interfaces to create this greedy generator, as mamba will +# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. +def generate_greedy(model_name, example_prompts, max_tokens): + # Create a text generation pipeline + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, + device=torch.cuda.current_device() + if torch.cuda.is_available() else -1) + + # Generate texts from the prompts + outputs = [] + for prompt in example_prompts: + # Tokenize the input prompt with truncation + inputs = tokenizer(prompt, return_tensors="pt", truncation=True) + input_ids = inputs["input_ids"].to(model.device) + + # Generate text using the model's generate method directly + generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) + generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + outputs.append((generated_ids[0].tolist(), generated_text)) + + return outputs + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # To pass the small model tests, we need full precision. + assert dtype == "float" + + hf_outputs = generate_greedy(model, example_prompts, max_tokens) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) From c319a21c9d203de56addc662691f176003ca5d91 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 23 Jul 2024 20:06:44 +0000 Subject: [PATCH 14/15] lint --- tests/models/test_mamba.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/models/test_mamba.py b/tests/models/test_mamba.py index 6a09d5f98f0..509027681f4 100644 --- a/tests/models/test_mamba.py +++ b/tests/models/test_mamba.py @@ -3,8 +3,7 @@ Run `pytest tests/models/test_mamba.py`. """ import pytest -from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline -import torch +from transformers import AutoModelForCausalLM, AutoTokenizer from .utils import check_outputs_equal @@ -12,6 +11,7 @@ "state-spaces/mamba-370m-hf", ] + # Use lower-level interfaces to create this greedy generator, as mamba will # choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used. 
def generate_greedy(model_name, example_prompts, max_tokens): @@ -19,25 +19,23 @@ def generate_greedy(model_name, example_prompts, max_tokens): tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) - generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, - device=torch.cuda.current_device() - if torch.cuda.is_available() else -1) - # Generate texts from the prompts outputs = [] for prompt in example_prompts: # Tokenize the input prompt with truncation inputs = tokenizer(prompt, return_tensors="pt", truncation=True) input_ids = inputs["input_ids"].to(model.device) - + # Generate text using the model's generate method directly generated_ids = model.generate(input_ids, max_new_tokens=max_tokens) - generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + generated_text = tokenizer.decode(generated_ids[0], + skip_special_tokens=True) outputs.append((generated_ids[0].tolist(), generated_text)) return outputs + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) From 1b7b5963d606f5067f812335c178efff66c0057f Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Thu, 25 Jul 2024 00:41:40 +1000 Subject: [PATCH 15/15] add model for rwkv6( No adaptation has been implemented --- vllm/model_executor/models/rwkv_6.py | 377 +++++++++++++++++++++++++++ 1 file changed, 377 insertions(+) create mode 100644 vllm/model_executor/models/rwkv_6.py diff --git a/vllm/model_executor/models/rwkv_6.py b/vllm/model_executor/models/rwkv_6.py new file mode 100644 index 00000000000..bf8c9f628cd --- /dev/null +++ b/vllm/model_executor/models/rwkv_6.py @@ -0,0 +1,377 @@ +# coding=utf-8 +"""PyTorch RWKV6 model.(native PyTorch version)""" +""" +author: @Zhiyuan Li +email: +date: 2024-07-22 +""" +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple + + +import torch +import torch.nn as nn +from transformers import RwkvConfig +from vllm.config import LoRAConfig, CacheConfig, SchedulerConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.models.interfaces import HasInnerState +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +MyModule = torch.jit.ScriptModule +MyFunction = torch.jit.script_method +KVCache = Tuple[torch.Tensor, torch.Tensor] + +@dataclass +class RwkvCacheParams: + is_prompt: bool = False + ssm_state: torch.Tensor = torch.Tensor() + + +class Rwkv_Block(MyModule): + def __init__(self, block_w: dict, hidden_size: int, n_head: int): + super().__init__() + self.hidden_size = hidden_size + self.n_head = n_head + self.head_size = hidden_size // 
n_head + + self.ln1 = nn.LayerNorm(hidden_size) + self.ln1.weight = nn.Parameter(block_w['ln1.weight']) + self.ln1.bias = nn.Parameter(block_w['ln1.bias']) + self.ln2 = nn.LayerNorm(hidden_size) + self.ln2.weight = nn.Parameter(block_w['ln2.weight']) + self.ln2.bias = nn.Parameter(block_w['ln2.bias']) + + + self.silu = nn.SiLU(inplace=False) + + self.att_time_maa_x = nn.Parameter(block_w['att.time_maa_x']) + self.att_time_maa_w = nn.Parameter(block_w['att.time_maa_w']) + self.att_time_maa_k = nn.Parameter(block_w['att.time_maa_k']) + self.att_time_maa_v = nn.Parameter(block_w['att.time_maa_v']) + self.att_time_maa_r = nn.Parameter(block_w['att.time_maa_r']) + self.att_time_maa_g = nn.Parameter(block_w['att.time_maa_g']) + self.att_time_maa_w1 = nn.Parameter(block_w['att.time_maa_w1']) + self.att_time_maa_w2 = nn.Parameter(block_w['att.time_maa_w2']) + self.att_time_decay = nn.Parameter(block_w['att.time_decay']) + self.att_time_decay_w1 = nn.Parameter(block_w['att.time_decay_w1']) + self.att_time_decay_w2 = nn.Parameter(block_w['att.time_decay_w2']) + self.att_time_faaaa = nn.Parameter(block_w['att.time_faaaa']) + self.att_receptance = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_receptance.weight = nn.Parameter(block_w['att.receptance.weight']) + self.att_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_key.weight = nn.Parameter(block_w['att.key.weight']) + self.att_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_value.weight = nn.Parameter(block_w['att.value.weight']) + self.att_output = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_output.weight = nn.Parameter(block_w['att.output.weight']) + self.att_gate = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.att_gate.weight = nn.Parameter(block_w['att.gate.weight']) + + + self.att_group_norm = nn.GroupNorm(num_groups=n_head, num_channels=hidden_size, eps=1e-5, affine=True) + self.att_group_norm.weight = nn.Parameter(block_w['att.ln_x.weight']) + self.att_group_norm.bias = nn.Parameter(block_w['att.ln_x.bias']) + + self.ffn_time_maa_k = nn.Parameter(block_w['ffn.time_maa_k']) + self.ffn_time_maa_r = nn.Parameter(block_w['ffn.time_maa_r']) + self.ffn_key = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ffn_key.weight = nn.Parameter(block_w['ffn.key.weight']) + self.ffn_receptance = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ffn_receptance.weight = nn.Parameter(block_w['ffn.receptance.weight']) + self.ffn_value = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.ffn_value.weight = nn.Parameter(block_w['ffn.value.weight']) + + @MyFunction + def channel_mixing(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + i0 = (2 + self.head_size) * i + 0 + sx = state[:, i0] - x + state[:, i0] = x + xk = x + sx * self.ffn_time_maa_k + xr = x + sx * self.ffn_time_maa_r + r = torch.sigmoid(self.ffn_receptance(xr)) + k = torch.relu(self.ffn_key(xk)).pow(2) + output = r * self.ffn_value(k) + return output, state + + @MyFunction + def channel_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + i0 = (2 + self.head_size) * i + 0 + + sx_lerp = torch.empty_like(x) + sx_lerp[:, 0] = state[:, i0] - x[:, 0] + sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:] + + state[:, i0] = x[:, -1] + + xk = x + sx_lerp * self.ffn_time_maa_k + xr = x + sx_lerp * self.ffn_time_maa_r + + r = torch.sigmoid(self.ffn_receptance(xr)) # [Batch, L, hiddle_size] + k = 
torch.relu(self.ffn_key(xk)).pow(2) + + output = r * self.ffn_value(k) + return output, state + + def time_mixing(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + batch_size, H, S = x.size(0), self.n_head, self.head_size + x, state, g = self.time_mixing_jit(x, state, i, batch_size, H, S) + + x = self.time_mixing_jit2(x, g) + + return x, state + + @MyFunction + def time_mixing_jit(self, x: torch.Tensor, state: torch.Tensor, i: int, + batch_size: int, H: int, S: int): + i1 = (2 + S) * i + 1 # i is the block number + + sx = state[:, i1] - x + state[:, i1] = x # Information is compressed to position 1 on each layer + + xxx = x + sx * self.att_time_maa_x + xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, 5, 1, -1) + xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, 5, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=1) + + xw = x + sx * (self.att_time_maa_w + mw) + xk = x + sx * (self.att_time_maa_k + mk) + xv = x + sx * (self.att_time_maa_v + mv) + xr = x + sx * (self.att_time_maa_r + mr) + xg = x + sx * (self.att_time_maa_g + mg) + + # calculate w, r, k, v, g + w = (self.att_time_decay + (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2)) + w = -torch.exp(w.view(batch_size, H, S, 1)) + + r = self.att_receptance(xr).view(batch_size, H, 1, S) + k = self.att_key(xk).view(batch_size, H, S, 1) + v = self.att_value(xv).view(batch_size, H, 1, S) + g = self.silu(self.att_gate(xg)) + + # Update state using attention mechanism + s = state[:, (2+S)*i+2:(2+S)*(i+1), :].view(batch_size, H, S, S) + a = k @ v + x = r @ (self.att_time_faaaa * a + s) + s = a + torch.exp(w) * s + # Update the attention parameters of the i-th layer STATE + state[:, (2+S)*i+2:(2+S)*(i+1), :] = s.view(batch_size, S, -1) + return x, state, g + + @MyFunction + def time_mixing_jit2(self, x:torch.Tensor, g): + return self.att_output(self.att_group_norm(x.flatten(start_dim=1)) * g) + + def time_mixing_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + batch_size, L, H, S = x.size(0), x.size(1), self.n_head, self.head_size + x, state, g = self.time_mixing_parallel_jit1(x, state, i, batch_size, L, H, S) + + x = self.time_mixing_parallel_jit2(x, g, batch_size, L) + + return x, state + + @MyFunction + def time_mixing_parallel_jit1(self, x: torch.Tensor, state: torch.Tensor, i: int, + batch_size: int, L: int, H: int, S: int): + i1 = (2 + S) * i + 1 + sx_lerp = torch.empty_like(x) + sx_lerp[:, 0] = state[:, i1] - x[:, 0] + + sx_lerp[:, 1:] = x[:, :-1] - x[:, 1:] + + state[:, i1] = x[:, -1] + + xxx = x + sx_lerp * self.att_time_maa_x # torch.Size([B, L, hiddle_size]) + xxx = torch.tanh(xxx @ self.att_time_maa_w1).view(batch_size, L, 5, 1, -1) # att_time_maa_w1: [hiddle_size, 160] + xxx = torch.matmul(xxx, self.att_time_maa_w2).view(batch_size, L, 5, -1) # [Batch, L, 5, hiddle_size] + + mw, mk, mv, mr, mg = xxx.unbind(dim=2) # [10, 100, hiddle_size] + + xw = x + sx_lerp * (self.att_time_maa_w + mw) # torch.Size([B, L, hiddle_size]) + xk = x + sx_lerp * (self.att_time_maa_k + mk) + xv = x + sx_lerp * (self.att_time_maa_v + mv) + xr = x + sx_lerp * (self.att_time_maa_r + mr) + xg = x + sx_lerp * (self.att_time_maa_g + mg) + + w = (self.att_time_decay + (torch.tanh(xw @ self.att_time_decay_w1) @ self.att_time_decay_w2)) + w = -torch.exp(w.view(batch_size, L, H, S, 1)) + + r = self.att_receptance(xr).view(batch_size, L, H, 1, S) + k = self.att_key(xk).view(batch_size, L, H, S, 1) + v = self.att_value(xv).view(batch_size, L, H, 1, S) + g = 
self.silu(self.att_gate(xg)) # [10, 100, hiddle_size] + # TODO, apply kernel here, cuda or fla(triton) + + + w = torch.exp(w) + s = state[:, (2+S)*i+2:(2+S)*(i+1)].view(batch_size, H, S, S) + a = k @ v # a: [batch_size, L, H, S, S] + + state_s = torch.zeros(batch_size, L, H, S, S, dtype=x.dtype, device=x.device) + state_s[:, 0] = s + + for l in range(L-1): + s = a[:, l] + w[:, l] * s + state_s[:, l+1] = s + s = a[:, -1] + w[:, -1] * s + + state[:, (2+S)*i+2:(2+S)*(i+1)] = s.view(batch_size, S, -1) + + x = r @ (self.att_time_faaaa * a + state_s) + return x, state, g + + @MyFunction + def time_mixing_parallel_jit2(self, x: torch.Tensor, g: torch.Tensor, batch_size: int, L:int): + return self.att_output(self.att_group_norm(x.flatten(start_dim=2).view(batch_size * L, -1)).view(batch_size, L, -1) * g) + + @torch.no_grad() + def forward(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + x_time, state = self.time_mixing(self.ln1(x), state, i) + x = x + x_time + x_channel, state = self.channel_mixing(self.ln2(x), state, i) + x = x + x_channel + + return x, state + + @torch.no_grad() + def forward_parallel(self, x: torch.Tensor, state: torch.Tensor, i: int) -> torch.Tensor: + # Time mixing + x_time, state = self.time_mixing_parallel(self.ln1(x), state, i) + x = x + x_time + + # Channel mixing + x_channel, state = self.channel_mixing_parallel(self.ln2(x), state, i) + x = x + x_channel + + return x, state + +class RwkvModel(MyModule): + def __init__(self, args: dict): + super().__init__() + self.args = args + self.load_params() + self.eval() + + + + def load_params(self, load_from_file: bool = True, w: dict = None): + # TODO: vllm + if load_from_file: + if not self.args['MODEL_NAME'].endswith('.pth'): + self.args['MODEL_NAME'] += '.pth' + w = torch.load(self.args['MODEL_NAME'], map_location="cpu") + else: + assert w is not None + + self.num_layer = 0 + for k in w.keys(): + if '.time_' in k: w[k] = w[k].squeeze() + if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1) + if "blocks" in k: self.num_layer = max(self.num_layer, int(k.split(".")[1])) + self.num_layer += 1 + + self.n_head = w['blocks.0.att.time_faaaa'].shape[0] + self.n_embd = w['blocks.0.ln1.weight'].shape[0] + self.head_size = self.n_embd // self.n_head + self.state_size = [self.num_layer * (2 + self.head_size), self.n_embd] + + self.emb = nn.Embedding.from_pretrained(w['emb.weight'], freeze=True) + + + self.ln0 = nn.LayerNorm(self.n_embd) + self.ln0.weight = nn.Parameter(w['blocks.0.ln0.weight']) + self.ln0.bias = nn.Parameter(w['blocks.0.ln0.bias']) + + + self.blocks = nn.ModuleList() + + for i in range(self.num_layer): + block_w = {k[len(f'blocks.{i}.'):]: v for k, v in w.items() if f'blocks.{i}.' 
in k} + self.blocks.append(Rwkv_Block(block_w, self.n_embd, self.n_head)) + + + self.ln_out = nn.LayerNorm(self.n_embd) + self.ln_out.weight = nn.Parameter(w['ln_out.weight']) + self.ln_out.bias = nn.Parameter(w['ln_out.bias']) + + + self.head = nn.Linear(self.n_embd, self.args['vocab_size'], bias=False) + self.head.weight = nn.Parameter(w['head.weight']) + + + @torch.no_grad() + def forward(self, + input_ids: torch.Tensor, + state: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + x = self.forward_jit1(input_ids) + + if attn_metadata.prefill_metadata is not None: + # Prefill phase + for i, block in enumerate(self.blocks): + x, state = block.forward_parallel(x, state, i) + else: + # Decoding phase + for i, block in enumerate(self.blocks): + x, state = block(x, state, i) + + x = self.forward_jit2(x) + + return x, state + + + @MyFunction + def forward_jit1(self, token: torch.Tensor) -> torch.Tensor: + return self.ln0(self.emb(token)) + + @MyFunction + def forward_jit2(self, x: torch.Tensor) -> torch.Tensor: + return self.head(self.ln_out(x)) + + @torch.no_grad() + def forward_parallel_slices(self, + input_ids: torch.Tensor, + state: torch.Tensor, + attn_metadata: AttentionMetadata, + slice_len: int = 64) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Chunked prefill forward pass of the RWKV6 model. + Args: + input_ids (torch.Tensor): Input token ids, shape [Batch, Seq Len]. + state (torch.Tensor): Hidden state tensor, shape [Batch, State Size, N_embd]. + attn_metadata (AttentionMetadata): Passed through to forward() to select the prefill path. + slice_len (int): Length of each prompt chunk. + Returns: + Tuple[torch.Tensor, torch.Tensor]: Logits for the last chunk and the updated state. + """ + # FIXME! + data_len = input_ids.shape[1] + for i in range((data_len-2)//slice_len+1): + start = i*slice_len + end = min((i+1)*slice_len, data_len) + input_ids_ith_chunk = input_ids[:, start:end] + token_out, state = self.forward(input_ids_ith_chunk, state, attn_metadata) + + return token_out, state + + def init_state(self, batch_size: int) -> torch.Tensor: + state = torch.zeros(batch_size, self.state_size[0], self.state_size[1]) + return state +
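With the pieces above in place, the intended driving loop looks roughly like the following sketch. It is illustrative only: the checkpoint path and vocab size are placeholders, the metadata objects merely mimic the one attribute forward() inspects, and it assumes this still-unadapted module loads and compiles as written.

# Illustrative usage sketch only; not part of the patch.
import torch
from types import SimpleNamespace

args = {"MODEL_NAME": "/path/to/rwkv6.pth", "vocab_size": 65536}  # placeholders
model = RwkvModel(args)

batch_size, prompt_len = 1, 100
prompt_ids = torch.randint(0, args["vocab_size"], (batch_size, prompt_len))
state = model.init_state(batch_size)

# Stand-ins for vLLM's AttentionMetadata: forward() only checks whether
# prefill_metadata is None to choose the parallel or recurrent path.
prefill_meta = SimpleNamespace(prefill_metadata=object())
decode_meta = SimpleNamespace(prefill_metadata=None)

# Chunked prefill over the prompt, then a single recurrent decode step that
# reuses the state produced by the prefill.
logits, state = model.forward_parallel_slices(prompt_ids, state, prefill_meta,
                                              slice_len=64)
next_token = logits[:, -1, :].argmax(dim=-1)  # shape: [batch_size]
logits, state = model.forward(next_token, state, decode_meta)

The recurrent decode path reuses the same state tensor that the chunked prefill produced, which is the point of the forward_parallel_slices helper.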