[vllm] support base_url parameter for vLLM client initialization #3324
base: main
Conversation
Force-pushed from 7219281 to 38c9c8e
tests/test_vllm_client_server.py (Outdated)

@pytest.mark.slow
@require_3_gpus
class TestVLLMClientServerTPBaseURL(unittest.TestCase):
    model_id = "Qwen/Qwen2.5-1.5B"

    @classmethod
    def setUpClass(cls):
        # We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1,2"  # Restrict to GPU 1 and 2

        # Start the server process
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
        )

        # Initialize the client with base_url
        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

    def test_generate(self):
        prompts = ["Hello, AI!", "Tell me a joke"]
        outputs = self.client.generate(prompts)

        # Check that the output is a list
        self.assertIsInstance(outputs, list)

        # Check that the number of generated sequences is equal to the number of prompts
        self.assertEqual(len(outputs), len(prompts))

        # Check that the generated sequences are lists of integers
        for seq in outputs:
            self.assertTrue(all(isinstance(tok, int) for tok in seq))

    def test_update_model_params(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        self.client.update_model_params(model)

    def test_reset_prefix_cache(self):
        # Test resetting the prefix cache
        self.client.reset_prefix_cache()

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()

        # Close the client
        cls.client.close_communicator()

        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie
        # processes, we need to kill the server process and its children explicitly.
        parent = psutil.Process(cls.server_process.pid)
        children = parent.children(recursive=True)
        for child in children:
            child.send_signal(signal.SIGTERM)
        cls.server_process.terminate()
        cls.server_process.wait()
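For reference, a minimal client-side usage sketch of the feature this test exercises. The import path and the host/server_port keyword names are assumptions based on the existing TRL client, not taken from this diff:

```python
# Sketch only: import path and host/server_port names are assumed, not verified.
from trl.extras.vllm_client import VLLMClient

# New in this PR: point the client at a full base URL (useful behind a proxy).
client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)

# Pre-existing style (assumed): host and port passed separately.
# client = VLLMClient(host="localhost", server_port=8000)

prompts = ["Hello, AI!", "Tell me a joke"]
token_ids = client.generate(prompts)  # one list of token ids per prompt
```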
I think one test is enough
I think this change makes sense, thanks for suggesting it.
Sorry for the delayed review.
I've made a few comments.
Have you tested on your infrastructure?
Force-pushed from cd35db0 to 0bb3624
Thanks for your review. 😄
Hey! Any update?
Hi @qgallouedec,

After several tests, I realized that using 0.0.0.0 as the host address in StatelessProcessGroup.create isn't viable: it requires the server's actual IP address to function properly. But if we use the base URL to initialize the vLLM client, we don't know the server's IP.

Proposed approach: use the base URL to resolve the server's IP programmatically in the client.

I'm currently testing this approach and diving deeper into how StatelessProcessGroup interacts with PyTorch's TCPStore. I'll update this thread once I've narrowed down the root cause and validated the fix. Let me know if you have insights or suggestions!
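A minimal sketch of that idea, using only the standard library; `resolve_host_ip` is a hypothetical helper, not code from this PR:

```python
import socket
from urllib.parse import urlparse


def resolve_host_ip(base_url: str) -> str:
    """Hypothetical helper: resolve the hostname in a base URL to an IP.

    The resolved address could then be passed to StatelessProcessGroup.create
    instead of 0.0.0.0.
    """
    hostname = urlparse(base_url).hostname  # e.g. "localhost" or a proxy name
    return socket.gethostbyname(hostname)   # e.g. "127.0.0.1"


# resolve_host_ip("http://localhost:8000") -> "127.0.0.1"
```

Note that behind a reverse proxy this resolves to the proxy's address rather than the backend's, which is why the follow-up commit also checks a forwarded header.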
…ocket connection
- Add logic to check for X-Forwarded-For header in response
- If header is not present, use socket connection to get host IP
- Update self.host with the obtained IP address
@qgallouedec
"seconds. Make sure the server is running by running `trl vllm-serve`." | ||
) from exc | ||
else: | ||
if response.status_code == 200: | ||
if "X-Forwarded-For" in response.headers: |
We still need the IP address when using base_url. Two ways to obtain it:
- An HTTP proxy can add a header containing the backend IP so that the client can read it from the response. The most common header for passing client IPs is X-Forwarded-For.
- Get the IP via a request and the underlying socket connection (see the sketch below).
Related issue: #3322
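Putting those two options together, a hedged sketch of how a client could obtain the server IP after a successful health check. The `/health/` path and the exact header handling are assumptions for illustration, not the code merged in this PR:

```python
import socket
from urllib.parse import urlparse

import requests


def get_server_ip(base_url: str) -> str:
    """Illustration only: prefer a proxy-provided header, else ask a socket.

    The /health/ endpoint path is an assumption about the server.
    """
    response = requests.get(f"{base_url}/health/")
    response.raise_for_status()

    # Option 1: a proxy in front of the server adds the backend IP as a header.
    forwarded = response.headers.get("X-Forwarded-For")
    if forwarded:
        # The header may carry a comma-separated chain; take the first entry.
        return forwarded.split(",")[0].strip()

    # Option 2: open a throwaway socket to the same host/port and read the
    # peer address of the established connection.
    parsed = urlparse(base_url)
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    with socket.create_connection((parsed.hostname, port), timeout=5) as sock:
        return sock.getpeername()[0]
```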