2 changes: 0 additions & 2 deletions tests/v1/generation/test_batch_invariance.py
@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         model=model,
         max_num_seqs=1,
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enforce_eager=True,
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
-        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
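
Note: both fixtures simply stop passing `enforce_eager=True`, so these batch-invariance tests now run with vLLM's default CUDA-graph capture rather than forcing eager mode. A minimal sketch of the resulting engine construction, keeping only the arguments shown in the diff (the model name and the `VLLM_TEST_MODEL` variable are placeholders, not part of the test):

```python
import os

from vllm import LLM

# Placeholder model; the real test passes its own `model` variable.
model = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")

llm = LLM(
    model=model,
    max_num_seqs=1,
    tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
    gpu_memory_utilization=0.9,
    max_model_len=2048,
    dtype="bfloat16",
)
```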
39 changes: 24 additions & 15 deletions vllm/model_executor/layers/batch_invariant.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
-import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -11,6 +10,7 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
 
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+
+    # Batch invariant matmuls are no longer needed after cublas overrides
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # For PyTorch 2.9, B200 uses GEMV for bs=1
+            # Requires https://github.com/pytorch/pytorch/pull/166735
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # Only source of batch invariance for Hopper is split-k, can disable through
+            # cuBLAS workspace config
+            _original_cublas_workspace_cfg = os.environ.get(
+                "CUBLAS_WORKSPACE_CONFIG", None
+            )
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
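
Note: the registration of the batch-invariant matmuls is now gated twice. On PyTorch 2.10+ nothing is registered, because the cuBLAS overrides already make matmuls batch-invariant; on older PyTorch, only Blackwell (device capability 100, i.e. 10.0) keeps the Triton `mm`/`addmm`/`matmul`/`linear` overrides, while other GPUs such as Hopper only need cuBLAS split-K ruled out via the workspace variables. A standalone sketch of that decision using plain `torch` calls; the version parsing, capability check, and helper names here are simplifying assumptions, not vLLM's actual helpers:

```python
import os

import torch


def needs_triton_matmul_overrides() -> bool:
    """Sketch of the new gating in enable_batch_invariant_mode (simplified).

    Assumptions: naive version parsing instead of is_torch_equal_or_newer(),
    and a direct capability check instead of current_platform.is_device_capability(100).
    """
    major, minor = (int(x) for x in torch.__version__.split(".")[:2])
    if (major, minor) >= (2, 10):
        # On PyTorch 2.10+ the cuBLAS overrides already give batch-invariant
        # matmuls, so no custom kernels need to be registered at all.
        return False
    # On PyTorch 2.9, Blackwell (compute capability 10.0) uses GEMV for batch
    # size 1, so the Triton mm/addmm/matmul/linear overrides are still needed.
    return torch.cuda.get_device_capability() == (10, 0)


def disable_cublas_split_k() -> tuple[str | None, str | None]:
    """Hopper path: split-K is the only remaining source of batch variance, so
    shrink the cuBLAS/cuBLASLt workspaces; the previous values are returned so
    a caller can restore them later (as disable_batch_invariant_mode presumably does).
    """
    previous = (
        os.environ.get("CUBLAS_WORKSPACE_CONFIG"),
        os.environ.get("CUBLASLT_WORKSPACE_SIZE"),
    )
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
    return previous
```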
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
         )
         torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
-
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
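
Note: dropping `@functools.cache` means `vllm_is_batch_invariant()` re-reads its override on every call instead of memoizing the first answer for the life of the process. A hedged usage sketch of why that matters, assuming the flag is driven by the `VLLM_BATCH_INVARIANT` environment variable named in the diff and that `"1"`/`"0"` are how it is expressed:

```python
import os

from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

# With the cache removed, the flag is re-evaluated on every call, so toggling
# it mid-process (e.g. between tests) takes effect immediately instead of
# being frozen at the first call.
os.environ["VLLM_BATCH_INVARIANT"] = "1"
assert vllm_is_batch_invariant()

os.environ["VLLM_BATCH_INVARIANT"] = "0"
assert not vllm_is_batch_invariant()
```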
6 changes: 6 additions & 0 deletions vllm/utils/flashinfer.py
@@ -19,6 +19,9 @@
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
+    if vllm_is_batch_invariant():
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
+        return False
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
 
 
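
Note: with this change, batch-invariant mode takes precedence over any explicit `VLLM_USE_TRTLLM_ATTENTION` setting, since the new early return fires before the environment override is consulted. A small illustrative model of that precedence; the `"1"` parsing of the env value is an assumption, not vLLM's exact handling:

```python
def resolve_trtllm_attention(batch_invariant: bool, env_value: str | None) -> bool | None:
    """Illustration only; the real logic is force_use_trtllm_attention in vllm/utils/flashinfer.py."""
    if batch_invariant:
        # Batch-invariant mode always disables TRTLLM attention.
        return False
    if env_value is None:
        return None  # leave the choice to vLLM's own heuristics
    return env_value == "1"


print(resolve_trtllm_attention(batch_invariant=True, env_value="1"))    # False
print(resolve_trtllm_attention(batch_invariant=False, env_value=None))  # None
```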