Skip to content

Commit 3878d3c

Browse files
committed
Override inductor default mm with batch invariant one for B200
1 parent c9791f1 commit 3878d3c

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

tests/v1/generation/test_batch_invariance.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
456456
model=model,
457457
max_num_seqs=1,
458458
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
459-
enforce_eager=True,
460459
gpu_memory_utilization=0.9,
461460
max_model_len=2048,
462461
dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
998997
dtype="bfloat16",
999998
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
1000999
enable_prefix_caching=False,
1001-
enforce_eager=True,
10021000
# Enable for MOE models
10031001
# enable_expert_parallel=True,
10041002
)

vllm/config/compilation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
1616
from vllm.config.utils import config
1717
from vllm.logger import init_logger
18+
from vllm.model_executor.layers.batch_invariant import (
19+
vllm_is_batch_invariant,
20+
)
1821
from vllm.platforms import current_platform
1922
from vllm.utils.import_utils import resolve_obj_by_qualname
2023
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -579,6 +582,13 @@ def __post_init__(self) -> None:
579582
self.inductor_compile_config["combo_kernels"] = True
580583
self.inductor_compile_config["benchmark_combo_kernel"] = True
581584

585+
# Batch invariance on Blackwell doesn't work with CUDA graphs
586+
if vllm_is_batch_invariant() and current_platform.is_device_capability(100):
587+
logger.warning(
588+
"Disabling Cudagraphs: Batch invariance on Blackwell doesn't work with cuda graphs"
589+
)
590+
self.cudagraph_mode = CUDAGraphMode.NONE
591+
582592
# migrate the deprecated flags
583593
if not self.use_cudagraph:
584594
logger.warning(

0 commit comments

Comments
 (0)