
Commit 79a9e7b

hust17yixuan authored and offline0806 committed
[6/N][refactor]delete torchair in rotary ops (vllm-project#2581)
### What this PR does / why we need it?
After moving the torchair-related rope ops into torchair_ops, split torchair out of the original rope ops to keep the code clean.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
vLLM version: main
vLLM main: vllm-project/vllm@ab9f2cf

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@81eea3d

Signed-off-by: hust17yixuan <303660421@qq.com>
Signed-off-by: offline0806 <z00858301@china.huawei.com>
1 parent c2c97f3 commit 79a9e7b
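For context on the refactor, here is a minimal, self-contained sketch of the dispatch shape this series moves toward: instead of checking the Ascend config inside the shared rope op on every forward call, the torchair variant becomes a separate implementation selected once at setup time. The names below (`rope_common`, `rope_torchair`, `select_rope_impl`, `torchair_enabled`) are illustrative stand-ins, not vllm-ascend APIs.

```python
# Hedged sketch of "select once, no per-call config check" dispatch.
# All names here are illustrative; they are not vllm-ascend functions.
from typing import Callable, Tuple

import torch


def rope_common(positions: torch.Tensor, query: torch.Tensor,
                key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stand-in for the cleaned-up rope op (no torchair check inside)."""
    return query, key


def rope_torchair(positions: torch.Tensor, query: torch.Tensor,
                  key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stand-in for the rope variant that now lives with the torchair ops."""
    return query, key


def select_rope_impl(
        torchair_enabled: bool
) -> Callable[..., Tuple[torch.Tensor, torch.Tensor]]:
    # Decide once (e.g. when ops are registered) instead of consulting a
    # global config on every forward call.
    return rope_torchair if torchair_enabled else rope_common


if __name__ == "__main__":
    rope = select_rope_impl(torchair_enabled=False)
    positions = torch.arange(4)
    q, k = torch.randn(4, 8), torch.randn(4, 8)
    q_out, k_out = rope(positions, q, k)
    print(q_out.shape, k_out.shape)
```

In this PR the torchair side of that split has already been moved into torchair_ops (per the description above), so the diffs below only delete the now-dead branches from the common path.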

File tree

2 files changed (+7, -83 lines)


tests/ut/ops/test_rotary_embedding.py

Lines changed: 4 additions & 32 deletions
@@ -88,36 +88,16 @@ def setUp(self):
         self.mock_self.cos_sin_cache = self.cos_sin_cache
         self.mock_self.is_neox_style = self.is_neox_style
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
-    def test_rope_forward_oot_torchair_enabled_base(self,
-                                                    mock_get_ascend_config):
-        # Setup mock for torchair enabled
-        mock_config = MagicMock()
-        mock_config.torchair_graph_config.enabled = True
-        mock_get_ascend_config.return_value = mock_config
-        with patch.object(self.layer,
-                          "forward_native",
-                          return_value=(self.query,
-                                        self.key)) as mock_forward_native:
-            result_q, result_k = self.layer.forward(self.positions, self.query,
-                                                    self.key)
-
-        mock_forward_native.assert_called_once()
-        self.assertTrue(torch.equal(result_q, self.query))
-        self.assertTrue(torch.equal(result_k, self.key))
-
     @patch('torch.ops._C')
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
            return_value=True)
     @patch('torch.ops._npu_rotary_embedding')
     def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
                                             mock_custom_enabled, mock_is_310p,
-                                            mock_get_ascend_config, mock__c):
+                                            mock__c):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Setup mock for custom kernel path
 
@@ -130,16 +110,13 @@ def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
         self.assertEqual(result_q.shape, self.query.shape)
         self.assertEqual(result_k.shape, self.key.shape)
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
            return_value=False)
     @patch('torch_npu._npu_rotary_embedding')
     def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
-                                         mock_custom_enabled,
-                                         mock_get_ascend_config):
+                                         mock_custom_enabled):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Test contiguous path when custom is disabled
         non_contig_query = self.query.transpose(0, 1)
@@ -153,27 +130,22 @@ def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
         self.assertEqual(result_q.shape, non_contig_query.shape)
         self.assertEqual(result_k.shape, non_contig_key.shape)
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
-    def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
+    def test_rope_forward_oot_with_offsets(self):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Test that NotImplementedError is raised when offsets is provided
         offsets = torch.tensor([1, 2, 3])
         with self.assertRaises(NotImplementedError):
             self.layer.forward(self.positions, self.query, self.key, offsets)
 
-    @patch('vllm_ascend.ops.rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.ops.rotary_embedding.custom_rotary_embedding_enabled',
            return_value=False)
     @patch('torch_npu._npu_rotary_embedding')
     def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
-                                                  mock_custom_enabled,
-                                                  mock_get_ascend_config):
+                                                  mock_custom_enabled):
         mock_config = MagicMock()
         mock_config.torchair_graph_config.enabled = False
-        mock_get_ascend_config.return_value = mock_config
 
         # Test neox_style override
         result_q, result_k = self.layer.forward(self.positions,
vllm_ascend/ops/rotary_embedding.py

Lines changed: 3 additions & 51 deletions
@@ -24,7 +24,6 @@
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import enable_custom_op, is_310p
 
@@ -43,15 +42,6 @@ def rope_forward_oot(
     is_neox_style_override: Optional[bool] = None,
     is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config(
-    ).torchair_graph_config.enabled and not is_qwen_torchair:
-        return self.forward_native(
-            positions,
-            query,
-            key,
-            offsets,
-        )
-
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -120,11 +110,6 @@ def __init__(
     ) -> None:
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
-        if get_ascend_config().torchair_graph_config.enabled:
-            set_cos_sin_cache(self,
-                              seq_len=max_position_embeddings,
-                              device="npu",
-                              dtype=dtype)
 
     def forward_oot(
         self,
@@ -137,42 +122,9 @@ def forward_oot(
         is_prefill: Optional[bool] = True,
         is_qwen_torchair: Optional[bool] = False,
     ):
-        if get_ascend_config().torchair_graph_config.enabled \
-                and is_qwen_torchair and not is_prefill:
-            if max_seq_len is not None and torch.gt(
-                    max_seq_len, self.max_position_embeddings):
-                set_cos_sin_cache(self,
-                                  seq_len=max_seq_len,
-                                  device=query.device,
-                                  dtype=torch.float32)
-
-            # bsnd/bnsd
-            if positions is not None:
-                cos = self.embed(positions, self.cos)
-                sin = self.embed(positions, self.sin)
-                self.cos_embed = cos
-                self.sin_embed = sin
-            else:
-                cos = self.cos_embed
-                sin = self.sin_embed
-
-            query = query.view(*query.shape[:-1], -1,
-                               self.head_size).contiguous()
-            key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
-
-            cos = cos.unsqueeze(-2).unsqueeze(-2)
-            sin = sin.unsqueeze(-2).unsqueeze(-2)
-
-            query = query.unsqueeze(1)
-            key = key.unsqueeze(1)
-
-            q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
-                query, key, cos, sin)
-            return q_embed.flatten(-2), k_embed.flatten(-2)
-        else:
-            return rope_forward_oot(self, positions, query, key, offsets,
-                                    is_neox_style_override,
-                                    is_qwen_torchair)  # type: ignore
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
 
 
 class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
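For readers without NPU hardware, the branch deleted above ultimately applied a neox-style rotation via `torch_npu.npu_apply_rotary_pos_emb`. A plain-PyTorch reference of that rotation is sketched below; it is a generic RoPE illustration runnable on CPU, not vllm-ascend or torch_npu code, and the shapes and frequency base are assumptions chosen only for the example.

```python
# Generic neox-style RoPE reference (illustration only, not vllm-ascend code).
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_neox(query: torch.Tensor, key: torch.Tensor,
                    cos: torch.Tensor, sin: torch.Tensor):
    # query/key: [num_tokens, num_heads, head_size]
    # cos/sin:   [num_tokens, head_size], broadcast over the head dimension.
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    q_embed = query * cos + rotate_half(query) * sin
    k_embed = key * cos + rotate_half(key) * sin
    return q_embed, k_embed


if __name__ == "__main__":
    num_tokens, num_heads, head_size = 3, 2, 8
    positions = torch.arange(num_tokens)
    # Standard RoPE inverse frequencies with base 10000 (assumed here).
    inv_freq = 1.0 / (10000**(torch.arange(0, head_size, 2).float() / head_size))
    freqs = torch.outer(positions.float(), inv_freq)   # [tokens, head_size/2]
    emb = torch.cat((freqs, freqs), dim=-1)            # [tokens, head_size]
    q = torch.randn(num_tokens, num_heads, head_size)
    k = torch.randn(num_tokens, num_heads, head_size)
    q_out, k_out = apply_rope_neox(q, k, emb.cos(), emb.sin())
    print(q_out.shape, k_out.shape)
```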
