
Commit b4ffda7

🙋 Add Optional Eager Execution Mode for vLLM Serving (#3335)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>

1 parent 0dad4eb commit b4ffda7

File tree

1 file changed: +12 −0 lines changed

trl/scripts/vllm_serve.py — 12 additions & 0 deletions

@@ -174,6 +174,9 @@ class ScriptArguments:
         enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
             Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
             this feature.
+        enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
+            Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
+            model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
     """
 
     model: str = field(metadata={"help": "Model name or path to load the model from."})
@@ -224,6 +227,14 @@ class ScriptArguments:
             "hardware support this feature."
         },
     )
+    enforce_eager: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
+            "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
+            "execution in hybrid."
+        },
+    )
 
 
 def main(script_args: ScriptArguments):
@@ -250,6 +261,7 @@ def main(script_args: ScriptArguments):
         revision=script_args.revision,
         tensor_parallel_size=script_args.tensor_parallel_size,
         gpu_memory_utilization=script_args.gpu_memory_utilization,
+        enforce_eager=script_args.enforce_eager,
         dtype=script_args.dtype,
         # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
         # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
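
For reference, the new `enforce_eager` argument is passed straight through to vLLM's `LLM` constructor, which exposes a parameter of the same name. Below is a minimal sketch of what the flag controls inside vLLM, assuming a local GPU; the model name and prompt are illustrative placeholders, not part of this commit:

    # Sketch of the behavior enforce_eager toggles in vLLM.
    # Model name and prompt are illustrative placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        gpu_memory_utilization=0.9,
        enforce_eager=True,  # True: skip CUDA graph capture, always run the model eagerly
    )
    outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)

Skipping CUDA graph capture trades some steady-state decoding speed for faster startup and lower memory overhead, which is often the right trade-off when debugging. Since the field is part of the parsed `ScriptArguments`, it should also be reachable from the command line, e.g. `python trl/scripts/vllm_serve.py --model <model> --enforce_eager True` (exact CLI spelling depends on TRL's argument parser).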
