
Commit c589d9b: [Refactor] Refactor torchair

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Parent: 6af35f6

19 files changed: +1024 −494 lines

tests/ut/__init__.py

Whitespace-only changes.

tests/ut/test_platform.py

Lines changed: 1 addition & 1 deletion
@@ -523,7 +523,7 @@ def test_get_attn_backend_cls_use_v1_and_torchair(self,
         )
         self.assertEqual(
             result,
-            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend"
         )
 
     @patch('vllm_ascend.platform.get_ascend_config')

tests/ut/test_utils.py

Lines changed: 4 additions & 43 deletions
@@ -123,16 +123,13 @@ def test_aligned_16(self):
     @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                 new=mock.MagicMock)
     @mock.patch('vllm_ascend.utils.is_310p')
-    @mock.patch('vllm_ascend.utils.get_ascend_config')
-    def test_maybe_converting_weight_acl_format(self, mock_get_config,
-                                                mock_310p, mock_npu_cast,
+    def test_maybe_converting_weight_acl_format(self, mock_310p, mock_npu_cast,
                                                 mock_get_format):
         ACL_FORMAT_FRACTAL_NZ = 29
         mock_310p.return_value = True
 
         mock_config = mock.MagicMock()
         mock_config.torchair_graph_config.enabled = True
-        mock_get_config.return_value = mock_config
         mock_get_format.return_value = 1
 
         mock_npu_cast.return_value = 1
@@ -145,23 +142,21 @@ def test_maybe_converting_weight_acl_format(self, mock_get_config,
         model = mock.MagicMock()
         model.modules.return_value = [fused_moe]
 
-        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
         self.assertEqual(fused_moe.w13_weight.data, 1)
 
     @mock.patch('torch_npu.get_npu_format')
     @mock.patch('torch_npu.npu_format_cast')
     @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                 new=mock.MagicMock)
     @mock.patch('vllm_ascend.utils.is_310p')
-    @mock.patch('vllm_ascend.utils.get_ascend_config')
     def test_maybe_converting_weight_acl_format_format_true(
-            self, mock_get_config, mock_310p, mock_npu_cast, mock_get_format):
+            self, mock_310p, mock_npu_cast, mock_get_format):
         ACL_FORMAT_FRACTAL_NZ = 29
         mock_310p.return_value = True
 
         mock_config = mock.MagicMock()
         mock_config.torchair_graph_config.enabled = True
-        mock_get_config.return_value = mock_config
         mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
 
         mock_npu_cast.return_value = 1
@@ -176,20 +171,7 @@ def test_maybe_converting_weight_acl_format_format_true(
 
         mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
 
-        utils.maybe_converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
-
-    @mock.patch('vllm_ascend.utils.get_ascend_config')
-    @mock.patch('vllm_ascend.utils.is_310p', return_value=False)
-    def test_maybe_converting_weight_acl_format_not_310_not_graph(
-            self, mock_310p, mock_get_config):
-        mock_config = mock.MagicMock()
-        mock_config.torchair_graph_config.enabled = False
-        mock_get_config.return_value = mock_config
-
-        mock_constant = mock.MagicMock()
-
-        mock_model = mock.MagicMock()
-        utils.maybe_converting_weight_acl_format(mock_model, mock_constant)
+        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
 
     @mock.patch('importlib.util.find_spec')
     @mock.patch('importlib.import_module')
@@ -280,27 +262,6 @@ def test_update_aclgraph_sizes(self):
             3,
             len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
 
-    def test_get_torchair_current_work_dir(self):
-        cache_dir = utils.TORCHAIR_CACHE_DIR
-        work_dir = utils.get_torchair_current_work_dir()
-        self.assertEqual(cache_dir, work_dir)
-        work_dir = utils.get_torchair_current_work_dir("test")
-        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
-
-    def test_torchair_cache_dir(self):
-        utils.write_kv_cache_bytes_to_file(0, 100)
-        self.assertTrue(utils.check_torchair_cache_exist(),
-                        "Create torchair cache dir failed")
-        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
-                        "Create kv cache bytes cache dir failed")
-        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
-        self.assertEqual(100, kv_cache_bytes)
-        utils.delete_torchair_cache_file()
-        self.assertFalse(utils.check_torchair_cache_exist(),
-                         "Delete torchair cache dir failed")
-        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
-                         "Delete kv cache bytes cache dir failed")
-
 
 class TestProfileExecuteDuration(unittest.TestCase):
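The dropped get_ascend_config mocks track a behavior change, not just a rename: the helper apparently no longer consults the Ascend config itself, so maybe_converting_weight_acl_format becomes converting_weight_acl_format and the not_310_not_graph case is deleted outright, with the torchair-graph gating presumably moving to the callers. A minimal sketch of what the slimmed-down helper could look like, inferred only from the mocks these tests patch (is_310p, torch_npu.get_npu_format, torch_npu.npu_format_cast); the actual implementation in vllm_ascend/utils.py may differ:

    import torch_npu
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
    from vllm_ascend.utils import is_310p

    def converting_weight_acl_format(model, acl_format):
        # Hypothetical reconstruction: cast FusedMoE weights into the target
        # ACL format on 310P devices; skip if already converted.
        if not is_310p():
            return
        for module in model.modules():
            if not isinstance(module, FusedMoE):
                continue
            if torch_npu.get_npu_format(module.w13_weight.data) == acl_format:
                return  # weights already in the requested format
            module.w13_weight.data = torch_npu.npu_format_cast(
                module.w13_weight.data, acl_format)
            module.w2_weight.data = torch_npu.npu_format_cast(
                module.w2_weight.data, acl_format)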

tests/ut/torchair/__init__.py

Whitespace-only changes.

tests/ut/torchair/test_utils.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import os
+
+from tests.ut.base import TestBase
+from vllm_ascend.torchair import utils
+
+
+class TestTorchairUtils(TestBase):
+
+    def test_get_torchair_current_work_dir(self):
+        cache_dir = utils.TORCHAIR_CACHE_DIR
+        work_dir = utils.get_torchair_current_work_dir()
+        self.assertEqual(cache_dir, work_dir)
+        work_dir = utils.get_torchair_current_work_dir("test")
+        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
+
+    def test_torchair_cache_dir(self):
+        utils.write_kv_cache_bytes_to_file(0, 100)
+        self.assertTrue(utils.check_torchair_cache_exist(),
+                        "Create torchair cache dir failed")
+        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
+                        "Create kv cache bytes cache dir failed")
+        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
+        self.assertEqual(100, kv_cache_bytes)
+        utils.delete_torchair_cache_file()
+        self.assertFalse(utils.check_torchair_cache_exist(),
+                         "Delete torchair cache dir failed")
+        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
+                         "Delete kv cache bytes cache dir failed")

tests/ut/worker/test_pooling_model_runner.py

Lines changed: 4 additions & 2 deletions
@@ -1,3 +1,4 @@
+import os
 import unittest
 from unittest.mock import MagicMock, patch
 
@@ -24,9 +25,10 @@ def _create_model_runner(self, model: str, *args,
     def setUp(self):
         """Initialize test fixtures and common mocks"""
         self.attn_backend = "npu"
-
+        model_path = os.path.join(os.path.dirname(__file__), "..",
+                                  "fake_weight")
         model_runner = self._create_model_runner(
-            "tests/ut/fake_weight",
+            model_path,
             trust_remote_code=True,
             enable_chunked_prefill=False,
         )
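This change is about robustness, not behavior: the literal "tests/ut/fake_weight" only resolves when pytest runs from the repository root, while the __file__-relative form resolves from the test module's own location. Illustration only:

    import os

    # Resolves to .../tests/ut/fake_weight regardless of the current
    # working directory pytest was launched from.
    model_path = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "..", "fake_weight"))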

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:

vllm_ascend/models/deepseek_v2.py

Lines changed: 2 additions & 2 deletions
@@ -74,8 +74,8 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import dispose_tensor
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -41,9 +41,9 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
-                               get_fused_moe_state, is_310p, npu_stream_switch,
-                               npu_wait_tensor)
+                               get_fused_moe_state, is_310p)
 
 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
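In all three files above (mla_v1.py, deepseek_v2.py, fused_moe.py) the change is purely an import move: npu_stream_switch and npu_wait_tensor now come from vllm_ascend.torchair.utils instead of vllm_ascend.utils, and call sites are untouched. For orientation, a hypothetical call-site shape (the function, its arguments, and the stream tag are placeholders; the helpers' exact signatures live in the relocated module):

    from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor

    def overlapped_shared_experts(shared_experts, hidden_states, router_logits):
        # Run the shared-expert computation on a secondary NPU stream so it
        # overlaps with the routed experts on the main stream; wait until the
        # inputs written by the main stream are ready before launching.
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(hidden_states, router_logits)
            return shared_experts(hidden_states)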

vllm_ascend/platform.py

Lines changed: 5 additions & 2 deletions
@@ -181,7 +181,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
-                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
+                if ascend_config.torchair_graph_config.enabled:
+                    parallel_config.worker_cls = "vllm_ascend.torchair.worker_torchair.NPUTorchairWorker"
+                else:
+                    parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
             elif vllm_config.speculative_config:
                 # NOTE: We set this var to `1` in vllm-ascend to avoid segment
                 # fault when using spec decode with V0 engine.
@@ -224,7 +227,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         use_torchair = get_ascend_config().torchair_graph_config.enabled
         if use_v1 and use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend"
         if use_v1:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
         if use_mla:
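This is the core of the refactor: both V1 dispatch points now key off torchair_graph_config.enabled, selecting a dedicated torchair worker class alongside the relocated torchair attention backend. Condensed into one sketch (the class paths are copied verbatim from the hunks above; select_v1_classes is an illustrative helper, not a function in the codebase):

    def select_v1_classes(torchair_enabled: bool) -> tuple[str, str]:
        # Returns (worker_cls, attention_backend_cls) for the V1 engine path.
        if torchair_enabled:
            return ("vllm_ascend.torchair.worker_torchair.NPUTorchairWorker",
                    "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend")
        return ("vllm_ascend.worker.worker_v1.NPUWorker",
                "vllm_ascend.attention.attention_v1.AscendAttentionBackend")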
