diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index ededf9c533f..f70e3bb20da 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -44,9 +44,23 @@
 @dataclass
 class MambaCacheParams:
     is_prompt: bool = False
+    state_indices: torch.Tensor = torch.Tensor()
     conv_state: torch.Tensor = torch.Tensor()
     ssm_state: torch.Tensor = torch.Tensor()
 
+    def with_batch_index(self, batch_idx: int) -> 'MambaCacheParams':
+        return MambaCacheParams(
+            is_prompt=True, #TODO: make less sketchy
+            state_indices=self.state_indices[batch_idx],
+            conv_state=self.conv_state,
+            ssm_state=self.ssm_state)
+
+    def with_state_layer(self, layer_idx: int) -> 'MambaCacheParams':
+        return MambaCacheParams(is_prompt=self.is_prompt,
+                                state_indices=self.state_indices,
+                                conv_state=self.conv_state[layer_idx],
+                                ssm_state=self.ssm_state[layer_idx])
+
 
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class JambaMambaMixer(nn.Module):
@@ -139,6 +153,7 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
     def mamba_forward(self,
                       hidden_states: torch.Tensor,
                       cache_params: MambaCacheParams = None):
+
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
         hidden_states, gate = projected_states.chunk(2, dim=1)
@@ -153,6 +168,7 @@ def mamba_forward(self,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
+                conv_state_indices=cache_params.state_indices,
             )
             hidden_states = hidden_states.unsqueeze(-1)
         else:
@@ -160,7 +176,8 @@ def mamba_forward(self,
                 conv_states = nn.functional.pad(
                     hidden_states,
                     (self.conv_kernel_size - hidden_states.shape[-1], 0))
-                cache_params.conv_state.copy_(conv_states)
+                cache_params.conv_state[cache_params.state_indices].copy_(
+                    conv_states.squeeze(0))
 
             hidden_states = causal_conv1d_fn(
                 hidden_states,
@@ -198,6 +215,7 @@ def mamba_forward(self,
                 gate[..., 0],
                 time_proj_bias,
                 dt_softplus=True,
+                state_batch_indices=cache_params.state_indices,
             ).unsqueeze(-1)
         else:
             scan_outputs, ssm_state = selective_scan_fn(
@@ -213,7 +231,8 @@ def mamba_forward(self,
                 return_last_state=True,
             )
             if ssm_state is not None and cache_params is not None:
-                cache_params.ssm_state[cache_params.state_indices].copy_(
-                cache_params.ssm_state.copy_(ssm_state)
+                cache_params.ssm_state[cache_params.state_indices].copy_(
+                    ssm_state.squeeze(0))
 
         # 4. Final linear projection
         contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
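Aside, not part of the patch: the hunks above replace whole-cache copies with slot addressing through `state_indices`. The fused decode kernels receive the indices directly (`conv_state_indices=` / `state_batch_indices=` above), while the prefill fallback indexes the buffers from Python. Below is a minimal sketch of that gather/scatter pattern using plain PyTorch indexing and made-up shapes, not the fused kernels.

# Illustrative sketch only: slot addressing with an index tensor.
# Shapes are toy example values.
import torch

max_batch, conv_dim, conv_k = 8, 4, 3
conv_state = torch.zeros(max_batch, conv_dim, conv_k)  # one persistent slot per sequence
state_indices = torch.tensor([5, 2, 7])                # slots of the 3 currently running sequences

# Read: gather each running sequence's state in batch order (returns a copy).
batch_view = conv_state[state_indices]                 # shape (3, conv_dim, conv_k)

# Write: scatter updated states back into their slots in place,
# without physically reordering the big buffer.
updated = torch.randn_like(batch_view)
conv_state.index_copy_(0, state_indices, updated)      # slots 5, 2 and 7 now hold `updated`

The point of the indirection is that per-sequence states never have to be moved to match the batch layout; only the index tensor changes between steps.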
@@ -223,27 +242,21 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        cache_params: MambaCacheParams,
     ):
         if attn_metadata.prefill_metadata is not None:
             offset = 0
             for i, prompt_len in enumerate(
                     attn_metadata.prefill_metadata.seq_lens):
-                cache = MambaCacheParams(True,
-                                         conv_state=conv_state[i].unsqueeze(0),
-                                         ssm_state=ssm_state[i].unsqueeze(0))
+
                 hidden_states[offset:offset + prompt_len].copy_(
-                    self.mamba_forward(hidden_states[offset:offset +
-                                                     prompt_len].unsqueeze(0),
-                                       cache_params=cache)[0])
+                    self.mamba_forward(
+                        hidden_states[offset:offset + prompt_len].unsqueeze(0),
+                        cache_params=cache_params.with_batch_index(i))[0])
                 offset += prompt_len
         else:
-            cache = MambaCacheParams(False,
-                                     conv_state=conv_state,
-                                     ssm_state=ssm_state)
             hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
-                                               cache_params=cache)
+                                               cache_params=cache_params)
             hidden_states = hidden_states.squeeze(1)
 
         return hidden_states
@@ -314,6 +327,7 @@ def __init__(self,
                                      quant_config=quant_config)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
 
         # router_logits: (batch * sequence_length, n_experts)
@@ -352,8 +366,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        cache_params: MambaCacheParams,
         **kwargs,
     ):
         if residual is None:
@@ -363,8 +376,8 @@ def forward(
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
-                                   ssm_state)
+        hidden_states = self.mamba(hidden_states, attn_metadata, cache_params)
+
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -521,8 +534,7 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        cache_params: MambaCacheParams,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -530,8 +542,7 @@ def forward(
         for i in range(len(self.layers)):
             layer = self.layers[i]
             kv_cache = None
-            current_ssm_state = None
-            current_conv_state = None
+            current_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
                 kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
                                      self.config.attn_layer_period]
@@ -539,8 +550,8 @@ def forward(
                 current_state_layer = i - (1 +
                                            (i - self.config.attn_layer_offset)
                                            // self.config.attn_layer_period)
-                current_ssm_state = ssm_state[current_state_layer]
-                current_conv_state = conv_state[current_state_layer]
+                current_cache_params = cache_params.with_state_layer(
+                    current_state_layer)
 
             hidden_states, residual = layer(
                 positions=positions,
@@ -548,8 +559,7 @@ def forward(
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                conv_state=current_conv_state,
-                ssm_state=current_ssm_state,
+                cache_params=current_cache_params,
             )
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
@@ -610,7 +620,8 @@ def __init__(
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
         # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
+        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor,
+                                torch.Tensor] = tuple()
         # Maps between the request id and a dict that maps between the seq_id
         # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
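Aside, outside the patch: the current_state_layer arithmetic in the JambaModel.forward hunk above converts a decoder-layer index into a dense index over Mamba layers only, by subtracting the number of attention layers that precede layer i; that dense index is what with_state_layer() uses to pick the per-layer slice of conv_state and ssm_state. A small standalone check of the formula, assuming attention layers sit at attn_layer_offset + k * attn_layer_period (consistent with the kv_caches indexing above); the concrete numbers are example values, not read from the Jamba config.

# Standalone sanity check of the current_state_layer formula.
# attn_layer_offset / attn_layer_period / num_layers are example values.
attn_layer_offset, attn_layer_period, num_layers = 4, 8, 16

dense_idx = 0
for i in range(num_layers):
    if i % attn_layer_period == attn_layer_offset:
        continue  # attention layer: owns a KV cache, no Mamba state
    current_state_layer = i - (1 + (i - attn_layer_offset) // attn_layer_period)
    assert current_state_layer == dense_idx, (i, current_state_layer, dense_idx)
    dense_idx += 1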
@@ -636,6 +647,7 @@ def forward(self,
 
             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
             finished_requests_ids = kwargs["finished_requests_ids"]
+            self._release_mamba_cache(finished_requests_ids)
 
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
@@ -644,97 +656,70 @@ def forward(self,
                 request_ids_to_seq_ids, batch_size, finished_requests_ids)
         else:
             # CUDA graph capturing runs
-            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
+            (state_indices, conv_state,
+             ssm_state) = kwargs["seqlen_agnostic_capture_inputs"]
+            mamba_cache = MambaCacheParams(False,
+                                           state_indices=state_indices,
+                                           conv_state=conv_state,
+                                           ssm_state=ssm_state)
 
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache[0],
-                                   mamba_cache[1])
+                                   attn_metadata, mamba_cache)
 
         return hidden_states
 
-    def _swap_mamba_cache(self, from_index: int, to_index: int):
-        assert len(self.mamba_cache) > 0
-        for cache_t in self.mamba_cache:
-            cache_t[:, [to_index,from_index]] = \
-             cache_t[:, [from_index,to_index]]
-
-    def _copy_mamba_cache(self, from_index: int, to_index: int):
-        assert len(self.mamba_cache) > 0
-        for cache_t in self.mamba_cache:
-            cache_t[:, to_index].copy_(cache_t[:, from_index],
-                                       non_blocking=True)
-
-    def _move_out_if_already_occupied(self, index: int,
-                                      all_occupied_indices: List[int]):
-        if index in all_occupied_indices:
-            first_free_index = self._first_free_index_in_mamba_cache()
-            # In case occupied, move the occupied to a new empty block
-            self._move_cache_index_and_mappings(from_index=index,
-                                                to_index=first_free_index)
-
-    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
-                                                       seq_id: int,
-                                                       destination_index: int):
+    def _assign_seq_id_to_mamba_cache(self, cur_rid: str, seq_id: int):
         """
-        Assign (req_id,seq_id) pair to a `destination_index` index, if
-        already occupied, move the occupying index to a free index.
+        Assign (req_id,seq_id) pair to some index in the cache.
""" - all_occupied_indices = self._get_all_occupied_indices() if cur_rid not in self.mamba_cache_indices_mapping: - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - self.mamba_cache_indices_mapping[cur_rid] = { - seq_id: destination_index - } + # Pick the first available entry + cache_idx = self._first_free_index_in_mamba_cache() + self.mamba_cache_indices_mapping[cur_rid] = {seq_id: cache_idx} + + return cache_idx elif seq_id not in (seq_ids2indices := self.mamba_cache_indices_mapping[cur_rid]): - # parallel sampling , where n > 1, assume prefill have + # Parallel sampling , where n > 1, assuming prefill has # already happened now we only need to copy the already - # existing cache into the siblings seq_ids caches - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) + # existing cache into the sibling's cache + cache_idx = self._first_free_index_in_mamba_cache() + index_exists = list(seq_ids2indices.values())[0] # case of decoding n>1, copy prefill cache to decoding indices - self._copy_mamba_cache(from_index=index_exists, - to_index=destination_index) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = destination_index + self._copy_mamba_cache(from_index=index_exists, to_index=cache_idx) + self.mamba_cache_indices_mapping[cur_rid][seq_id] = cache_idx + + return cache_idx else: - # already exists - cache_index_already_exists = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - if cache_index_already_exists != destination_index: - # In case the seq id already exists but not in - # the right destination, swap it with what's occupying it - self._swap_pair_indices_and_mappings( - from_index=cache_index_already_exists, - to_index=destination_index) + # Already exists + return self.mamba_cache_indices_mapping[cur_rid][seq_id] def _prepare_current_run_mamba_cache( self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, finished_requests_ids: List[str]): - running_indices = [] + # Use the first batch_size entries of the mamba cache state. 
@@ -743,46 +728,6 @@
             for cache_idx in seq_ids2indices.values()
         ]
 
-    def _clean_up_first_bs_blocks(self, batch_size: int,
-                                  indices_for_current_run: List[int]):
-        # move out all of the occupied but currently not running blocks
-        # outside of the first n blocks
-        destination_indices = set([range(batch_size)])
-        max_possible_batch_size = self.mamba_cache[0].shape[1]
-        for destination_index in destination_indices:
-            if destination_index in self._get_all_occupied_indices() and \
-                destination_index not in indices_for_current_run:
-                # move not running indices outside of the batch
-                all_other_indices = list(
-                    range(batch_size, max_possible_batch_size))
-                first_avail_index = self._first_free_index_in_mamba_cache(
-                    all_other_indices)
-                self._swap_indices(from_index=destination_index,
-                                   to_index=first_avail_index)
-
-    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
-        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
-        self._update_mapping_index(from_index=from_index, to_index=to_index)
-
-    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
-        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
-        self._swap_mapping_index(from_index=from_index, to_index=to_index)
-
-    def _swap_mapping_index(self, from_index: int, to_index: int):
-        for seq_ids2index in self.mamba_cache_indices_mapping.values():
-            for seq_id, index in seq_ids2index.items():
-                if from_index == index:
-                    seq_ids2index.update({seq_id: to_index})
-                elif to_index == index:
-                    seq_ids2index.update({seq_id: from_index})
-
-    def _update_mapping_index(self, from_index: int, to_index: int):
-        for seq_ids2index in self.mamba_cache_indices_mapping.values():
-            for seq_id, index in seq_ids2index.items():
-                if from_index == index:
-                    seq_ids2index.update({seq_id: to_index})
-        return
-
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
         Copy the relevant Mamba cache into the CUDA graph input buffer
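Aside, outside the patch: with slots pinned in place, the CUDA graph path only needs buffers whose addresses never change; per-step slot indices are written into the same pre-allocated tensor in place (state_indices.copy_(running_indices) in _prepare_current_run_mamba_cache above). A small sketch of that in-place discipline follows; it runs on CPU and the names are illustrative, while the real buffers live on "cuda".

# Sketch of the fixed-buffer discipline behind CUDA-graph replay (not vLLM API).
import torch

max_batch = 8
state_indices_buf = torch.empty(max_batch, dtype=torch.int32)  # allocated once, captured once

ptr_before = state_indices_buf.data_ptr()

# Each step: overwrite the slots of the currently running sequences in place.
step_indices = torch.tensor([5, 2, 7], dtype=torch.int32)
state_indices_buf[:step_indices.numel()].copy_(step_indices)

# Same storage as before, so a graph captured against this buffer stays valid.
assert state_indices_buf.data_ptr() == ptr_before

This is why the swap/move helpers deleted above are no longer needed: the captured graph never sees the request mapping, only the contents of the fixed buffers.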
""" - return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) + state_indices = self.mamba_cache[0][:batch_size] + return (state_indices, ) + tuple(self.mamba_cache[1:]) def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: @@ -817,7 +763,7 @@ def _first_free_index_in_mamba_cache( self, indices_range: Optional[List[int]] = None) -> int: assert self.mamba_cache is not None if indices_range is None: - max_possible_batch_size = self.mamba_cache[0].shape[1] + max_possible_batch_size = self.mamba_cache[0].shape[0] indices_range = list(range(max_possible_batch_size)) all_occupied_indices = self._get_all_occupied_indices() for i in indices_range: @@ -852,7 +798,10 @@ def _prepare_mamba_cache(self): conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None - self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) + + self.mamba_cache = (torch.empty((max_batch_size), + dtype=torch.int32, + device="cuda"), + torch.empty(size=(mamba_layers, max_batch_size) + conv_state_shape, dtype=dtype, device="cuda"),