Skip to content

Commit 60857a3

Browse files
committed
Factor out moving out the occupied index
1 parent df269e5 commit 60857a3

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

vllm/model_executor/models/jamba.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -663,30 +663,35 @@ def _copy_mamba_cache(self, to_index: int, from_index: int):
663663
cache_t[:, to_index].copy_(cache_t[:, from_index],
664664
non_blocking=True)
665665

666+
def _move_out_if_already_occupied(self, index: int,
667+
all_occupied_indices: List[int]):
668+
if index in all_occupied_indices:
669+
first_free_index = self._first_free_index_in_mamba_cache()
670+
# In case occupied, move the occupied to a new empty block
671+
self._move_cache_index_and_mappings(from_index=index,
672+
to_index=first_free_index)
673+
666674
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
667675
seq_id: int,
668676
destination_index: int):
669677
all_occupied_indices = self._get_all_occupied_indices()
670678
if cur_rid not in self.mamba_cache_indices_mapping:
671-
## assign new free index
672-
if destination_index in all_occupied_indices:
673-
# In case occupied, move the occupied to a new empty block
674-
self._move_cache_index_and_mappings(
675-
from_index=destination_index,
676-
to_index=self._first_free_index_in_mamba_cache())
679+
self._move_out_if_already_occupied(
680+
index=destination_index,
681+
all_occupied_indices=all_occupied_indices)
677682
self.mamba_cache_indices_mapping[cur_rid] = {
678683
seq_id: destination_index
679684
}
680685
elif seq_id not in (seq_ids2indices :=
681686
self.mamba_cache_indices_mapping[cur_rid]):
682-
# N > 1
683-
first_free_index = self._first_free_index_in_mamba_cache()
684-
if destination_index in all_occupied_indices:
685-
# In case occupied, move the occupied to a new empty block
686-
self._move_cache_index_and_mappings(
687-
from_index=destination_index, to_index=first_free_index)
687+
# parallel sampling , where n > 1, assume prefill already happend
688+
# now we only need to copy the already existing cache into the
689+
# siblings seq_ids caches
690+
self._move_out_if_already_occupied(
691+
index=destination_index,
692+
all_occupied_indices=all_occupied_indices)
688693
index_exists = list(seq_ids2indices.values())[0]
689-
## case of decoding n>1, copy prefill cache to decoding indices
694+
# case of decoding n>1, copy prefill cache to decoding indices
690695
self._copy_mamba_cache(from_index=index_exists,
691696
to_index=destination_index)
692697
self.mamba_cache_indices_mapping[cur_rid][

0 commit comments

Comments
 (0)