@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import json
-import re
 import uuid
 from collections.abc import Sequence
 from typing import Any
 
+import regex as re
+
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     DeltaFunctionCall,
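Note on the swap above: `regex` is a third-party, near drop-in replacement for the stdlib `re` module. The practical benefit for parser code is its per-call `timeout`, which turns a catastrophically backtracking pattern into a recoverable error instead of a hang (assumption: that hardening is the motivation here; the diff itself does not say). A minimal sketch, assuming the `regex` package is installed:

```python
import regex as re

# Classic catastrophic-backtracking shape: stdlib `re` would effectively hang
# on the non-matching input below, while `regex` can bail out via a timeout.
pattern = re.compile(r"(a+)+$")

try:
    pattern.search("a" * 64 + "b", timeout=0.5)  # seconds
except TimeoutError:  # raised by the `regex` package when the budget expires
    print("pattern timed out instead of hanging")
```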
44 changes: 35 additions & 9 deletions vllm/model_executor/layers/utils.py
@@ -100,12 +100,42 @@ def default_unquantized_gemm(
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def use_aiter_triton_gemm(n, m, k, dtype):
+    if (
+        envs.VLLM_ROCM_USE_AITER == 0
+        # MI300-series parts report fp8 fnuz and are excluded here
+        or current_platform.is_fp8_fnuz()
+        or dtype not in [torch.float16, torch.bfloat16]
+    ):
+        return False
+
+    # prefer hipBLASLt for the larger GEMMs
+    if n > 2048 and m > 512:
+        return False
+    return (
+        (m == 5120 and k == 2880)
+        or (m == 2880 and k == 4096)
+        or (m == 128 and k == 2880)
+        or (m == 640 and k == 2880)
+        or (m == 2880 and k == 512)
+    )
+
+
 def rocm_unquantized_gemm_impl(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
 
+    x_view = x.view(-1, x.size(-1))
+    n = x_view.shape[0]
+    m = weight.shape[0]
+    k = weight.shape[1]
+
+    if use_aiter_triton_gemm(n, m, k, x.dtype):
+        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+
+        return gemm_a16w16(x, weight, bias)
+
     use_skinny = (
         envs.VLLM_ROCM_USE_SKINNY_GEMM
         and on_gfx9()
@@ -116,12 +146,8 @@ def rocm_unquantized_gemm_impl(
     if use_skinny is not True:
         return torch.nn.functional.linear(x, weight, bias)
 
-    x_view = x.view(-1, x.size(-1))
-    n = x_view.shape[0]
-    m = weight.shape[0]
-    cu_count = current_platform.get_cu_count()
-
     if m > 8 and 0 < n <= 4:
+        cu_count = current_platform.get_cu_count()
         out = ops.wvSplitK(weight, x_view, cu_count, bias)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
@@ -130,7 +156,7 @@ def rocm_unquantized_gemm_impl(
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm_impl_fake(
+def rocm_unquantized_gemm_fake(
     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
 ) -> torch.Tensor:
     return x.new_empty((*x.shape[:-1], weight.shape[0]))
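For orientation: in the dispatch above, `n` is the flattened token count while `m` and `k` come from the weight's `(out_features, in_features)`. The allow-listed `(m, k)` pairs appear to match gpt-oss projection sizes, e.g. `(128, 2880)` for a 128-expert router over a 2880-wide hidden state (an inference from the gpt_oss.py change below, not stated in the diff). A shape-only sketch:

```python
import torch

# Shape-only sketch (not vLLM code); sizes assumed from gpt-oss-120b.
x = torch.randn(2, 32, 2880, dtype=torch.bfloat16)     # (batch, seq, hidden)
weight = torch.randn(128, 2880, dtype=torch.bfloat16)  # router: 128 experts

x_view = x.view(-1, x.size(-1))  # flatten batch/seq into rows
n = x_view.shape[0]              # 64 tokens
m, k = weight.shape              # out_features=128, in_features=2880

print(n, m, k)  # 64 128 2880 -> hits the (m == 128 and k == 2880) case
```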
@@ -142,13 +168,13 @@ def rocm_unquantized_gemm(
     weight: torch.Tensor,
     bias: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+    return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)
 
 
 direct_register_custom_op(
-    op_name="rocm_unquantized_gemm_impl",
+    op_name="rocm_unquantized_gemm",
     op_func=rocm_unquantized_gemm_impl,
-    fake_impl=rocm_unquantized_gemm_impl_fake,
+    fake_impl=rocm_unquantized_gemm_fake,
 )


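Context for the rename: `direct_register_custom_op` binds the eager `_impl` under the public op name, so call sites go through `torch.ops.vllm.rocm_unquantized_gemm`, while `fake_impl` gives `torch.compile`/meta tracing a shape-only stand-in. A standalone sketch of the same pattern using the public `torch.library` API (assumption: PyTorch >= 2.4; vLLM's helper wraps equivalent machinery):

```python
from typing import Optional

import torch


# Eager implementation, registered as a custom op.
@torch.library.custom_op("demo::unquantized_gemm", mutates_args=())
def unquantized_gemm(
    x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    return torch.nn.functional.linear(x, weight, bias)


# Shape/dtype-only stand-in so tracing never runs the real kernel.
@unquantized_gemm.register_fake
def _(x, weight, bias=None):
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


out = torch.ops.demo.unquantized_gemm(torch.randn(4, 8), torch.randn(16, 8))
print(out.shape)  # torch.Size([4, 16])
```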
10 changes: 9 additions & 1 deletion vllm/model_executor/models/gpt_oss.py
@@ -25,12 +25,14 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.utils import rocm_unquantized_gemm
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.math_utils import cdiv

@@ -153,6 +155,7 @@ def __init__(

         self.layer_idx = layer_idx
         self.num_experts = config.num_local_experts
+        self.hidden_size = config.hidden_size
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
         self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
@@ -177,7 +180,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             x = sequence_parallel_chunk(x)
 
-        g = self.router(x)
+        if current_platform.is_rocm():
+            g = rocm_unquantized_gemm(
+                self, x[:, : self.hidden_size], self.router.weight, self.router.bias
+            )
+        else:
+            g = self.router(x)
         x = self.experts(hidden_states=x, router_logits=g)
 
         if self.is_sequence_parallel:
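Taken together, the gpt-oss change sends the router projection, whose weight is `(num_local_experts, hidden_size)`, through the ROCm dispatcher above instead of `torch.nn.Linear`; shapes that match the allow-list take the aiter Triton kernel, and everything else falls back to `F.linear`. A sanity sketch of that fallback equivalence (sizes 2880/128 assumed from gpt-oss-120b; the `x[:, : self.hidden_size]` slice is an identity for unpadded inputs):

```python
import torch

hidden_size, num_experts, tokens = 2880, 128, 16
router = torch.nn.Linear(hidden_size, num_experts)
x = torch.randn(tokens, hidden_size)

# The non-matching/fallback path of the dispatcher is plain F.linear,
# which is exactly what nn.Linear computes.
g_module = router(x)
g_manual = torch.nn.functional.linear(
    x[:, :hidden_size], router.weight, router.bias
)
torch.testing.assert_close(g_module, g_manual)
print(g_manual.shape)  # torch.Size([16, 128])
```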