Commit ab142d6

[Refactor] Create torchair module

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent e65fcea

File tree

9 files changed: +984 −408 lines

vllm_ascend/platform.py

Lines changed: 5 additions & 2 deletions

@@ -181,7 +181,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
-                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
+                if ascend_config.torchair_graph_config.enabled:
+                    parallel_config.worker_cls = "vllm_ascend.torchair.worker_v1_torchair.NPUTorchairWorker"
+                else:
+                    parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
         elif vllm_config.speculative_config:
             # NOTE: We set this var to `1` in vllm-ascend to avoid segment
             # fault when using spec decode with V0 engine.
@@ -224,7 +227,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm_ascend.attention.mla_v1.AscendMLABackend"
         use_torchair = get_ascend_config().torchair_graph_config.enabled
         if use_v1 and use_torchair:
-            return "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
+            return "vllm_ascend.torchair.attention_torchair.AscendAttentionTorchairBackend"
         if use_v1:
             return "vllm_ascend.attention.attention_v1.AscendAttentionBackend"
         if use_mla:
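The net effect of the first hunk: with V1 enabled, the worker class now branches on the torchair graph flag. A minimal sketch of that dispatch in isolation (the helper function and its boolean parameters are stand-ins for illustration, not part of the commit):

def select_worker_cls(use_v1: bool, torchair_enabled: bool) -> str:
    # Mirrors the branching above; real values come from envs.VLLM_USE_V1
    # and get_ascend_config().torchair_graph_config.enabled.
    if use_v1:
        if torchair_enabled:
            # torchair graph mode routes to the new torchair worker
            return "vllm_ascend.torchair.worker_v1_torchair.NPUTorchairWorker"
        # default V1 path keeps the existing worker
        return "vllm_ascend.worker.worker_v1.NPUWorker"
    raise NotImplementedError("V0 path not shown in this hunk")

assert select_worker_cls(True, True).endswith("NPUTorchairWorker")
assert select_worker_cls(True, False).endswith("NPUWorker")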

vllm_ascend/torchair/__init__.py

Whitespace-only changes.

vllm_ascend/torchair/model_runner_torchair.py

Lines changed: 735 additions & 0 deletions
Large diffs are not rendered by default.

vllm_ascend/torchair/utils.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import fcntl
+import os
+import shutil
+from contextlib import contextmanager
+
+
+KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
+KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
+TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
+TORCHAIR_CACHE_DIR = os.getenv(
+    'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
+
+
+def get_torchair_current_work_dir(file_name=None):
+    if file_name is None:
+        return TORCHAIR_CACHE_DIR
+    return os.path.join(TORCHAIR_CACHE_DIR, file_name)
+
+
+def check_torchair_cache_exist():
+    res = False
+    torch_air_abs_path = get_torchair_current_work_dir()
+    if os.path.exists(torch_air_abs_path):
+        file_list = os.listdir(torch_air_abs_path)
+        if len(file_list) != 0:
+            res = True
+    return res
+
+
+def check_kv_cache_bytes_cache_exist():
+    res = False
+    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
+        KV_CACHE_BYTES_CACHE_PATH_NAME)
+    if os.path.exists(kv_cache_bytes_cache_abs_path):
+        file_list = os.listdir(kv_cache_bytes_cache_abs_path)
+        if len(file_list) != 0:
+            res = True
+    return res
+
+
+def read_kv_cache_bytes_from_file(rank) -> int:
+    kv_cache_bytes = -1
+    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
+        KV_CACHE_BYTES_CACHE_PATH_NAME)
+    kv_cache_bytes_file = os.path.join(
+        kv_cache_bytes_cache_abs_path,
+        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
+    with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
+        with file_lock(f, fcntl.LOCK_SH):
+            kv_cache_bytes = int(f.readline())
+    return kv_cache_bytes
+
+
+@contextmanager
+def file_lock(file_descriptor, lock_type):
+    fcntl.flock(file_descriptor, lock_type)
+    try:
+        yield
+    finally:
+        fcntl.flock(file_descriptor, fcntl.LOCK_UN)
+
+
+def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
+    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
+        KV_CACHE_BYTES_CACHE_PATH_NAME)
+    os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
+    kv_cache_bytes_file = os.path.join(
+        kv_cache_bytes_cache_abs_path,
+        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
+    with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
+        with file_lock(f, fcntl.LOCK_EX):
+            f.write(f"{kv_cache_bytes}")
+
+
+def delete_torchair_cache_file():
+    torch_air_abs_path = get_torchair_current_work_dir()
+    if os.path.exists(torch_air_abs_path):
+        shutil.rmtree(torch_air_abs_path)
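
For context, a minimal usage sketch of these helpers. The rank and byte count are invented, and TORCHAIR_CACHE_HOME is pointed at a scratch directory so the demo cannot clobber a real cache. Note the env var must be set before import, since TORCHAIR_CACHE_DIR is resolved at import time:

import os

os.environ["TORCHAIR_CACHE_HOME"] = "/tmp/torchair_cache_demo"

from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
                                        delete_torchair_cache_file,
                                        read_kv_cache_bytes_from_file,
                                        write_kv_cache_bytes_to_file)

# Writer takes an exclusive flock; readers take a shared one.
write_kv_cache_bytes_to_file(rank=0, kv_cache_bytes=1 << 30)
assert check_kv_cache_bytes_cache_exist()
print(read_kv_cache_bytes_from_file(0))  # 1073741824
delete_torchair_cache_file()  # removes the whole .torchair_cache tree

The flock pair (LOCK_EX for the writer, LOCK_SH for readers) serializes writes to each per-rank file while still letting concurrent processes read the cached value.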
vllm_ascend/torchair/worker_v1_torchair.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
+#
+
+import torch
+import torch_npu
+from vllm.logger import logger
+
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.device_allocator.camem import CaMemAllocator
+from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.torchair.model_runner_torchair import NPUTorchairModelRunner
+from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
+                                        check_torchair_cache_exist,
+                                        delete_torchair_cache_file,
+                                        read_kv_cache_bytes_from_file)
+from vllm_ascend.worker.worker_v1 import NPUWorker
+
+
+class NPUTorchairWorker(NPUWorker):
+
+    def init_device(self):
+        device = torch.device(f"npu:{self.local_rank}")
+        NPUPlatform.set_device(device)
+        NPUPlatform.empty_cache()
+        self.init_npu_memory = NPUPlatform.mem_get_info()[0]
+
+        # Initialize the distributed environment.
+        self._init_worker_distributed_environment()
+        # Set random seed.
+        NPUPlatform.seed_everything(self.model_config.seed)
+
+        # Init ModelRunner here, so that we have access to self.device.
+        self.model_runner = NPUTorchairModelRunner(self.vllm_config, device)
+
+    def determine_available_memory(self) -> int:
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+        NPUPlatform.clear_npu_memory()
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        _, total_npu_memory = NPUPlatform.mem_get_info()
+        self.model_runner.profile_run()
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        free_npu_memory, _ = NPUPlatform.mem_get_info()
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # NPU did not change their memory usage during the profiling.
+        assert self.init_npu_memory > free_npu_memory, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_npu_memory}, current free memory"
+            f" {free_npu_memory}. This happens when the NPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
+        # Get the peak memory allocation recorded by torch.
+        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
+        # TODO: this is no longer needed once empty_cache in
+        # Worker.determine_num_available_blocks() is unified.
+        NPUPlatform.empty_cache()
+        torch_allocated_bytes = torch_npu.npu.memory_stats(
+        )["allocated_bytes.all.current"]
+        total_allocated_bytes = torch_npu.npu.mem_get_info(
+        )[1] - torch_npu.npu.mem_get_info()[0]
+        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+        available_kv_cache_memory = int(
+            total_npu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
+        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
+        logger.info(
+            f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}"
+        )
+        if get_ascend_config().torchair_graph_config.enabled:
+            if check_torchair_cache_exist(
+            ) and check_kv_cache_bytes_cache_exist():
+                old_kv_cache_bytes = read_kv_cache_bytes_from_file(
+                    torch.distributed.get_rank())
+                if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
+                    logger.info(
+                        f"Use cached torchair kv_cache_bytes: {old_kv_cache_bytes}"
+                    )
+                    self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
+                    return old_kv_cache_bytes
+                else:
+                    logger.info(
+                        "Cached torchair kv_cache_bytes is too big; invalidating old torchair cache"
+                    )
+                    delete_torchair_cache_file()
+            bytes_floating_tolerance = (
+                1024 * 1024 *
+                envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE)
+            available_kv_cache_memory -= bytes_floating_tolerance
+            logger.info(f"Use new kv_cache_bytes: {available_kv_cache_memory}")
+            self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
+
+        return available_kv_cache_memory
+
+    def execute_dummy_batch(self) -> None:
+        runner = self.model_runner
+        max_num_tokens = 1
+        with_prefill = False
+        if runner.dp_size > 1:
+            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
+                max_num_tokens, with_prefill)
+        if runner.torchair_graph_enabled and not with_prefill:
+            max_num_tokens = runner.select_torchair_padded_batch_size(
+                max_num_tokens)
+        runner._dummy_run(max_num_tokens,
+                          is_compile=False,
+                          with_prefill=with_prefill)
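
The kv-cache budget computed in determine_available_memory() is total device memory scaled by gpu_memory_utilization, minus the profiled torch peak, minus any non-torch allocations. A worked example with invented numbers (the real values come from mem_get_info() and memory_stats()):

GiB = 1024**3
total_npu_memory = 64 * GiB   # device total from mem_get_info()
peak_memory = 30 * GiB        # torch peak: allocated_bytes.all.peak
torch_allocated = 28 * GiB    # allocated_bytes.all.current after empty_cache
total_allocated = 31 * GiB    # total minus free from mem_get_info()
gpu_memory_utilization = 0.9

non_torch_allocations = total_allocated - torch_allocated  # 3 GiB of runtime overhead
if non_torch_allocations > 0:
    peak_memory += non_torch_allocations                   # 33 GiB effective peak
available_kv_cache_memory = int(
    total_npu_memory * gpu_memory_utilization - peak_memory)
available_kv_cache_memory = max(available_kv_cache_memory, 0)
print(available_kv_cache_memory / GiB)  # 24.6 GiB left for the KV cache

If a previously cached kv_cache_bytes fits inside this budget, it is reused so the compiled torchair graphs stay valid across restarts; otherwise the stale cache is deleted and the fresh value, reduced by the configured megabyte tolerance, is assigned to the model runner instead.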

vllm_ascend/utils.py

Lines changed: 0 additions & 77 deletions
@@ -18,10 +18,7 @@
 #

 import atexit
-import fcntl
 import math
-import os
-import shutil
 from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
@@ -443,77 +440,3 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
         return FusedMoEState.All2All
     else:
         return FusedMoEState.MC2
-
-
-KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
-KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
-TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
-TORCHAIR_CACHE_DIR = os.getenv(
-    'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
-
-
-def get_torchair_current_work_dir(file_name=None):
-    if file_name is None:
-        return TORCHAIR_CACHE_DIR
-    return os.path.join(TORCHAIR_CACHE_DIR, file_name)
-
-
-def check_torchair_cache_exist():
-    res = False
-    torch_air_abs_path = get_torchair_current_work_dir()
-    if os.path.exists(torch_air_abs_path):
-        file_list = os.listdir(torch_air_abs_path)
-        if len(file_list) != 0:
-            res = True
-    return res
-
-
-def check_kv_cache_bytes_cache_exist():
-    res = False
-    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
-        KV_CACHE_BYTES_CACHE_PATH_NAME)
-    if os.path.exists(kv_cache_bytes_cache_abs_path):
-        file_list = os.listdir(kv_cache_bytes_cache_abs_path)
-        if len(file_list) != 0:
-            res = True
-    return res
-
-
-def read_kv_cache_bytes_from_file(rank) -> int:
-    kv_cache_bytes = -1
-    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
-        KV_CACHE_BYTES_CACHE_PATH_NAME)
-    kv_cache_bytes_file = os.path.join(
-        kv_cache_bytes_cache_abs_path,
-        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
-    with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
-        with file_lock(f, fcntl.LOCK_SH):
-            kv_cache_bytes = int(f.readline())
-    return kv_cache_bytes
-
-
-@contextmanager
-def file_lock(file_descriptor, lock_type):
-    fcntl.flock(file_descriptor, lock_type)
-    try:
-        yield
-    finally:
-        fcntl.flock(file_descriptor, fcntl.LOCK_UN)
-
-
-def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
-    kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
-        KV_CACHE_BYTES_CACHE_PATH_NAME)
-    os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
-    kv_cache_bytes_file = os.path.join(
-        kv_cache_bytes_cache_abs_path,
-        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
-    with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
-        with file_lock(f, fcntl.LOCK_EX):
-            f.write(f"{kv_cache_bytes}")
-
-
-def delete_torchair_cache_file():
-    torch_air_abs_path = get_torchair_current_work_dir()
-    if os.path.exists(torch_air_abs_path):
-        shutil.rmtree(torch_air_abs_path)
