If all the above steps execute successfully, FastDeploy is installed correctly.
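As a quick sanity check, you can try importing the package (a minimal sketch; the exact attributes exposed may vary between FastDeploy releases):

```python
# Minimal install sanity check: assumes the Python package is importable as
# "fastdeploy"; the __version__ attribute may differ between releases.
import fastdeploy

print(getattr(fastdeploy, "__version__", "installed (version attribute not found)"))
```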
## How to deploy services on Kunlunxin XPU
Refer to [**Supported Models and Service Deployment**](../../usage/kunlunxin_xpu_deployment.md) for details about the supported models and how to deploy services on Kunlunxin XPU.
From **docs/parameters.md**:

| Parameter | Type | Description |
| :--- | :--- | :--- |
|```long_prefill_token_threshold```|`int`| When Chunked Prefill is enabled, requests with a token count exceeding this value are considered long requests, default: max_model_len*0.04 |
|```static_decode_blocks```|`int`| During inference, each request is forced to allocate the corresponding number of blocks from Prefill's KVCache for Decode use, default: 2 |
|```reasoning_parser```|`str`| Specify the reasoning parser to extract reasoning content from model output |
|```use_cudagraph```|`bool`| Whether to use cuda graph, default: False |
|```graph_optimization_config```|`str`| Graph-optimization-related options, passed as a JSON string; default: `'{"use_cudagraph": false, "graph_opt_level": 0, "cudagraph_capture_sizes": null}'` |
|```innode_prefill_ports```|`str`| Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
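Since `graph_optimization_config` is passed as a JSON string, it can be built or validated with standard tooling. A minimal sketch in plain Python (no FastDeploy APIs assumed), parsing the documented default:

```python
import json

# The documented default value of graph_optimization_config.
default_cfg = '{"use_cudagraph": false, "graph_opt_level": 0, "cudagraph_capture_sizes": null}'

cfg = json.loads(default_cfg)
assert cfg["use_cudagraph"] is False           # CudaGraph off by default
assert cfg["graph_opt_level"] == 0             # dynamic graph by default
assert cfg["cudagraph_capture_sizes"] is None  # capture list auto-generated
```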
To optimize scheduling priority for short requests, a new parameter combination of `max_long_partial_prefills` and `long_prefill_token_threshold` has been added. The former limits the number of long requests in a single prefill batch, while the latter defines the token threshold beyond which a request counts as long. The system prioritizes batch space for short requests, reducing short-request latency in mixed workload scenarios while maintaining stable throughput.
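The following toy sketch illustrates the described policy (illustrative only, with hypothetical names; it is not FastDeploy's actual scheduler code):

```python
# Illustrative sketch of the long/short prefill scheduling policy described above.
# `pending_token_counts` holds the prompt token count of each waiting request.
def select_prefill_batch(pending_token_counts, max_long_partial_prefills,
                         long_prefill_token_threshold):
    batch, long_count = [], 0
    # Visit short requests first so they are not starved by long ones.
    for tokens in sorted(pending_token_counts):
        is_long = tokens > long_prefill_token_threshold
        if is_long and long_count >= max_long_partial_prefills:
            continue  # defer this long request to a later prefill batch
        batch.append(tokens)
        long_count += int(is_long)
    return batch

# With at most 1 long request per batch, the second long request is deferred.
print(select_prefill_batch([100, 8000, 9000, 200], 1, 4096))  # [100, 200, 8000]
```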
## 4. GraphOptimizationBackend-related configuration parameters
Currently, only the following parameters can be configured by the user:

- `use_cudagraph` : bool = False
- `graph_optimization_config` : Dict[str, Any]
  - `graph_opt_level`: int = 0
  - `use_cudagraph`: bool = False
  - `cudagraph_capture_sizes` : List[int] = None
CudaGraph can be enabled by setting `--use-cudagraph` or `--graph-optimization-config '{"use_cudagraph":true}'`. Setting `use_cudagraph` through both methods at the same time may cause conflicts, so use only one of them.
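For example, a full configuration string can be assembled from the documented keys (plain Python; the key names come from the parameter list above):

```python
import json

# Assemble a --graph-optimization-config value containing all documented keys.
graph_opt_cfg = {
    "use_cudagraph": True,           # enable CudaGraph
    "graph_opt_level": 0,            # 0 = dynamic compute graph (default)
    "cudagraph_capture_sizes": None, # None = auto-generate from max_num_seqs
}
print(f"--graph-optimization-config '{json.dumps(graph_opt_cfg)}'")
```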
The `graph_opt_level` parameter within `--graph-optimization-config` is used to configure the graph optimization level, with the following available options (see the example after this list):

- `0`: use the dynamic compute graph; this is the default
- `1`: use the static compute graph; during the initialization phase, the Paddle API is used to convert the dynamic graph into a static graph
- `2`: on top of the static compute graph, use Paddle's compiler (CINN, Compiler Infrastructure for Neural Networks) to compile and optimize
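For instance, to request static-graph execution together with CudaGraph, the JSON passed to `--graph-optimization-config` would look like `'{"use_cudagraph": true, "graph_opt_level": 1}'` (an illustrative value; omitted keys keep their defaults).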
In general, static graphs have lower kernel launch overhead than dynamic graphs, so using static graphs is recommended.
For adapted models, FastDeploy's CudaGraph **can support both dynamic and static graphs** simultaneously.
When CudaGraph is enabled with the default configuration, a list of batch sizes for CudaGraph to capture is set automatically based on the `max_num_seqs` parameter. The logic for generating this capture list is as follows:
1. Generate a candidate list of batch sizes covering the range [1, 1024]:
```
# Batch sizes [1, 2, 4, 8, 16, ..., 120, 128]
candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
# Batch sizes (128, 144, ..., 240, 256]
candidate_capture_sizes += [16 * i for i in range(9, 17)]
# Batch sizes (256, 288, ..., 992, 1024]
candidate_capture_sizes += [32 * i for i in range(17, 33)]
```
2. Crop the candidate list to the range [1, `max_num_seqs`] based on the user-set `max_num_seqs`, obtaining the final CudaGraph capture list (a sketch follows below).
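A sketch of both steps together (plain Python, following the logic described above; the actual implementation may differ in details such as whether `max_num_seqs` itself is appended):

```python
# Step 1: rebuild the candidate list from above.
candidate_capture_sizes = [1, 2, 4] + [8 * i for i in range(1, 17)]
candidate_capture_sizes += [16 * i for i in range(9, 17)]
candidate_capture_sizes += [32 * i for i in range(17, 33)]

# Step 2: crop to [1, max_num_seqs].
max_num_seqs = 192  # hypothetical user setting
capture_sizes = [s for s in candidate_capture_sizes if s <= max_num_seqs]
print(capture_sizes)  # [1, 2, 4, 8, ..., 128, 144, 160, 176, 192]
```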
Users can also customize the list of batch sizes that CudaGraph should capture through the `cudagraph_capture_sizes` parameter in `--graph-optimization-config`, for example:
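`--graph-optimization-config '{"cudagraph_capture_sizes": [1, 3, 5, 7, 9]}'` (an illustrative, hypothetical list).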
### CudaGraph-related parameters
Using CudaGraph incurs some additional memory overhead, which FastDeploy divides into two categories:
* Additional input buffer overhead
* CudaGraph uses a dedicated memory pool, so it holds some intermediate activation memory that is isolated from the main framework
During initialization, FastDeploy first uses the `gpu_memory_utilization` parameter to calculate the memory available for the `KVCache`; after the `KVCache` is initialized, the remaining memory is used to initialize CudaGraph. Since CudaGraph is not enabled by default at present, enabling it while keeping the default startup parameters may lead to `Out of memory` errors. The following solutions can be tried:
* Lower the `gpu_memory_utilization` value to reserve more memory for CudaGraph.
* Lower `max_num_seqs` to decrease the maximum concurrency.
* Customize the list of batch sizes that CudaGraph needs to capture through `graph_optimization_config`, reducing the number of captured graphs via `cudagraph_capture_sizes`.
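As a rough illustration of the first option: if `gpu_memory_utilization` is lowered from, say, 0.9 to 0.8, about an extra 10% of device memory stays outside the `KVCache` budget and remains available to CudaGraph (illustrative numbers, not defaults taken from the original text).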
- Before use, make sure the loaded model is properly decorated with ```@support_graph_optimization``` (see the sketch after this list).
- When ```use_cudagraph``` is enabled, only single-GPU inference is currently supported, i.e. ```tensor_parallel_size``` must be set to 1.
- When ```use_cudagraph``` is enabled, ```enable_prefix_caching``` and ```enable_chunked_prefill``` cannot be enabled.
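A sketch of the decoration requirement (hypothetical import path and model class; check the FastDeploy source for the actual location of `support_graph_optimization`):

```python
# Hypothetical sketch: the decorator's real import path and the model base
# class may differ between FastDeploy versions.
import paddle.nn as nn
from fastdeploy.model_executor.graph_optimization.decorator import (
    support_graph_optimization,
)

@support_graph_optimization
class MyModel(nn.Layer):  # hypothetical model class
    def forward(self, ids_remove_padding, forward_meta):  # hypothetical signature
        ...
```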