Skip to content

Commit f827c17

Browse files
committed
Override inductor default mm with batch invariant one for B200
1 parent c9791f1 commit f827c17

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

vllm/config/compilation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
1616
from vllm.config.utils import config
1717
from vllm.logger import init_logger
18+
from vllm.model_executor.layers.batch_invariant import (
19+
vllm_is_batch_invariant,
20+
)
1821
from vllm.platforms import current_platform
1922
from vllm.utils.import_utils import resolve_obj_by_qualname
2023
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -579,6 +582,16 @@ def __post_init__(self) -> None:
579582
self.inductor_compile_config["combo_kernels"] = True
580583
self.inductor_compile_config["benchmark_combo_kernel"] = True
581584

585+
# Batch invariance on Blackwell doesn't work with cuda graphs
586+
if vllm_is_batch_invariant() and current_platform.is_device_capability(100) >= (
587+
10,
588+
0,
589+
):
590+
logger.warning(
591+
"Disabling Cudagraphs: Batch invariance on Blackwell doesn't work with cuda graphs"
592+
)
593+
self.cudagraph_mode = CUDAGraphMode.NONE
594+
582595
# migrate the deprecated flags
583596
if not self.use_cudagraph:
584597
logger.warning(

0 commit comments

Comments
 (0)