Commit 139aa82

add ut for npu_apply_rotary_pos_emb
Signed-off-by: David9857 <985700846@qq.com>
1 parent: 71db15b

1 file changed (+65, -2 lines)

tests/singlecard/ops/test_rotary_embedding.py

Lines changed: 65 additions & 2 deletions
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu

 from vllm_ascend.utils import enable_custom_op

@@ -141,8 +142,8 @@ def forward_native(
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)
@@ -198,3 +199,65 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+
+# test npu_apply_rotary_pos_emb with head_size=128 and rotary_dim=128 and is_neox_style=True
+@pytest.mark.parametrize("is_neox_style", [True])
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_apply_rotary_pos_emb_with_head_size_equals_rotary_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
+    qkv_tensor = torch.randn(1, num_tokens,
+                             num_heads, head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [head_size, head_size, head_size],
+        dim=-1,
+    )
+
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+    cos_sin = rope.cos_sin_cache.index_select(0, positions)
+    last_dim = cos_sin.size()[-1]
+    cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+    # BSNH
+    cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+        1, -1, 1, last_dim).contiguous()
+    torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
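The least obvious part of the new test is how cos/sin are prepared for the kernel: each selected row of rope.cos_sin_cache is split into a cos half and a sin half, each half is tiled along the last dimension with repeat(1, 1, 2), and the result is viewed in a BSNH layout (batch, sequence, heads, head size) so it broadcasts over heads. The sketch below is an editor-added illustration of the standard neox-style rotation that the comparison against forward_native relies on; it is not the vllm-ascend or torch_npu implementation, and all names and shapes in it are illustrative assumptions chosen to mirror the test (head_size == rotary_dim == 128, cos/sin tiled as above).

```python
# Illustrative sketch only: NOT the vllm-ascend forward_native code or the
# torch_npu kernel. Shapes and names are assumptions mirroring the test above.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: split the head dim in half, map [x1, x2] -> [-x2, x1].
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def neox_rope_reference(x: torch.Tensor, cos: torch.Tensor,
                        sin: torch.Tensor) -> torch.Tensor:
    # x:   [num_tokens, num_heads, head_size]
    # cos: [num_tokens, head_size], laid out as [cos_half, cos_half] (tiled),
    # sin: [num_tokens, head_size], laid out as [sin_half, sin_half] (tiled),
    # matching the repeat(1, 1, 2) layout built from cos_sin_cache in the test.
    cos = cos.unsqueeze(-2)  # broadcast over the heads dimension
    sin = sin.unsqueeze(-2)
    return x * cos + rotate_half(x) * sin


if __name__ == "__main__":
    num_tokens, num_heads, head_size = 4, 2, 128
    half = head_size // 2
    # Standard RoPE frequencies with base 10000 (same default as the test).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, half, dtype=torch.float32) / half))
    positions = torch.arange(num_tokens, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)   # [num_tokens, half]
    cos = freqs.cos().repeat(1, 2)             # [num_tokens, head_size], tiled
    sin = freqs.sin().repeat(1, 2)
    q = torch.randn(num_tokens, num_heads, head_size)
    out = neox_rope_reference(q, cos, sin)
    print(out.shape)  # torch.Size([4, 2, 128])
```

With this tiled layout, the first half of each head becomes x1*cos - x2*sin and the second half becomes x2*cos + x1*sin, which is why the test can compare the kernel's in-place output directly against rope.forward_native with plain element-wise tolerances.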

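To run only the new case locally, a typical invocation (assuming an Ascend NPU environment with torch_npu and vllm-ascend installed) uses pytest's -k filter: pytest tests/singlecard/ops/test_rotary_embedding.py -k npu_apply_rotary_pos_emb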