From eef1b169dd1077be68a2c9a465054c9a4e495eef Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 16 Oct 2025 00:56:49 +0000 Subject: [PATCH 1/6] gemm_a16w16 upstreaming Signed-off-by: Aleksandr Malyshev --- vllm/envs.py | 6 ++++++ vllm/model_executor/layers/utils.py | 25 ++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 6f40209dd000..fbb1a5c491af 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -109,6 +109,7 @@ VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True + VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = False VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True @@ -894,6 +895,11 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_MHA": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1") ), + # Whether to use aiter fp16 triton gemm. + # By default is disabled. + "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "False").lower() in ("true", "1") + ), # Whether to use aiter fp4 gemm asm. # By default is disabled. "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": lambda: ( diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 87ffcb48c8c0..c386ef0bec0e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -100,12 +100,34 @@ def default_unquantized_gemm( return torch.nn.functional.linear(x, weight, bias) +def use_aiter_triton_gemm(m, k): + if envs.VLLM_ROCM_USE_AITER == 0 or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0: + return False + + return ( + (m == 5120 and k == 2880) + or (m == 2880 and k == 4096) + or (m == 128 and k == 2880) + or (m == 640 and k == 2880) + or (m == 2880 and k == 512) + ) + + def rocm_unquantized_gemm_impl( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 + x_view = x.view(-1, x.size(-1)) + n = x_view.shape[0] + m = weight.shape[0] k = weight.shape[1] + + if use_aiter_triton_gemm(m, k): + from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + + return gemm_a16w16(x, weight, bias) + use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() @@ -116,9 +138,6 @@ def rocm_unquantized_gemm_impl( if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - x_view = x.view(-1, x.size(-1)) - n = x_view.shape[0] - m = weight.shape[0] cu_count = current_platform.get_cu_count() if m > 8 and 0 < n <= 4: From 5538c0f4ff38f36987969962491440caf8580afa Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Fri, 17 Oct 2025 22:45:43 +0000 Subject: [PATCH 2/6] triton fp16 kernel Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/layers/utils.py | 16 +++++++++++----- vllm/model_executor/models/gpt_oss.py | 8 +++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index c386ef0bec0e..518c9fc7d296 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -100,10 +100,17 @@ def default_unquantized_gemm( return torch.nn.functional.linear(x, weight, bias) -def use_aiter_triton_gemm(m, k): - if envs.VLLM_ROCM_USE_AITER == 0 or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0: +def use_aiter_triton_gemm(n, m, k, dtype): + if ( + envs.VLLM_ROCM_USE_AITER == 0 + or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0 + or dtype not 
in [torch.float16, torch.bfloat16] + ): return False + # use hipblaslt for the larger GEMMs + if n > 2048 and m > 512: + return False return ( (m == 5120 and k == 2880) or (m == 2880 and k == 4096) @@ -123,7 +130,7 @@ def rocm_unquantized_gemm_impl( m = weight.shape[0] k = weight.shape[1] - if use_aiter_triton_gemm(m, k): + if use_aiter_triton_gemm(n, m, k, x.dtype): from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 return gemm_a16w16(x, weight, bias) @@ -138,9 +145,8 @@ def rocm_unquantized_gemm_impl( if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - cu_count = current_platform.get_cu_count() - if m > 8 and 0 < n <= 4: + cu_count = current_platform.get_cu_count() out = ops.wvSplitK(weight, x_view, cu_count, bias) return out.view(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None: diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index fcba9b8e66c2..8c1504b7e580 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -29,6 +29,7 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv @@ -175,7 +176,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: x = sequence_parallel_chunk(x) - g = self.router(x) + if current_platform.is_rocm(): + g = torch.ops.vllm.rocm_unquantized_gemm_impl( + x[:, : self.hidden_size], self.router.weight, self.router.bias + ) + else: + g = self.router(x) x = self.experts(hidden_states=x, router_logits=g) if self.is_sequence_parallel: From 13503848ad99d0c4f935e365987ed4aca0be1503 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Sat, 18 Oct 2025 04:05:41 +0000 Subject: [PATCH 3/6] triton fp16 kernel Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/layers/utils.py | 11 +++++------ vllm/model_executor/models/gpt_oss.py | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 518c9fc7d296..8261cefc94ab 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -133,7 +133,7 @@ def rocm_unquantized_gemm_impl( if use_aiter_triton_gemm(n, m, k, x.dtype): from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 - return gemm_a16w16(x, weight, bias) + return gemm_a16w16(x, weight) use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM @@ -155,25 +155,24 @@ def rocm_unquantized_gemm_impl( return torch.nn.functional.linear(x, weight, bias) -def rocm_unquantized_gemm_impl_fake( +def rocm_unquantized_gemm_fake( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: return x.new_empty((*x.shape[:-1], weight.shape[0])) def rocm_unquantized_gemm( - layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) + return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias) direct_register_custom_op( - op_name="rocm_unquantized_gemm_impl", + op_name="rocm_unquantized_gemm", op_func=rocm_unquantized_gemm_impl, - fake_impl=rocm_unquantized_gemm_impl_fake, + fake_impl=rocm_unquantized_gemm_fake, ) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 
8c1504b7e580..7a581a99a80e 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.utils import rocm_unquantized_gemm from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -177,7 +178,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = sequence_parallel_chunk(x) if current_platform.is_rocm(): - g = torch.ops.vllm.rocm_unquantized_gemm_impl( + g = rocm_unquantized_gemm( x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: From feaae0f7c72c4f2618f6d16930e4456c80717a17 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 21 Oct 2025 02:59:59 +0000 Subject: [PATCH 4/6] Torch compile fix Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/layers/utils.py | 1 + vllm/model_executor/models/gpt_oss.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 7cf1708ef06c..a7498f151e40 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -162,6 +162,7 @@ def rocm_unquantized_gemm_fake( def rocm_unquantized_gemm( + layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None, diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index bac8814e0654..fae93322e935 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -153,6 +153,7 @@ def __init__( self.layer_idx = layer_idx self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) @@ -179,7 +180,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if current_platform.is_rocm(): g = rocm_unquantized_gemm( - x[:, : self.hidden_size], self.router.weight, self.router.bias + self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: g = self.router(x) From 70b2746499a345920afefca151e6ae459c3f3ea0 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Mon, 27 Oct 2025 22:25:38 +0000 Subject: [PATCH 5/6] removed flag, added missed bias Signed-off-by: Aleksandr Malyshev --- vllm/envs.py | 6 ------ vllm/model_executor/layers/utils.py | 5 +++-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index f9621e5c6d68..0c45f93ec057 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -106,7 +106,6 @@ VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True - VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = False VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True @@ -869,11 +868,6 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_MHA": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1") ), - # Whether to use aiter fp16 triton gemm. - # By default is disabled. 
- "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "False").lower() in ("true", "1") - ), # Whether to use aiter fp4 gemm asm. # By default is disabled. "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": lambda: ( diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 123ab1cd9a3f..07b3958382f8 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -103,7 +103,8 @@ def default_unquantized_gemm( def use_aiter_triton_gemm(n, m, k, dtype): if ( envs.VLLM_ROCM_USE_AITER == 0 - or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0 + # MI300's - fp8nuz=True + or current_platform.is_fp8_fnuz() or dtype not in [torch.float16, torch.bfloat16] ): return False @@ -133,7 +134,7 @@ def rocm_unquantized_gemm_impl( if use_aiter_triton_gemm(n, m, k, x.dtype): from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 - return gemm_a16w16(x, weight) + return gemm_a16w16(x, weight, bias) use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM From f79a9ffa0ad2e2d390077188c22c8d0407be1bd2 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 28 Oct 2025 16:56:34 +0000 Subject: [PATCH 6/6] minor corrections Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/layers/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 07b3958382f8..e7dc39dc502b 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -126,8 +126,7 @@ def rocm_unquantized_gemm_impl( ) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 - x_view = x.view(-1, x.size(-1)) - n = x_view.shape[0] + n = x.numel() / x.size(-1) m = weight.shape[0] k = weight.shape[1] @@ -146,6 +145,7 @@ def rocm_unquantized_gemm_impl( if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) + x_view = x.view(-1, x.size(-1)) if m > 8 and 0 < n <= 4: cu_count = current_platform.get_cu_count() out = ops.wvSplitK(weight, x_view, cu_count, bias)