from vllm.v1.request import Request

import llm_datadist  # type: ignore
-from llm_datadist import LLMException, LLMStatusCode
+from llm_datadist import LLMConfig, LLMException, LLMStatusCode

-import vllm_ascend.envs as envs
+import vllm_ascend.envs as envs_ascend
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata

TORCH_DTYPE_TO_NPU_DTYPE = {
    torch.int32: llm_datadist.DataType.DT_INT32,
}

-GLOBAL_RANKTABLE = envs.GLOBAL_RANKTABLE
+GLOBAL_RANKTABLE = envs_ascend.LLMDATADIST_GLOBAL_RANKTABLE


class ServerRole(enum.Enum):
@@ -289,16 +289,19 @@ def __init__(self, role: llm_datadist.LLMRole, local_rank: int,
                    self.role, self.cluster_id)

    def prepare_data_dist(self):
-        # TODO: The maximum size of the mbuf for the llm datadist. We need to
-        # find an appropriate value to minimize memory waste.
-        options = {
-            "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
-            "ge.flowGraphMemMaxSize": f"{int(2.25 * 1024 * 1024 * 1024):d}",
+        buff_size = envs_ascend.LLMDATADIST_BUFFSIZE_MB * 1024 * 1024
+        llm_config = LLMConfig()
+        llm_config.ge_options = {
+            "llm.SyncKvCacheWaitTime":
+            envs_ascend.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
+            "ge.flowGraphMemMaxSize": f"{buff_size:d}",
            "ge.exec.deviceId": str(self.local_rank),
        }
+        llm_config.buf_pool_cfg = '{"buf_cfg": [{"total_size":2097152,"blk_size":256,"max_buf_size":256}]}'
        if self.role == llm_datadist.LLMRole.PROMPT:
-            options["llm.listenIpInfo"] = f"{self.local_device_ip}:26000"
-        self.datadist_engine.init(options)
+            llm_config.listen_ip_info = f"{self.local_device_ip}:26000"
+        engine_options = llm_config.generate_options()
+        self.datadist_engine.init(engine_options)
        logger.info("llm_datadist init done")
        self.kv_transfer = self.datadist_engine.kv_cache_manager
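
For clarity, here is the new initialization path from this hunk in isolation. It is a minimal sketch using only the attribute names that appear in the diff; it assumes the Ascend llm_datadist package is available, and the concrete values (buffer size, wait time, device id, listen address) are illustrative stand-ins for the env-driven values, not defaults.

# Minimal sketch of the new prepare_data_dist flow; all literal values are examples.
from llm_datadist import LLMConfig

buff_size_mb = 2048  # example stand-in for LLMDATADIST_BUFFSIZE_MB
llm_config = LLMConfig()
llm_config.ge_options = {
    "llm.SyncKvCacheWaitTime": "5000",  # example stand-in for LLMDATADIST_SYNC_CACHE_WAIT_TIME
    "ge.flowGraphMemMaxSize": f"{buff_size_mb * 1024 * 1024:d}",
    "ge.exec.deviceId": "0",
}
# Only the prefill (PROMPT) role advertises a listen address.
llm_config.listen_ip_info = "127.0.0.1:26000"
# generate_options() produces the options that are then handed to datadist_engine.init().
engine_options = llm_config.generate_options()
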
@@ -869,7 +872,7 @@ def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
-    ) -> int:
+    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the external KV cache
        beyond the num_computed_tokens.
@@ -887,15 +890,15 @@ def get_num_new_matched_tokens(
        # the block granularity. And it expects the returned blocks and
        # num_computed_tokens to also be aligned with the block granularity.

-        # NOTE: only request in waiting queue will come here. we use datadist
+        # NOTE: only requests in waiting queue will come here. we use datadist
        # pull cache to do transfer, so we don't align to block_size in prefill,
        # we won't have extra new matched tokens; in decode, new request kv
        # cache will be transferred from prefill, so num_computed_tokens = 0,
        # and extra new matched tokens should be len(request.prompt_token_ids) -
        # 1
        if self.kv_role == llm_datadist.LLMRole.PROMPT:
-            return 0
-        return len(request.prompt_token_ids) - 1
+            return 0, False
+        return len(request.prompt_token_ids) - 1, False

    def update_state_after_alloc(self, request: "Request",
                                 num_external_tokens: int):
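
The NOTE comment explains the arithmetic behind the new tuple return. As a standalone sketch: the helper name and role flag below are hypothetical, and treating the second tuple element as a "load asynchronously" flag is an assumption, not something stated in this diff.

# Hypothetical helper mirroring the new return contract; not the connector itself.
def new_matched_tokens(is_prompt_role: bool,
                       prompt_token_ids: list[int]) -> tuple[int, bool]:
    if is_prompt_role:
        # Prefill: datadist pull-cache needs no block alignment, so no extra tokens.
        return 0, False
    # Decode: all but the last prompt token are transferred from prefill; False
    # assumes the load is reported as synchronous.
    return len(prompt_token_ids) - 1, False

print(new_matched_tokens(False, list(range(100))))  # (99, False)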