
Commit 89a485b

[Feature] Support using prefix-caching + cudagraph for inference (#2924)

* Fix the bug in cudagraph + prefix-caching (profiling still had issues at this point)
  Change-Id: Ibf2ba3f2e3b08641d03f4b1391d7c862c3efa397
* Add a signal so workers can be sure the cache manager has launched
* Fix the judge condition
* Remove useless control
* Update the control stream
* Update
* Fix xpu
* Change the do_profile flag
* Update
* Add a new thread to init the cache_manager

---------

Co-authored-by: RAM <gstian5555@outlook.com>
1 parent 48e6a0c commit 89a485b

11 files changed: +63 additions, -65 deletions
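What the commit does, end to end: the config-time assertion that banned combining CUDA graph with prefix caching is removed, and in its place a small cross-process handshake is added. The engine creates an IPC flag, launched_cache_manager_signal, whenever prefix caching or splitwise serving is enabled, and sets it to 1 once the cache manager processes are up (on both the normal start path and the profiling path). Each worker polls that flag before allocating its final KV cache, so CUDA graph capture only happens after the cache-manager-backed cache exists.

A minimal sketch of the two sides of the handshake. The IPCSignal import path and the helper function names are assumptions for illustration; the constructor arguments are copied from the diffs that follow:

import time

import numpy as np

from fastdeploy.inter_communicator import IPCSignal  # assumed import path


def engine_announce_cache_manager(suffix):
    """Engine side: create the flag, set it once the cache manager is launched."""
    sig = IPCSignal(
        name="launched_cache_manager_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=suffix,
        create=True,
    )
    # ... launch cache manager processes here ...
    sig.value[0] = 1  # tell workers the cache manager is ready
    return sig


def worker_wait_cache_manager(engine_pid):
    """Worker side: attach to the same flag and poll until the engine sets it."""
    sig = IPCSignal(
        name="launched_cache_manager_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=engine_pid,  # must match the suffix the engine used
        create=False,
    )
    while sig.value[0] <= 0:
        time.sleep(0.01)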

fastdeploy/engine/args_utils.py

Lines changed: 0 additions & 2 deletions
@@ -863,8 +863,6 @@ def create_engine_config(self) -> Config:
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)

-        assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"
-
         assert not (
             self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
         ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"

fastdeploy/engine/engine.py

Lines changed: 19 additions & 3 deletions
@@ -183,6 +183,7 @@ def start(self, api_server_pid=None):
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=self.ipc_signal_suffix,
             )
+            self.launched_cache_manager_signal.value[0] = 1

         self.worker_proc = self._start_worker_service()
         console_logger.info("Waitting worker processes ready...")
@@ -217,9 +218,6 @@ def start(self, api_server_pid=None):
         # Start TokenProcessor thread
         self.token_processor.run()

-        if self.do_profile:
-            self._stop_profile()
-
         if self.cfg.splitwise_role != "mixed":
             # single-machine logic
             self.engine_worker_queue.available_prefill_instances.put(1)
@@ -849,6 +847,17 @@ def _init_worker_signals(self):
             create=True,
         )

+        # launched_cache_manager_signal is used to sense whether the engine has launched the cache_manager
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.ipc_signal_suffix,
+                create=True,
+            )
+
         # worker_live_signal is used by the engine to sense whether each worker process is alive; it records the time of each step
         worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
@@ -1133,6 +1142,7 @@ def _stop_profile(self):
             engine_worker_queue_port=self.cfg.engine_worker_queue_port,
             pid_suffix=self.ipc_signal_suffix,
         )
+        self.launched_cache_manager_signal.value[0] = 1

     def check_health(self, time_interval_threashold=30):
         """
@@ -1171,6 +1181,10 @@ def detect_thread():

         self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
         self.checking_worker_status_thread.start()
+        checking_worker_init_kv_cache_status_thread = None
+        if self.do_profile:
+            checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True)
+            checking_worker_init_kv_cache_status_thread.start()

         # display weight loadding progress
         with tqdm(total=100, desc="Loading Weights") as pbar:
@@ -1201,6 +1215,8 @@ def detect_thread():
             self.worker_init_status["finished"] = True
             try:
                 self.checking_worker_status_thread.join(timeout=1)
+                if checking_worker_init_kv_cache_status_thread is not None:
+                    checking_worker_init_kv_cache_status_thread.join(timeout=1)
             except Exception:
                 pass
             return True
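Taken together, the engine changes move profiling off the startup critical path: _stop_profile() is no longer called inline in start() but runs in a daemon thread alongside the worker-status checker, and both threads are joined with a short timeout once workers report ready. launched_cache_manager_signal is created next to the other worker signals whenever prefix caching or splitwise serving is on, and set to 1 at both points where the cache manager comes up (the normal start() path and the end of _stop_profile()).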

fastdeploy/worker/gcu_model_runner.py

Lines changed: 4 additions & 9 deletions
@@ -151,8 +151,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
         """
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
             os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
@@ -561,7 +559,7 @@ def update_parameters(self, pid):
         self.initialize_kv_cache()
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
@@ -582,7 +580,7 @@ def initialize_kv_cache(self) -> None:
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not self.parallel_config.do_profile and (
+        if not profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
@@ -1012,7 +1010,7 @@ def profile_run(self) -> None:

         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         self.num_gcu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache

@@ -1038,8 +1036,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         self.num_gcu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(
@@ -1057,8 +1054,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
             }
         )

-        self.parallel_config.do_profile = False
-
         if self.speculative_method in ["mtp"]:
             self.proposer.update_block_num(num_gpu_blocks)

fastdeploy/worker/gcu_worker.py

Lines changed: 2 additions & 6 deletions
@@ -98,9 +98,9 @@ def get_model(self) -> nn.Layer:
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
@@ -134,7 +134,3 @@ def check_health(self) -> bool:
     def cal_theortical_kvcache(self) -> int:
         """ """
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

fastdeploy/worker/gpu_model_runner.py

Lines changed: 4 additions & 11 deletions
@@ -193,9 +193,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         Process inputs for prefill tasks and insert it to share_inputs buffer
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Lazy initialize kv cache
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         # NOTE(luotingdan): Set environment variable of prefill node
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -700,7 +697,7 @@ def initialize_forward_meta(self):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
@@ -721,7 +718,7 @@ def initialize_kv_cache(self) -> None:
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not self.parallel_config.do_profile and (
+        if not profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             cache_kvs_list = []
@@ -739,7 +736,6 @@ def initialize_kv_cache(self) -> None:

         else:
             for i in range(self.model_config.num_hidden_layers):
-
                 cache_kvs[f"key_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,
@@ -1218,7 +1214,7 @@ def profile_run(self) -> None:
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache

@@ -1243,8 +1239,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         self.num_gpu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(
@@ -1262,8 +1257,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
             }
         )

-        self.parallel_config.do_profile = False
-
         if self.speculative_method in ["mtp"]:
             self.proposer.update_block_num(num_gpu_blocks)
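The runner-side change is the same across the GPU, GCU, and iluvatar backends: whether initialize_kv_cache attaches to cache-manager-owned tensors is now decided by an explicit profile argument instead of the mutable parallel_config.do_profile flag, and update_share_input_block_num always rebuilds the cache after profiling. A condensed sketch of the resulting control flow (the signature and condition are taken from the diff; the branch bodies are elided):

def initialize_kv_cache(self, profile: bool = False) -> None:
    if not profile and (
        self.parallel_config.enable_prefix_caching
        or self.parallel_config.splitwise_role != "mixed"
    ):
        # Serving path: attach to KV tensors owned by the cache manager;
        # this is why the worker must first see launched_cache_manager_signal.
        ...
    else:
        # Profile path (and plain mixed serving): allocate local paddle
        # tensors, so profiling never depends on the cache manager.
        ...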

fastdeploy/worker/gpu_worker.py

Lines changed: 4 additions & 7 deletions
@@ -165,9 +165,10 @@ def get_model(self) -> nn.Layer:
         """Get current model"""
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
-        """Initizlize the KV Cache"""
-        pass
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initialize the KV cache with the accurate num_gpu_blocks"""
+        # accurate cache size
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
@@ -198,7 +199,3 @@ def check_health(self) -> bool:
     def cal_theortical_kvcache(self) -> int:
         """Calculate the block memory required"""
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """Reinitialize the kv cache using the parameters from the profile"""
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

fastdeploy/worker/iluvatar_model_runner.py

Lines changed: 3 additions & 9 deletions
@@ -141,9 +141,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         Process inputs for prefill tasks and insert it to share_inputs buffer
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Lazy initialize kv cache
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         # NOTE(luotingdan): Set environment variable of prefill node
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -552,7 +549,7 @@ def clear_cache(self):
         if self.forward_meta is not None:
             self.forward_meta.clear_caches()

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
@@ -992,7 +989,7 @@ def profile_run(self) -> None:
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache

@@ -1016,8 +1013,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         self.num_gpu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(
@@ -1035,8 +1031,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
             }
         )

-        self.parallel_config.do_profile = False
-
     def cal_theortical_kvcache(self):
         """
         Calculate the total block memory required at the model level

fastdeploy/worker/iluvatar_worker.py

Lines changed: 2 additions & 6 deletions
@@ -99,9 +99,9 @@ def get_model(self) -> nn.Layer:
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
@@ -135,7 +135,3 @@ def check_health(self) -> bool:
     def cal_theortical_kvcache(self) -> int:
         """ """
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

fastdeploy/worker/worker_base.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def init_device(self) -> None:
         raise NotImplementedError

     @abstractmethod
-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """Initizlize the KV Cache with the given size in blocks."""
         raise NotImplementedError
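With num_cpu_blocks dropped (it was unused) and the separate reinitialize_kv_cache hook folded into initialize_cache, every worker subclass now implements one hook with one meaning. A minimal sketch of a conforming subclass; the class name is hypothetical and the base-class import path is assumed from the file layout, while the body mirrors what the GPU, GCU, and iluvatar workers do above:

from fastdeploy.worker.worker_base import WorkerBase  # assumed import path


class ExampleWorker(WorkerBase):  # hypothetical subclass for illustration
    def initialize_cache(self, num_gpu_blocks: int) -> None:
        # Rebuild block tables and the KV cache with the post-profiling block count.
        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)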

fastdeploy/worker/worker_process.py

Lines changed: 22 additions & 5 deletions
@@ -347,7 +347,7 @@ def event_loop_normal(self) -> None:

             self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()

-    def determine_num_available_blocks(self) -> None:
+    def initialize_kv_cache(self) -> None:
         """Profiles the peak memory usage of the model to determine how many
         KV blocks may be allocated without OOMs.

@@ -400,8 +400,25 @@ def determine_num_available_blocks(self) -> None:
             self.get_profile_block_num_signal.value[0] = num_blocks_local
         else:
             num_blocks_local = self.fd_config.parallel_config.total_block_num
-        # 4. Updata share inputs
-        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)
+
+        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
+        # wait for the engine to launch the cache_manager
+        if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
+                time.sleep(0.01)
+        # 4. init kv_cache with accurate num_blocks
+        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
+
+    def graph_optimize_and_warm_up_model(self) -> None:
+        self.worker.graph_optimize_and_warm_up_model()

     def init_device(self) -> None:
         """Initialize device and Construct model runner"""
@@ -714,8 +731,8 @@ def run_worker_proc() -> None:

     # Load model
     worker_proc.load_model()
-    logger.info("determine_num_available_blocks")
-    worker_proc.determine_num_available_blocks()
+    # Initialize KV Cache
+    worker_proc.initialize_kv_cache()

     # Trigger CUDAGraph capture
     worker_proc.worker.graph_optimize_and_warm_up_model()
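The renamed entry point makes the worker startup sequence explicit, and the ordering is what lets CUDA graphs coexist with prefix caching: graphs are captured only after the final KV cache is in place. Condensed from run_worker_proc above (a reading aid, not standalone code):

# Worker startup order after this commit:
worker_proc.load_model()           # 1. load weights
worker_proc.initialize_kv_cache()  # 2. profile memory, wait for the engine's
                                   #    cache manager, then allocate the real
                                   #    cache via initialize_cache(...)
worker_proc.worker.graph_optimize_and_warm_up_model()  # 3. capture CUDA graphs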
