Skip to content

Commit fade2e2

Browse files
committed
Override inductor default mm with batch invariant one for B200
1 parent c9791f1 commit fade2e2

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

vllm/config/compilation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from vllm.platforms import current_platform
1919
from vllm.utils.import_utils import resolve_obj_by_qualname
2020
from vllm.utils.torch_utils import is_torch_equal_or_newer
21+
from vllm.model_executor.layers.batch_invariant import (
22+
vllm_is_batch_invariant,
23+
)
2124

2225
if TYPE_CHECKING:
2326
from vllm.config import VllmConfig
@@ -579,6 +582,13 @@ def __post_init__(self) -> None:
579582
self.inductor_compile_config["combo_kernels"] = True
580583
self.inductor_compile_config["benchmark_combo_kernel"] = True
581584

585+
# Batch invariance on Blackwell doesn't work with cuda graphs
586+
if vllm_is_batch_invariant() and current_platform.get_device_capability() >= (10, 0):
587+
logger.warning(
588+
"Disabling Cudagraphs: Batch invariance on Blackwell doesn't work with cuda graphs"
589+
)
590+
self.cudagraph_mode = CUDAGraphMode.NONE
591+
582592
# migrate the deprecated flags
583593
if not self.use_cudagraph:
584594
logger.warning(

vllm/model_executor/layers/batch_invariant.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def matmul_kernel_persistent(
140140

141141

142142
def matmul_persistent(
143-
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
143+
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None, out = None
144144
):
145145
# Check constraints.
146146
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
@@ -213,7 +213,11 @@ def grid(META):
213213
HAS_BIAS=bias is not None,
214214
**configs[dtype],
215215
)
216-
return c
216+
217+
if out is not None:
218+
out.copy_(c)
219+
else:
220+
return c
217221

218222

219223
@triton.jit
@@ -466,6 +470,9 @@ def mean_dim(
466470
def mm_batch_invariant(a, b):
    """Batch-invariant drop-in for torch.mm, backed by the persistent matmul kernel."""
    product = matmul_persistent(a, b)
    return product
468472

473+
def mm_batch_invariant_out(a, b, out=None):
    """Batch-invariant replacement for the ``out=`` overload of torch.mm.

    Computes ``a @ b`` with the persistent (batch-invariant) matmul kernel.
    When ``out`` is provided the product is written into it.

    Returns:
        The ``out`` tensor when one was supplied (mirroring torch.mm's
        ``out=`` contract), otherwise a freshly allocated result tensor.
    """
    result = matmul_persistent(a, b, bias=None, out=out)
    # matmul_persistent copies into `out` and returns None in that case;
    # torch.mm's out= variant is expected to return the output tensor, so
    # hand `out` back explicitly instead of propagating None.
    return out if out is not None else result
475+
469476

470477
def matmul_batch_invariant(a, b, *, out=None):
471478
# torch.matmul can handle various dimensions

0 commit comments

Comments
 (0)