diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py
index 817ebcf92f..98f50c91f4 100644
--- a/tests/test_vllm_client_server.py
+++ b/tests/test_vllm_client_server.py
@@ -252,3 +252,79 @@ def tearDownClass(cls):
             child.send_signal(signal.SIGTERM)
         cls.server_process.terminate()
         cls.server_process.wait()
+
+
+@pytest.mark.slow
+@require_torch_multi_gpu
+class TestVLLMClientServerBaseURL(unittest.TestCase):
+    model_id = "Qwen/Qwen2.5-1.5B"
+
+    @classmethod
+    def setUpClass(cls):
+        # We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
+        env = os.environ.copy()
+        env["CUDA_VISIBLE_DEVICES"] = "1"  # Restrict to GPU 1
+
+        # Start the server process
+        cls.server_process = subprocess.Popen(
+            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
+        )
+
+        # Initialize the client with base_url
+        cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=120)
+
+    def test_generate(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is equal to the number of prompts
+        self.assertEqual(len(outputs), len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+    def test_generate_with_params(self):
+        prompts = ["Hello, AI!", "Tell me a joke"]
+        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)
+
+        # Check that the output is a list
+        self.assertIsInstance(outputs, list)
+
+        # Check that the number of generated sequences is 2 times the number of prompts
+        self.assertEqual(len(outputs), 2 * len(prompts))
+
+        # Check that the generated sequences are lists of integers
+        for seq in outputs:
+            self.assertTrue(all(isinstance(tok, int) for tok in seq))
+
+        # Check that the length of the generated sequences is less than or equal to 32
+        for seq in outputs:
+            self.assertLessEqual(len(seq), 32)
+
+    def test_update_model_params(self):
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
+        self.client.update_model_params(model)
+
+    def test_reset_prefix_cache(self):
+        # Test resetting the prefix cache
+        self.client.reset_prefix_cache()
+
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+
+        # Close the client
+        cls.client.close_communicator()
+
+        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
+        # kill the server process and its children explicitly.
+        parent = psutil.Process(cls.server_process.pid)
+        children = parent.children(recursive=True)
+        for child in children:
+            child.send_signal(signal.SIGTERM)
+        cls.server_process.terminate()
+        cls.server_process.wait()
\ No newline at end of file
diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py
index 761a79f0bf..2c9eda776e 100644
--- a/trl/extras/vllm_client.py
+++ b/trl/extras/vllm_client.py
@@ -15,7 +15,8 @@
 import atexit
 import logging
 import time
-from typing import Optional
+from typing import Optional, Union
+from urllib.parse import urlparse
 
 import torch
 from torch import nn
@@ -47,10 +48,12 @@ class VLLMClient:
     weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
     Args:
+        base_url (`str`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are ignored.
         host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            IP address of the vLLM server.
+            IP address of the vLLM server. Ignored if `base_url` is provided.
         server_port (`int`, *optional*, defaults to `8000`):
-            Port number of the vLLM server.
+            Port number of the vLLM server. Ignored if `base_url` is provided.
         group_port (`int`, *optional*, defaults to `51216`):
             Port number for the weight update group.
         connection_timeout (`float`, *optional*, defaults to `0.0`):
@@ -67,8 +70,25 @@ class VLLMClient:
         INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
         ```
 
-        Use the client to generate completions and update model weights:
+        There are two ways to initialize the client:
+        1. Using `base_url`:
+        ```python
+        >>> from trl.extras.vllm_client import VLLMClient
+        >>> # Connect to a local server
+        >>> client = VLLMClient(base_url="http://localhost:8000")
+        >>> # Or connect to a remote server
+        >>> client = VLLMClient(base_url="http://192.168.1.100:8000")
+        ```
+        2. Using `host` and `server_port`:
+        ```python
+        >>> from trl.extras.vllm_client import VLLMClient
+        >>> # Connect to a local server
+        >>> client = VLLMClient(host="localhost", server_port=8000)
+        >>> # Or connect to a remote server
+        >>> client = VLLMClient(host="192.168.1.100", server_port=8000)
+        ```
+        Use the client to generate completions and update model weights:
 
         ```python
         >>> from trl.extras.vllm_client import VLLMClient
         >>> client = VLLMClient()
@@ -84,7 +104,12 @@ class VLLMClient:
     """
 
     def __init__(
-        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
+        self,
+        base_url: Optional[str] = None,
+        host: str = "0.0.0.0",
+        server_port: int = 8000,
+        group_port: int = 51216,
+        connection_timeout: float = 0.0,
     ):
         if not is_requests_available():
             raise ImportError("requests is not installed. Please install it with `pip install requests`.")
@@ -92,8 +117,16 @@ def __init__(
             raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
 
         self.session = requests.Session()
-        self.host = host
-        self.server_port = server_port
+
+        if base_url is not None:
+            # Parse the base_url to extract host and port
+            parsed_url = urlparse(base_url)
+            scheme = parsed_url.scheme or "http"
+            self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
+        else:
+            self.host = host
+            self.server_port = server_port
+            self.base_url = f"http://{self.host}:{self.server_port}"
         self.group_port = group_port
         self.check_server(connection_timeout)  # check server and fail after timeout
 
@@ -108,7 +141,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
             total_timeout (`float`, *optional*, defaults to `0.0`):
                 Total timeout duration in seconds.
         """
-        url = f"http://{self.host}:{self.server_port}/health/"
+        url = f"{self.base_url}/health/"
         start_time = time.time()  # Record the start time
 
         while True:
@@ -119,11 +152,19 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
                 elapsed_time = time.time() - start_time
                 if elapsed_time >= total_timeout:
                     raise ConnectionError(
-                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
+                        f"The vLLM server can't be reached at {self.base_url} after {total_timeout} "
                         "seconds. Make sure the server is running by running `trl vllm-serve`."
                     ) from exc
             else:
                 if response.status_code == 200:
+                    if "X-Forwarded-For" in response.headers:
+                        self.host = response.headers["X-Forwarded-For"]
+                    else:
+                        resp = requests.get(url, stream=True)
+                        resp.raise_for_status()
+                        sock = resp.raw._connection.sock
+                        ip, port = sock.getpeername()
+                        self.host = ip
                     logger.info("Server is up!")
                     return None
 
@@ -170,7 +211,7 @@ def generate(
             `list[list[int]]`:
                 List of lists of token IDs representing the model-generated completions for each prompt.
         """
-        url = f"http://{self.host}:{self.server_port}/generate/"
+        url = f"{self.base_url}/generate/"
         response = self.session.post(
             url,
             json={
@@ -195,7 +236,7 @@ def init_communicator(self):
        Initializes the weight update group in a distributed setup for model synchronization.
        """
        # Get the world size from the server
-        url = f"http://{self.host}:{self.server_port}/get_world_size/"
+        url = f"{self.base_url}/get_world_size/"
        response = requests.get(url)
        if response.status_code == 200:
            vllm_world_size = response.json()["world_size"]
@@ -206,7 +247,7 @@ def init_communicator(self):
         self.rank = vllm_world_size  # the client's rank is the last process
 
         # Initialize weight update group
-        url = f"http://{self.host}:{self.server_port}/init_communicator/"
+        url = f"{self.base_url}/init_communicator/"
         # In the server side, the host is set to 0.0.0.0
         response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
         if response.status_code != 200:
@@ -235,7 +276,7 @@ def update_named_param(self, name: str, weights: torch.Tensor):
                 Tensor containing the updated weights.
         """
         dtype, shape = str(weights.dtype), tuple(weights.shape)
-        url = f"http://{self.host}:{self.server_port}/update_named_param/"
+        url = f"{self.base_url}/update_named_param/"
         response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -260,7 +301,7 @@ def reset_prefix_cache(self):
         """
         Resets the prefix cache for the model.
         """
-        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
+        url = f"{self.base_url}/reset_prefix_cache/"
         response = self.session.post(url)
         if response.status_code != 200:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")
@@ -269,7 +310,7 @@ def close_communicator(self):
         """
         Closes the weight update group and cleans up the communication group.
         """
-        url = f"http://{self.host}:{self.server_port}/close_communicator/"
+        url = f"{self.base_url}/close_communicator/"
 
         try:
             response = self.session.post(url)
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index fafdc30cb4..c61a8a9b60 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -90,8 +90,8 @@ class GRPOConfig(TrainingArguments):
         > Parameters that control generation acceleration powered by vLLM
 
         use_vllm (`bool`, *optional*, defaults to `False`):
-            Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
-            instead of the default model.generate(). Requires `vllm` to be installed.
+            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
+            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
         vllm_mode (`str`, *optional*, defaults to `"server"`):
             Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
             `"colocate"`.
@@ -104,12 +104,14 @@ class GRPOConfig(TrainingArguments):
             Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
 
         > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
-
+        vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
+            Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
+            `vllm_server_port` are ignored.
         vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
-            Host of the vLLM server to connect to.
+            Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
         vllm_server_port (`int`, *optional*, defaults to `8000`):
-            Port of the vLLM server to connect to.
+            Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
         vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
             Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
             timeout, a `ConnectionError` is raised.
@@ -318,6 +320,13 @@ class GRPOConfig(TrainingArguments):
             "generation instead of the default model.generate(). Requires `vllm` to be installed."
         },
     )
+    vllm_server_base_url: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, vllm_server_host and "
+            "vllm_server_port are ignored."
+        },
+    )
     vllm_mode: str = field(
         default="server",
         metadata={
@@ -336,11 +345,11 @@ class GRPOConfig(TrainingArguments):
     # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
     vllm_server_host: str = field(
         default="0.0.0.0",
-        metadata={"help": "Host of the vLLM server to connect to."},
+        metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
     )
     vllm_server_port: int = field(
         default=8000,
-        metadata={"help": "Port of the vLLM server to connect to."},
+        metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
     )
     vllm_server_timeout: float = field(
         default=240.0,
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 513b8bfac1..c624eae0ba 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -616,10 +616,10 @@ def data_collator(features):  # No data collation is needed in GRPO
                 "`pip install vllm` to use it."
             )
 
         if self.vllm_mode == "server" and self.accelerator.is_main_process:
-            self.vllm_client = VLLMClient(
-                args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
-            )
+            base_url = args.vllm_server_base_url if args.vllm_server_base_url is not None else f"http://{args.vllm_server_host}:{args.vllm_server_port}"
+            self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout)
+            self.vllm_client.init_communicator()
         elif self.vllm_mode == "colocate":
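Usage note: with this change applied, a GRPO run can target a vLLM server on another machine (or behind a reverse proxy) by setting the new `vllm_server_base_url` option instead of `vllm_server_host`/`vllm_server_port`. The sketch below is illustrative only, not part of this PR: the dataset, reward function, output directory, and server address are placeholders, and it assumes a server was started elsewhere with `trl vllm-serve --model Qwen/Qwen2.5-1.5B`.

```python
# Minimal sketch (assumes this diff is applied). Dataset, reward function,
# output_dir, and the server address below are placeholders.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters long
    return [-abs(20 - len(completion)) for completion in completions]


dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2.5-1.5B-GRPO",
    use_vllm=True,
    vllm_mode="server",
    # New in this PR: a full base URL replaces vllm_server_host/vllm_server_port,
    # which are ignored when it is set.
    vllm_server_base_url="http://192.168.1.100:8000",
    vllm_server_timeout=240.0,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-1.5B",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

The same base URL can also be passed directly to `VLLMClient(base_url=...)` when the client is used outside the trainer, which is what the new `TestVLLMClientServerBaseURL` test exercises.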