
[Feature] Support using prefix-caching + cudagraph for inference #2924


Merged: 14 commits merged on Jul 22, 2025
2 changes: 0 additions & 2 deletions fastdeploy/engine/args_utils.py
@@ -863,8 +863,6 @@ def create_engine_config(self) -> Config:
graph_opt_cfg = self.create_graph_optimization_config()
graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)

assert not (self.use_cudagraph and self.enable_prefix_caching), "Prefix caching cannot be used with CUDA graph"

assert not (
self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce
), "enable_custom_all_reduce must be used with tensor_parallel_size>1"
19 changes: 16 additions & 3 deletions fastdeploy/engine/engine.py
@@ -183,6 +183,7 @@ def start(self, api_server_pid=None):
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix,
)
self.launched_cache_manager_signal.value[0] = 1

self.worker_proc = self._start_worker_service()
console_logger.info("Waitting worker processes ready...")
@@ -217,9 +218,6 @@ def start(self, api_server_pid=None):
# Start TokenProcessor thread
self.token_processor.run()

if self.do_profile:
self._stop_profile()

if self.cfg.splitwise_role != "mixed":
# Single-machine logic
self.engine_worker_queue.available_prefill_instances.put(1)
@@ -849,6 +847,17 @@ def _init_worker_signals(self):
create=True,
)

# launched_cache_manager_signal is used to detect whether the engine has launched the cache_manager
if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.ipc_signal_suffix,
create=True,
)

# worker_live_signal lets the engine detect whether each worker process is alive; it records the time of each step
worker_healthy_live_recorded_time_array = np.zeros(shape=[self.cfg.worker_num_per_node], dtype=np.int32)
self.worker_healthy_live_signal = IPCSignal(
@@ -1133,6 +1142,7 @@ def _stop_profile(self):
engine_worker_queue_port=self.cfg.engine_worker_queue_port,
pid_suffix=self.ipc_signal_suffix,
)
self.launched_cache_manager_signal.value[0] = 1

def check_health(self, time_interval_threashold=30):
"""
@@ -1168,6 +1178,9 @@ def detect_thread():
self.worker_init_status["layer_loadding"] = progress
if self.worker_init_status["layer_loadding"] == self.cfg.model_config.num_layers - 1:
self.worker_init_status["finished"] = True
elif match := re.search(r"num_blocks_global", line):
if self.do_profile:
self._stop_profile()

self.checking_worker_status_thread = threading.Thread(target=detect_thread, daemon=True)
self.checking_worker_status_thread.start()
Comment on lines 1181 to 1183

Collaborator

These lines also need to be deleted.

Contributor Author

@zeroRains zeroRains Jul 22, 2025

Once this is deleted, the model-loading progress bar shown during engine startup is gone; the loading progress can only be seen in workerlog.0.

Contributor Author

In the current version, the engine also relies on this thread to determine that the workers have finished starting up. Does that mean the existing way the engine detects worker startup needs to be changed?

Collaborator

I marked the wrong comment range.
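Taken together, the engine-side changes form a small startup handshake: the engine creates launched_cache_manager_signal, launches the cache manager (directly in start(), or from _stop_profile() once detect_thread sees the worker's num_blocks_global line), and then sets the signal to 1; the worker blocks on that signal before capturing CUDA graphs. Below is a minimal sketch of that handshake pattern only, using multiprocessing.Value as a stand-in for fastdeploy's shared-memory IPCSignal and placeholder steps for the real engine/worker logic:

```python
# Minimal sketch of the engine/worker handshake this PR introduces.
# multiprocessing.Value stands in for fastdeploy's IPCSignal; the engine_side /
# worker_side bodies are placeholders for the real engine and worker code paths.
import time
from multiprocessing import Process, Value


def engine_side(launched_cache_manager_signal) -> None:
    # ... detect_thread has seen "num_blocks_global" and _stop_profile()
    # has launched the cache manager ...
    launched_cache_manager_signal.value = 1  # tell the worker the cache manager is ready


def worker_side(launched_cache_manager_signal) -> None:
    # graph_optimize_and_warm_up_model(): block until the cache manager exists,
    # so CUDA graph capture sees the shared KV cache that prefix caching uses.
    while launched_cache_manager_signal.value <= 0:
        time.sleep(0.01)
    print("cache manager ready, capturing CUDA graphs")


if __name__ == "__main__":
    signal = Value("i", 0)
    worker = Process(target=worker_side, args=(signal,))
    worker.start()
    time.sleep(0.1)  # simulate profiling / cache-manager startup on the engine side
    engine_side(signal)
    worker.join()
```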
6 changes: 5 additions & 1 deletion fastdeploy/worker/gpu_model_runner.py
@@ -999,6 +999,9 @@ def capture_model(self) -> None:
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
need_init_cache = "caches" not in self.share_inputs
if need_init_cache:
self.initialize_kv_cache()
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
num_tokens=self.parallel_config.max_num_batched_tokens,
@@ -1007,7 +1010,8 @@
expected_decode_len=expected_decode_len,
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")

if need_init_cache:
self.clear_cache()
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")

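The capture_model change follows an acquire-if-missing / release-if-acquired pattern: with prefix caching enabled, the real KV cache is attached later by the cache manager, so a temporary cache is allocated only for the capture passes and released afterwards. A condensed sketch of that pattern, with the runner's methods replaced by caller-supplied stand-ins:

```python
# Condensed sketch of the temporary-cache pattern used around CUDA graph capture.
# initialize_kv_cache / clear_cache / dummy_run are stand-ins for GPUModelRunner methods.
from typing import Callable, Dict, List


def capture_with_temporary_cache(
    share_inputs: Dict[str, object],
    capture_sizes: List[int],
    dummy_run: Callable[[int], None],
    initialize_kv_cache: Callable[[], None],
    clear_cache: Callable[[], None],
) -> None:
    # With prefix caching enabled, "caches" is not yet present in share_inputs at capture time.
    need_init_cache = "caches" not in share_inputs
    if need_init_cache:
        initialize_kv_cache()  # allocate a throwaway KV cache just for capture
    for batch_size in sorted(capture_sizes, reverse=True):
        dummy_run(batch_size)  # warm-up/capture pass at this batch size
    if need_init_cache:
        clear_cache()  # free the throwaway cache; the cache manager provides the real one
```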
19 changes: 18 additions & 1 deletion fastdeploy/worker/worker_process.py
@@ -400,9 +400,26 @@ def determine_num_available_blocks(self) -> None:
self.get_profile_block_num_signal.value[0] = num_blocks_local
else:
num_blocks_local = self.fd_config.parallel_config.total_block_num
# logger.info writes to worker_process.log
# `print` is needed to trigger engine->check_worker_initialize_status->detect_thread
print(f"------- num_blocks_global: {num_blocks_local} --------")
# 4. Update share inputs
self.worker.reinitialize_kv_cache(num_gpu_blocks=num_blocks_local)

def graph_optimize_and_warm_up_model(self) -> None:
if self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed":
launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
self.launched_cache_manager_signal = IPCSignal(
name="launched_cache_manager_signal",
array=launched_cache_manager_signal_data,
dtype=np.int32,
suffix=self.parallel_config.engine_pid,
create=False,
)
while np.any(self.launched_cache_manager_signal.value[0] <= 0):
time.sleep(0.01)
self.worker.graph_optimize_and_warm_up_model()

def init_device(self) -> None:
"""Initialize device and Construct model runner"""
self.worker.init_device()
Expand Down Expand Up @@ -718,7 +735,7 @@ def run_worker_proc() -> None:
worker_proc.determine_num_available_blocks()

# Trigger CUDAGraph capture
worker_proc.worker.graph_optimize_and_warm_up_model()
worker_proc.graph_optimize_and_warm_up_model()

# Initialize health status
worker_proc.init_health_status()
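For context on the print added in determine_num_available_blocks: the engine's detect_thread tails the worker log and treats the num_blocks_global line as the cue to stop profiling and launch the cache manager, which in turn sets launched_cache_manager_signal and unblocks CUDA graph capture. A rough sketch of that reaction, with handle_worker_log_line and stop_profile as illustrative placeholders for the logic inside detect_thread:

```python
# Sketch of the engine-side reaction to the worker's num_blocks_global print.
# handle_worker_log_line and stop_profile are illustrative names; the real logic
# lives in check_worker_initialize_status -> detect_thread (see engine.py above).
import re
from typing import Callable


def handle_worker_log_line(line: str, do_profile: bool, stop_profile: Callable[[], None]) -> None:
    # The worker prints "------- num_blocks_global: N --------"; once that shows up
    # in the tailed log, profiling is done and the cache manager can be launched.
    if do_profile and re.search(r"num_blocks_global", line):
        stop_profile()


if __name__ == "__main__":
    handle_worker_log_line(
        "------- num_blocks_global: 4096 --------",
        do_profile=True,
        stop_profile=lambda: print("launch cache manager, set launched_cache_manager_signal to 1"),
    )
```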