6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -109,6 +109,7 @@
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = False
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
@@ -894,6 +895,11 @@ def get_vllm_port() -> int | None:
"VLLM_ROCM_USE_AITER_MHA": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")
),
# Whether to use aiter fp16 triton gemm.
# By default is disabled.
"VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "False").lower() in ("true", "1")
),
# Whether to use aiter fp4 gemm asm.
# By default is disabled.
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": lambda: (
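For reference, a minimal sketch of how the new flag would be enabled at runtime (the variable names come from this diff; the standalone setup around them is illustrative only and assumes a ROCm build of vLLM with aiter installed):

```python
import os

# Set the flags before vllm is imported so the lazily evaluated lambdas above see them.
os.environ["VLLM_ROCM_USE_AITER"] = "1"              # aiter itself must be enabled too
os.environ["VLLM_ROCM_USE_AITER_TRITON_GEMM"] = "1"  # new flag, off by default

import vllm.envs as envs

assert envs.VLLM_ROCM_USE_AITER is True
assert envs.VLLM_ROCM_USE_AITER_TRITON_GEMM is True
```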
44 changes: 34 additions & 10 deletions vllm/model_executor/layers/utils.py
@@ -100,12 +100,41 @@ def default_unquantized_gemm(
return torch.nn.functional.linear(x, weight, bias)


def use_aiter_triton_gemm(n, m, k, dtype):
if (
envs.VLLM_ROCM_USE_AITER == 0
or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
or dtype not in [torch.float16, torch.bfloat16]
):
return False

# use hipblaslt for the larger GEMMs
if n > 2048 and m > 512:
return False
return (
(m == 5120 and k == 2880)
or (m == 2880 and k == 4096)
or (m == 128 and k == 2880)
or (m == 640 and k == 2880)
or (m == 2880 and k == 512)
)
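Here `n` is the number of input rows (flattened tokens) while `m` and `k` are the weight's output and input features, so the whitelist above pins the Triton kernel to a few small projection shapes and leaves everything else to the existing paths. A quick sketch of how the gate behaves on a ROCm build, assuming both environment flags are enabled as shown earlier (shape values are illustrative):

```python
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_TRITON_GEMM"] = "1"

import torch
from vllm.model_executor.layers.utils import use_aiter_triton_gemm

print(use_aiter_triton_gemm(64, 128, 2880, torch.bfloat16))     # True: whitelisted (m, k) pair
print(use_aiter_triton_gemm(64, 4096, 4096, torch.bfloat16))    # False: shape not whitelisted
print(use_aiter_triton_gemm(4096, 5120, 2880, torch.bfloat16))  # False: large n/m stays on hipBLASLt
print(use_aiter_triton_gemm(64, 128, 2880, torch.float32))      # False: only fp16/bf16 are eligible
```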


def rocm_unquantized_gemm_impl(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9

x_view = x.view(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
k = weight.shape[1]

if use_aiter_triton_gemm(n, m, k, x.dtype):
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

return gemm_a16w16(x, weight)

use_skinny = (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_gfx9()
@@ -116,12 +145,8 @@ def rocm_unquantized_gemm_impl(
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)

x_view = x.view(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
cu_count = current_platform.get_cu_count()

if m > 8 and 0 < n <= 4:
cu_count = current_platform.get_cu_count()
out = ops.wvSplitK(weight, x_view, cu_count, bias)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
@@ -130,25 +155,24 @@
return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_impl_fake(
def rocm_unquantized_gemm_fake(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return x.new_empty((*x.shape[:-1], weight.shape[0]))


def rocm_unquantized_gemm(
layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)


direct_register_custom_op(
op_name="rocm_unquantized_gemm_impl",
op_name="rocm_unquantized_gemm",
op_func=rocm_unquantized_gemm_impl,
fake_impl=rocm_unquantized_gemm_impl_fake,
fake_impl=rocm_unquantized_gemm_fake,
)


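The op is registered with a real and a fake implementation so `torch.compile` can trace through it without running the ROCm kernels; the fake variant only has to return a tensor of the right shape and dtype. A small shape check illustrating that contract (tensor sizes are made up for the example):

```python
import torch

x = torch.randn(4, 16, 2880, dtype=torch.bfloat16)        # (*batch, k)
weight = torch.randn(128, 2880, dtype=torch.bfloat16)     # (m, k)

fake_out = x.new_empty((*x.shape[:-1], weight.shape[0]))  # what rocm_unquantized_gemm_fake returns
real_out = torch.nn.functional.linear(x, weight)          # reference fallback path of the real impl

assert fake_out.shape == real_out.shape == (4, 16, 128)
assert fake_out.dtype == real_out.dtype
```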
9 changes: 8 additions & 1 deletion vllm/model_executor/models/gpt_oss.py
@@ -23,12 +23,14 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

@@ -175,7 +177,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.is_sequence_parallel:
x = sequence_parallel_chunk(x)

g = self.router(x)
if current_platform.is_rocm():
g = rocm_unquantized_gemm(
x[:, : self.hidden_size], self.router.weight, self.router.bias
)
else:
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)

if self.is_sequence_parallel:
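On ROCm the router logits are now computed by calling `rocm_unquantized_gemm` directly on the first `hidden_size` features of `x`, which lets this small projection hit the shape whitelist above instead of going through the stock `nn.Linear` forward. A minimal equivalence sketch of what that branch computes (the sizes below are hypothetical, and plain `F.linear` stands in for the op's fallback path):

```python
import torch

hidden_size, num_experts, num_tokens = 2880, 128, 16       # illustrative sizes only
router = torch.nn.Linear(hidden_size, num_experts, dtype=torch.bfloat16)
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)

g_ref = router(x)                                           # stock path: self.router(x)
g_rocm = torch.nn.functional.linear(                        # what the ROCm branch reduces to when
    x[:, :hidden_size], router.weight, router.bias          # no skinny/Triton kernel is selected
)

torch.testing.assert_close(g_ref, g_rocm)
```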