@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 
 from vllm_ascend.utils import enable_custom_op
 
@@ -198,3 +199,72 @@ def test_rotary_embedding_quant_with_leading_dim(
                                ref_key,
                                atol=DEFAULT_ATOL,
                                rtol=DEFAULT_RTOL)
+
+
+# Test npu_apply_rotary_pos_emb with head_size == rotary_dim == 128 and is_neox_style=True.
+@pytest.mark.parametrize("is_neox_style", [True])
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_apply_rotary_pos_emb_with_head_size_equals_rotary_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
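+    """Check torch_npu.npu_apply_rotary_pos_emb against forward_native."""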
+    torch.set_default_device(device)
+    torch.manual_seed(seed)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (num_tokens, ))
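+    # Pack random q/k/v into one tensor; only query and key are used here.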
+    qkv_tensor = torch.randn(1,
+                             num_tokens,
+                             num_heads,
+                             head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [head_size, head_size, head_size],
+        dim=-1,
+    )
+
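+    # Reference rotation from the native (pure PyTorch) implementation.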
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+    cos_sin = rope.cos_sin_cache.index_select(0, positions)
+    last_dim = cos_sin.size()[-1]
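+    # The cache packs [cos | sin] halves; repeat each to the full rotary width.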
+    cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(1, 1,
+                                                            2).chunk(2, dim=-2)
+    # View as BSNH (batch, seq, num_heads, head_size) so cos/sin broadcast over heads.
+    cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+        1, -1, 1, last_dim).contiguous()
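+    # The NPU op rotates query and key in place.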
+    torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)