Skip to content

Commit fb846ce

Browse files
committed
apply fix from #6214
1 parent b733a84 commit fb846ce

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

vllm/model_executor/models/mamba.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -538,20 +538,20 @@ def _prepare_current_run_mamba_cache(
538538

539539
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
540540
"""
541-
Copy the relevant Mamba cache into the CUDA graph input buffer
542-
that was provided during the capture runs
543-
(MambaForCausalLM.mamba_gc_cache_buffer).
541+
Copy the relevant Mamba cache into the CUDA graph input buffer
542+
that was provided during the capture runs
543+
(MambaForCausalLM.mamba_gc_cache_buffer).
544544
"""
545545
assert all(
546546
key in kwargs
547547
for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
548548
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
549-
batch_size = len(request_ids_to_seq_ids)
549+
cg_batch_size = input_buffers['input_ids'].shape[0]
550550
(
551551
current_mamba_cache,
552552
indices,
553553
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
554-
batch_size)
554+
cg_batch_size)
555555
self.current_indices = indices
556556
finished_requests_ids = kwargs["finished_requests_ids"]
557557
self._release_mamba_cache(finished_requests_ids)

0 commit comments

Comments
 (0)