 import pytest
 import torch
 import torch.nn as nn
+import torch_npu

 from vllm_ascend.utils import enable_custom_op

@@ -141,8 +142,8 @@ def forward_native(
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", DEVICES)

@@ -198,3 +199,65 @@ def test_rotary_embedding_quant_with_leading_dim(
                                 ref_key,
                                 atol=DEFAULT_ATOL,
                                 rtol=DEFAULT_RTOL)
+
+
+# Test torch_npu.npu_apply_rotary_pos_emb with head_size=128, rotary_dim=128
+# and is_neox_style=True.
+@pytest.mark.parametrize("is_neox_style", [True])
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("rotary_dim", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_npu_apply_rotary_pos_emb_with_head_size_equals_rotary_dim(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                           is_neox_style, dtype)
+    rope = rope.to(dtype=dtype)
+    num_tokens = batch_size * seq_len
+    positions = torch.randint(0, max_position, (num_tokens, ))
+    qkv_tensor = torch.randn(1, num_tokens,
+                             num_heads, head_size * 3,
+                             dtype=dtype)
+    query, key, _ = qkv_tensor.split(
+        [head_size, head_size, head_size],
+        dim=-1,
+    )
+
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+    # The cache stores [cos | sin] concatenated along the last dimension;
+    # split the halves and repeat each so both elements of a rotated pair
+    # share the same angle.
+    cos_sin = rope.cos_sin_cache.index_select(0, positions)
+    last_dim = cos_sin.size()[-1]
+    cos, sin = cos_sin.reshape(-1, 2,
+                               last_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+    # Reshape to BSNH so cos/sin broadcast over the head dimension.
+    cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
+        1, -1, 1, last_dim).contiguous()
+    # The fused op updates query and key in place.
+    torch_npu.npu_apply_rotary_pos_emb(query, key, cos, sin)
+
+    # Compare the results.
+    torch.testing.assert_close(query.view(ref_query.size()),
+                               ref_query,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+    torch.testing.assert_close(key.view(ref_key.size()),
+                               ref_key,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
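
Note on the cos/sin preparation above: vLLM's `RotaryEmbedding.cos_sin_cache` stores `[cos | sin]` halves of length `rotary_dim // 2` concatenated along the last dimension, so the test expands them to full width and a broadcastable BSNH layout before calling the fused op. A minimal CPU-only sketch of that transformation (the sizes here are illustrative, not part of the test):

```python
import torch

# Illustrative sizes; the test uses num_tokens = batch_size * seq_len
# and rotary_dim = 128.
num_tokens, rotary_dim = 4, 128
cos_sin = torch.randn(num_tokens, rotary_dim)  # each row: [cos | sin] halves

# Split the halves, then repeat each so both elements of a rotated pair
# (x_i, x_{i + rotary_dim // 2}) share the same angle (neox-style rotation).
cos, sin = cos_sin.reshape(-1, 2, rotary_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)

# BSNH layout (batch=1, seq=num_tokens, num_heads=1, head_size=rotary_dim),
# which broadcasts across all attention heads of query and key.
cos = cos.view(1, -1, 1, rotary_dim).contiguous()
sin = sin.view(1, -1, 1, rotary_dim).contiguous()

assert cos.shape == sin.shape == (1, num_tokens, 1, rotary_dim)
```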