
Commit 9d6c8b0 (parent: e65fcea)

[Refactor] Create torchair module

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

File tree: 11 files changed, +989 −429 lines


tests/ut/test_utils.py

Lines changed: 0 additions & 21 deletions
@@ -280,27 +280,6 @@ def test_update_aclgraph_sizes(self):
             3,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
 
-    def test_get_torchair_current_work_dir(self):
-        cache_dir = utils.TORCHAIR_CACHE_DIR
-        work_dir = utils.get_torchair_current_work_dir()
-        self.assertEqual(cache_dir, work_dir)
-        work_dir = utils.get_torchair_current_work_dir("test")
-        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
-
-    def test_torchair_cache_dir(self):
-        utils.write_kv_cache_bytes_to_file(0, 100)
-        self.assertTrue(utils.check_torchair_cache_exist(),
-                        "Create torchair cache dir failed")
-        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
-                        "Create kv cache bytes cache dir failed")
-        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
-        self.assertEqual(100, kv_cache_bytes)
-        utils.delete_torchair_cache_file()
-        self.assertFalse(utils.check_torchair_cache_exist(),
-                         "Delete torchair cache dir failed")
-        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
-                         "Delete kv cache bytes cache dir failed")
-
 
 class TestProfileExecuteDuration(unittest.TestCase):
tests/ut/torchair/test_utils.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import os
+
+from tests.ut.base import TestBase
+from vllm_ascend.torchair import utils
+
+
+class TestTorchairUtils(TestBase):
+    def test_get_torchair_current_work_dir(self):
+        cache_dir = utils.TORCHAIR_CACHE_DIR
+        work_dir = utils.get_torchair_current_work_dir()
+        self.assertEqual(cache_dir, work_dir)
+        work_dir = utils.get_torchair_current_work_dir("test")
+        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
+
+    def test_torchair_cache_dir(self):
+        utils.write_kv_cache_bytes_to_file(0, 100)
+        self.assertTrue(utils.check_torchair_cache_exist(),
+                        "Create torchair cache dir failed")
+        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
+                        "Create kv cache bytes cache dir failed")
+        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
+        self.assertEqual(100, kv_cache_bytes)
+        utils.delete_torchair_cache_file()
+        self.assertFalse(utils.check_torchair_cache_exist(),
+                         "Delete torchair cache dir failed")
+        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
+                         "Delete kv cache bytes cache dir failed")

vllm_ascend/platform.py

Lines changed: 5 additions & 2 deletions
@@ -181,7 +181,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
-                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
+                if ascend_config.torchair_graph_config.enabled:
+                    parallel_config.worker_cls = "vllm_ascend.torchair.worker_v1_torchair.NPUTorchairWorker"
+                else:
+                    parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
             elif vllm_config.speculative_config:
                 # NOTE: We set this var to `1` in vllm-ascend to avoid segment
                 # fault when using spec decode with V0 engine.

@@ -224,7 +227,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         use_torchair = get_ascend_config().torchair_graph_config.enabled
         if use_v1 and use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend"
         if use_v1:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
         if use_mla:
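Both hunks key off ascend_config.torchair_graph_config.enabled. A hedged sketch of how a user flips that switch, assuming vllm-ascend's usual additional_config plumbing (the model name is a placeholder, not taken from this commit):

# Sketch: enabling torchair graph mode so check_and_update_config()
# selects NPUTorchairWorker and get_attn_backend_cls() returns the
# torchair attention backend. The additional_config keys follow
# vllm-ascend's ascend_config conventions; treat them as assumptions.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
        },
    },
)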

vllm_ascend/torchair/__init__.py

Whitespace-only changes.
