Commit 3ac6ae4

Wider coverage of batch invariance: torch.compile, CUDA graphs, and B200

1 parent c9791f1 · commit 3ac6ae4

File tree (4 files changed, +30 -15 lines):

  tests/v1/generation/test_batch_invariance.py
  vllm/config/compilation.py
  vllm/model_executor/layers/batch_invariant.py
  vllm/utils/flashinfer.py

tests/v1/generation/test_batch_invariance.py

Lines changed: 0 additions & 2 deletions
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         model=model,
         max_num_seqs=1,
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enforce_eager=True,
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
-        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
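Dropping `enforce_eager=True` from both test fixtures means the batch-invariance tests now exercise vLLM's default execution path, including torch.compile and CUDA graphs, rather than pure eager mode. Below is a minimal sketch of that kind of setup; it is not taken from the commit, the model name is a placeholder, and the `VLLM_BATCH_INVARIANT` environment variable is assumed to be the toggle behind `vllm_is_batch_invariant()`.

```python
import os

from vllm import LLM, SamplingParams

# Assumed toggle for batch-invariant mode; not shown in this commit.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model for illustration
    max_num_seqs=1,
    tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
    # enforce_eager is left at its default (False), so torch.compile and
    # CUDA graphs stay enabled, which is the coverage this commit is after.
    gpu_memory_utilization=0.9,
    max_model_len=2048,
    dtype="bfloat16",
)

params = SamplingParams(temperature=0.0, max_tokens=8)
out1 = llm.generate(["The capital of France is"], params)
out2 = llm.generate(["The capital of France is"], params)
# Greedy decoding should be reproducible run-to-run in batch-invariant mode.
assert out1[0].outputs[0].text == out2[0].outputs[0].text
```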

vllm/config/compilation.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.torch_utils import is_torch_equal_or_newer

vllm/model_executor/layers/batch_invariant.py

Lines changed: 21 additions & 13 deletions
@@ -736,11 +736,26 @@ def enable_batch_invariant_mode():
 
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+    # Batch invariant matmuls are no longer needed after cublas overrides
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # For PyTorch 2.9, B200 uses GEMV for bs=1
+            # Requires https://github.com/pytorch/pytorch/pull/166735, so need to use mm overrides
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # Only source of batch invariance for Hopper is split-k, can disable through
+            # cuBLAS workspace config
+            _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
@@ -749,6 +764,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
@@ -770,14 +786,6 @@ def enable_batch_invariant_mode():
     )
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
-
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
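The rewritten block splits the pre-2.10 PyTorch handling by GPU generation: on Blackwell (device capability 10.0, e.g. B200) the mm/addmm/matmul/linear overrides are kept because batch size 1 would otherwise dispatch to a GEMV kernel, while on Hopper it is enough to shrink the cuBLAS/cuBLASLt workspaces so split-K kernels, the remaining source of batch-size-dependent reductions, are never selected; the previous environment values are saved so they can be restored later. The sketch below restates that workspace-constraining idea as a standalone helper outside vLLM; the helper name is made up for illustration, and it assumes the variables are set before cuBLAS first allocates its workspace in the process (they are read at initialization).

```python
import os

import torch


def run_with_constrained_cublas_workspace(fn):
    """Run fn() with the reduced cuBLAS workspace config, then restore the old env."""
    saved = {
        key: os.environ.get(key)
        for key in ("CUBLAS_WORKSPACE_CONFIG", "CUBLASLT_WORKSPACE_SIZE")
    }
    # Values copied from the diff above: a small fixed workspace rules out split-K.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
    try:
        return fn()
    finally:
        # Restore whatever the caller had set (or unset the variables again).
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value


if torch.cuda.is_available():
    a = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    # The matmul below goes through cuBLAS with the constrained workspace.
    out = run_with_constrained_cublas_workspace(lambda: a @ w)
```

The commit keeps the saved values in the module-level globals `_original_cublas_workspace_cfg` and `_original_cublaslt_workspace_size`, presumably so `disable_batch_invariant_mode()` can put the caller's configuration back.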

vllm/utils/flashinfer.py

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,9 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 
 logger = init_logger(__name__)
 
@@ -213,6 +216,9 @@ def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
     """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
     if env_value is not None:
         logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
+    if vllm_is_batch_invariant():
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
+        return False
     return env_value
 
 
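The flashinfer change makes `_force_use_trtllm_attention` return False whenever batch-invariant mode is active, so even an explicit `VLLM_USE_TRTLLM_ATTENTION` opt-in is overridden rather than honored. The snippet below is a simplified model of that decision, not the actual vLLM helper (which also logs and caches the result).

```python
def resolve_use_trtllm_attention(env_value: bool | None, batch_invariant: bool) -> bool | None:
    """Effective TRTLLM-attention setting; None means the decision is deferred."""
    if batch_invariant:
        # Batch invariance wins over any explicit opt-in, mirroring the diff above.
        return False
    return env_value


assert resolve_use_trtllm_attention(True, batch_invariant=True) is False
assert resolve_use_trtllm_attention(True, batch_invariant=False) is True
assert resolve_use_trtllm_attention(None, batch_invariant=False) is None
```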