
Commit 9c681be

- refactor test: split the combined host/port and base_url client tests into dedicated *BaseURL test classes, and comment out the host/port client in the vllm_client.py usage example
1 parent 92c0e01 commit 9c681be

File tree

2 files changed: +145 -58 lines changed


tests/test_vllm_client_server.py

Lines changed: 143 additions & 56 deletions
@@ -43,32 +43,8 @@ def setUpClass(cls):
             ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
         )
 
-        # Initialize the clients using both initialization methods
-        cls.client = VLLMClient(connection_timeout=120)  # Default host and port
-        cls.client_base_url = VLLMClient(base_url="http://0.0.0.0:8000", connection_timeout=120)  # Using base_url
-
-    def test_initialization_methods(self):
-        """Test that both initialization methods work correctly."""
-        # Test generation with default client (host+port)
-        prompts = ["Test initialization 1"]
-        outputs_default = self.client.generate(prompts)
-        self.assertIsInstance(outputs_default, list)
-        self.assertEqual(len(outputs_default), len(prompts))
-
-        # Test generation with base_url client
-        outputs_base_url = self.client_base_url.generate(prompts)
-        self.assertIsInstance(outputs_base_url, list)
-        self.assertEqual(len(outputs_base_url), len(prompts))
-
-    def test_base_url_attribute(self):
-        """Test that both initialization methods set the base_url attribute correctly."""
-        # Both clients should have the same base_url
-        self.assertEqual(self.client.base_url, "http://0.0.0.0:8000")
-        self.assertEqual(self.client_base_url.base_url, "http://0.0.0.0:8000")
-
-        # Verify the client doesn't store host/port when base_url is provided
-        self.assertTrue(not hasattr(self.client_base_url, 'host') or self.client_base_url.host is None)
-        self.assertTrue(not hasattr(self.client_base_url, 'server_port') or self.client_base_url.server_port is None)
+        # Initialize the client
+        cls.client = VLLMClient(connection_timeout=120)
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
@@ -114,9 +90,84 @@ def test_reset_prefix_cache(self):
     def tearDownClass(cls):
         super().tearDownClass()
 
-        # Close the clients
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
+
+
+@pytest.mark.slow
+@require_torch_multi_gpu
+class TestVLLMClientServerBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
+        env = os.environ.copy()
+        env["CUDA_VISIBLE_DEVICES"] = "1"  # Restrict to GPU 1
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+        )
+
+        # Initialize the client with base_url
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        self.assertEqual(len(outputs), 2 * len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in outputs:
+            self.assertLessEqual(len(seq), 32)
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
         cls.client.close_communicator()
-        cls.client_base_url.close_communicator()
 
         # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
         # kill the server process and its children explicitly.
@@ -147,32 +198,69 @@ def setUpClass(cls):
             env=env,
         )
 
-        # Initialize the clients using both initialization methods
-        cls.client = VLLMClient(connection_timeout=120)  # Default host and port
-        cls.client_base_url = VLLMClient(base_url="http://0.0.0.0:8000", connection_timeout=120)  # Using base_url
-
-    def test_initialization_methods(self):
-        """Test that both initialization methods work correctly with tensor parallelism enabled."""
-        # Test generation with default client (host+port)
-        prompts = ["Test TP initialization 1"]
-        outputs_default = self.client.generate(prompts)
-        self.assertIsInstance(outputs_default, list)
-        self.assertEqual(len(outputs_default), len(prompts))
-
-        # Test generation with base_url client
-        outputs_base_url = self.client_base_url.generate(prompts)
-        self.assertIsInstance(outputs_base_url, list)
-        self.assertEqual(len(outputs_base_url), len(prompts))
-
-    def test_base_url_attribute(self):
-        """Test that both initialization methods set the base_url attribute correctly."""
-        # Both clients should have the same base_url
-        self.assertEqual(self.client.base_url, "http://0.0.0.0:8000")
-        self.assertEqual(self.client_base_url.base_url, "http://0.0.0.0:8000")
-
-        # Verify the client doesn't store host/port when base_url is provided
-        self.assertTrue(not hasattr(self.client_base_url, 'host') or self.client_base_url.host is None)
-        self.assertTrue(not hasattr(self.client_base_url, 'server_port') or self.client_base_url.server_port is None)
+        # Initialize the client
+        cls.client = VLLMClient(connection_timeout=120)
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
+
+
+@pytest.mark.slow
+@require_3_gpus
+class TestVLLMClientServerTPBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
+        env = os.environ.copy()
+        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+        )
+
+        # Initialize the client with base_url
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
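
All four classes pin the server to specific GPUs the same way, as the comments above spell out: copy the environment and set CUDA_VISIBLE_DEVICES before spawning the server, so the vLLM process only sees the GPUs reserved for it and they re-enumerate from cuda:0 inside the child. A standalone sketch of that behavior (the torch probe is illustrative only and assumes a machine with at least three GPUs):

    import os
    import subprocess

    # Restrict a child process to GPUs 1 and 2; inside the child they appear
    # as cuda:0 and cuda:1, leaving GPU 0 free for the parent test process.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1,2"

    probe = "import torch; print(torch.cuda.device_count())"
    result = subprocess.run(["python", "-c", probe], env=env, capture_output=True, text=True)
    print(result.stdout.strip())  # prints "2"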
@@ -200,9 +288,8 @@ def test_reset_prefix_cache(self):
     def tearDownClass(cls):
         super().tearDownClass()
 
-        # Close the clients
+        # Close the client
         cls.client.close_communicator()
-        cls.client_base_url.close_communicator()
 
         # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
        # kill the server process and its children explicitly.
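
An aside on the teardown pattern now duplicated across all four classes: terminate the server's children before the server itself, then wait() to reap the parent. A minimal standalone sketch of the same idea (the sleep child is a hypothetical stand-in for the vLLM server process tree; psutil is assumed to be installed):

    import signal
    import subprocess

    import psutil

    # Hypothetical stand-in for the "trl vllm-serve" server, which may fork workers.
    proc = subprocess.Popen(["sleep", "60"])

    # Signal children first: Popen.terminate() only reaches the direct child, and
    # orphaned grandchildren are exactly how zombie GPU processes linger after tests.
    parent = psutil.Process(proc.pid)
    for child in parent.children(recursive=True):
        child.send_signal(signal.SIGTERM)
    proc.terminate()
    proc.wait()  # reap the parent so no zombie entry remains in the process table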

trl/extras/vllm_client.py

Lines changed: 2 additions & 2 deletions
@@ -315,8 +315,8 @@ def close_communicator(self):
     print("Client 1 initialized with base_url")
 
     # Example 2: Initialize with host and port
-    client2 = VLLMClient(host="0.0.0.0", server_port=8000)
-    print("Client 2 initialized with host and port")
+    # client2 = VLLMClient(host="0.0.0.0", server_port=8000)
+    # print("Client 2 initialized with host and port")
 
     # Choose one client to use for the example
    client = client1
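
For reference, the two construction styles this commit now exercises in separate test classes would sit side by side like this (a sketch based only on the signatures visible in this diff; the address and timeout values are illustrative):

    from trl.extras.vllm_client import VLLMClient

    # Style 1: default host/port resolution, as in TestVLLMClientServer.
    client_default = VLLMClient(connection_timeout=120)

    # Style 2: explicit base_url, as in TestVLLMClientServerBaseURL.
    client_base = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

    # Both expose the same API afterwards.
    outputs = client_base.generate(["Hello, AI!"], n=1, temperature=0.8, max_tokens=32)

    client_default.close_communicator()
    client_base.close_communicator()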
