
Commit 07890f4

[Refactor] Create torchair module
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent e65fcea commit 07890f4

15 files changed: +995 -432 lines changed

tests/ut/__init__.py

Whitespace-only changes.

tests/ut/test_platform.py

Lines changed: 1 addition & 1 deletion
@@ -523,7 +523,7 @@ def test_get_attn_backend_cls_use_v1_and_torchair(self,
         )
         self.assertEqual(
             result,
-            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend"
         )
 
     @patch('vllm_ascend.platform.get_ascend_config')

tests/ut/test_utils.py

Lines changed: 0 additions & 21 deletions
@@ -280,27 +280,6 @@ def test_update_aclgraph_sizes(self):
             3,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
 
-    def test_get_torchair_current_work_dir(self):
-        cache_dir = utils.TORCHAIR_CACHE_DIR
-        work_dir = utils.get_torchair_current_work_dir()
-        self.assertEqual(cache_dir, work_dir)
-        work_dir = utils.get_torchair_current_work_dir("test")
-        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
-
-    def test_torchair_cache_dir(self):
-        utils.write_kv_cache_bytes_to_file(0, 100)
-        self.assertTrue(utils.check_torchair_cache_exist(),
-                        "Create torchair cache dir failed")
-        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
-                        "Create kv cache bytes cache dir failed")
-        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
-        self.assertEqual(100, kv_cache_bytes)
-        utils.delete_torchair_cache_file()
-        self.assertFalse(utils.check_torchair_cache_exist(),
-                        "Delete torchair cache dir failed")
-        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
-                        "Delete kv cache bytes cache dir failed")
-
 
 class TestProfileExecuteDuration(unittest.TestCase):
 

tests/ut/torchair/__init__.py

Whitespace-only changes.

tests/ut/torchair/test_utils.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import os
+
+from tests.ut.base import TestBase
+from vllm_ascend.torchair import utils
+
+
+class TestTorchairUtils(TestBase):
+
+    def test_get_torchair_current_work_dir(self):
+        cache_dir = utils.TORCHAIR_CACHE_DIR
+        work_dir = utils.get_torchair_current_work_dir()
+        self.assertEqual(cache_dir, work_dir)
+        work_dir = utils.get_torchair_current_work_dir("test")
+        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
+
+    def test_torchair_cache_dir(self):
+        utils.write_kv_cache_bytes_to_file(0, 100)
+        self.assertTrue(utils.check_torchair_cache_exist(),
+                        "Create torchair cache dir failed")
+        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
+                        "Create kv cache bytes cache dir failed")
+        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
+        self.assertEqual(100, kv_cache_bytes)
+        utils.delete_torchair_cache_file()
+        self.assertFalse(utils.check_torchair_cache_exist(),
+                        "Delete torchair cache dir failed")
+        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
+                        "Delete kv cache bytes cache dir failed")

tests/ut/worker/test_pooling_model_runner.py

Lines changed: 4 additions & 2 deletions
@@ -1,3 +1,4 @@
+import os
 import unittest
 from unittest.mock import MagicMock, patch
 
@@ -24,9 +25,10 @@ def _create_model_runner(self, model: str, *args,
     def setUp(self):
         """Initialize test fixtures and common mocks"""
         self.attn_backend = "npu"
-
+        model_path = os.path.join(os.path.dirname(__file__), "..",
+                                  "fake_weight")
         model_runner = self._create_model_runner(
-            "tests/ut/fake_weight",
+            model_path,
             trust_remote_code=True,
             enable_chunked_prefill=False,
         )
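
The only functional change here is that the fake-weight fixture is now located relative to the test file rather than the process working directory, so the test no longer depends on pytest being launched from the repository root. A minimal illustration of the difference (only the "fake_weight" directory name is taken from the diff; the rest is illustrative):

import os

cwd_relative = "tests/ut/fake_weight"          # breaks when run outside the repo root
file_relative = os.path.join(os.path.dirname(__file__),
                             "..", "fake_weight")  # stable wherever pytest is invoked
print(os.path.normpath(file_relative))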

vllm_ascend/platform.py

Lines changed: 5 additions & 2 deletions
@@ -181,7 +181,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
-                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
+                if ascend_config.torchair_graph_config.enabled:
+                    parallel_config.worker_cls = "vllm_ascend.torchair.worker_v1_torchair.NPUTorchairWorker"
+                else:
+                    parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
             elif vllm_config.speculative_config:
                 # NOTE: We set this var to `1` in vllm-ascend to avoid segment
                 # fault when using spec decode with V0 engine.
@@ -224,7 +227,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         use_torchair = get_ascend_config().torchair_graph_config.enabled
         if use_v1 and use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend"
         if use_v1:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
         if use_mla:
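
Both worker_cls and the attention backend are configured as dotted import strings that are resolved lazily at runtime, which is why this refactor only rewrites string paths (plus the tests that assert them) rather than any call sites. A rough sketch of how such a dotted path can be resolved with the standard library — the helper below is illustrative and not necessarily the exact mechanism vLLM uses:

import importlib

def resolve_class_by_path(qualname: str):
    # Split "pkg.module.ClassName" into module path and attribute,
    # import the module, and return the class object.
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# e.g. the worker selected when torchair graph mode is enabled:
# NPUTorchairWorker = resolve_class_by_path(
#     "vllm_ascend.torchair.worker_v1_torchair.NPUTorchairWorker")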

vllm_ascend/torchair/__init__.py

Whitespace-only changes.
