diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 546d9de52a..633b6837c3 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -863,8 +863,6 @@ def create_engine_config(self) -> Config:
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)

-        assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"
-
         assert not (
             self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
         ), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 5b555465e9..42963ee807 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -183,6 +183,7 @@ def start(self, api_server_pid=None):
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=self.ipc_signal_suffix,
             )
+            self.launched_cache_manager_signal.value[0] = 1
         self.worker_proc = self._start_worker_service()
         console_logger.info("Waitting worker processes ready...")
@@ -217,9 +218,6 @@ def start(self, api_server_pid=None):
         # Start TokenProcessor thread
         self.token_processor.run()

-        if self.do_profile:
-            self._stop_profile()
-
         if self.cfg.splitwise_role != "mixed":
             # Single-node logic
             self.engine_worker_queue.available_prefill_instances.put(1)
@@ -849,6 +847,17 @@ def _init_worker_signals(self):
             create=True,
         )

+        # launched_cache_manager_signal is used to detect whether the engine has launched the cache_manager
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.ipc_signal_suffix,
+                create=True,
+            )
+
         # worker_live_signal lets the engine detect whether each worker process is alive and records the time of each step
         worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
         self.worker_healthy_live_signal = IPCSignal(
@@ -1133,6 +1142,7 @@ def _stop_profile(self):
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=self.ipc_signal_suffix,
             )
+            self.launched_cache_manager_signal.value[0] = 1

     def check_health(self, time_interval_threashold=30):
         """
@@ -1171,6 +1181,10 @@ def detect_thread():
         self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
         self.checking_worker_status_thread.start()

+        checking_worker_init_kv_cache_status_thread = None
+        if self.do_profile:
+            checking_worker_init_kv_cache_status_thread = threading.Thread(target=self._stop_profile, daemon=True)
+            checking_worker_init_kv_cache_status_thread.start()
         # display weight loadding progress
         with tqdm(total=100, desc="Loading Weights") as pbar:
@@ -1201,6 +1215,8 @@ def detect_thread():
         self.worker_init_status["finished"] = True
         try:
             self.checking_worker_status_thread.join(timeout=1)
+            if checking_worker_init_kv_cache_status_thread is not None:
+                checking_worker_init_kv_cache_status_thread.join(timeout=1)
         except Exception:
             pass
         return True
diff --git a/fastdeploy/worker/gcu_model_runner.py b/fastdeploy/worker/gcu_model_runner.py
index b58c2237fe..406ce53d98 100644
--- a/fastdeploy/worker/gcu_model_runner.py
+++ b/fastdeploy/worker/gcu_model_runner.py
@@ -151,8 +151,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         """
         Process inputs for prefill tasks and insert it to share_inputs buffer
         """
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
             os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
@@ -561,7 +559,7 @@ def update_parameters(self, pid):
         self.initialize_kv_cache()
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
@@ -582,7 +580,7 @@ def initialize_kv_cache(self) -> None:
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)

         # local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not self.parallel_config.do_profile and (
+        if not profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
@@ -1012,7 +1010,7 @@ def profile_run(self) -> None:

         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         self.num_gcu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache
@@ -1038,8 +1036,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         self.num_gcu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(
@@ -1057,8 +1054,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
             }
         )

-        self.parallel_config.do_profile = False
-
         if self.speculative_method in ["mtp"]:
             self.proposer.update_block_num(num_gpu_blocks)
diff --git a/fastdeploy/worker/gcu_worker.py b/fastdeploy/worker/gcu_worker.py
index 2e4e83885b..1b98e3b0c8 100644
--- a/fastdeploy/worker/gcu_worker.py
+++ b/fastdeploy/worker/gcu_worker.py
@@ -98,9 +98,9 @@ def get_model(self) -> nn.Layer:
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
@@ -134,7 +134,3 @@ def check_health(self) -> bool:
     def cal_theortical_kvcache(self) -> int:
         """ """
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 0d199c57d2..710ecaff34 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -193,9 +193,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         Process inputs for prefill tasks and insert it to share_inputs buffer
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Lazy initialize kv cache
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         # NOTE(luotingdan): Set environment variable of prefill node
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
             os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
@@ -700,7 +697,7 @@ def initialize_forward_meta(self):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-    def initialize_kv_cache(self) -> None:
+    def initialize_kv_cache(self, profile: bool = False) -> None:
         """
         Initialize kv cache
         """
@@ -721,7 +718,7 @@ def initialize_kv_cache(self) -> None:
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not self.parallel_config.do_profile and (
+        if not profile and (
             self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
         ):
             cache_kvs_list = []
@@ -739,7 +736,6 @@ def initialize_kv_cache(self) -> None:

         else:
             for i in range(self.model_config.num_hidden_layers):
-
                 cache_kvs[f"key_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,
@@ -1218,7 +1214,7 @@ def profile_run(self) -> None:
         # Initialize kv cache for profile run. After profile run kv cache will be reset.
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
+        self.initialize_kv_cache(profile=True)

         # 1. Profile with multimodal encoder & encoder cache
@@ -1243,8 +1239,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         self.num_gpu_blocks = num_gpu_blocks

         # Reset block table and kv cache with global block num
-        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache()

         # Reset free list
         free_list = list(
@@ -1262,8 +1257,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
             }
         )

-        self.parallel_config.do_profile = False
-
         if self.speculative_method in ["mtp"]:
             self.proposer.update_block_num(num_gpu_blocks)
diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py
index 45aa96b4c1..812b62935a 100644
--- a/fastdeploy/worker/gpu_worker.py
+++ b/fastdeploy/worker/gpu_worker.py
@@ -165,9 +165,10 @@ def get_model(self) -> nn.Layer:
         """Get current model"""
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
-        """Initizlize the KV Cache"""
-        pass
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initialize the KV Cache with accurate num_gpu_blocks"""
+        # accurate cache size
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
@@ -198,7 +199,3 @@ def check_health(self) -> bool:
     def cal_theortical_kvcache(self) -> int:
         """Calculate the block memory required"""
         return self.model_runner.cal_theortical_kvcache()
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """Reinitialize the kv cache using the parameters from the profile"""
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
diff --git a/fastdeploy/worker/iluvatar_model_runner.py b/fastdeploy/worker/iluvatar_model_runner.py
index 1bce9d19e2..54d6600d32 100644
--- a/fastdeploy/worker/iluvatar_model_runner.py
+++ b/fastdeploy/worker/iluvatar_model_runner.py
@@ -141,9 +141,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         Process inputs for prefill tasks and insert it to share_inputs buffer
         TODO(gongshaotian): Refactor this func
         """
-        # NOTE(luotingdan): Lazy initialize kv cache
-        if "caches" not in self.share_inputs:
-            self.initialize_kv_cache()

         # NOTE(luotingdan): Set environment variable of prefill node
         if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
== "prefill": @@ -552,7 +549,7 @@ def clear_cache(self): if self.forward_meta is not None: self.forward_meta.clear_caches() - def initialize_kv_cache(self) -> None: + def initialize_kv_cache(self, profile: bool = False) -> None: """ Initialize kv cache """ @@ -992,7 +989,7 @@ def profile_run(self) -> None: # Initialize kv cache for profile run. After profile run kv cache will be reset. # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.parallel_config.total_block_num - self.initialize_kv_cache() + self.initialize_kv_cache(profile=True) # 1. Profile with multimodal encoder & encoder cache @@ -1016,8 +1013,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num - if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): - self.initialize_kv_cache() + self.initialize_kv_cache() # Reset free list free_list = list( @@ -1035,8 +1031,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: } ) - self.parallel_config.do_profile = False - def cal_theortical_kvcache(self): """ Calculate the total block memory required at the model level diff --git a/fastdeploy/worker/iluvatar_worker.py b/fastdeploy/worker/iluvatar_worker.py index f855466ffc..76fcb558bc 100644 --- a/fastdeploy/worker/iluvatar_worker.py +++ b/fastdeploy/worker/iluvatar_worker.py @@ -99,9 +99,9 @@ def get_model(self) -> nn.Layer: """ """ return self.model_runner.get_model() - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int) -> None: """ """ - pass + self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) def execute_model( self, @@ -135,7 +135,3 @@ def check_health(self) -> bool: def cal_theortical_kvcache(self) -> int: """ """ return self.model_runner.cal_theortical_kvcache() - - def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None: - """ """ - self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks) diff --git a/fastdeploy/worker/worker_base.py b/fastdeploy/worker/worker_base.py index 0d604c2e05..281776bb0e 100644 --- a/fastdeploy/worker/worker_base.py +++ b/fastdeploy/worker/worker_base.py @@ -64,7 +64,7 @@ def init_device(self) -> None: raise NotImplementedError @abstractmethod - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int) -> None: """Initizlize the KV Cache with the given size in blocks.""" raise NotImplementedError diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index d656f60ae9..96503fd800 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -347,7 +347,7 @@ def event_loop_normal(self) -> None: self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished() - def determine_num_available_blocks(self) -> None: + def initialize_kv_cache(self) -> None: """Profiles the peak memory usage of the model to determine how many KV blocks may be allocated without OOMs. @@ -400,8 +400,25 @@ def determine_num_available_blocks(self) -> None: self.get_profile_block_num_signal.value[0] = num_blocks_local else: num_blocks_local = self.fd_config.parallel_config.total_block_num - # 4. 
-        self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)
+
+        logger.info(f"------- num_blocks_global: {num_blocks_local} --------")
+        # wait for the engine to launch the cache_manager
+        if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
+            launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
+            self.launched_cache_manager_signal = IPCSignal(
+                name="launched_cache_manager_signal",
+                array=launched_cache_manager_signal_data,
+                dtype=np.int32,
+                suffix=self.parallel_config.engine_pid,
+                create=False,
+            )
+            while np.any(self.launched_cache_manager_signal.value[0] <= 0):
+                time.sleep(0.01)
+        # 4. init kv_cache with accurate num_blocks
+        self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)
+
+    def graph_optimize_and_warm_up_model(self) -> None:
+        self.worker.graph_optimize_and_warm_up_model()

     def init_device(self) -> None:
         """Initialize device and Construct model runner"""
@@ -714,8 +731,8 @@ def run_worker_proc() -> None:

     # Load model
     worker_proc.load_model()
-    logger.info("determine_num_available_blocks")
-    worker_proc.determine_num_available_blocks()
+    # Initialize KV Cache
+    worker_proc.initialize_kv_cache()

     # Trigger CUDAGraph capture
     worker_proc.worker.graph_optimize_and_warm_up_model()
diff --git a/fastdeploy/worker/xpu_worker.py b/fastdeploy/worker/xpu_worker.py
index 8ce43b4dd0..49b4a74ac0 100644
--- a/fastdeploy/worker/xpu_worker.py
+++ b/fastdeploy/worker/xpu_worker.py
@@ -131,9 +131,9 @@ def get_model(self) -> nn.Layer:
         """ """
         return self.model_runner.get_model()

-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
         """ """
-        pass
+        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

     def execute_model(
         self,
@@ -159,7 +159,3 @@ def preprocess_new_task(self, req_dicts: List[Request]) -> None:
     def check_health(self) -> bool:
         """ """
         return True
-
-    def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
-        """ """
-        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
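
Note on the handshake this patch introduces: the engine now creates a one-element launched_cache_manager_signal in _init_worker_signals() and sets it to 1 once the cache manager has been launched (in start() or _stop_profile()), while the worker process attaches to the same signal and polls until it is non-zero before calling initialize_cache(num_gpu_blocks=...). The sketch below restates that producer/consumer protocol outside the diff; it is illustrative only, assumes IPCSignal is importable from fastdeploy.inter_communicator (adjust the import if it lives elsewhere), and the function names and suffix argument are hypothetical, not part of this patch.

import time

import numpy as np

# Assumption: import path for the shared-memory signal class used in the diff above.
from fastdeploy.inter_communicator import IPCSignal


def engine_side_set_signal(ipc_suffix):
    """Engine side: create the signal (create=True) and flip it once the cache manager is up."""
    launched = IPCSignal(
        name="launched_cache_manager_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=ipc_suffix,
        create=True,
    )
    # ... launch the cache manager / transfer service here ...
    launched.value[0] = 1
    return launched


def worker_side_wait_signal(ipc_suffix):
    """Worker side: attach to the engine-owned signal (create=False) and block until it is set."""
    launched = IPCSignal(
        name="launched_cache_manager_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=ipc_suffix,
        create=False,
    )
    while launched.value[0] <= 0:
        time.sleep(0.01)
    # Safe to call worker.initialize_cache(num_gpu_blocks=...) from this point on.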