@@ -609,12 +609,8 @@ def __init__(
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
-        # Current step used indices
-        self.current_indices: List[int] = []
        # Used to track and store the Mamba cache between steps.
        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
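A minimal sketch (illustrative only, not part of the diff; the request ids and slot numbers are made up) of the bookkeeping these fields set up: each cache tensor is laid out as [layer, batch_slot, ...], and the mapping records which batch slot every (request_id, seq_id) pair owns.

    # self.mamba_cache == (conv_state, temporal_state), each indexed as
    # [layer, batch_slot, ...]; self.mamba_cache_indices_mapping might hold:
    {
        "req-a": {0: 0},        # seq 0 of "req-a" owns cache slot 0
        "req-b": {1: 1, 2: 2},  # n > 1 sampling: sibling seqs own slots 1 and 2
    }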
@@ -644,95 +640,148 @@ def forward(self,
        batch_size = input_ids.shape[0]
        if attn_metadata.prefill_metadata:
            batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
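+            # eager runs: stage this step's states into the first
+            # batch_size slots of the persistent mamba_cache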
+            mamba_cache = self._prepare_current_run_mamba_cache(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
        else:
            # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-        self.current_indices = indices
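+            # the capture inputs handed over here are views into
+            # self.mamba_cache, so graph replay updates the buffer in place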
+            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata,
-                                   current_seqlen_agnostic_cache[0],
-                                   current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
-
+                                   attn_metadata, mamba_cache[0],
+                                   mamba_cache[1])
        return hidden_states

-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
+    def _swap_mamba_cache(self, from_index: int, to_index: int):
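+        # swap two batch columns of every layer's state in place; the RHS
+        # of the paired advanced-indexing assignment is materialized first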
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, [to_index, from_index]] = \
+                cache_t[:, [from_index, to_index]]

-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
+    def _copy_mamba_cache(self, from_index: int, to_index: int):
        assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
+        for cache_t in self.mamba_cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                       non_blocking=True)

-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
+    def _move_out_if_already_occupied(self, index: int,
+                                      all_occupied_indices: List[int]):
+        if index in all_occupied_indices:
+            first_free_index = self._first_free_index_in_mamba_cache()
+            # in case it is occupied, move the occupying cache to a new
+            # empty block
+            self._move_cache_index_and_mappings(from_index=index,
+                                                to_index=first_free_index)
+
+    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
+                                                       seq_id: int,
+                                                       destination_index: int):
+        """
+        Assign the (req_id, seq_id) pair to the `destination_index` index;
+        if it is already occupied, first move the occupying entry to a free
+        index.
+        """
+        all_occupied_indices = self._get_all_occupied_indices()
+        if cur_rid not in self.mamba_cache_indices_mapping:
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            self.mamba_cache_indices_mapping[cur_rid] = {
+                seq_id: destination_index
+            }
+        elif seq_id not in (seq_ids2indices :=
+                            self.mamba_cache_indices_mapping[cur_rid]):
+            # parallel sampling (n > 1): prefill has already happened, so we
+            # only need to copy the existing cache into the sibling seq_ids'
+            # caches
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            index_exists = list(seq_ids2indices.values())[0]
+            # case of decoding n>1: copy the prefill cache to the decoding
+            # index
+            self._copy_mamba_cache(from_index=index_exists,
+                                   to_index=destination_index)
+            self.mamba_cache_indices_mapping[cur_rid][
+                seq_id] = destination_index
+        else:
+            # already exists
+            cache_index_already_exists = self.mamba_cache_indices_mapping[
+                cur_rid][seq_id]
+            if cache_index_already_exists != destination_index:
+                # in case the seq id already exists but is not in the right
+                # destination, swap it with whatever is occupying it
+                self._swap_pair_indices_and_mappings(
+                    from_index=cache_index_already_exists,
+                    to_index=destination_index)

    def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-            finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
+            self, request_ids_to_seq_ids: Dict[str, list[int]],
+            batch_size: int, finished_requests_ids: List[str]):
+        running_indices = []
+        request_ids_to_seq_ids_flatten = [
+            (req_id, seq_id)
+            for req_id, seq_ids in request_ids_to_seq_ids.items()
+            for seq_id in seq_ids
+        ]
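+        # pin each running (request, seq) pair to its batch position so that
+        # cache slots 0..batch_size-1 line up with the model inputs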
+        for dest_index, (request_id,
+                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
            if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
+                # Do not allocate cache index for requests that run
                # and finish right after
                continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
+            self._assign_seq_id_to_mamba_cache_in_specific_dest(
+                request_id, seq_id, dest_index)
+            running_indices.append(dest_index)

-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
+        self._clean_up_first_bs_blocks(batch_size, running_indices)
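+        # the first batch_size slots now hold only running sequences, so a
+        # contiguous slice is safe and no index gather is needed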
+        conv_state = self.mamba_cache[0][:, :batch_size]
+        temporal_state = self.mamba_cache[1][:, :batch_size]

-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
+        return (conv_state, temporal_state)

-        return (conv_state, temporal_state), indices_for_current_run
+    def _get_all_occupied_indices(self):
+        return [
+            cache_idx
+            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
+            for cache_idx in seq_ids2indices.values()
+        ]
+
+    def _clean_up_first_bs_blocks(self, batch_size: int,
+                                  indices_for_current_run: List[int]):
+        # move all occupied but currently not-running blocks out of the
+        # first n blocks
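+        # e.g. with batch_size=3, a stale state found in slot 1 is swapped
+        # out to the first free slot >= 3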
+        destination_indices = set(range(batch_size))
+        max_possible_batch_size = self.mamba_cache[0].shape[1]
+        for destination_index in destination_indices:
+            if destination_index in self._get_all_occupied_indices() and \
+                    destination_index not in indices_for_current_run:
+                # move not-running indices outside of the batch
+                all_other_indices = list(
+                    range(batch_size, max_possible_batch_size))
+                first_avail_index = self._first_free_index_in_mamba_cache(
+                    all_other_indices)
+                self._swap_pair_indices_and_mappings(
+                    from_index=destination_index,
+                    to_index=first_avail_index)
+
+    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
+        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
+        self._update_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
+        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
+        self._swap_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_mapping_index(self, from_index: int, to_index: int):
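+        # exchange the two slots wherever they appear across all requests'
+        # seq_id -> index mappings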
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                elif to_index == index:
+                    seq_ids2index.update({seq_id: from_index})
+
+    def _update_mapping_index(self, from_index: int, to_index: int):
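+        # redirect the single mapping entry that points at from_index, then
+        # stop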
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                    return

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
@@ -747,55 +796,35 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        self._release_mamba_cache(finished_requests_ids)
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the JambaForCausalLM.mamba_cache after CUDA
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
+        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
+                                              cg_batch_size,
+                                              finished_requests_ids)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer in adjusted size.
        The buffer is used to maintain the Mamba Cache during the CUDA graph
        replay runs.
        """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
+        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)

    def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.mamba_cache_indices_mapping:
                self.mamba_cache_indices_mapping.pop(req_id)

-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
+    def _first_free_index_in_mamba_cache(
+            self, indices_range: Optional[List[int]] = None) -> int:
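+        # indices_range narrows the search, e.g. to the slots beyond the
+        # current batch when evicting states in _clean_up_first_bs_blocks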
+        assert self.mamba_cache is not None
+        if indices_range is None:
            max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
+            indices_range = list(range(max_possible_batch_size))
+        all_occupied_indices = self._get_all_occupied_indices()
+        for i in indices_range:
+            if i not in all_occupied_indices:
+                return i
+        raise Exception("Couldn't find a free spot in the mamba cache! This "
+                        "should never happen")

    def _get_mamba_cache_shape(
            self
@@ -819,20 +848,18 @@ def _prepare_mamba_cache(self):
            [layer_type == "mamba" for layer_type in layers_type])
        max_batch_size = (_get_graph_batch_size(
            self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
+                          max(_BATCH_SIZES_TO_CAPTURE) + 2)
        conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
        assert conv_state_shape is not None and temporal_state_shape is not None

-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+        self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
+                                        conv_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"),
+                            torch.empty(size=(mamba_layers, max_batch_size) +
+                                        temporal_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"))

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
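A minimal standalone sketch (not part of the diff) of the in-place swap idiom used by _swap_mamba_cache above: with paired advanced indexing, the right-hand side is materialized before assignment, so two batch columns exchange cleanly without an explicit temporary.

    import torch

    cache_t = torch.arange(12.).reshape(2, 3, 2)  # (layers, batch, state)
    before = cache_t.clone()
    cache_t[:, [0, 2]] = cache_t[:, [2, 0]]       # swap batch slots 0 and 2
    assert torch.equal(cache_t[:, 0], before[:, 2])
    assert torch.equal(cache_t[:, 2], before[:, 0])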