Commit 38c9c8e

feat(vllm_client): add base_url parameter for vLLM server connection
- Introduce base_url parameter to simplify server connection
- Update __init__ method to support both base_url and host+port configurations
- Modify URL construction in various methods to use base_url
- Update documentation to include new initialization examples
1 parent 294f35b commit 38c9c8e
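
For quick orientation, the two connection styles this commit enables look like this — a minimal sketch mirroring the docstring examples added in the diff below, assuming a server started with `trl vllm-serve` is reachable at the given address:

```python
from trl.extras.vllm_client import VLLMClient

# Option 1: pass the full base URL; host and server_port are then ignored
client = VLLMClient(base_url="http://localhost:8000")

# Option 2: pass host and port separately; the client builds
# base_url as f"http://{host}:{server_port}" internally
client = VLLMClient(host="localhost", server_port=8000)
```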

1 file changed: +62 -21 lines changed
trl/extras/vllm_client.py

@@ -15,27 +15,25 @@
 import atexit
 import logging
 import time
-from typing import Optional
+from typing import Optional, Union
+from urllib.parse import urlparse
 
 import torch
 from torch import nn
 
 from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available
 
-
 if is_requests_available():
     import requests
     from requests import ConnectionError
 
-
 if is_vllm_available():
     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
     from vllm.distributed.utils import StatelessProcessGroup
 
 if is_vllm_ascend_available():
     from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
 
-
 logger = logging.getLogger(__name__)
 
@@ -47,10 +45,12 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
 
     Args:
+        base_url (`str`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -67,11 +67,29 @@ class VLLMClient:
     INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
     ```
 
-    Use the client to generate completions and update model weights:
+    There are two ways to initialize the client:
+
+    1. Using `base_url`:
+    ```python
+    >>> from trl.extras.vllm_client import VLLMClient
+    >>> # Connect to a local server
+    >>> client = VLLMClient(base_url="http://localhost:8000")
+    >>> # Or connect to a remote server
+    >>> client = VLLMClient(base_url="http://192.168.1.100:8000")
+    ```
 
+    2. Using `host` and `server_port`:
     ```python
     >>> from trl.extras.vllm_client import VLLMClient
-    >>> client = VLLMClient()
+    >>> # Connect to a local server
+    >>> client = VLLMClient(host="localhost", server_port=8000)
+    >>> # Or connect to a remote server
+    >>> client = VLLMClient(host="192.168.1.100", server_port=8000)
+    ```
+
+    Use the client to generate completions and update model weights:
+
+    ```python
     >>> client.generate(["Hello, AI!", "Tell me a joke"])
     [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
     [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
@@ -83,16 +101,30 @@ class VLLMClient:
     """
 
     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self,
+        base_url: Optional[str] = None,
+        host: str = "0.0.0.0",
+        server_port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
             raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
 
         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            scheme = parsed_url.scheme or "http"
+            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f"http://{self.host}:{self.server_port}"
+
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout
         self.init_communicator()
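
Aside on the hunk above: the normalization keeps only the scheme, netloc, and path of the supplied URL, falling back to `http` when the scheme is missing but a netloc is present. A minimal standalone sketch of that step follows (the `normalize` helper is illustrative, not part of the commit); note that a trailing slash survives normalization, so an endpoint URL built as `f"{self.base_url}/health/"` would then contain a double slash:

```python
from urllib.parse import urlparse

def normalize(base_url: str) -> str:
    # Same steps as in VLLMClient.__init__ above (illustrative helper)
    parsed = urlparse(base_url)
    scheme = parsed.scheme or "http"
    return f"{scheme}://{parsed.netloc}{parsed.path}"

print(normalize("http://localhost:8000"))  # http://localhost:8000
print(normalize("//localhost:8000"))       # http://localhost:8000 (scheme defaults to http)
print(normalize("http://localhost:8000/") + "/health/")  # http://localhost:8000//health/
```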
@@ -109,7 +141,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
             total_timeout (`float`, *optional*, defaults to `0.0`):
                 Total timeout duration in seconds.
         """
-        url = f"http://{self.host}:{self.server_port}/health/"
+        url = f"{self.base_url}/health/"
         start_time = time.time()  # Record the start time
 
         while True:
@@ -120,7 +152,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                 elapsed_time = time.time() - start_time
                 if elapsed_time >= total_timeout:
                     raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} "
                         "seconds. Make sure the server is running by running `trl vllm-serve`."
                     ) from exc
                 else:
@@ -171,7 +203,7 @@ def generate(
             `list[list[int]]`:
                 List of lists of token IDs representing the model-generated completions for each prompt.
         """
-        url = f"http://{self.host}:{self.server_port}/generate/"
+        url = f"{self.base_url}/generate/"
         response = self.session.post(
             url,
             json={
@@ -196,7 +228,7 @@ def init_communicator(self):
         Initializes the weight update group in a distributed setup for model synchronization.
         """
         # Get the tensor parallel size from the server
-        url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/"
+        url = f"{self.base_url}/get_tensor_parallel_size/"
         response = requests.get(url)
         if response.status_code == 200:
             tensor_parallel_size = response.json()["tensor_parallel_size"]
@@ -207,7 +239,7 @@ def init_communicator(self):
         self.rank = tensor_parallel_size  # The client's rank is the last process
 
         # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"{self.base_url}/init_communicator/"
         # In the server side, the host is set to 0.0.0.0
         response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
         if response.status_code != 200:
@@ -228,7 +260,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
                 Tensor containing the updated weights.
         """
         dtype, shape = str(weights.dtype), tuple(weights.shape)
-        url = f"http://{self.host}:{self.server_port}/update_named_param/"
+        url = f"{self.base_url}/update_named_param/"
         response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -253,7 +285,7 @@ def reset_prefix_cache(self):
         """
         Resets the prefix cache for the model.
         """
-        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+        url = f"{self.base_url}/reset_prefix_cache/"
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -262,7 +294,7 @@ def close_communicator(self):
         """
         Closes the weight update group and cleans up the communication group.
         """
-        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+        url = f"{self.base_url}/close_communicator/"
 
         try:
             response = self.session.post(url)
@@ -278,10 +310,19 @@ def close_communicator(self):
 if __name__ == "__main__":
     from vllm import SamplingParams
 
-    client = VLLMClient()
+    # Example 1: Initialize with base_url (recommended)
+    client1 = VLLMClient(base_url="http://0.0.0.0:8000")
+    print("Client 1 initialized with base_url")
+
+    # Example 2: Initialize with host and port
+    client2 = VLLMClient(host="0.0.0.0", server_port=8000)
+    print("Client 2 initialized with host and port")
+
+    # Choose one client to use for the example
+    client = client1
 
     # Generate completions
-    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams())
+    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
     print("Responses:", responses)  # noqa
 
     # Update model weights