diff --git a/vllm/envs.py b/vllm/envs.py
index 18870c1c6b5..cdc4a8ca81f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -72,6 +72,7 @@
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
+    VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER: float = 1.5
 
 
 def get_default_cache_root():
@@ -466,6 +467,8 @@ def get_default_config_root():
     lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
     "VLLM_DISABLE_COMPILE_CACHE":
     lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
+    "VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER":
+    lambda: float(os.getenv("VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER", "1.5")),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 91786db5ddc..92804d19cd7 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -5,6 +5,7 @@
 from torch import nn
 from transformers import JambaConfig
 
+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
@@ -422,17 +423,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
+
+        effective_max_batch_size = int(
+            self.vllm_config.scheduler_config.max_num_seqs * \
+            envs.VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER
+        )
+        if not self.model_config.enforce_eager \
+                and effective_max_batch_size <= \
                     vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
+            self.max_batch_size = vllm_config.pad_for_cudagraph(
+                effective_max_batch_size)
         else:
-            self.max_batch_size = 8192 + 2
+            self.max_batch_size = effective_max_batch_size
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 06c8d9723cd..ee5c26cde30 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -5,6 +5,7 @@
 from torch import nn
 from transformers import MambaConfig
 
+from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -195,17 +196,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
+
+        effective_max_batch_size = int(
+            self.vllm_config.scheduler_config.max_num_seqs * \
+            envs.VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER
+        )
+        if not self.model_config.enforce_eager \
+                and effective_max_batch_size <= \
                     vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
+            self.max_batch_size = vllm_config.pad_for_cudagraph(
+                effective_max_batch_size)
         else:
-            self.max_batch_size = 8192 + 2
+            self.max_batch_size = effective_max_batch_size
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
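
For reference, a minimal standalone sketch (not part of the diff) of the slot-sizing logic the change introduces. The constants standing in for `scheduler_config.max_num_seqs`, `compilation_config.max_capture_size`, and `model_config.enforce_eager`, and the simplified `pad_for_cudagraph` helper, are hypothetical placeholders used only to show how `VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER` feeds into `max_batch_size`:

```python
import math
import os

# Hypothetical stand-ins for the real vLLM config objects used in the diff.
MAX_NUM_SEQS = 256       # scheduler_config.max_num_seqs (assumed value)
MAX_CAPTURE_SIZE = 512   # compilation_config.max_capture_size (assumed value)
ENFORCE_EAGER = False    # model_config.enforce_eager (assumed value)

# Same default as the new env var definition in vllm/envs.py.
multiplier = float(os.getenv("VLLM_MAMBA_NUM_OF_SLOTS_MULTIPLIER", "1.5"))


def pad_for_cudagraph(batch_size: int) -> int:
    # Simplified stand-in for vllm_config.pad_for_cudagraph: round up to a
    # CUDA-graph-friendly size (here, the next multiple of 8).
    return math.ceil(batch_size / 8) * 8


# Scale the scheduler's slot count by the multiplier, as in the diff.
effective_max_batch_size = int(MAX_NUM_SEQS * multiplier)

if not ENFORCE_EAGER and effective_max_batch_size <= MAX_CAPTURE_SIZE:
    # CUDA graphs apply: pad the mamba cache slot count to a captured size.
    max_batch_size = pad_for_cudagraph(effective_max_batch_size)
else:
    # Eager mode, or too large to capture: use the unpadded value directly.
    max_batch_size = effective_max_batch_size

print(max_batch_size)  # 384 with the assumed values above (256 * 1.5)
```

With the assumed values, 256 * 1.5 = 384 slots fit within the capture size, so the value is padded for CUDA graph capture; under `--enforce-eager`, or when the scaled batch exceeds the capture size, the raw `effective_max_batch_size` is used instead of the previous hardcoded `8192 + 2`.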