
Commit d8017cb

fixes from 6425
1 parent 09b1495 commit d8017cb

File tree

2 files changed: +25 -9 lines


vllm/model_executor/models/interfaces.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ class HasInnerState(Protocol):
     """
         A flag that indicates this model has inner state.
         Models that has inner state usually need access to the scheduler_config
-        for max_num_seqs ,etc... (Currently only used by Jamba)
+        for max_num_seqs ,etc... (Currently used by Jamba and Mamba)
     """

     def __init__(self,
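For context, HasInnerState is a marker protocol: a model opts in by listing it as a base class, and the loader can then decide whether to thread the scheduler config through to the constructor. The snippet below is a minimal, self-contained sketch of that pattern under assumed names; SchedulerConfig here is a local stand-in class and build() is an illustrative helper, not the actual vLLM loader API.

from typing import ClassVar, Literal, Optional, Protocol, runtime_checkable


class SchedulerConfig:
    """Stand-in for vllm.config.SchedulerConfig (illustrative only)."""

    def __init__(self, max_num_seqs: int) -> None:
        self.max_num_seqs = max_num_seqs


@runtime_checkable
class HasInnerState(Protocol):
    """Marker protocol: the model keeps state across scheduler steps."""

    has_inner_state: ClassVar[Literal[True]] = True


class StatefulModel(HasInnerState):
    """Opts in the same way MambaForCausalLM now does: by inheriting the marker."""

    def __init__(self, scheduler_config: Optional[SchedulerConfig] = None) -> None:
        self.scheduler_config = scheduler_config


def build(model_cls, scheduler_config: SchedulerConfig):
    # Hypothetical loader logic: only models that declare inner state get the
    # scheduler config injected into their constructor.
    if getattr(model_cls, "has_inner_state", False):
        return model_cls(scheduler_config=scheduler_config)
    return model_cls()


model = build(StatefulModel, SchedulerConfig(max_num_seqs=256))
assert isinstance(model, HasInnerState)
assert model.scheduler_config.max_num_seqs == 256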

vllm/model_executor/models/mamba.py

Lines changed: 24 additions & 8 deletions
@@ -12,7 +12,7 @@
 from transformers import MambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import CacheConfig, LoRAConfig
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -27,10 +27,12 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import HasInnerState
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.worker.model_runner import _BATCH_SIZES_TO_CAPTURE
+from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                      _get_graph_batch_size)

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -376,7 +378,7 @@ def forward(
         return hidden_states


-class MambaForCausalLM(nn.Module):
+class MambaForCausalLM(nn.Module, HasInnerState):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -404,9 +406,11 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        scheduler_config: Optional[SchedulerConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.scheduler_config = scheduler_config
         self.backbone = MambaModel(config,
                                    cache_config=cache_config,
                                    quant_config=quant_config,
@@ -436,7 +440,6 @@ def forward(self,
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs):
-
         if not self.mamba_cache:
             self._prepare_mamba_cache()

@@ -447,14 +450,16 @@ def forward(self,
                 for key in ["request_ids_to_seq_ids", "finished_requests_ids"])

             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+            finished_requests_ids = kwargs["finished_requests_ids"]
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
                 batch_size = len(request_ids_to_seq_ids)
             (
                 current_seqlen_agnostic_cache,
                 indices,
             ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size)
+                                                      batch_size,
+                                                      finished_requests_ids)
             finished_requests_ids = kwargs["finished_requests_ids"]
             self._release_mamba_cache(finished_requests_ids)
         else:
@@ -518,10 +523,15 @@ def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
         return indices_for_current_run

     def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
+            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
+            finished_requests_ids: List[str]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
         indices_for_current_run = []
         for request_id, seqs_id in request_ids_to_seq_ids.items():
+            if request_id in finished_requests_ids:
+                # Do not allocate cache for requests that run
+                # and finish right after
+                continue
             indices_for_current_run += self._assign_seq_id_to_mamba_cache(
                 request_id, seqs_id)
         ## Pad the batch in case of running batch that was not captured via CG
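The guard added in the hunk above keeps a request that is scheduled and immediately marked finished (i.e. present in both request_ids_to_seq_ids and finished_requests_ids) from ever taking a cache slot. Below is a minimal sketch of that allocation pattern, using a made-up SlotAllocator in place of the real mamba cache bookkeeping:

from typing import Dict, List


class SlotAllocator:
    """Toy stand-in for the per-request mamba cache index bookkeeping."""

    def __init__(self, num_slots: int) -> None:
        self.free = list(range(num_slots))
        self.by_request: Dict[str, int] = {}

    def assign(self, request_id: str) -> int:
        self.by_request[request_id] = self.free.pop()
        return self.by_request[request_id]

    def release(self, request_id: str) -> None:
        if request_id in self.by_request:
            self.free.append(self.by_request.pop(request_id))


def prepare_current_run(allocator: SlotAllocator,
                        request_ids_to_seq_ids: Dict[str, List[int]],
                        finished_requests_ids: List[str]) -> List[int]:
    indices: List[int] = []
    for request_id in request_ids_to_seq_ids:
        if request_id in finished_requests_ids:
            # Same idea as the new guard: a request that runs and finishes
            # within the same step never takes a slot in the first place.
            continue
        indices.append(allocator.assign(request_id))
    return indices


alloc = SlotAllocator(num_slots=4)
indices = prepare_current_run(alloc,
                              {"req-a": [0], "req-b": [1]},
                              finished_requests_ids=["req-b"])
assert len(indices) == 1 and len(alloc.free) == 3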
@@ -545,13 +555,16 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
+        finished_requests_ids = kwargs["finished_requests_ids"]
+        self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
         (
             current_mamba_cache,
             indices,
         ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size)
+                                                  cg_batch_size,
+                                                  finished_requests_ids)
         self.current_indices = indices
         finished_requests_ids = kwargs["finished_requests_ids"]
         self._release_mamba_cache(finished_requests_ids)
@@ -615,9 +628,12 @@ def _get_mamba_cache_shape(
     def _prepare_mamba_cache(self):
         dtype = self.lm_head.weight.dtype
         num_mamba_layers = self.config.num_hidden_layers
-        max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
+        max_batch_size = (_get_graph_batch_size(
+            self.scheduler_config.max_num_seqs) if self.scheduler_config else
+                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None
+
         for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
             buffer = (torch.empty(size=(num_mamba_layers, max_batch_size) +
                                   conv_state_shape,