 from transformers import MambaConfig
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -27,10 +27,12 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import HasInnerState
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
+from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                      _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -376,7 +378,7 @@ def forward(
         return hidden_states
 
 
-class MambaForCausalLM(nn.Module):
+class MambaForCausalLM(nn.Module, HasInnerState):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -404,9 +406,11 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        scheduler_config: Optional[SchedulerConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.scheduler_config = scheduler_config
         self.backbone = MambaModel(config,
                                    cache_config=cache_config,
                                    quant_config=quant_config,
@@ -436,7 +440,6 @@ def forward(self,
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs):
-
         if not self.mamba_cache:
             self._prepare_mamba_cache()
 
@@ -447,14 +450,16 @@ def forward(self,
                 for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
 
             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+            finished_requests_ids = kwargs["finished_requests_ids"]
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
                 batch_size = len(request_ids_to_seq_ids)
             (
                 current_seqlen_agnostic_cache,
                 indices,
             ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size)
+                                                      batch_size,
+                                                      finished_requests_ids)
             finished_requests_ids = kwargs["finished_requests_ids"]
             self._release_mamba_cache(finished_requests_ids)
         else:
@@ -518,10 +523,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
         return indices_for_current_run
 
     def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
+            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
+            finished_requests_ids: List[str]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
         indices_for_current_run = []
         for request_id, seqs_id in request_ids_to_seq_ids.items():
+            if request_id in finished_requests_ids:
+                # Do not allocate cache for requests that run
+                # and finish right after
+                continue
             indices_for_current_run += self._assign_seq_id_to_mamba_cache(
                 request_id, seqs_id)
         ## Pad the batch in case of running batch that was not captured via CG
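
To see the allocation rule from the hunk above in isolation: a request that already appears in finished_requests_ids for the current step is skipped, so a request that runs and finishes in the same step never claims a state slot it no longer needs. The sketch below is a minimal, self-contained illustration of that pattern using plain integer slots; the names SimpleMambaCacheManager, prepare_current_run and release are invented for illustration and are not vLLM APIs (the real logic operates on torch tensors inside MambaForCausalLM).

from typing import Dict, List


class SimpleMambaCacheManager:
    """Illustrative slot bookkeeping; not the vLLM implementation."""

    def __init__(self, num_slots: int):
        self.free_slots: List[int] = list(range(num_slots))
        # request_id -> {seq_id: slot index into the state buffers}
        self.cache_indices: Dict[str, Dict[int, int]] = {}

    def prepare_current_run(
            self, request_ids_to_seq_ids: Dict[str, List[int]],
            finished_requests_ids: List[str]) -> List[int]:
        indices: List[int] = []
        for request_id, seq_ids in request_ids_to_seq_ids.items():
            # Same rule as the diff: a request that is already marked
            # finished for this step gets no slot at all.
            if request_id in finished_requests_ids:
                continue
            for seq_id in seq_ids:
                slot = self.free_slots.pop()
                self.cache_indices.setdefault(request_id, {})[seq_id] = slot
                indices.append(slot)
        return indices

    def release(self, finished_requests_ids: List[str]) -> None:
        for request_id in finished_requests_ids:
            for slot in self.cache_indices.pop(request_id, {}).values():
                self.free_slots.append(slot)


manager = SimpleMambaCacheManager(num_slots=4)
# "req-b" is already finished, so only "req-a" receives a slot.
assert manager.prepare_current_run({"req-a": [0], "req-b": [1]},
                                   finished_requests_ids=["req-b"]) == [3]
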
@@ -545,13 +555,16 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
+        finished_requests_ids = kwargs["finished_requests_ids"]
+        self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size)
+                                                  cg_batch_size,
+                                                  finished_requests_ids)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
@@ -615,9 +628,12 @@ def _get_mamba_cache_shape(
     def _prepare_mamba_cache(self):
         dtype = self.lm_head.weight.dtype
         num_mamba_layers = self.config.num_hidden_layers
-        max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
+        max_batch_size = (_get_graph_batch_size(
+            self.scheduler_config.max_num_seqs) if self.scheduler_config else
+                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None
+
         for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
             buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) +
                                   conv_state_shape,
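
The final hunk ties the Mamba state buffer size to the scheduler instead of always allocating for the largest CUDA-graph capture size. Below is a rough standalone sketch of that sizing decision, under the assumption that _get_graph_batch_size rounds a batch size up to the next captured size; the capture list values and the helper round_up_to_graph_batch_size here are stand-ins rather than vLLM code, and the +10 slack mirrors the diff.

from typing import Optional

# Stand-in values; the real list and rounding rule live in
# vllm/worker/model_runner.py and may differ.
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


def round_up_to_graph_batch_size(batch_size: int) -> int:
    # Smallest captured batch size that can hold `batch_size` sequences.
    return min((bs for bs in _BATCH_SIZES_TO_CAPTURE if bs >= batch_size),
               default=_BATCH_SIZES_TO_CAPTURE[-1])


def mamba_cache_max_batch_size(max_num_seqs: Optional[int]) -> int:
    # With a scheduler config, size the state buffers for the largest batch
    # the scheduler can actually form (rounded to a captured graph size);
    # without one, fall back to the largest captured size, as before.
    if max_num_seqs is not None:
        return round_up_to_graph_batch_size(max_num_seqs) + 10
    return max(_BATCH_SIZES_TO_CAPTURE) + 10


# With max_num_seqs=5 the buffers hold 8 + 10 = 18 entries instead of
# max(_BATCH_SIZES_TO_CAPTURE) + 10 entries.
assert mamba_cache_max_batch_size(5) == 18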