
Commit 0bb3624: refactor
1 parent 294f35b

File tree: 4 files changed, +205 -64 lines


tests/test_vllm_client_server.py (+76 lines)

@@ -162,3 +162,79 @@ def tearDownClass(cls):
             child.send_signal(signal.SIGTERM)
         cls.server_process.terminate()
         cls.server_process.wait()
+
+
+@pytest.mark.slow
+@require_torch_multi_gpu
+class TestVLLMClientServerBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
+        env = os.environ.copy()
+        env["CUDA_VISIBLE_DEVICES"] = "1"  # Restrict to GPU 1
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+        )
+
+        # Initialize the client with base_url
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        self.assertEqual(len(outputs), 2 * len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in outputs:
+            self.assertLessEqual(len(seq), 32)
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
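For anyone who wants to reproduce the new test by hand, here is a minimal sketch of the same flow outside unittest: a `trl vllm-serve` process pinned to GPU 1 and a `VLLMClient` connected through `base_url`. The model name, port, and timeout simply mirror the values used in the test above.

```python
import os
import subprocess

from trl.extras.vllm_client import VLLMClient

# Keep GPU 1 for the vLLM server, as the test does
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1"

# Launch the server out of process
server = subprocess.Popen(
    ["trl", "vllm-serve", "--model", "Qwen/Qwen2.5-1.5B"],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    env=env,
)

try:
    # connection_timeout gives the server time to come up before the health check gives up
    client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
    completions = client.generate(["Hello, AI!"], max_tokens=32)
    print(completions)  # list of token-ID lists, one per prompt
finally:
    server.terminate()
    server.wait()
```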

trl/extras/vllm_client.py (+59 -18 lines)

@@ -15,7 +15,8 @@
 import atexit
 import logging
 import time
-from typing import Optional
+from typing import Optional, Union
+from urllib.parse import urlparse
 
 import torch
 from torch import nn
@@ -47,10 +48,12 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
 
     Args:
+        base_url (`str`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, host and server_port are ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -67,8 +70,25 @@ class VLLMClient:
     INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
     ```
 
-    Use the client to generate completions and update model weights:
+    There are two ways to initialize the client:
 
+    1. Using base_url:
+    ```python
+    >>> from trl.extras.vllm_client import VLLMClient
+    >>> # Connect to a local server
+    >>> client = VLLMClient(base_url="http://localhost:8000")
+    >>> # Or connect to a remote server
+    >>> client = VLLMClient(base_url="http://192.168.1.100:8000")
+    ```
+    2. Using host and server_port:
+    ```python
+    >>> from trl.extras.vllm_client import VLLMClient
+    >>> # Connect to a local server
+    >>> client = VLLMClient(host="localhost", server_port=8000)
+    >>> # Or connect to a remote server
+    >>> client = VLLMClient(host="192.168.1.100", server_port=8000)
+    ```
+    Use the client to generate completions and update model weights:
     ```python
     >>> from trl.extras.vllm_client import VLLMClient
    >>> client = VLLMClient()
@@ -78,25 +98,37 @@ class VLLMClient:
 
     >>> from transformers import AutoModelForCausalLM
     >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
+    >>> client.init_communicator()
     >>> client.update_model_params(model)
     ```
     """
 
     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self,
+        base_url: Optional[str] = None,
+        host: str = "0.0.0.0",
+        server_port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
         if not is_vllm_available():
             raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
 
         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            scheme = parsed_url.scheme or "http"
+            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f"http://{self.host}:{self.server_port}"
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout
-        self.init_communicator()
-        atexit.register(self.close_communicator)  # when the client object is deleted, close the weight update group
 
     def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
         """
@@ -109,7 +141,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
        """
-        url = f"http://{self.host}:{self.server_port}/health/"
+        url = f"{self.base_url}/health/"
        start_time = time.time()  # Record the start time
 
        while True:
@@ -120,7 +152,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} "
                        "seconds. Make sure the server is running by running `trl vllm-serve`."
                    ) from exc
            else:
@@ -171,7 +203,7 @@ def generate(
            `list[list[int]]`:
                List of lists of token IDs representing the model-generated completions for each prompt.
        """
-        url = f"http://{self.host}:{self.server_port}/generate/"
+        url = f"{self.base_url}/generate/"
        response = self.session.post(
            url,
            json={
@@ -195,28 +227,36 @@ def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
        """
-        # Get the tensor parallel size from the server
-        url = f"http://{self.host}:{self.server_port}/get_tensor_parallel_size/"
+        # Get the world size from the server
+        url = f"{self.base_url}/get_world_size/"
        response = requests.get(url)
        if response.status_code == 200:
-            tensor_parallel_size = response.json()["tensor_parallel_size"]
+            vllm_world_size = response.json()["world_size"]
        else:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
-        world_size = tensor_parallel_size + 1
-        self.rank = tensor_parallel_size  # The client's rank is the last process
+        world_size = vllm_world_size + 1  # add the client to the world
+        self.rank = vllm_world_size  # the client's rank is the last process
 
        # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"{self.base_url}/init_communicator/"
        # In the server side, the host is set to 0.0.0.0
        response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")
 
+        # Brief delay to allow server initialization. While not strictly required (client socket will retry on
+        # connection failure), this prevents log warnings like:
+        # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
+        time.sleep(0.1)
+
        # Set up the communication group for weight broadcasting
        pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
        self.pynccl_comm = PyNcclCommunicator(pg, device=0)
 
+        # When the client object is deleted, close the weight update group
+        atexit.register(self.close_communicator)
+
    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.
@@ -279,6 +319,7 @@ def close_communicator(self):
    from vllm import SamplingParams
 
    client = VLLMClient()
+    client.init_communicator()
 
    # Generate completions
    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams())
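With this change, the constructor resolves everything into a single `self.base_url` string, and every endpoint URL (`/health/`, `/generate/`, `/get_world_size/`, `/init_communicator/`) is built from it. Below is a standalone sketch of that resolution logic; the `resolve_base_url` helper is hypothetical and exists only to show, in isolation, what the committed `__init__` branch does.

```python
from typing import Optional
from urllib.parse import urlparse


def resolve_base_url(base_url: Optional[str] = None, host: str = "0.0.0.0", server_port: int = 8000) -> str:
    """Hypothetical helper mirroring the URL handling added to VLLMClient.__init__ (not part of trl)."""
    if base_url is not None:
        # Keep the scheme if one was given; fall back to http when urlparse yields an empty scheme
        parsed = urlparse(base_url)
        scheme = parsed.scheme or "http"
        return f"{scheme}://{parsed.netloc}{parsed.path}"
    # Legacy path: build the URL from the host/port pair
    return f"http://{host}:{server_port}"


print(resolve_base_url(base_url="http://192.168.1.100:8000"))  # -> http://192.168.1.100:8000
print(resolve_base_url(host="localhost", server_port=8001))    # -> http://localhost:8001
```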

trl/trainer/grpo_config.py (+14 -4 lines)

@@ -85,10 +85,13 @@ class GRPOConfig(TrainingArguments):
         use_vllm (`bool`, *optional*, defaults to `False`):
             Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
             training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
+        vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., "http://localhost:8000"). If provided, vllm_server_host and
+            vllm_server_port are ignored.
         vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            Host of the vLLM server to connect to.
+            Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
         vllm_server_port (`int`, *optional*, defaults to `8000`):
-            Port of the vLLM server to connect to.
+            Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
         vllm_server_timeout (`float`, *optional*, defaults to `120.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
@@ -270,13 +273,20 @@ class GRPOConfig(TrainingArguments):
             "running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
         },
     )
+    vllm_server_base_url: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, vllm_server_host and "
+            "vllm_server_port are ignored."
+        },
+    )
     vllm_server_host: str = field(
         default="0.0.0.0",
-        metadata={"help": "Host of the vLLM server to connect to."},
+        metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
     )
     vllm_server_port: int = field(
         default=8000,
-        metadata={"help": "Port of the vLLM server to connect to."},
+        metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
     )
     vllm_server_timeout: float = field(
         default=120.0,
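A minimal sketch of how the new field would be set from training code, assuming `GRPOConfig` is imported from `trl` as usual; `output_dir` and the URL are placeholder values, and only `vllm_server_base_url` is new in this commit.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2.5-1.5B-GRPO",  # placeholder output directory
    use_vllm=True,
    # Point the trainer at an already running `trl vllm-serve` instance by URL;
    # when this is set, vllm_server_host and vllm_server_port are ignored.
    vllm_server_base_url="http://192.168.1.100:8000",
)
```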
