Commit fb891e0

add torchair ops
Signed-off-by: hust17yixuan <303660421@qq.com>
1 parent 516e14a commit fb891e0

File tree

4 files changed: +721 -0 lines changed

Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
import math
from unittest.mock import MagicMock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
    custom_rotary_embedding_enabled, native_rope_deepseek_forward,
    rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)


class TestCustomRotaryEmbeddingEnabled(TestBase):

    def setUp(self):
        # Common setup for tests
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 4, dtype=torch.float16)
        self.key = torch.randn(3, 4, dtype=torch.float16)
        self.head_size = 32
        self.cos_sin_cache = torch.randn(3, 4)

        # Mock self object for rope_forward_oot
        self.mock_self = MagicMock()
        self.mock_self.head_size = self.head_size
        self.mock_self.cos_sin_cache = self.cos_sin_cache
        self.mock_self.is_neox_style = True
        self.mock_self.forward_native.return_value = (self.query, self.key)

    def test_custom_rotary_embedding_enabled(self):
        # Test when all conditions are True
        with patch(
                'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
                return_value=True):
            result = custom_rotary_embedding_enabled(self.query, True,
                                                     self.head_size)
            self.assertTrue(result)

        # Test when dtype is not float16
        with patch(
                'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
                return_value=True):
            query = self.query.to(torch.float32)
            result = custom_rotary_embedding_enabled(query, True,
                                                     self.head_size)
            self.assertFalse(result)

        # Test when neox_style is False
        with patch(
                'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
                return_value=True):
            result = custom_rotary_embedding_enabled(self.query, False,
                                                     self.head_size)
            self.assertFalse(result)

        # Test when head_size is not divisible by 32
        with patch(
                'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
                return_value=True):
            result = custom_rotary_embedding_enabled(self.query, True,
                                                     self.head_size + 1)
            self.assertFalse(result)

        # Test when custom op is disabled
        with patch(
                'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
                return_value=False):
            result = custom_rotary_embedding_enabled(self.query, True,
                                                     self.head_size)
            self.assertFalse(result)


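# The cases above imply that custom_rotary_embedding_enabled(query, neox_style,
# head_size) is expected to return True only when the query is float16, the
# rotation is neox style, head_size is a multiple of 32, and enable_custom_op()
# reports the custom NPU op as available; any single failing condition yields
# False. (This summary is inferred from the test expectations only.)
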
class TestRopeForwardOot(TestBase):

    def setUp(self):
        # Common setup for tests
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 4, dtype=torch.float16)
        self.key = torch.randn(3, 4, dtype=torch.float16)
        self.head_size = 32
        self.cos_sin_cache = torch.randn(3, 4)

        # Mock self object for rope_forward_oot
        self.mock_self = MagicMock()
        self.mock_self.head_size = self.head_size
        self.mock_self.cos_sin_cache = self.cos_sin_cache
        self.mock_self.is_neox_style = True
        self.mock_self.forward_native.return_value = (self.query, self.key)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    def test_rope_forward_oot_torchair_enabled_base(self,
                                                    mock_get_ascend_config):
        # Setup mock for torchair enabled
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = True
        mock_get_ascend_config.return_value = mock_config

        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
                                              self.query, self.key)

        self.mock_self.forward_native.assert_called_once_with(
            self.positions, self.query, self.key, None)
        self.assertTrue(torch.equal(result_q, self.query))
        self.assertTrue(torch.equal(result_k, self.key))

    @patch('torch.ops._C')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
           return_value=False)
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
        return_value=True)
    @patch('torch.ops._npu_rotary_embedding')
    def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
                                            mock_custom_enabled, mock_is_310p,
                                            mock_get_ascend_config, mock__c):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        # Setup mock for custom kernel path
        mock__c.rotary_embedding.return_value = self.query, self.key

        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
                                              self.query, self.key)

        self.assertEqual(result_q.shape, self.query.shape)
        self.assertEqual(result_k.shape, self.key.shape)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
        return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
                                         mock_custom_enabled,
                                         mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        # Test contiguous path when custom is disabled
        non_contig_query = self.query.transpose(0, 1)
        non_contig_key = self.key.transpose(0, 1)

        result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
                                              non_contig_query,
                                              non_contig_key)

        mock_npu_rotary.assert_called_once()
        self.assertEqual(result_q.shape, non_contig_query.shape)
        self.assertEqual(result_k.shape, non_contig_key.shape)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        # Test that NotImplementedError is raised when offsets is provided
        offsets = torch.tensor([1, 2, 3])
        with self.assertRaises(NotImplementedError):
            rope_forward_oot(self.mock_self, self.positions, self.query,
                             self.key, offsets)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
        return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
                                                  mock_custom_enabled,
                                                  mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        # Test neox_style override
        result_q, result_k = rope_forward_oot(self.mock_self,
                                              self.positions,
                                              self.query,
                                              self.key,
                                              is_neox_style_override=False)

        # Check that neox_style=False was passed to the NPU function
        args, kwargs = mock_npu_rotary.call_args
        self.assertFalse(args[-1])


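# Taken together, the tests above describe the dispatch expected of
# rope_forward_oot: with the torchair graph enabled it should delegate to
# self.forward_native(positions, query, key, None); with the graph disabled it
# should call an NPU rotary-embedding kernel (the custom-op path when
# custom_rotary_embedding_enabled is True, otherwise torch_npu's kernel on
# contiguous tensors with the neox-style flag as the last positional argument);
# passing offsets is expected to raise NotImplementedError. (Summary inferred
# from the test expectations only.)
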
class MockRopeModule:

    def __init__(self, max_seq_len=2048, is_neox_style=True):
        self.max_seq_len = max_seq_len
        self.is_neox_style = is_neox_style
        self.cos_cached = None
        self.sin_cached = None
        self.rotary_dim = 1
        self.base = 1


class TestNativeRopeDeepseekForward(TestBase):

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
        module = MockRopeModule()
        positions = torch.tensor([1, 2, 3])
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 8, 128)

        mock_rope_forward_oot.return_value = (query, key)

        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
                                                  key)

        assert q_pe.shape == query.shape
        assert k_pe.shape == key.shape

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
    )
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_cache_handling(
            self, mock_rope_forward_oot, mock_set_cache):
        # Test cache situation is true
        module = MockRopeModule(max_seq_len=1024)
        positions = torch.tensor([1, 2, 3])
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 8, 128)

        mock_rope_forward_oot.return_value = (query, key)

        q_pe, k_pe = native_rope_deepseek_forward(module,
                                                  positions,
                                                  query,
                                                  key,
                                                  max_seq_len=2048)

        assert q_pe.shape == query.shape
        assert k_pe.shape == key.shape

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_key_reshaping(
            self, mock_rope_forward_oot):
        module = MockRopeModule()
        positions = torch.tensor([1, 2, 3])
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 128)

        mock_rope_forward_oot.return_value = (query, key)

        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
                                                  key)

        assert q_pe.shape == query.shape
        assert k_pe.shape == (1, 128)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_non_neox_style(
            self, mock_rope_forward_oot):
        module = MockRopeModule(is_neox_style=False)
        positions = torch.tensor([1, 2, 3])
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 8, 128)

        mock_rope_forward_oot.return_value = (query, key)

        q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
                                                  key)

        assert q_pe.shape == query.shape
        assert k_pe.shape == key.shape


class TestRotateHalf(TestBase):

    def test_rotate_half_even_dim(self):
        # Test with even dimension
        x = torch.tensor([1.0, 2.0, 3.0, 4.0])
        expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
        result = rotate_half(x)
        self.assertTrue(torch.allclose(result, expected))


class TestYarnFindCorrectionDim(TestBase):

    def test_basic_case(self):
        # Test with standard values
        num_rotations = 100
        dim = 512
        base = 10000
        max_position_embeddings = 2048

        result = yarn_find_correction_dim(num_rotations, dim, base,
                                          max_position_embeddings)

        # Calculate expected value manually
        expected = (dim * torch.log(
            torch.tensor(max_position_embeddings) /
            (num_rotations * 2 * torch.pi))) / (
                2 * torch.log(torch.tensor(base)))

        self.assertTrue(torch.allclose(result, expected))


class TestYarnGetMscale(TestBase):

    def test_scale_less_than_or_equal_1(self):
        self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
        self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
        self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)

    def test_scale_greater_than_1(self):
        test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
                      (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
                      (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
                      (math.e, 1.0, 1.0 + 0.1)]

        for scale, mscale, expected in test_cases:
            result = yarn_get_mscale(scale, mscale)
            self.assertAlmostEqual(
                result,
                expected,
                places=6,
                msg=f"Failed for scale={scale}, mscale={mscale}")
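
For reference, the helper behaviour that the last three test classes pin down can be read off directly from their expected values. The sketch below is inferred from those assertions alone and is not the module's actual implementation; the names rotate_half, yarn_find_correction_dim and yarn_get_mscale are reused only for illustration, and the formulas simply mirror the manually computed expectations in the tests.

import math

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: [x1, x2] -> [-x2, x1],
    # matching the expectation [1, 2, 3, 4] -> [-3, -4, 1, 2] above.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def yarn_find_correction_dim(num_rotations, dim, base=10000,
                             max_position_embeddings=2048):
    # Dimension index at which `num_rotations` full rotations fit into the
    # maximum position range; same expression as the test's expected value.
    return (dim * torch.log(
        torch.tensor(max_position_embeddings) /
        (num_rotations * 2 * torch.pi))) / (2 * torch.log(torch.tensor(base)))


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # No rescaling for scale <= 1, logarithmic growth above that, matching the
    # (scale, mscale, expected) triples in TestYarnGetMscale.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0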
