@@ -663,30 +663,35 @@ def _copy_mamba_cache(self, to_index: int, from_index: int):
663
663
cache_t [:, to_index ].copy_ (cache_t [:, from_index ],
664
664
non_blocking = True )
665
665
666
def _move_out_if_already_occupied(self, index: int,
                                  all_occupied_indices: List[int]):
    """Vacate ``index`` if it appears in ``all_occupied_indices``.

    When ``index`` is listed as occupied, its current contents are moved
    to the index returned by ``_first_free_index_in_mamba_cache`` through
    ``_move_cache_index_and_mappings``. When it is not listed, this
    method does nothing.

    Args:
        index: Cache index the caller wants to use.
        all_occupied_indices: Indices currently considered occupied.
    """
    if index not in all_occupied_indices:
        # Nothing lives at the requested index; no relocation needed.
        return
    # In case occupied, move the occupied to a new empty block
    self._move_cache_index_and_mappings(
        from_index=index,
        to_index=self._first_free_index_in_mamba_cache())
673
+
666
674
def _assign_seq_id_to_mamba_cache_in_specific_dest (self , cur_rid : str ,
667
675
seq_id : int ,
668
676
destination_index : int ):
669
677
all_occupied_indices = self ._get_all_occupied_indices ()
670
678
if cur_rid not in self .mamba_cache_indices_mapping :
671
- ## assign new free index
672
- if destination_index in all_occupied_indices :
673
- # In case occupied, move the occupied to a new empty block
674
- self ._move_cache_index_and_mappings (
675
- from_index = destination_index ,
676
- to_index = self ._first_free_index_in_mamba_cache ())
679
+ self ._move_out_if_already_occupied (
680
+ index = destination_index ,
681
+ all_occupied_indices = all_occupied_indices )
677
682
self .mamba_cache_indices_mapping [cur_rid ] = {
678
683
seq_id : destination_index
679
684
}
680
685
elif seq_id not in (seq_ids2indices :=
681
686
self .mamba_cache_indices_mapping [cur_rid ]):
682
- # N > 1
683
- first_free_index = self . _first_free_index_in_mamba_cache ()
684
- if destination_index in all_occupied_indices :
685
- # In case occupied, move the occupied to a new empty block
686
- self . _move_cache_index_and_mappings (
687
- from_index = destination_index , to_index = first_free_index )
687
+ # parallel sampling, where n > 1, assume prefill already happened
688
+ # now we only need to copy the already existing cache into the
689
+ # siblings seq_ids caches
690
+ self . _move_out_if_already_occupied (
691
+ index = destination_index ,
692
+ all_occupied_indices = all_occupied_indices )
688
693
index_exists = list (seq_ids2indices .values ())[0 ]
689
- ## case of decoding n>1, copy prefill cache to decoding indices
694
+ # case of decoding n>1, copy prefill cache to decoding indices
690
695
self ._copy_mamba_cache (from_index = index_exists ,
691
696
to_index = destination_index )
692
697
self .mamba_cache_indices_mapping [cur_rid ][
0 commit comments