import atexit
import logging
import time
- from typing import Optional
+ from typing import Optional, Union
+ from urllib.parse import urlparse

import torch
from torch import nn
@@ -47,10 +48,12 @@ class VLLMClient:
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
+       base_url (`str`, *optional*, defaults to `None`):
+           Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are ignored.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
-           IP address of the vLLM server.
+           IP address of the vLLM server. Ignored if `base_url` is provided.
        server_port (`int`, *optional*, defaults to `8000`):
-           Port number of the vLLM server.
+           Port number of the vLLM server. Ignored if `base_url` is provided.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -67,8 +70,25 @@ class VLLMClient:
    INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
    ```

-   Use the client to generate completions and update model weights:
+   There are two ways to initialize the client:

+   1. Using `base_url`:
+      ```python
+      >>> from trl.extras.vllm_client import VLLMClient
+      >>> # Connect to a local server
+      >>> client = VLLMClient(base_url="http://localhost:8000")
+      >>> # Or connect to a remote server
+      >>> client = VLLMClient(base_url="http://192.168.1.100:8000")
+      ```
+   2. Using `host` and `server_port`:
+      ```python
+      >>> from trl.extras.vllm_client import VLLMClient
+      >>> # Connect to a local server
+      >>> client = VLLMClient(host="localhost", server_port=8000)
+      >>> # Or connect to a remote server
+      >>> client = VLLMClient(host="192.168.1.100", server_port=8000)
+      ```
+   Use the client to generate completions and update model weights:

    ```python
    >>> from trl.extras.vllm_client import VLLMClient
    >>> client = VLLMClient()
@@ -78,25 +98,37 @@ class VLLMClient:
    >>> from transformers import AutoModelForCausalLM
    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
+   >>> client.init_communicator()
    >>> client.update_model_params(model)
    ```
    """

    def __init__(
-       self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+       self,
+       base_url: Optional[str] = None,
+       host: str = "0.0.0.0",
+       server_port: int = 8000,
+       group_port: int = 51216,
+       connection_timeout: float = 0.0
    ):
        if not is_requests_available():
            raise ImportError("requests is not installed. Please install it with `pip install requests`.")
        if not is_vllm_available():
            raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

        self.session = requests.Session()
-       self.host = host
-       self.server_port = server_port
+
+       if base_url is not None:
+           # Parse the base_url to extract host and port
+           parsed_url = urlparse(base_url)
+           scheme = parsed_url.scheme or "http"
+           self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+       else:
+           self.host = host
+           self.server_port = server_port
+           self.base_url = f"http://{self.host}:{self.server_port}"
        self.group_port = group_port
        self.check_server(connection_timeout)  # check server and fail after timeout
-       self.init_communicator()
-       atexit.register(self.close_communicator)  # when the client object is deleted, close the weight update group

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
@@ -109,7 +141,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
        """
-       url = f"http://{self.host}:{self.server_port}/health/"
+       url = f"{self.base_url}/health/"
        start_time = time.time()  # Record the start time

        while True:
@@ -120,7 +152,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(
-                       f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                       f"The vLLM server can't be reached at {self.base_url} after {total_timeout} "
                        "seconds. Make sure the server is running by running `trl vllm-serve`."
                    ) from exc
            else:
@@ -171,7 +203,7 @@ def generate(
            `list[list[int]]`:
                List of lists of token IDs representing the model-generated completions for each prompt.
        """
-       url = f"http://{self.host}:{self.server_port}/generate/"
+       url = f"{self.base_url}/generate/"
        response = self.session.post(
            url,
            json={
@@ -195,28 +227,36 @@ def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
        """
-       # Get the tensor parallel size from the server
-       url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/"
+       # Get the world size from the server
+       url = f"{self.base_url}/get_world_size/"
        response = requests.get(url)
        if response.status_code == 200:
-           tensor_parallel_size = response.json()["tensor_parallel_size"]
+           vllm_world_size = response.json()["world_size"]
        else:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

-       world_size = tensor_parallel_size + 1
-       self.rank = tensor_parallel_size  # The client's rank is the last process
+       world_size = vllm_world_size + 1  # add the client to the world
+       self.rank = vllm_world_size  # the client's rank is the last process

        # Initialize weight update group
-       url = f"http://{self.host}:{self.server_port}/init_communicator/"
+       url = f"{self.base_url}/init_communicator/"
        # In the server side, the host is set to 0.0.0.0
        response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

+       # Brief delay to allow server initialization. While not strictly required (client socket will retry on
+       # connection failure), this prevents log warnings like:
+       # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
+       time.sleep(0.1)
+
        # Set up the communication group for weight broadcasting
        pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
        self.pynccl_comm = PyNcclCommunicator(pg, device=0)

+       # When the client object is deleted, close the weight update group
+       atexit.register(self.close_communicator)
+
    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.
@@ -279,6 +319,7 @@ def close_communicator(self):
    from vllm import SamplingParams

    client = VLLMClient()
+   client.init_communicator()

    # Generate completions
    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams())
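For reference, here is a minimal standalone sketch of the `base_url` normalization that the updated `__init__` performs, handy for checking what a given URL resolves to before passing it to the client. The `normalize_base_url` helper name is hypothetical and used only for illustration; the client runs the equivalent `urlparse` logic inline.

```python
# Minimal sketch (not part of trl) mirroring the base_url parsing added in __init__ above.
from urllib.parse import urlparse


def normalize_base_url(base_url: str) -> str:  # hypothetical helper, for illustration only
    parsed_url = urlparse(base_url)
    scheme = parsed_url.scheme or "http"  # fall back to http when the scheme is omitted
    return f"{scheme}://{parsed_url.netloc}{parsed_url.path}"


print(normalize_base_url("http://localhost:8000"))  # http://localhost:8000
print(normalize_base_url("//192.168.1.100:8000"))   # http://192.168.1.100:8000 (scheme defaulted)
```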