@@ -85,10 +85,13 @@ class GRPOConfig(TrainingArguments):
85
85
use_vllm (`bool`, *optional*, defaults to `False`):
86
86
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
87
87
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
88
+ vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
89
+ Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, `vllm_server_host` and
90
+ `vllm_server_port` are ignored.
88
91
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
89
- Host of the vLLM server to connect to.
92
+ Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
90
93
vllm_server_port (`int`, *optional*, defaults to `8000`):
91
- Port of the vLLM server to connect to.
94
+ Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
92
95
vllm_server_timeout (`float`, *optional*, defaults to `120.0`):
93
96
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
94
97
timeout, a `ConnectionError` is raised.
@@ -270,13 +273,20 @@ class GRPOConfig(TrainingArguments):
270
273
"running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
271
274
},
272
275
)
276
+ vllm_server_base_url : Optional [str ] = field (
277
+ default = None ,
278
+ metadata = {
279
+ "help" : "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, vllm_server_host and "
280
+ "vllm_server_port are ignored."
281
+ },
282
+ )
273
283
vllm_server_host : str = field (
274
284
default = "0.0.0.0" ,
275
- metadata = {"help" : "Host of the vLLM server to connect to." },
285
+ metadata = {"help" : "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided. " },
276
286
)
277
287
vllm_server_port : int = field (
278
288
default = 8000 ,
279
- metadata = {"help" : "Port of the vLLM server to connect to." },
289
+ metadata = {"help" : "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided. " },
280
290
)
281
291
vllm_server_timeout : float = field (
282
292
default = 120.0 ,
0 commit comments