
Commit e568126

add ut for npu_apply_rotary_pos_emb
Signed-off-by: David9857 <985700846@qq.com>
1 parent 71db15b commit e568126

File tree

tests/singlecard/ops/test_rotary_embedding.py
vllm_ascend/models/qwen3.py

2 files changed: 72 additions, 5 deletions


tests/singlecard/ops/test_rotary_embedding.py

Lines changed: 67 additions & 0 deletions
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu

 from vllm_ascend.utils import enable_custom_op

@@ -198,3 +199,69 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+
+
+# test npu_apply_rotary_pos_emb with head_size=128 and rotary_dim=128 and is_neox_style=True
+@pytest.mark.parametrize("is_neox_style", [True])
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_apply_rotary_pos_emb_with_head_size_equals_rotary_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (batch_size * seq_len, ))
+    qkv_tensor = torch.randn(1,
+                             num_tokens,
+                             num_heads,
+                             head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [head_size, head_size, head_size],
+        dim=-1,
+    )
+
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+    cos_sin = rope.cos_sin_cache.index_select(0, positions)
+    last_dim = cos_sin.size()[-1]
+    cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1,
+                                                            2).chunk(2, dim=-2)
+    # BSNH
+    cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+        1, -1, 1, last_dim).contiguous()
+    torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
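
For reference, the neox-style rotation that rope.forward_native computes, and that torch_npu.npu_apply_rotary_pos_emb is expected to reproduce, boils down to the sketch below. This is a minimal standalone sketch rather than the vLLM implementation; it assumes cos and sin have already been expanded to the full rotary_dim in the repeated-halves layout, exactly like the BSNH tensors built in the test above, and the helper names (rotate_half, apply_rotary_reference) are illustrative only.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: split the last dim into two halves and
    # recombine them as (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(query: torch.Tensor, key: torch.Tensor,
                           cos: torch.Tensor, sin: torch.Tensor):
    # cos/sin broadcast against the BSNH query/key tensors, so every
    # head vector is rotated by the angles cached for its position.
    return (query * cos + rotate_half(query) * sin,
            key * cos + rotate_half(key) * sin)

The new case can be run on its own with, for example, pytest tests/singlecard/ops/test_rotary_embedding.py -k npu_apply_rotary_pos_emb on a host with a single Ascend NPU and torch_npu installed.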

vllm_ascend/models/qwen3.py

Lines changed: 5 additions & 5 deletions
@@ -141,23 +141,23 @@ def __init__(
             "Expected quant_config to be an instance of AscendQuantConfig"

         if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
-                          AscendW8A8LinearMethod):
+                      AscendW8A8LinearMethod):
             self.input_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
                 layer=self.self_attn.qkv_proj,
                 eps=config.rms_norm_eps)
         else:
             self.input_layernorm = RMSNorm(config.hidden_size,
-                                               eps=config.rms_norm_eps)
+                                           eps=config.rms_norm_eps)
         if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
-                          AscendW8A8LinearMethod):
+                      AscendW8A8LinearMethod):
             self.post_attention_layernorm = AddRMSNormW8A8Quant(
                 config.hidden_size,
                 layer=self.mlp.gate_up_proj,
                 eps=config.rms_norm_eps)
         else:
-            self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                    eps=config.rms_norm_eps)
+            self.post_attention_layernorm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
