From 94e54032c2bdf0c116c77605d014be274b73824c Mon Sep 17 00:00:00 2001 From: zeroRains Date: Fri, 18 Jul 2025 16:43:18 +0800 Subject: [PATCH 01/10] fix the bug in cudagraph+prefix-caching but still have some bug with profile Change-Id: Ibf2ba3f2e3b08641d03f4b1391d7c862c3efa397 --- fastdeploy/engine/args_utils.py | 2 -- fastdeploy/worker/gpu_model_runner.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index cdd9e81d93..5159c133c7 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -795,8 +795,6 @@ def create_engine_config(self) -> Config: graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) - assert not (self.use_cudagraph and self.enable_prefix_caching), \ - "Prefix caching cannot be used with CUDA graph" assert not (self.tensor_parallel_size<=1 and self.enable_custom_all_reduce), \ "enable_custom_all_reduce must be used with tensor_parallel_size>1" diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8ad834c708..318a375f9e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1078,6 +1078,7 @@ def capture_model(self) -> None: time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() + self.initialize_kv_cache() for batch_size in sorted(capture_sizes, reverse=True): self._dummy_run(num_tokens=self.parallel_config.max_model_len, batch_size=batch_size, @@ -1086,7 +1087,7 @@ def capture_model(self) -> None: logger.info( f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" ) - + self.clear_cache() time_after_capture = time.perf_counter() logger.info( f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds" From 6080e07b5e1f257b9124a63039c80bb4a5bfd5eb Mon Sep 17 00:00:00 2001 From: zeroRains Date: Sat, 19 Jul 2025 15:20:16 +0800 Subject: [PATCH 02/10] add the signal to make sure cache manager launched --- fastdeploy/engine/engine.py | 18 +++++++++++++++ fastdeploy/worker/gpu_model_runner.py | 7 ++++-- fastdeploy/worker/worker_process.py | 32 ++++++++++++++++++++------- 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 414a7b2092..5dcece2294 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -178,6 +178,7 @@ def start(self, api_server_pid=None): pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=self.ipc_signal_suffix) + self.launched_cache_manager_signal.value[0] = 1 self.worker_proc = self._start_worker_service() console_logger.info("Waitting worker processes ready...") @@ -869,6 +870,16 @@ def _init_worker_signals(self): suffix=self.ipc_signal_suffix, create=True) + # launched_cache_manager_signal 用于感知engine是否启动了cache_manager + if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed": + launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) + self.launched_cache_manager_signal = IPCSignal( + name="launched_cache_manager_signal", + array=launched_cache_manager_signal_data, + dtype=np.int32, + suffix=self.ipc_signal_suffix, + create=True) + # worker_live_signal 用于engine感知各worker进程是否存活,记录每个step 时间 worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32) @@ -1165,6 +1176,9 @@ 
def _stop_profile(self): pod_ip=self.cfg.master_ip, engine_worker_queue_port=self.cfg.engine_worker_queue_port, pid_suffix=self.ipc_signal_suffix) + self.launched_cache_manager_signal.value[0] = 1 + + def check_health(self, time_interval_threashold=30): """ Check the health of the model server by checking whether all workers are alive. @@ -1203,6 +1217,10 @@ def detect_thread(): if self.worker_init_status[ "layer_loadding"] == self.cfg.model_config.num_layers - 1: self.worker_init_status["finished"] = True + elif match := re.search(r'num_blocks_global', + line): + if self.do_profile: + self._stop_profile() self.checking_worker_status_thread = threading.Thread( target=detect_thread, daemon=True) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 318a375f9e..946f316d75 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1078,7 +1078,9 @@ def capture_model(self) -> None: time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - self.initialize_kv_cache() + need_init_cache = "caches" not in self.share_inputs + if need_init_cache: + self.initialize_kv_cache() for batch_size in sorted(capture_sizes, reverse=True): self._dummy_run(num_tokens=self.parallel_config.max_model_len, batch_size=batch_size, @@ -1087,7 +1089,8 @@ def capture_model(self) -> None: logger.info( f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}" ) - self.clear_cache() + if need_init_cache: + self.clear_cache() time_after_capture = time.perf_counter() logger.info( f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds" diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 6123a37b47..81b22290a4 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -397,23 +397,39 @@ def determine_num_available_blocks(self) -> None: if num_blocks_global < 0: logger.error( - f"The total number of blocks cannot be less than zero." - f"Please increase gpu_memory_utilization" - f"Or decrease max_num_batched_tokens(max model length) ") + "The total number of blocks cannot be less than zero." + "Please increase gpu_memory_utilization" + "Or decrease max_num_batched_tokens(max model length) ") raise ValueError( - f"The total number of blocks cannot be less than zero." - f"Please increase gpu_memory_utilization" - f"Or decrease max_num_batched_tokens(max model length) ") - + "The total number of blocks cannot be less than zero." + "Please increase gpu_memory_utilization" + "Or decrease max_num_batched_tokens(max model length) ") + self.get_profile_block_num_signal.value[ self.local_rank] = num_blocks_global else: num_blocks_global = self.fd_config.parallel_config.total_block_num + # logger.info will write in worker_process.log + # Need `print` to triger engine->check_worker_initialize_status->detect_thread + print(f"------- num_blocks_global: {num_blocks_global} --------") # NOTE(liuzichang): Too big num_blocks_global will lead to error 700 # 4. 
Updata share inputs self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global) + def graph_optimize_and_warm_up_model(self) ->None: + if self.parallel_config.enable_prefix_caching: + launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) + self.launched_cache_manager_signal = IPCSignal( + name="launched_cache_manager_signal", + array=launched_cache_manager_signal_data, + dtype=np.int32, + suffix=self.parallel_config.engine_pid, + create=False) + while np.any(self.launched_cache_manager_signal.value[0] <= 0): + time.sleep(0.01) + self.worker.graph_optimize_and_warm_up_model() + def init_device(self) -> None: """ Initialize device and Construct model runner """ self.worker.init_device() @@ -729,7 +745,7 @@ def run_worker_proc() -> None: worker_proc.determine_num_available_blocks() # Trigger CUDAGraph capture - worker_proc.worker.graph_optimize_and_warm_up_model() + worker_proc.graph_optimize_and_warm_up_model() # Initialize health status worker_proc.init_health_status() From 560ab7d9711decbfe3d8f51749680b6f33273746 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Sat, 19 Jul 2025 15:35:12 +0800 Subject: [PATCH 03/10] fix judge condition --- fastdeploy/worker/worker_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 81b22290a4..4d1d4d2b20 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -418,7 +418,7 @@ def determine_num_available_blocks(self) -> None: self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_global) def graph_optimize_and_warm_up_model(self) ->None: - if self.parallel_config.enable_prefix_caching: + if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed": launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) self.launched_cache_manager_signal = IPCSignal( name="launched_cache_manager_signal", From e0ffe2ebd73aab35145a77a09f0f65589890ed6a Mon Sep 17 00:00:00 2001 From: zeroRains Date: Sun, 20 Jul 2025 12:48:16 +0800 Subject: [PATCH 04/10] reomove useless control --- fastdeploy/engine/engine.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 5ce47cb973..48b03d9329 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -218,9 +218,6 @@ def start(self, api_server_pid=None): # Start TokenProcessor thread self.token_processor.run() - if self.do_profile: - self._stop_profile() - if self.cfg.splitwise_role != "mixed": # 单机逻辑 self.engine_worker_queue.available_prefill_instances.put(1) From 9dad411e56d9544c9e9ec47ef3a48fa0eddd51cc Mon Sep 17 00:00:00 2001 From: zeroRains Date: Mon, 21 Jul 2025 21:39:04 +0800 Subject: [PATCH 05/10] update control stream --- fastdeploy/worker/gcu_model_runner.py | 8 ++------ fastdeploy/worker/gcu_worker.py | 8 ++------ fastdeploy/worker/gpu_model_runner.py | 16 +++------------- fastdeploy/worker/gpu_worker.py | 11 ++++------- fastdeploy/worker/iluvatar_model_runner.py | 9 ++------- fastdeploy/worker/iluvatar_worker.py | 8 ++------ fastdeploy/worker/worker_base.py | 2 +- fastdeploy/worker/worker_process.py | 17 +++++++++-------- 8 files changed, 25 insertions(+), 54 deletions(-) diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index b58c2237fe..0a43c1f631 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -151,8 +151,6 @@ def insert_prefill_inputs(self, 
req_dicts: List[Request]): """ Process inputs for prefill tasks and insert it to share_inputs buffer """ - if "caches" not in self.share_inputs: - self.initialize_kv_cache() if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill": os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1" @@ -1035,11 +1033,11 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: Args: num_gpu_blocks: """ + self.parallel_config.do_profile = False self.num_gcu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num - if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): - self.initialize_kv_cache() + self.initialize_kv_cache() # Reset free list free_list = list( @@ -1057,8 +1055,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: } ) - self.parallel_config.do_profile = False - if self.speculative_method in ["mtp"]: self.proposer.update_block_num(num_gpu_blocks) diff --git a/fastdeploy/worker/gcu_worker.py b/fastdeploy/worker/gcu_worker.py index 2e4e83885b..1b98e3b0c8 100644 --- a/fastdeploy/worker/gcu_worker.py +++ b/fastdeploy/worker/gcu_worker.py @@ -98,9 +98,9 @@ def get_model(self) -> nn.Layer: """ """ return self.model_runner.get_model() - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int) -> None: """ """ - pass + self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) def execute_model( self, @@ -134,7 +134,3 @@ def check_health(self) -> bool: def cal_theortical_kvcache(self) -> int: """ """ return self.model_runner.cal_theortical_kvcache() - - def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: - """ """ - self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8ed7f24100..2ab567f95b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -193,9 +193,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]): Process inputs for prefill tasks and insert it to share_inputs buffer TODO(gongshaotian): Refactor this func """ - # NOTE(luotingdan): Lazy initialize kv cache - if "caches" not in self.share_inputs: - self.initialize_kv_cache() # NOTE(luotingdan): Set environment variable of prefill node if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill": @@ -739,7 +736,6 @@ def initialize_kv_cache(self) -> None: else: for i in range(self.model_config.num_hidden_layers): - cache_kvs[f"key_caches_{i}"] = paddle.full( shape=kv_cache_shape, fill_value=0, @@ -999,9 +995,6 @@ def capture_model(self) -> None: time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - need_init_cache = "caches" not in self.share_inputs - if need_init_cache: - self.initialize_kv_cache() for batch_size in sorted(capture_sizes, reverse=True): self._dummy_run( num_tokens=self.parallel_config.max_num_batched_tokens, @@ -1010,8 +1003,7 @@ def capture_model(self) -> None: expected_decode_len=expected_decode_len, ) logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") - if need_init_cache: - self.clear_cache() + time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") @@ -1237,6 +1229,7 @@ def 
profile_run(self) -> None: if self.speculative_method in ["mtp"]: self.proposer.clear_dummy_input() + self.parallel_config.do_profile = False def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ @@ -1247,8 +1240,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num - if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): - self.initialize_kv_cache() + self.initialize_kv_cache() # Reset free list free_list = list( @@ -1266,8 +1258,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: } ) - self.parallel_config.do_profile = False - if self.speculative_method in ["mtp"]: self.proposer.update_block_num(num_gpu_blocks) diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 7dcdcbe8f7..498593658a 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -165,9 +165,10 @@ def get_model(self) -> nn.Layer: """Get current model""" return self.model_runner.get_model() - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Initizlize the KV Cache""" - pass + def initialize_cache(self, num_gpu_blocks: int) -> None: + """Initizlize the KV Cache with accurate num_gpu_blocks""" + # accurate cache size + self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) def execute_model( self, @@ -198,7 +199,3 @@ def check_health(self) -> bool: def cal_theortical_kvcache(self) -> int: """Calculate the block memory required""" return self.model_runner.cal_theortical_kvcache() - - def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: - """Reinitialize the kv cache using the parameters from the profile""" - self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index 1bce9d19e2..aa571e9681 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -141,9 +141,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]): Process inputs for prefill tasks and insert it to share_inputs buffer TODO(gongshaotian): Refactor this func """ - # NOTE(luotingdan): Lazy initialize kv cache - if "caches" not in self.share_inputs: - self.initialize_kv_cache() # NOTE(luotingdan): Set environment variable of prefill node if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill": @@ -1013,11 +1010,11 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: Args: num_gpu_blocks: """ + self.parallel_config.do_profile = False self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num - if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): - self.initialize_kv_cache() + self.initialize_kv_cache() # Reset free list free_list = list( @@ -1035,8 +1032,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: } ) - self.parallel_config.do_profile = False - def cal_theortical_kvcache(self): """ Calculate the total block memory required at the model level diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py index f855466ffc..76fcb558bc 100644 --- a/fastdeploy/worker/iluvatar_worker.py +++ b/fastdeploy/worker/iluvatar_worker.py @@ -99,9 +99,9 @@ def get_model(self) -> 
nn.Layer: """ """ return self.model_runner.get_model() - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int) -> None: """ """ - pass + self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) def execute_model( self, @@ -135,7 +135,3 @@ def check_health(self) -> bool: def cal_theortical_kvcache(self) -> int: """ """ return self.model_runner.cal_theortical_kvcache() - - def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: - """ """ - self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/worker_base.py b/fastdeploy/worker/worker_base.py index 0d604c2e05..281776bb0e 100644 --- a/fastdeploy/worker/worker_base.py +++ b/fastdeploy/worker/worker_base.py @@ -64,7 +64,7 @@ def init_device(self) -> None: raise NotImplementedError @abstractmethod - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int) -> None: """Initizlize the KV Cache with the given size in blocks.""" raise NotImplementedError diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index bb49920366..17c39abc4a 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -347,7 +347,7 @@ def event_loop_normal(self) -> None: self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished() - def determine_num_available_blocks(self) -> None: + def initialize_kv_cache(self) -> None: """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. @@ -403,10 +403,7 @@ def determine_num_available_blocks(self) -> None: # logger.info will write in worker_process.log # Need `print` to triger engine->check_worker_initialize_status->detect_thread print(f"------- num_blocks_global: {num_blocks_local} --------") - # 4. Updata share inputs - self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local) - - def graph_optimize_and_warm_up_model(self) -> None: + # wait engine launch cache_manager if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed": launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32) self.launched_cache_manager_signal = IPCSignal( @@ -418,6 +415,10 @@ def graph_optimize_and_warm_up_model(self) -> None: ) while np.any(self.launched_cache_manager_signal.value[0] <= 0): time.sleep(0.01) + # 4. 
init kv_cache with accurate num_blocks + self.worker.initialize_cache(num_gpu_blocks=num_blocks_local) + + def graph_optimize_and_warm_up_model(self) -> None: self.worker.graph_optimize_and_warm_up_model() def init_device(self) -> None: @@ -731,11 +732,11 @@ def run_worker_proc() -> None: # Load model worker_proc.load_model() - logger.info("determine_num_available_blocks") - worker_proc.determine_num_available_blocks() + # logger.info("determine_num_available_blocks") + worker_proc.initialize_kv_cache() # Trigger CUDAGraph capture - worker_proc.graph_optimize_and_warm_up_model() + worker_proc.worker.graph_optimize_and_warm_up_model() # Initialize health status worker_proc.init_health_status() From bf28bcf3b7e6324e43cf740c8681470db527f0a4 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Mon, 21 Jul 2025 21:42:32 +0800 Subject: [PATCH 06/10] update --- fastdeploy/worker/gcu_model_runner.py | 2 +- fastdeploy/worker/iluvatar_model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 0a43c1f631..5af120236a 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -1026,6 +1026,7 @@ def profile_run(self) -> None: if self.speculative_method in ["mtp"]: self.proposer.clear_dummy_input() # paddle.device.cuda.synchronize() + self.parallel_config.do_profile = False def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ @@ -1033,7 +1034,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: Args: num_gpu_blocks: """ - self.parallel_config.do_profile = False self.num_gcu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index aa571e9681..d2e5bff1b0 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -1003,6 +1003,7 @@ def profile_run(self) -> None: self.clear_cache() # paddle.device.cuda.synchronize() + self.parallel_config.do_profile = False def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ @@ -1010,7 +1011,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: Args: num_gpu_blocks: """ - self.parallel_config.do_profile = False self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num From d8b931713ad99f69085974e6b21beb4dff8e1b49 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Mon, 21 Jul 2025 22:02:05 +0800 Subject: [PATCH 07/10] fix xpu --- fastdeploy/worker/xpu_worker.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py index 8ce43b4dd0..49b4a74ac0 100644 --- a/fastdeploy/worker/xpu_worker.py +++ b/fastdeploy/worker/xpu_worker.py @@ -131,9 +131,9 @@ def get_model(self) -> nn.Layer: """ """ return self.model_runner.get_model() - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int) -> None: """ """ - pass + self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) def execute_model( self, @@ -159,7 +159,3 @@ def preprocess_new_task(self, req_dicts: List[Request]) -> None: def check_health(self) -> bool: """ """ return True - - def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: - """ """ - self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) From 
2fda7096766dedd3782edc04c6d33d8a30ce3c52 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Tue, 22 Jul 2025 11:35:40 +0800 Subject: [PATCH 08/10] change the do_profile flag --- fastdeploy/worker/gcu_model_runner.py | 7 +++---- fastdeploy/worker/gpu_model_runner.py | 7 +++---- fastdeploy/worker/iluvatar_model_runner.py | 5 ++--- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py index 5af120236a..406ce53d98 100644 --- a/fastdeploy/worker/gcu_model_runner.py +++ b/fastdeploy/worker/gcu_model_runner.py @@ -559,7 +559,7 @@ def update_parameters(self, pid): self.initialize_kv_cache() self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") - def initialize_kv_cache(self) -> None: + def initialize_kv_cache(self, profile: bool = False) -> None: """ Initialize kv cache """ @@ -580,7 +580,7 @@ def initialize_kv_cache(self) -> None: kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not self.parallel_config.do_profile and ( + if not profile and ( self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed" ): raise NotImplementedError("prefix_caching is not support by GCUModelRunner.") @@ -1010,7 +1010,7 @@ def profile_run(self) -> None: # Initialize kv cache for profile run. After profile run kv cache will be reset. self.num_gcu_blocks = self.parallel_config.total_block_num - self.initialize_kv_cache() + self.initialize_kv_cache(profile=True) # 1. Profile with multimodal encoder & encoder cache @@ -1026,7 +1026,6 @@ def profile_run(self) -> None: if self.speculative_method in ["mtp"]: self.proposer.clear_dummy_input() # paddle.device.cuda.synchronize() - self.parallel_config.do_profile = False def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2ab567f95b..710ecaff34 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -697,7 +697,7 @@ def initialize_forward_meta(self): for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) - def initialize_kv_cache(self) -> None: + def initialize_kv_cache(self, profile: bool = False) -> None: """ Initialize kv cache """ @@ -718,7 +718,7 @@ def initialize_kv_cache(self) -> None: kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not self.parallel_config.do_profile and ( + if not profile and ( self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed" ): cache_kvs_list = [] @@ -1214,7 +1214,7 @@ def profile_run(self) -> None: # Initialize kv cache for profile run. After profile run kv cache will be reset. # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.parallel_config.total_block_num - self.initialize_kv_cache() + self.initialize_kv_cache(profile=True) # 1. 
Profile with multimodal encoder & encoder cache @@ -1229,7 +1229,6 @@ def profile_run(self) -> None: if self.speculative_method in ["mtp"]: self.proposer.clear_dummy_input() - self.parallel_config.do_profile = False def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py index d2e5bff1b0..54d6600d32 100644 --- a/fastdeploy/worker/iluvatar_model_runner.py +++ b/fastdeploy/worker/iluvatar_model_runner.py @@ -549,7 +549,7 @@ def clear_cache(self): if self.forward_meta is not None: self.forward_meta.clear_caches() - def initialize_kv_cache(self) -> None: + def initialize_kv_cache(self, profile: bool = False) -> None: """ Initialize kv cache """ @@ -989,7 +989,7 @@ def profile_run(self) -> None: # Initialize kv cache for profile run. After profile run kv cache will be reset. # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.parallel_config.total_block_num - self.initialize_kv_cache() + self.initialize_kv_cache(profile=True) # 1. Profile with multimodal encoder & encoder cache @@ -1003,7 +1003,6 @@ def profile_run(self) -> None: self.clear_cache() # paddle.device.cuda.synchronize() - self.parallel_config.do_profile = False def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ From ad3ca954592343ad5727933af3f032c94e40819c Mon Sep 17 00:00:00 2001 From: zeroRains Date: Tue, 22 Jul 2025 11:39:29 +0800 Subject: [PATCH 09/10] update --- fastdeploy/worker/worker_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 17c39abc4a..b26fa10541 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -732,7 +732,7 @@ def run_worker_proc() -> None: # Load model worker_proc.load_model() - # logger.info("determine_num_available_blocks") + # Initialize KV Cache worker_proc.initialize_kv_cache() # Trigger CUDAGraph capture From b12667b2cba05b6f2c8702d479e29a068518ef15 Mon Sep 17 00:00:00 2001 From: zeroRains Date: Tue, 22 Jul 2025 13:08:54 +0800 Subject: [PATCH 10/10] add new threads to init cache_manager --- fastdeploy/engine/engine.py | 9 ++++++--- fastdeploy/worker/worker_process.py | 5 ++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index c86d328c2e..a03fc5f99a 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -1178,12 +1178,13 @@ def detect_thread(): self.worker_init_status["layer_loadding"] = progress if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_layers - 1: self.worker_init_status["finished"] = True - elif match := re.search(r"num_blocks_global", line): - if self.do_profile: - self._stop_profile() self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True) self.checking_worker_status_thread.start() + checking_worker_init_kv_cache_status_thread = None + if self.do_profile: + checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True) + checking_worker_init_kv_cache_status_thread.start() # display weight loadding progress with tqdm(total=100, desc="Loading Weights") as pbar: @@ -1214,6 +1215,8 @@ def detect_thread(): self.worker_init_status["finished"] = True try: self.checking_worker_status_thread.join(timeout=1) + if checking_worker_init_kv_cache_status_thread is not None: + 
checking_worker_init_kv_cache_status_thread.join(timeout=1) except Exception: pass return True diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index b26fa10541..04dd5d74b2 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -400,9 +400,8 @@ def initialize_kv_cache(self) -> None: self.get_profile_block_num_signal.value[0] = num_blocks_local else: num_blocks_local = self.fd_config.parallel_config.total_block_num - # logger.info will write in worker_process.log - # Need `print` to triger engine->check_worker_initialize_status->detect_thread - print(f"------- num_blocks_global: {num_blocks_local} --------") + + logger.info(f"------- num_blocks_global: {num_blocks_local} --------") # wait engine launch cache_manager if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed": launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
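
Editor's note on the mechanism this series introduces (a simplified sketch, not part of the patches): the engine creates a one-element IPCSignal named "launched_cache_manager_signal" and sets value[0] = 1 only after launch_cache_manager has run (in start() and again in _stop_profile); the worker, when prefix caching is enabled or splitwise_role != "mixed", opens the same signal with create=False and polls it before calling initialize_cache(num_gpu_blocks=num_blocks_local), so the accurately sized KV cache is only built once the cache manager is up, and CUDA graph capture then runs against that final cache. Waiting on the worker side is what allows patch 01 to drop the old args_utils.py assertion that forbade combining use_cudagraph with enable_prefix_caching. The stand-alone sketch below mimics that handshake with multiprocessing.Value standing in for IPCSignal; the names engine_side/worker_side, the 0.5 s delay, and the 1024 block count are illustrative only and do not exist in FastDeploy.

import time
from multiprocessing import Process, Value


def engine_side(launched_cache_manager_signal):
    # Engine: launch the cache manager, then publish the flag so workers know
    # the prefix-caching transfer services are ready (the real code sets
    # self.launched_cache_manager_signal.value[0] = 1 after launch_cache_manager).
    time.sleep(0.5)  # stands in for launch_cache_manager(...)
    launched_cache_manager_signal.value = 1


def worker_side(launched_cache_manager_signal, num_blocks_local):
    # Worker: after profiling decides num_blocks_local, block until the engine's
    # cache manager is up, then size the KV cache with the accurate block count
    # (mirrors worker_process.initialize_kv_cache in patches 05-10).
    while launched_cache_manager_signal.value <= 0:
        time.sleep(0.01)
    print(f"initialize_cache(num_gpu_blocks={num_blocks_local})")


if __name__ == "__main__":
    signal = Value("i", 0)
    worker = Process(target=worker_side, args=(signal, 1024))
    worker.start()
    engine_side(signal)
    worker.join()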