@@ -43,32 +43,8 @@ def setUpClass(cls):
             ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
         )
 
-        # Initialize the clients using both initialization methods
-        cls.client = VLLMClient(connection_timeout=120)  # Default host and port
-        cls.client_base_url = VLLMClient(base_url="http://0.0.0.0:8000", connection_timeout=120)  # Using base_url
-
-    def test_initialization_methods(self):
-        """Test that both initialization methods work correctly."""
-        # Test generation with default client (host+port)
-        prompts = ["Test initialization 1"]
-        outputs_default = self.client.generate(prompts)
-        self.assertIsInstance(outputs_default, list)
-        self.assertEqual(len(outputs_default), len(prompts))
-
-        # Test generation with base_url client
-        outputs_base_url = self.client_base_url.generate(prompts)
-        self.assertIsInstance(outputs_base_url, list)
-        self.assertEqual(len(outputs_base_url), len(prompts))
-
-    def test_base_url_attribute(self):
-        """Test that both initialization methods set the base_url attribute correctly."""
-        # Both clients should have the same base_url
-        self.assertEqual(self.client.base_url, "http://0.0.0.0:8000")
-        self.assertEqual(self.client_base_url.base_url, "http://0.0.0.0:8000")
-
-        # Verify the client doesn't store host/port when base_url is provided
-        self.assertTrue(not hasattr(self.client_base_url, 'host') or self.client_base_url.host is None)
-        self.assertTrue(not hasattr(self.client_base_url, 'server_port') or self.client_base_url.server_port is None)
+        # Initialize the client
+        cls.client = VLLMClient(connection_timeout=120)
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
@@ -114,9 +90,84 @@ def test_reset_prefix_cache(self):
     def tearDownClass(cls):
         super().tearDownClass()
 
-        # Close the clients
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
+
+
+@pytest.mark.slow
+@require_torch_multi_gpu
+class TestVLLMClientServerBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
+        env = os.environ.copy()
+        env["CUDA_VISIBLE_DEVICES"] = "1"  # Restrict to GPU 1
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+        )
+
+        # Initialize the client with base_url
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        self.assertEqual(len(outputs), 2 * len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in outputs:
+            self.assertLessEqual(len(seq), 32)
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
         cls.client.close_communicator()
-        cls.client_base_url.close_communicator()
 
         # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
         # kill the server process and its children explicitly.
@@ -147,32 +198,69 @@ def setUpClass(cls):
             env=env,
         )
 
-        # Initialize the clients using both initialization methods
-        cls.client = VLLMClient(connection_timeout=120)  # Default host and port
-        cls.client_base_url = VLLMClient(base_url="http://0.0.0.0:8000", connection_timeout=120)  # Using base_url
-
-    def test_initialization_methods(self):
-        """Test that both initialization methods work correctly with tensor parallelism enabled."""
-        # Test generation with default client (host+port)
-        prompts = ["Test TP initialization 1"]
-        outputs_default = self.client.generate(prompts)
-        self.assertIsInstance(outputs_default, list)
-        self.assertEqual(len(outputs_default), len(prompts))
-
-        # Test generation with base_url client
-        outputs_base_url = self.client_base_url.generate(prompts)
-        self.assertIsInstance(outputs_base_url, list)
-        self.assertEqual(len(outputs_base_url), len(prompts))
-
-    def test_base_url_attribute(self):
-        """Test that both initialization methods set the base_url attribute correctly."""
-        # Both clients should have the same base_url
-        self.assertEqual(self.client.base_url, "http://0.0.0.0:8000")
-        self.assertEqual(self.client_base_url.base_url, "http://0.0.0.0:8000")
-
-        # Verify the client doesn't store host/port when base_url is provided
-        self.assertTrue(not hasattr(self.client_base_url, 'host') or self.client_base_url.host is None)
-        self.assertTrue(not hasattr(self.client_base_url, 'server_port') or self.client_base_url.server_port is None)
+        # Initialize the client
+        cls.client = VLLMClient(connection_timeout=120)
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
+
+
+@pytest.mark.slow
+@require_3_gpus
+class TestVLLMClientServerTPBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
+        env = os.environ.copy()
+        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+        )
+
+        # Initialize the client with base_url
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
@@ -200,9 +288,8 @@ def test_reset_prefix_cache(self):
     def tearDownClass(cls):
         super().tearDownClass()
 
-        # Close the clients
+        # Close the client
         cls.client.close_communicator()
-        cls.client_base_url.close_communicator()
 
         # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
        # kill the server process and its children explicitly.
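
For reference, a minimal usage sketch of the client API exercised by these tests. It assumes a server has already been started with "trl vllm-serve --model Qwen/Qwen2.5-1.5B" and that VLLMClient is importable from trl.extras.vllm_client (import path assumed, not shown in this diff):

from trl.extras.vllm_client import VLLMClient  # assumed import path

# Either connect to the default host/port, as the pre-existing test classes do...
client = VLLMClient(connection_timeout=120)
# ...or point the client at an explicit base URL, as the new *BaseURL test classes do:
# client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

prompts = ["Hello, AI!", "Tell me a joke"]
outputs = client.generate(prompts, n=2, temperature=0.8, max_tokens=32)
# outputs is a list of token-id lists: 2 * len(prompts) sequences, each at most 32 tokens long

client.reset_prefix_cache()    # clear the server-side prefix cache
client.close_communicator()    # release the weight-sync communicator when done

The only difference between the new *BaseURL classes and the existing ones is the constructor call; the rest of the client API (generate, update_model_params, reset_prefix_cache, close_communicator) is exercised identically.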