[ROCm][AITER] Support AITER Rope ops in RotaryEmbedding Module. #22521

Merged 14 commits on Aug 11, 2025
71 changes: 71 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/base.py
@@ -8,6 +8,7 @@
from vllm.model_executor.custom_op import CustomOp

from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled


@CustomOp.register("rotary_embedding")
@@ -35,6 +36,7 @@ def __init__(
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Compute the inverse frequency."""
@@ -119,6 +121,75 @@ def forward_cuda(
                             self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        is_nope_first: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Currently only the rotary embedding ops from the AITER package
        # are supported for the HIP forward path.
        if self.is_rocm_aiter_enabled:
            return self.forward_hip_rocm_aiter(positions, query, key, offsets,
                                               is_nope_first)
        return self.forward_native(positions, query, key, offsets)

    def forward_hip_rocm_aiter(
        self,
        positions: torch.Tensor,
        # if is_nope_first:
        #   [batch_size, seq_len, num_heads, nope_size + rope_size]
        # if NOT is_nope_first:
        #   [batch_size, seq_len, num_heads, rope_size + nope_size]
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        is_nope_first: bool = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.cos_sin_cache.device != query.device or \
                self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

        cos = cos.unsqueeze(-2).unsqueeze(-2)
        sin = sin.unsqueeze(-2).unsqueeze(-2)

        # 0 selects the NeoX (rotate-half) style, 1 the GPT-J (interleaved)
        # style.
        rotate_style = 0 if self.is_neox_style else 1

        num_tokens = positions.numel()

        # Flatten to the [1, num_tokens, num_heads, head_size] layout that
        # the AITER in-place kernels operate on.
        query_shape = query.shape
        query = query.view(1, num_tokens, -1, self.head_size)
        if key is not None:
            key_shape = key.shape
            key = key.view(1, num_tokens, -1, self.head_size)

        positions = positions.view(*query.shape[:2])
        if offsets is not None:
            offsets = offsets.view(*query.shape[:2])

        # Only the rotary slice of each head is rotated in place.
        if not is_nope_first:
            query_ = query[..., :self.rotary_dim]
            key_ = key[..., :self.rotary_dim] if key is not None else None
        else:
            query_ = query[..., -self.rotary_dim:]
            key_ = key[..., -self.rotary_dim:] if key is not None else None

        if key_ is None:
            torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
                positions, sin, cos, query_, offsets, rotate_style,
                is_nope_first)
            return query.view(query_shape), None

        torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
            positions, sin, cos, query_, key_, offsets, rotate_style,
            is_nope_first)

        return query.view(query_shape), key.view(key_shape)

    def forward_xpu(
        self,
        positions: torch.Tensor,
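For context (not part of the diff): a minimal sketch of how this HIP path would be exercised on a ROCm build with VLLM_ROCM_USE_AITER=1. The get_rope arguments, shapes, and dtypes below are illustrative assumptions rather than values taken from this PR.

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

# Illustrative sizes only: 8 tokens, 4 heads, head_size 128, full-head rotation.
num_tokens, num_heads, head_size = 8, 4, 128
rope = get_rope(head_size=head_size,
                rotary_dim=head_size,
                max_position=4096,
                base=10000,
                is_neox_style=True)

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    device="cuda", dtype=torch.bfloat16)
key = torch.randn_like(query)

# On ROCm with AITER enabled, CustomOp dispatch routes the call to
# forward_hip(), which in turn invokes the in-place AITER kernels.
query, key = rope(positions, query, key)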
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/rotary_embedding/common.py
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
     return ramp_func
 
 
-def yarn_get_mscale(scale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
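For illustration (not part of the diff), a quick numeric check of what the new mscale argument does; the inputs are arbitrary example values:

import math

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# With the default mscale=1 the result matches the old one-argument version.
print(yarn_get_mscale(4.0))         # 0.1 * ln(4) + 1 ~= 1.1386
# Callers such as the DeepSeek scaling path can now pass a model-specific
# mscale multiplier instead of keeping their own copy of this helper.
print(yarn_get_mscale(4.0, 0.707))  # 0.1 * 0.707 * ln(4) + 1 ~= 1.0980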
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import math
 from typing import Optional
 
 import torch
@@ -10,13 +9,7 @@
 
 from .base import RotaryEmbedding
 from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_linear_ramp_mask)
-
-
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+                     yarn_get_mscale, yarn_linear_ramp_mask)
 
 
 class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -96,6 +89,9 @@ def forward(
         offsets: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
+        if self.is_rocm_aiter_enabled:
+            return self.forward_hip_rocm_aiter(positions, query, key, offsets)
+
         assert key is not None
         query_rot = query[..., :self.rotary_dim]
         key_rot = key[..., :self.rotary_dim]
127 changes: 127 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def is_rocm_rotary_embedding_enabled() -> bool:
    return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)


def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    # AITER is only available on ROCm builds, so import it lazily here.
    import aiter as ops
    if offsets is None:
        ops.rope_cached_positions_fwd_inplace(
            query,
            cos,
            sin,
            positions,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
    else:
        ops.rope_cached_positions_offsets_fwd_inplace(
            query,
            cos,
            sin,
            positions,
            offsets,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )


def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    import aiter as ops
    if offsets is None:
        ops.rope_cached_positions_2c_fwd_inplace(
            query,
            key,
            cos,
            sin,
            positions,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
    else:
        ops.rope_cached_positions_offsets_2c_fwd_inplace(
            query,
            key,
            cos,
            sin,
            positions,
            offsets,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )


def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    pass


def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    pass


if is_rocm_rotary_embedding_enabled():

    # The real kernels mutate query/key in place and return nothing, so the
    # fake implementations used for graph tracing are simple no-ops.
    direct_register_custom_op(
        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
        mutates_args=["key", "query"],
        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    direct_register_custom_op(
        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
        mutates_args=["query"],
        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
        dispatch_key=current_platform.dispatch_key,
    )
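
For reference (not part of the diff): a minimal sketch of calling the registered op directly once registration has run. The shapes are illustrative assumptions that mirror the [1, num_tokens, num_heads, head_size] layout built in forward_hip_rocm_aiter; in normal use the op is only reached through that method.

import torch

# Illustrative sizes; a ROCm device and VLLM_ROCM_USE_AITER=1 are assumed.
num_tokens, num_heads, head_size, max_position = 8, 4, 128, 4096

positions = torch.arange(num_tokens, device="cuda").view(1, num_tokens)
query = torch.randn(1, num_tokens, num_heads, head_size,
                    device="cuda", dtype=torch.bfloat16)
key = torch.randn_like(query)

# Stand-in for the module's cos_sin_cache (here rotary_dim == head_size),
# split and reshaped as in forward_hip_rocm_aiter above.
cos_sin = torch.randn(max_position, head_size,
                      device="cuda", dtype=torch.bfloat16)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)

# rotate_style=0 selects the NeoX-style rotation; the op mutates query and
# key in place and returns None.
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
    positions, sin, cos, query, key, None, 0, False)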