[Feature] Support using prefix-caching + cudagraph for inference #2924

Merged: 14 commits, Jul 22, 2025
2 changes: 0 additions & 2 deletions fastdeploy/engine/args_utils.py
@@ -863,8 +863,6 @@ def create_engine_config(self) -> Config:
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)

assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"

assert not (
self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
19 changes: 16 additions & 3 deletions fastdeploy/engine/engine.py
@@ -183,6 +183,7 @@ def start(self, api_server_pid=None):
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix,
)
self.launched_cache_manager_signal.value[0] = 1

self.worker_proc = self._start_worker_service()
console_logger.info("Waitting worker processes ready...")
@@ -217,9 +218,6 @@ def start(self, api_server_pid=None):
# Start TokenProcessor thread
self.token_processor.run()

if self.do_profile:
self._stop_profile()

if self.cfg.splitwise_role != "mixed":
# Single-machine logic
self.engine_worker_queue.available_prefill_instances.put(1)
@@ -849,6 +847,17 @@ def _init_worker_signals(self):
create=True,
)

# launched_cache_manager_signal is used to detect whether the engine has launched the cache_manager
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)

# worker_live_signal is used by the engine to detect whether each worker process is alive, and records the time of each step
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
@@ -1133,6 +1142,7 @@ def _stop_profile(self):
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix,
)
self.launched_cache_manager_signal.value[0] = 1

def check_health(self, time_interval_threashold=30):
"""
@@ -1168,6 +1178,9 @@ def detect_thread():
self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_layers - 1:
self.worker_init_status["finished"] = True
elif match := re.search(r"num_blocks_global", line):
if self.do_profile:
self._stop_profile()

self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
self.checking_worker_status_thread.start()
Comment on lines 1181 to 1183
Collaborator:

These lines also need to be deleted.

Contributor Author (@zeroRains, Jul 22, 2025):

Once this is deleted, the model-loading progress bar shown on the engine startup side is gone; the loading progress can then only be seen in workerlog.0.

Contributor Author (@zeroRains):

In the current version the engine also relies on this thread to detect that the workers have finished starting up. Does that mean the existing way the engine detects worker startup needs to change?

Collaborator:

The comment range was marked incorrectly.

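The engine.py changes above form one half of a handshake: the engine creates a `launched_cache_manager_signal` IPC signal, the log-scanning `detect_thread` watches worker output for the `num_blocks_global` marker to stop profiling and launch the cache manager, and the signal is then set to 1 so workers can proceed. The sketch below condenses that engine-side flow. It is a minimal sketch, not FastDeploy code: `SharedFlag` only imitates the `IPCSignal` constructor shown in the diff (the real class is backed by named shared memory), and `launch_cache_manager` stands in for the `_stop_profile()` path.

```python
# Engine-side sketch of the cache-manager handshake added in this PR.
import re
import numpy as np


class SharedFlag:
    """Stand-in for FastDeploy's IPCSignal: a named int32 array shared
    between the engine and worker processes."""

    _store: dict = {}  # name -> array; stand-in for shared memory

    def __init__(self, name, array, dtype, suffix, create):
        key = f"{name}.{suffix}"
        if create:
            SharedFlag._store[key] = np.array(array, dtype=dtype)
        self.value = SharedFlag._store[key]


def engine_side_flow(ipc_signal_suffix, worker_log_lines, launch_cache_manager):
    # _init_worker_signals: create the flag with value 0.
    launched_cache_manager_signal = SharedFlag(
        name="launched_cache_manager_signal",
        array=np.zeros([1], dtype=np.int32),
        dtype=np.int32,
        suffix=ipc_signal_suffix,
        create=True,
    )
    # detect_thread: when the worker prints its final block count,
    # stop profiling, launch the cache manager, then raise the flag.
    for line in worker_log_lines:
        if re.search(r"num_blocks_global", line):
            launch_cache_manager()  # maps to _stop_profile() in the diff
            launched_cache_manager_signal.value[0] = 1
            break
```

When profiling is disabled and the cache manager is launched up front in `start()`, the diff sets the same flag directly after the launch call.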
13 changes: 4 additions & 9 deletions fastdeploy/worker/gcu_model_runner.py
@@ -151,8 +151,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
"""
Process inputs for prefill tasks and insert it to share_inputs buffer
"""
if "caches" not in self.share_inputs:
self.initialize_kv_cache()

if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
@@ -561,7 +559,7 @@ def update_parameters(self, pid):
self.initialize_kv_cache()
self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

def initialize_kv_cache(self) -> None:
def initialize_kv_cache(self, profile: bool = False) -> None:
"""
Initialize kv cache
"""
@@ -582,7 +580,7 @@ def initialize_kv_cache(self) -> None:
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
# local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

if not self.parallel_config.do_profile and (
if not profile and (
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
):
raise NotImplementedError("prefix_caching is not support by GCUModelRunner.")
@@ -1012,7 +1010,7 @@ def profile_run(self) -> None:

# Initialize kv cache for profile run. After profile run kv cache will be reset.
self.num_gcu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
self.initialize_kv_cache(profile=True)

# 1. Profile with multimodal encoder & encoder cache

@@ -1038,8 +1036,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
self.num_gcu_blocks = num_gpu_blocks

# Reset block table and kv cache with global block num
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
self.initialize_kv_cache()
self.initialize_kv_cache()

# Reset free list
free_list = list(
@@ -1057,8 +1054,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
}
)

self.parallel_config.do_profile = False

if self.speculative_method in ["mtp"]:
self.proposer.update_block_num(num_gpu_blocks)

8 changes: 2 additions & 6 deletions fastdeploy/worker/gcu_worker.py
@@ -98,9 +98,9 @@ def get_model(self) -> nn.Layer:
""" """
return self.model_runner.get_model()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
def initialize_cache(self, num_gpu_blocks: int) -> None:
""" """
pass
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

def execute_model(
self,
@@ -134,7 +134,3 @@ def check_health(self) -> bool:
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()

def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
15 changes: 4 additions & 11 deletions fastdeploy/worker/gpu_model_runner.py
@@ -193,9 +193,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
Process inputs for prefill tasks and insert it to share_inputs buffer
TODO(gongshaotian): Refactor this func
"""
# NOTE(luotingdan): Lazy initialize kv cache
if "caches" not in self.share_inputs:
self.initialize_kv_cache()

# NOTE(luotingdan): Set environment variable of prefill node
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -700,7 +697,7 @@ def initialize_forward_meta(self):
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)

def initialize_kv_cache(self) -> None:
def initialize_kv_cache(self, profile: bool = False) -> None:
"""
Initialize kv cache
"""
@@ -721,7 +718,7 @@ def initialize_kv_cache(self) -> None:
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=max_block_num)
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

if not self.parallel_config.do_profile and (
if not profile and (
self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
):
cache_kvs_list = []
@@ -739,7 +736,6 @@ def initialize_kv_cache(self) -> None:

else:
for i in range(self.model_config.num_hidden_layers):

cache_kvs[f"key_caches_{i}"] = paddle.full(
shape=kv_cache_shape,
fill_value=0,
@@ -1218,7 +1214,7 @@ def profile_run(self) -> None:
# Initialize kv cache for profile run. After profile run kv cache will be reset.
# TODO(gongshaotian): Optimize the management logic of kvcache
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
self.initialize_kv_cache(profile=True)

# 1. Profile with multimodal encoder & encoder cache

@@ -1243,8 +1239,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
self.num_gpu_blocks = num_gpu_blocks

# Reset block table and kv cache with global block num
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
self.initialize_kv_cache()
self.initialize_kv_cache()

# Reset free list
free_list = list(
@@ -1262,8 +1257,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
}
)

self.parallel_config.do_profile = False

if self.speculative_method in ["mtp"]:
self.proposer.update_block_num(num_gpu_blocks)

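The runner changes (mirrored across the GCU, GPU, and Iluvatar runners) split KV-cache setup into two phases: a throwaway allocation for the profile run via `initialize_kv_cache(profile=True)`, and a final allocation in `update_share_input_block_num`, which now reinitializes the cache unconditionally once the accurate block count is known instead of skipping it when prefix caching or splitwise mode is enabled. Below is a minimal sketch of that control flow; the class and attribute names are schematic, only the method names and the `profile` flag come from the diff.

```python
# Schematic two-phase KV-cache allocation (not the actual runner class).
class RunnerSketch:
    def __init__(self, total_block_num: int, enable_prefix_caching: bool, splitwise_role: str):
        self.num_gpu_blocks = total_block_num
        self.enable_prefix_caching = enable_prefix_caching
        self.splitwise_role = splitwise_role
        self.caches = None

    def initialize_kv_cache(self, profile: bool = False) -> None:
        uses_cache_manager = self.enable_prefix_caching or self.splitwise_role != "mixed"
        if not profile and uses_cache_manager:
            # Final allocation: attach to cache tensors owned by the cache
            # manager (the prefix-caching / splitwise path in the diff).
            self.caches = ("shared", self.num_gpu_blocks)
        else:
            # Profile run, or plain mixed mode: allocate local tensors.
            self.caches = ("local", self.num_gpu_blocks)

    def profile_run(self) -> None:
        # Temporary cache sized from the config estimate; reset afterwards.
        self.initialize_kv_cache(profile=True)

    def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
        # Called with the profiled block count; the old guard that skipped
        # reinitialization under prefix caching / splitwise is removed.
        self.num_gpu_blocks = num_gpu_blocks
        self.initialize_kv_cache()
```

Passing `profile` explicitly also lets the diff drop the `parallel_config.do_profile` bookkeeping and the lazy "caches not in share_inputs" initialization in `insert_prefill_inputs`.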
11 changes: 4 additions & 7 deletions fastdeploy/worker/gpu_worker.py
@@ -165,9 +165,10 @@ def get_model(self) -> nn.Layer:
"""Get current model"""
return self.model_runner.get_model()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initizlize the KV Cache"""
pass
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initizlize the KV Cache with accurate num_gpu_blocks"""
# accurate cache size
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

def execute_model(
self,
@@ -198,7 +199,3 @@ def check_health(self) -> bool:
def cal_theortical_kvcache(self) -> int:
"""Calculate the block memory required"""
return self.model_runner.cal_theortical_kvcache()

def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
"""Reinitialize the kv cache using the parameters from the profile"""
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
12 changes: 3 additions & 9 deletions fastdeploy/worker/iluvatar_model_runner.py
@@ -141,9 +141,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
Process inputs for prefill tasks and insert it to share_inputs buffer
TODO(gongshaotian): Refactor this func
"""
# NOTE(luotingdan): Lazy initialize kv cache
if "caches" not in self.share_inputs:
self.initialize_kv_cache()

# NOTE(luotingdan): Set environment variable of prefill node
if req_dicts[-1].disaggregate_info is not None and req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -552,7 +549,7 @@ def clear_cache(self):
if self.forward_meta is not None:
self.forward_meta.clear_caches()

def initialize_kv_cache(self) -> None:
def initialize_kv_cache(self, profile: bool = False) -> None:
"""
Initialize kv cache
"""
@@ -992,7 +989,7 @@ def profile_run(self) -> None:
# Initialize kv cache for profile run. After profile run kv cache will be reset.
# TODO(gongshaotian): Optimize the management logic of kvcache
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
self.initialize_kv_cache(profile=True)

# 1. Profile with multimodal encoder & encoder cache

@@ -1016,8 +1013,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
self.num_gpu_blocks = num_gpu_blocks

# Reset block table and kv cache with global block num
if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
self.initialize_kv_cache()
self.initialize_kv_cache()

# Reset free list
free_list = list(
@@ -1035,8 +1031,6 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
}
)

self.parallel_config.do_profile = False

def cal_theortical_kvcache(self):
"""
Calculate the total block memory required at the model level
8 changes: 2 additions & 6 deletions fastdeploy/worker/iluvatar_worker.py
@@ -99,9 +99,9 @@ def get_model(self) -> nn.Layer:
""" """
return self.model_runner.get_model()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
def initialize_cache(self, num_gpu_blocks: int) -> None:
""" """
pass
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

def execute_model(
self,
@@ -135,7 +135,3 @@ def check_health(self) -> bool:
def cal_theortical_kvcache(self) -> int:
""" """
return self.model_runner.cal_theortical_kvcache()

def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
2 changes: 1 addition & 1 deletion fastdeploy/worker/worker_base.py
@@ -64,7 +64,7 @@ def init_device(self) -> None:
raise NotImplementedError

@abstractmethod
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initizlize the KV Cache with the given size in blocks."""
raise NotImplementedError

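With `reinitialize_kv_cache` removed, the base interface keeps a single `initialize_cache(num_gpu_blocks)` hook (the unused `num_cpu_blocks` argument is dropped), and each backend worker forwards it to the runner's `update_share_input_block_num`. A hedged sketch of the resulting shape, using stand-in class names rather than the real worker classes:

```python
from abc import ABC, abstractmethod


class WorkerSketchBase(ABC):
    """Stand-in for WorkerBase after this PR."""

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError


class DeviceWorkerSketch(WorkerSketchBase):
    """Shape shared by the GPU, GCU, Iluvatar, and XPU workers in the diff."""

    def __init__(self, model_runner):
        self.model_runner = model_runner

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        # Formerly reinitialize_kv_cache: rebuild block tables and the cache
        # with the accurate block count obtained from profiling.
        self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)
```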
28 changes: 23 additions & 5 deletions fastdeploy/worker/worker_process.py
@@ -347,7 +347,7 @@ def event_loop_normal(self) -> None:

self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()

def determine_num_available_blocks(self) -> None:
def initialize_kv_cache(self) -> None:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.

Expand Down Expand Up @@ -400,8 +400,26 @@ def determine_num_available_blocks(self) -> None:
self.get_profile_block_num_signal.value[0] = num_blocks_local
else:
num_blocks_local = self.fd_config.parallel_config.total_block_num
# 4. Updata share inputs
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)
# logger.info will write in worker_process.log
# Need `print` to triger engine->check_worker_initialize_status->detect_thread
print(f"------- num_blocks_global: {num_blocks_local} --------")
# wait engine launch cache_manager
if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False,
)
while np.any(self.launched_cache_manager_signal.value[0] <= 0):
time.sleep(0.01)
# 4. init kv_cache with accurate num_blocks
self.worker.initialize_cache(num_gpu_blocks=num_blocks_local)

def graph_optimize_and_warm_up_model(self) -> None:
self.worker.graph_optimize_and_warm_up_model()

def init_device(self) -> None:
"""Initialize device and Construct model runner"""
Expand Down Expand Up @@ -714,8 +732,8 @@ def run_worker_proc() -> None:

# Load model
worker_proc.load_model()
logger.info("determine_num_available_blocks")
worker_proc.determine_num_available_blocks()
# Initialize KV Cache
worker_proc.initialize_kv_cache()

# Trigger CUDAGraph capture
worker_proc.worker.graph_optimize_and_warm_up_model()
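On the worker side, `determine_num_available_blocks` becomes `initialize_kv_cache` and now ends with the other half of the handshake: publish the block count, print the `num_blocks_global` marker (a plain `print`, so the engine's log-scanning thread sees it), wait for `launched_cache_manager_signal`, then build the final cache. The sketch below condenses that tail end under the same stand-in signal class as the engine-side sketch above; `profile_block_count` is a hypothetical placeholder for the profiling steps elided here.

```python
import time
import numpy as np


def worker_initialize_kv_cache_sketch(worker, parallel_config, signal_cls):
    # Placeholder for steps 1-3 in the diff: profile peak memory and agree on
    # a global block count across ranks.
    num_blocks_local = worker.profile_block_count()  # hypothetical helper

    # Plain print on purpose: the engine's detect_thread greps worker output
    # for "num_blocks_global" to know profiling has finished.
    print(f"------- num_blocks_global: {num_blocks_local} --------")

    if parallel_config.enable_prefix_caching or parallel_config.splitwise_role != "mixed":
        # Attach (create=False) to the flag the engine created, and block
        # until the engine has launched the cache manager.
        launched_cache_manager_signal = signal_cls(
            name="launched_cache_manager_signal",
            array=np.zeros([1], dtype=np.int32),
            dtype=np.int32,
            suffix=parallel_config.engine_pid,
            create=False,
        )
        while launched_cache_manager_signal.value[0] <= 0:
            time.sleep(0.01)

    # Step 4: build the KV cache with the accurate block count.
    worker.initialize_cache(num_gpu_blocks=num_blocks_local)
```

Because `run_worker_proc` now calls this before `graph_optimize_and_warm_up_model`, CUDA graph capture sees the final cache tensors, which is presumably what allows prefix caching and CUDA graphs to be combined, as the PR title states.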
8 changes: 2 additions & 6 deletions fastdeploy/worker/xpu_worker.py
@@ -131,9 +131,9 @@ def get_model(self) -> nn.Layer:
""" """
return self.model_runner.get_model()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
def initialize_cache(self, num_gpu_blocks: int) -> None:
""" """
pass
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

def execute_model(
self,
@@ -159,7 +159,3 @@ def preprocess_new_task(self, req_dicts: List[Request]) -> None:
def check_health(self) -> bool:
""" """
return True

def reinitialize_kv_cache(self, num_gpu_blocks: int) -> None:
""" """
self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)