71 changes: 71 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/base.py
@@ -8,6 +8,7 @@
from vllm.model_executor.custom_op import CustomOp

from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled


@CustomOp.register("rotary_embedding")
@@ -35,6 +36,7 @@ def __init__(
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()

def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
@@ -119,6 +121,75 @@ def forward_cuda(
self.cos_sin_cache, self.is_neox_style)
return query, key

def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
        is_nope_first: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Currently only the rotary embedding ops from the AITER package
        # are supported on the HIP forward path.
if self.is_rocm_aiter_enabled:
return self.forward_hip_rocm_aiter(positions, query, key, offsets,
is_nope_first)
return self.forward_native(positions, query, key, offsets)

def forward_hip_rocm_aiter(
self,
positions: torch.Tensor,
        # Expected query layout:
        #   if is_nope_first:
        #     [batch_size, seq_len, num_heads, nope_size + rope_size]
        #   if NOT is_nope_first:
        #     [batch_size, seq_len, num_heads, rope_size + nope_size]
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
is_nope_first: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

cos = cos.unsqueeze(-2).unsqueeze(-2)
sin = sin.unsqueeze(-2).unsqueeze(-2)
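        # cos_sin_cache is [max_position, rotary_dim] with the cos half stored
        # first; after chunking and unsqueezing, cos and sin each have shape
        # [max_position, 1, 1, rotary_dim // 2], the layout consumed by the
        # AITER cached-position ops below.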

rotate_style = 0 if self.is_neox_style else 1

num_tokens = positions.numel()

query_shape = query.shape
query = query.view(1, num_tokens, -1, self.head_size)
if key is not None:
key_shape = key.shape
key = key.view(1, num_tokens, -1, self.head_size)

positions = positions.view(*query.shape[:2])
if offsets is not None:
offsets = offsets.view(*query.shape[:2])

if not is_nope_first:
query_ = query[..., :self.rotary_dim]
key_ = key[..., :self.rotary_dim] if key is not None else None
else:
query_ = query[..., -self.rotary_dim:]
key_ = key[..., -self.rotary_dim:] if key is not None else None

if key_ is None:
torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
positions, sin, cos, query_, offsets, rotate_style,
is_nope_first)
return query.view(query_shape), None

torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
positions, sin, cos, query_, key_, offsets, rotate_style,
is_nope_first)

return query.view(query_shape), key.view(key_shape)

def forward_xpu(
self,
positions: torch.Tensor,
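The new `forward_hip` path reshapes the flat query/key tensors into the 4-D `[batch, seq, heads, head_size]` layout the AITER kernels operate on and rotates them in place, falling back to `forward_native` when AITER is unavailable. A quick way to check that the two paths agree is to call them side by side. A minimal sketch, assuming a ROCm build of vLLM with `VLLM_ROCM_USE_AITER=1`, an installed `aiter` package, and the current `RotaryEmbedding` constructor signature `(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)`; none of these assumptions come from the diff itself:

```python
import torch

from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbedding

head_size, num_heads, num_tokens, max_pos = 128, 8, 16, 4096
rope = RotaryEmbedding(head_size, head_size, max_pos, 10000, True,
                       torch.bfloat16).to("cuda")

positions = torch.randint(0, max_pos, (num_tokens,), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.bfloat16, device="cuda")
key = torch.randn_like(query)

# forward_hip mutates its inputs in place via the AITER kernels, so clone.
q_ref, k_ref = rope.forward_native(positions, query.clone(), key.clone())
q_hip, k_hip = rope.forward_hip(positions, query.clone(), key.clone())

# Tolerances are illustrative; bf16 plus a different kernel implies some drift.
torch.testing.assert_close(q_hip, q_ref, atol=2e-2, rtol=2e-2)
torch.testing.assert_close(k_hip, k_ref, atol=2e-2, rtol=2e-2)
```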
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/rotary_embedding/common.py
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
return ramp_func


-def yarn_get_mscale(scale: float = 1) -> float:
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     if scale <= 1:
         return 1.0
-    return 0.1 * math.log(scale) + 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
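The only functional change to `common.py` is the new `mscale` argument, which scales the logarithmic YaRN attention-magnitude correction; existing callers are unaffected because it defaults to 1. A small worked example (the import path is the modified module itself; the `0.707` values are purely illustrative, DeepSeek-style configs supply their own `mscale`/`mscale_all_dim`):

```python
import math

from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale

scale = 40.0  # YaRN context-extension factor

# mscale defaults to 1, so existing callers keep the old behaviour.
assert math.isclose(yarn_get_mscale(scale), 0.1 * math.log(scale) + 1.0)

# The new knob scales the log term: 0.1 * mscale * ln(scale) + 1.0.
assert math.isclose(yarn_get_mscale(scale, mscale=0.707),
                    0.1 * 0.707 * math.log(scale) + 1.0)

# DeepSeek-style YaRN combines two such values into one attention multiplier
# (hypothetical numbers, not taken from this diff):
mscale, mscale_all_dim, attn_factor = 0.707, 0.707, 1.0
attn_scale = (yarn_get_mscale(scale, mscale) /
              yarn_get_mscale(scale, mscale_all_dim) * attn_factor)
```

The hunks that follow drop the DeepSeek scaling module's local copy of this helper and import the shared one instead.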
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import math
from typing import Optional

import torch
@@ -10,13 +9,7 @@

from .base import RotaryEmbedding
from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range,
-                     yarn_linear_ramp_mask)
-
-
-def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
-    if scale <= 1:
-        return 1.0
-    return 0.1 * mscale * math.log(scale) + 1.0
+                     yarn_get_mscale, yarn_linear_ramp_mask)


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
@@ -96,6 +89,12 @@ def forward(
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
if self.is_rocm_aiter_enabled:
query, key = super().forward(positions, query, key, offsets)
if positions.numel() == 1:
key = key.clone()
return query, key

assert key is not None
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
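The new branch at the top of `forward()` steps back into the base class whenever AITER is enabled, so the ROCm kernels added in `base.py` also serve the DeepSeek scaling layer, while every other platform keeps running the PyTorch-native code below. A toy sketch of that dispatch shape (this is not vLLM's actual `CustomOp` machinery, only an illustration of why calling `super().forward()` is enough):

```python
class Base:
    def __init__(self, use_hip: bool):
        self.use_hip = use_hip

    def forward(self, x):
        # Pick the platform-specific implementation, like CustomOp dispatch.
        return self.forward_hip(x) if self.use_hip else self.forward_native(x)

    def forward_native(self, x):
        return f"native({x})"

    def forward_hip(self, x):
        return f"aiter({x})"


class DeepseekLike(Base):
    def forward(self, x):
        if self.use_hip:                # mirrors `if self.is_rocm_aiter_enabled:`
            return super().forward(x)   # lands in Base.forward_hip -> AITER ops
        return f"deepseek_native({x})"  # the PyTorch-native path kept below


assert DeepseekLike(True).forward("q") == "aiter(q)"
assert DeepseekLike(False).forward("q") == "deepseek_native(q)"
```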
127 changes: 127 additions & 0 deletions vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def is_rocm_rotary_embedding_enabled() -> bool:
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER)


def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter as ops
if offsets is None:
ops.rope_cached_positions_fwd_inplace(
query,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)
else:
ops.rope_cached_positions_offsets_fwd_inplace(
query,
cos,
sin,
positions,
offsets,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)


def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter as ops
if offsets is None:
ops.rope_cached_positions_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)
else:
ops.rope_cached_positions_offsets_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
offsets,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)


def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass


def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass


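# The *_fake stubs above give torch.compile / fake-tensor tracing a shape-only
# stand-in for the mutating AITER kernels; the real implementations are only
# registered when running on ROCm with AITER enabled.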
if is_rocm_rotary_embedding_enabled():

direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
mutates_args=["key", "query"],
fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
mutates_args=["query"],
fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
dispatch_key=current_platform.dispatch_key,
)
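For completeness, the registered ops can also be exercised directly. A hedged sketch, assuming a ROCm build with `VLLM_ROCM_USE_AITER=1` so that the registration above actually ran, and using the tensor layout that `forward_hip_rocm_aiter` in `base.py` prepares (shapes and values are illustrative):

```python
import torch

# Importing the module triggers the registration at the bottom of this file.
import vllm.model_executor.layers.rotary_embedding.rocm_aiter_rope_ops  # noqa: F401

num_tokens, num_heads, num_kv_heads, head_size, max_pos = 8, 16, 4, 128, 4096

# Half of the rotary frequencies per position, broadcast over seq and heads.
cos = torch.randn(max_pos, 1, 1, head_size // 2,
                  dtype=torch.bfloat16, device="cuda")
sin = torch.randn_like(cos)

positions = torch.randint(0, max_pos, (1, num_tokens), device="cuda")
query = torch.randn(1, num_tokens, num_heads, head_size,
                    dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, num_tokens, num_kv_heads, head_size,
                  dtype=torch.bfloat16, device="cuda")

# rotate_style=0 is the NeoX rotation (see base.py); query and key are
# rotated in place and nothing is returned.
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
    positions, sin, cos, query, key, None, 0, False)
```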