Skip to content

Commit a232a1a

Browse files
committed
- Introduce new parameter `vllm_server_base_url` in `GRPOConfig`
- Update `VLLMClient` initialization to support base URL
- Modify existing parameters `vllm_server_host` and `vllm_server_port` to be ignored if base URL is provided
1 parent 9c681be commit a232a1a

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

trl/trainer/grpo_config.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,13 @@ class GRPOConfig(TrainingArguments):
8585
use_vllm (`bool`, *optional*, defaults to `False`):
8686
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
8787
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
88+
vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
89+
Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, vllm_server_host and
90+
vllm_server_port are ignored.
8891
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
89-
Host of the vLLM server to connect to.
92+
Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided.
9093
vllm_server_port (`int`, *optional*, defaults to `8000`):
91-
Port of the vLLM server to connect to.
94+
Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided.
9295
vllm_server_timeout (`float`, *optional*, defaults to `120.0`):
9396
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
9497
timeout, a `ConnectionError` is raised.
@@ -270,13 +273,20 @@ class GRPOConfig(TrainingArguments):
270273
"running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
271274
},
272275
)
276+
vllm_server_base_url: Optional[str] = field(
277+
default=None,
278+
metadata={
279+
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, vllm_server_host and "
280+
"vllm_server_port are ignored."
281+
},
282+
)
273283
vllm_server_host: str = field(
274284
default="0.0.0.0",
275-
metadata={"help": "Host of the vLLM server to connect to."},
285+
metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
276286
)
277287
vllm_server_port: int = field(
278288
default=8000,
279-
metadata={"help": "Port of the vLLM server to connect to."},
289+
metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
280290
)
281291
vllm_server_timeout: float = field(
282292
default=120.0,

trl/trainer/grpo_trainer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,16 @@ def data_collator(features): # No data collation is needed in GRPO
603603
)
604604

605605
if self.accelerator.is_main_process:
606-
self.vllm_client = VLLMClient(
607-
args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
608-
)
606+
if args.vllm_server_base_url is not None:
607+
self.vllm_client = VLLMClient(
608+
base_url=args.vllm_server_base_url, connection_timeout=args.vllm_server_timeout
609+
)
610+
else:
611+
self.vllm_client = VLLMClient(
612+
host=args.vllm_server_host,
613+
server_port=args.vllm_server_port,
614+
connection_timeout=args.vllm_server_timeout
615+
)
609616

610617
# vLLM specific sampling arguments
611618
self.guided_decoding_regex = args.vllm_guided_decoding_regex

0 commit comments

Comments
 (0)