[Model][Jamba] Mamba cache single buffer #6739

Merged
Changes from all commits
269 changes: 148 additions & 121 deletions vllm/model_executor/models/jamba.py
@@ -609,12 +609,8 @@ def __init__(
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Current step used indices
self.current_indices: List[int] = []
# Used to track and store the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Used as an input_buffer for the CUDA graph runs.
self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
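
For orientation while reading the rest of the diff: the PR keeps a single pair of persistent buffers, indexed by a per-sequence slot along dim 1, plus a mapping from request id to per-seq_id slot. A minimal standalone sketch (illustrative shapes and names, not the PR's code):

```python
import torch
from typing import Dict, Tuple

# One slot (a column on dim 1) per sequence, shared by eager runs and CUDA graph runs:
#   conv_state:     (num_mamba_layers, max_batch_size, *conv_state_shape)
#   temporal_state: (num_mamba_layers, max_batch_size, *temporal_state_shape)
num_mamba_layers, max_batch_size = 4, 8
mamba_cache: Tuple[torch.Tensor, torch.Tensor] = (
    torch.zeros(num_mamba_layers, max_batch_size, 16, 4),   # conv state, toy shape
    torch.zeros(num_mamba_layers, max_batch_size, 16, 16),  # temporal (SSM) state, toy shape
)

# request_id -> {seq_id -> slot index inside mamba_cache}
mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {
    "req-a": {0: 0},        # single-sequence request in slot 0
    "req-b": {1: 1, 2: 2},  # n>1 sampling: two sibling seq_ids, two slots
}
```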
@@ -644,95 +640,148 @@ def forward(self,
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
(
current_seqlen_agnostic_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size,
finished_requests_ids)
mamba_cache = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, batch_size, finished_requests_ids)
else:
# CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = (
kwargs["seqlen_agnostic_capture_inputs"],
[],
)
self.current_indices = indices
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata,
current_seqlen_agnostic_cache[0],
current_seqlen_agnostic_cache[1])

if "seqlen_agnostic_capture_inputs" not in kwargs:
self._copy_mamba_cache_by_indices(self.current_indices,
current_seqlen_agnostic_cache)

attn_metadata, mamba_cache[0],
mamba_cache[1])
return hidden_states

def _copy_mamba_cache_by_indices(
self, indices: List[int],
current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
for i, offset in enumerate(indices):
self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]

def _copy_mamba_cache(self, index_to: int, index_from: int,
from_buffer: Tuple[torch.Tensor, torch.Tensor]):
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)
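
Both helpers treat dim 1 of each cache tensor as the slot axis. A toy sketch of the two primitives on a made-up tensor:

```python
import torch

cache_t = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)  # (layers, slots, state)

# Swap the states held in slots 1 and 3, as _swap_mamba_cache does per cache tensor;
# the advanced indexing on the right-hand side materializes a copy, so the swap is safe.
cache_t[:, [3, 1]] = cache_t[:, [1, 3]]

# Copy the state of slot 0 into slot 2, as _copy_mamba_cache does per cache tensor.
cache_t[:, 2].copy_(cache_t[:, 0])
```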

def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
seqs_id: List[int]) -> List[int]:
indices_for_current_run = []
for seq_id in seqs_id:
if cur_rid not in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping[cur_rid] = {}
first_free_index = self._first_free_index_in_mamba_cache()
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index
index_for_current_run = first_free_index
## case of decoding n>1, copy prefill cache to decoding indices
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
first_free_index = self._first_free_index_in_mamba_cache()
index_exist = list(seq_ids2indices.values())[0]
self._copy_mamba_cache(index_from=index_exist,
index_to=first_free_index,
from_buffer=self.mamba_cache)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index
index_for_current_run = first_free_index
else:
index_for_current_run = self.mamba_cache_indices_mapping[
cur_rid][seq_id]

indices_for_current_run.append(index_for_current_run)
return indices_for_current_run
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)

def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign the (req_id, seq_id) pair to `destination_index`; if that index
is already occupied, move the current occupant to a free index first.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling, where n > 1: assume prefill has
# already happened, so we only need to copy the already
# existing cache into the sibling seq_ids' caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)
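
The branch structure compresses to three cases; a standalone restatement with plain dicts (the real method also performs the evictions, copies and swaps on self.mamba_cache):

```python
from typing import Dict

Mapping = Dict[str, Dict[int, int]]  # request_id -> {seq_id -> slot}

def describe_assignment(mapping: Mapping, req_id: str, seq_id: int, dest: int) -> str:
    """Decision logic only; the tensor moves are elided."""
    if req_id not in mapping:
        return "new request: free up `dest` if occupied, then claim it"
    if seq_id not in mapping[req_id]:
        return "sibling seq_id (n>1 sampling): copy an existing sibling's state into `dest`"
    if mapping[req_id][seq_id] != dest:
        return "known seq_id in the wrong slot: swap it with whatever occupies `dest`"
    return "already in place: nothing to do"

assert describe_assignment({}, "req-a", 0, 0).startswith("new request")
assert describe_assignment({"req-a": {0: 0}}, "req-a", 1, 1).startswith("sibling")
assert describe_assignment({"req-a": {0: 3}}, "req-a", 0, 0).startswith("known seq_id")
```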

def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
finished_requests_ids: List[str]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
indices_for_current_run = []
for request_id, seqs_id in request_ids_to_seq_ids.items():
self, request_ids_to_seq_ids: Dict[str, list[int]],
batch_size: int, finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache for requests that run
# Do not allocate cache index for requests that run
# and finish right after
continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
request_id, seqs_id)
## Pad the batch in case of running batch that was not captured via CG
padded_indices = indices_for_current_run.copy()
pad_index = self._first_free_index_in_mamba_cache()
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)

for _ in range(batch_size - len(indices_for_current_run)):
padded_indices.append(pad_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]

conv_state = self.mamba_cache[0][:, padded_indices]
temporal_state = self.mamba_cache[1][:, padded_indices]
return (conv_state, temporal_state)

return (conv_state, temporal_state), indices_for_current_run
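
Net effect, sketched standalone with made-up shapes: every running (request, seq) pair is packed into the first batch_size slots, so the forward pass receives contiguous views of the buffer instead of a gathered copy:

```python
import torch

request_ids_to_seq_ids = {"req-a": [0], "req-b": [1, 2]}
flat = [(rid, sid) for rid, sids in request_ids_to_seq_ids.items() for sid in sids]
# -> [("req-a", 0), ("req-b", 1), ("req-b", 2)]; enumerate() assigns them slots 0, 1, 2.

mamba_cache = (torch.zeros(4, 8, 16, 4), torch.zeros(4, 8, 16, 16))  # toy shapes
batch_size = len(flat)
conv_state = mamba_cache[0][:, :batch_size]      # a view, not a copy
temporal_state = mamba_cache[1][:, :batch_size]  # a view, not a copy
assert conv_state.data_ptr() == mamba_cache[0].data_ptr()
```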
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]

def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = set(range(batch_size))
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_indices(from_index=destination_index,
to_index=first_avail_index)

def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)

def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)

def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})

def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return

def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
@@ -747,55 +796,35 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
self.current_indices = indices

for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
current_mamba_cache):
input_buffer.copy_(current_cache_buffer, non_blocking=True)

def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self._copy_mamba_cache_by_indices(
self.current_indices,
input_buffers["seqlen_agnostic_capture_inputs"])
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)

def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer of adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size]
for buffer in self.mamba_gc_cache_buffer)
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
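
Because these capture inputs are views of self.mamba_cache rather than a separate mamba_gc_cache_buffer, whatever the graph writes during replay is already in the persistent cache; that is what makes the copy_outputs_after_cuda_graphs hook removed from model_runner.py below unnecessary. A standalone sketch of the aliasing (toy shapes):

```python
import torch

mamba_cache = (torch.zeros(4, 8, 16, 4), torch.zeros(4, 8, 16, 16))  # toy shapes
capture_inputs = tuple(buf[:, :2] for buf in mamba_cache)  # what a batch_size=2 capture would get

capture_inputs[0][0, 0, 0, 0] = 1.0              # a write "during replay"...
assert mamba_cache[0][0, 0, 0, 0].item() == 1.0  # ...is visible in the persistent cache
```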

def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)

def _first_free_index_in_mamba_cache(self) -> int:
if self.mamba_cache:
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
occupied = [
id for seq_ids in self.mamba_cache_indices_mapping.values()
for id in seq_ids.values()
]
first_free_index = [
i not in occupied for i in range(max_possible_batch_size)
].index(True)
return first_free_index
return 0
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")

def _get_mamba_cache_shape(
self
@@ -819,20 +848,18 @@ def _prepare_mamba_cache(self):
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10
max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None

for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
setattr(self, buffername, buffer)
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
3 changes: 0 additions & 3 deletions vllm/worker/model_runner.py
@@ -1711,9 +1711,6 @@ def forward(
non_blocking=True)
# Run the graph.
self.graph.replay()
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
**kwargs)
# Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"]