Commit b1cf5a7

Mor Zusman (mzusman) authored and committed
[Model][Jamba] Mamba cache single buffer (vllm-project#6739)
Co-authored-by: Mor Zusman <morz@ai21.com>
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent cd30e06 · commit b1cf5a7

File tree

2 files changed: +148 -124 lines changed

  vllm/model_executor/models/jamba.py
  vllm/worker/model_runner.py


vllm/model_executor/models/jamba.py

Lines changed: 148 additions & 121 deletions
@@ -609,12 +609,8 @@ def __init__(
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        # Current step used indices
-        self.current_indices: List[int] = []
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
         # Maps between the request id and a dict that maps between the seq_id
         # and its index inside the self.mamba_cache
         self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
@@ -644,95 +640,148 @@ def forward(self,
         batch_size = input_ids.shape[0]
         if attn_metadata.prefill_metadata:
             batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
+            mamba_cache = self._prepare_current_run_mamba_cache(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
         else:
             # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-            self.current_indices = indices
+            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata,
-                                   current_seqlen_agnostic_cache[0],
-                                   current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
-
+                                   attn_metadata, mamba_cache[0],
+                                   mamba_cache[1])
         return hidden_states

-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
+    def _swap_mamba_cache(self, from_index: int, to_index: int):
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, [to_index, from_index]] = \
+                cache_t[:, [from_index, to_index]]

-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
+    def _copy_mamba_cache(self, from_index: int, to_index: int):
         assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
+        for cache_t in self.mamba_cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                        non_blocking=True)

-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
+    def _move_out_if_already_occupied(self, index: int,
+                                      all_occupied_indices: List[int]):
+        if index in all_occupied_indices:
+            first_free_index = self._first_free_index_in_mamba_cache()
+            # In case occupied, move the occupied to a new empty block
+            self._move_cache_index_and_mappings(from_index=index,
+                                                to_index=first_free_index)
+
+    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
+                                                       seq_id: int,
+                                                       destination_index: int):
+        """
+        Assign (req_id,seq_id) pair to a `destination_index` index, if
+        already occupied, move the occupying index to a free index.
+        """
+        all_occupied_indices = self._get_all_occupied_indices()
+        if cur_rid not in self.mamba_cache_indices_mapping:
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            self.mamba_cache_indices_mapping[cur_rid] = {
+                seq_id: destination_index
+            }
+        elif seq_id not in (seq_ids2indices :=
+                            self.mamba_cache_indices_mapping[cur_rid]):
+            # parallel sampling, where n > 1: assume prefill has
+            # already happened, so we only need to copy the already
+            # existing cache into the sibling seq_ids' caches
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            index_exists = list(seq_ids2indices.values())[0]
+            # case of decoding n>1, copy prefill cache to decoding indices
+            self._copy_mamba_cache(from_index=index_exists,
+                                   to_index=destination_index)
+            self.mamba_cache_indices_mapping[cur_rid][
+                seq_id] = destination_index
+        else:
+            # already exists
+            cache_index_already_exists = self.mamba_cache_indices_mapping[
+                cur_rid][seq_id]
+            if cache_index_already_exists != destination_index:
+                # In case the seq id already exists but not in
+                # the right destination, swap it with what's occupying it
+                self._swap_pair_indices_and_mappings(
+                    from_index=cache_index_already_exists,
+                    to_index=destination_index)

     def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-            finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
+            self, request_ids_to_seq_ids: Dict[str, list[int]],
+            batch_size: int, finished_requests_ids: List[str]):
+        running_indices = []
+        request_ids_to_seq_ids_flatten = [
+            (req_id, seq_id)
+            for req_id, seq_ids in request_ids_to_seq_ids.items()
+            for seq_id in seq_ids
+        ]
+        for dest_index, (request_id,
+                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
             if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
+                # Do not allocate cache index for requests that run
                 # and finish right after
                 continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
+            self._assign_seq_id_to_mamba_cache_in_specific_dest(
+                request_id, seq_id, dest_index)
+            running_indices.append(dest_index)

-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
+        self._clean_up_first_bs_blocks(batch_size, running_indices)
+        conv_state = self.mamba_cache[0][:, :batch_size]
+        temporal_state = self.mamba_cache[1][:, :batch_size]

-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
+        return (conv_state, temporal_state)

-        return (conv_state, temporal_state), indices_for_current_run
+    def _get_all_occupied_indices(self):
+        return [
+            cache_idx
+            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
+            for cache_idx in seq_ids2indices.values()
+        ]
+
+    def _clean_up_first_bs_blocks(self, batch_size: int,
+                                  indices_for_current_run: List[int]):
+        # move out all of the occupied but currently not running blocks
+        # outside of the first n blocks
+        destination_indices = range(batch_size)
+        max_possible_batch_size = self.mamba_cache[0].shape[1]
+        for destination_index in destination_indices:
+            if destination_index in self._get_all_occupied_indices() and \
+               destination_index not in indices_for_current_run:
+                # move not running indices outside of the batch
+                all_other_indices = list(
+                    range(batch_size, max_possible_batch_size))
+                first_avail_index = self._first_free_index_in_mamba_cache(
+                    all_other_indices)
+                self._swap_indices(from_index=destination_index,
+                                   to_index=first_avail_index)
+
+    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
+        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
+        self._update_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
+        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
+        self._swap_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                elif to_index == index:
+                    seq_ids2index.update({seq_id: from_index})
+
+    def _update_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                    return

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
@@ -747,55 +796,35 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the JambaForCausalLM.mamba_cache after CUDA
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
+        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
+                                              cg_batch_size,
+                                              finished_requests_ids)

     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
         Provide the CUDA graph capture runs with a buffer in adjusted size.
         The buffer is used to maintain the Mamba Cache during the CUDA graph
         replay runs.
         """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
+        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)

     def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
         for req_id in finished_seq_groups_req_ids:
             if req_id in self.mamba_cache_indices_mapping:
                 self.mamba_cache_indices_mapping.pop(req_id)

-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
+    def _first_free_index_in_mamba_cache(
+            self, indices_range: Optional[List[int]] = None) -> int:
+        assert self.mamba_cache is not None
+        if indices_range is None:
             max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
+            indices_range = list(range(max_possible_batch_size))
+        all_occupied_indices = self._get_all_occupied_indices()
+        for i in indices_range:
+            if i not in all_occupied_indices:
+                return i
+        raise Exception("Couldn't find a free spot in the mamba cache! "
+                        "This should never happen")

     def _get_mamba_cache_shape(
             self
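
A detail worth calling out in get_seqlen_agnostic_capture_inputs above: buffer[:, :batch_size] is basic slicing, which in PyTorch returns a view sharing storage with self.mamba_cache, so the tensors the CUDA graph captures and the single cache buffer are the same memory. A tiny sketch of that behaviour (shapes are made up for illustration, not the real cache shapes):

import torch

mamba_cache = torch.zeros(2, 8, 3)           # (layers, max_batch_size, state_dim), toy sizes
capture_inputs = mamba_cache[:, :4]          # slice for batch_size=4 is a view, not a copy

capture_inputs[:, 0] = 1.0                   # write through the view...
assert torch.equal(mamba_cache[:, 0], torch.ones(2, 3))      # ...the full cache sees it
print(capture_inputs.data_ptr() == mamba_cache.data_ptr())   # True: shared storage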
@@ -819,20 +848,18 @@ def _prepare_mamba_cache(self):
             [layer_type == "mamba" for layer_type in layers_type])
         max_batch_size = (_get_graph_batch_size(
             self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
+                          max(_BATCH_SIZES_TO_CAPTURE) + 2)
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None

-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+        self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
+                                        conv_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"),
+                            torch.empty(size=(mamba_layers, max_batch_size) +
+                                        temporal_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"))

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:

vllm/worker/model_runner.py

Lines changed: 0 additions & 3 deletions
@@ -1711,9 +1711,6 @@ def forward(
                                      non_blocking=True)
         # Run the graph.
         self.graph.replay()
-        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
-            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
-                                                      **kwargs)
         # Return the output tensor.
         if get_pp_group().is_last_rank:
             return self.output_buffers["hidden_states"]
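
The deletion above is the model-runner side of the same change: the Mamba cache tensors captured by the CUDA graph are now the persistent mamba_cache buffer itself, so the in-place updates performed during graph.replay() already land in that buffer and no copy-back hook is required. A minimal sketch of the general CUDA-graph behaviour this relies on (not vLLM code; assumes a CUDA device and follows the usual warm-up-then-capture pattern):

import torch

state = torch.zeros(4, device="cuda")        # persistent buffer, like mamba_cache
delta = torch.ones(4, device="cuda")         # stand-in for a per-step state update

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    state.add_(delta)
torch.cuda.current_stream().wait_stream(s)
state.zero_()                                # reset the buffer after warm-up

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    state.add_(delta)                        # in-place update captured in the graph

g.replay()                                   # replay mutates `state` directly
torch.cuda.synchronize()
print(state)                                 # tensor([1., 1., 1., 1.]) -- no copy-back needed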

0 commit comments