 import atexit
 import logging
 import time
-from typing import Optional
+from typing import Optional, Union
+from urllib.parse import urlparse
 
 import torch
 from torch import nn
 
 from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available
 
-
 if is_requests_available():
     import requests
     from requests import ConnectionError
 
-
 if is_vllm_available():
     from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
     from vllm.distributed.utils import StatelessProcessGroup
 
 if is_vllm_ascend_available():
     from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -47,10 +45,12 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
 
     Args:
+        base_url (`str`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, `host` and `server_port` are ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -67,11 +67,29 @@ class VLLMClient:
     INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
     ```
 
-    Use the client to generate completions and update model weights:
+    There are two ways to initialize the client:
+
+    1. Using `base_url`:
+    ```python
+    >>> from trl.extras.vllm_client import VLLMClient
+    >>> # Connect to a local server
+    >>> client = VLLMClient(base_url="http://localhost:8000")
+    >>> # Or connect to a remote server
+    >>> client = VLLMClient(base_url="http://192.168.1.100:8000")
+    ```
 
+    2. Using `host` and `server_port`:
     ```python
     >>> from trl.extras.vllm_client import VLLMClient
-    >>> client = VLLMClient()
+    >>> # Connect to a local server
+    >>> client = VLLMClient(host="localhost", server_port=8000)
+    >>> # Or connect to a remote server
+    >>> client = VLLMClient(host="192.168.1.100", server_port=8000)
+    ```
+
+    Use the client to generate completions and update model weights:
+
+    ```python
     >>> client.generate(["Hello, AI!", "Tell me a joke"])
     [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
     [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
@@ -83,16 +101,30 @@ class VLLMClient:
     """
 
     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self,
+        base_url: Optional[str] = None,
+        host: str = "0.0.0.0",
+        server_port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
             raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
 
         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            scheme = parsed_url.scheme or "http"
+            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f"http://{self.host}:{self.server_port}"
+
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout
         self.init_communicator()
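
A note on the new `base_url` handling above: the constructor rebuilds the URL from `urlparse` output, so the scheme, the `host:port` pair, and any path component are preserved. Below is a minimal, hypothetical sketch of that normalization using only the standard library (the helper name and example URLs are illustrative, not part of this PR):

```python
from urllib.parse import urlparse


def normalize_base_url(base_url: str) -> str:
    # Mirrors the parsing in __init__: keep the scheme, netloc (host:port), and any path prefix.
    parsed = urlparse(base_url)
    scheme = parsed.scheme or "http"
    return f"{scheme}://{parsed.netloc}{parsed.path}"


print(normalize_base_url("http://localhost:8000"))       # -> http://localhost:8000
print(normalize_base_url("https://10.0.0.5:8000/vllm"))  # -> https://10.0.0.5:8000/vllm
# Note: urlparse only fills netloc when a scheme is present, so "localhost:8000"
# (without "http://") would not normalize to a usable URL here.
```
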
@@ -109,7 +141,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
             total_timeout (`float`, *optional*, defaults to `0.0`):
                 Total timeout duration in seconds.
         """
-        url = f"http://{self.host}:{self.server_port}/health/"
+        url = f"{self.base_url}/health/"
         start_time = time.time()  # Record the start time
 
         while True:
@@ -120,7 +152,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                 elapsed_time = time.time() - start_time
                 if elapsed_time >= total_timeout:
                     raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} "
                         "seconds. Make sure the server is running by running `trl vllm-serve`."
                     ) from exc
             else:
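
The two `check_server` hunks above only change how the health URL and the error message are built; the polling logic is untouched. For context, a simplified sketch of that retry pattern (not the PR's exact code), assuming the server exposes a `/health/` endpoint that returns HTTP 200 once it is ready:

```python
import time

import requests


def wait_for_server(base_url: str, total_timeout: float = 0.0, retry_interval: float = 2.0) -> None:
    """Poll `{base_url}/health/` until it responds, or raise once `total_timeout` seconds have elapsed."""
    url = f"{base_url}/health/"
    start_time = time.time()
    while True:
        try:
            response = requests.get(url)
        except requests.exceptions.RequestException:
            if time.time() - start_time >= total_timeout:
                raise ConnectionError(f"The vLLM server can't be reached at {base_url}")
        else:
            if response.status_code == 200:
                return
        time.sleep(retry_interval)
```
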
@@ -171,7 +203,7 @@ def generate(
             `list[list[int]]`:
                 List of lists of token IDs representing the model-generated completions for each prompt.
         """
-        url = f"http://{self.host}:{self.server_port}/generate/"
+        url = f"{self.base_url}/generate/"
         response = self.session.post(
             url,
             json={
@@ -196,7 +228,7 @@ def init_communicator(self):
         Initializes the weight update group in a distributed setup for model synchronization.
         """
         # Get the tensor parallel size from the server
-        url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/"
+        url = f"{self.base_url}/get_tensor_parallel_size/"
         response = requests.get(url)
         if response.status_code == 200:
             tensor_parallel_size = response.json()["tensor_parallel_size"]
@@ -207,7 +239,7 @@ def init_communicator(self):
         self.rank = tensor_parallel_size  # The client's rank is the last process
 
         # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"{self.base_url}/init_communicator/"
         # In the server side, the host is set to 0.0.0.0
         response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
         if response.status_code != 200:
@@ -228,7 +260,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
                 Tensor containing the updated weights.
         """
         dtype, shape = str(weights.dtype), tuple(weights.shape)
-        url = f"http://{self.host}:{self.server_port}/update_named_param/"
+        url = f"{self.base_url}/update_named_param/"
         response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -253,7 +285,7 @@ def reset_prefix_cache(self):
         """
         Resets the prefix cache for the model.
         """
-        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+        url = f"{self.base_url}/reset_prefix_cache/"
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -262,7 +294,7 @@ def close_communicator(self):
         """
         Closes the weight update group and cleans up the communication group.
         """
-        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+        url = f"{self.base_url}/close_communicator/"
 
         try:
             response = self.session.post(url)
@@ -278,10 +310,19 @@ def close_communicator(self):
 if __name__ == "__main__":
     from vllm import SamplingParams
 
-    client = VLLMClient()
+    # Example 1: Initialize with base_url (recommended)
+    client1 = VLLMClient(base_url="http://0.0.0.0:8000")
+    print("Client 1 initialized with base_url")
+
+    # Example 2: Initialize with host and port
+    client2 = VLLMClient(host="0.0.0.0", server_port=8000)
+    print("Client 2 initialized with host and port")
+
+    # Choose one client to use for the example
+    client = client1
 
     # Generate completions
-    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams())
+    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
     print("Responses:", responses)  # noqa
 
     # Update model weights
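
Because every endpoint is now derived from `self.base_url` (e.g., `f"{self.base_url}/generate/"`), a client can also target a server that is only reachable through a prefixed URL. A hedged usage sketch; the gateway URL below is hypothetical:

```python
from trl.extras.vllm_client import VLLMClient

# Hypothetical deployment: the vLLM server is proxied at https://gateway.example.com/vllm,
# so requests go to e.g. https://gateway.example.com/vllm/generate/
client = VLLMClient(base_url="https://gateway.example.com/vllm", connection_timeout=120.0)

completion_ids = client.generate(["Hello, AI!"], n=1, max_tokens=16)
print(completion_ids)  # list of token-ID lists, one per prompt

client.close_communicator()
```
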