From e036ac56111e135320076fd1d4aa8970a6ed32d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Willem=20R=C3=B6pke?=
Date: Thu, 3 Apr 2025 17:56:29 +0200
Subject: [PATCH] Add a raw generate API to the vLLM server

---
 trl/scripts/vllm_serve.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py
index bf854b3a35..bf91c7d939 100644
--- a/trl/scripts/vllm_serve.py
+++ b/trl/scripts/vllm_serve.py
@@ -323,6 +323,23 @@ async def generate(request: GenerateRequest):
         {"completion_ids": [[101, 102, 103], [201, 202, 203]]}
         ```
         """
+        all_outputs = await generate_raw(request)
+        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
+        return {"completion_ids": completion_ids}
+
+    @app.post("/generate_raw/")
+    async def generate_raw(request: GenerateRequest):
+        """
+        Generates completions for the provided prompts and returns the raw list of `RequestOutput` objects.
+
+        Args:
+            request (`GenerateRequest`):
+                - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
+
+        Returns:
+            `list[RequestOutput]`: The raw vLLM `RequestOutput` objects, one per prompt, each containing the
+                generated completions and their token IDs.
+        """
 
         # Guided decoding, if enabled
         if request.guided_decoding_regex is not None:
@@ -342,8 +359,7 @@ async def generate(request: GenerateRequest):
             guided_decoding=guided_decoding,
         )
         all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)
-        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
-        return {"completion_ids": completion_ids}
+        return all_outputs
 
     class InitCommunicatorRequest(BaseModel):
         host: str
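
A minimal client sketch of how the two endpoints could be exercised once this patch is applied. The host, port, and prompts below are assumptions, not part of this patch, and the JSON shape of the raw response depends on how FastAPI serializes vLLM's RequestOutput objects:

    # Hypothetical client for a server started from trl/scripts/vllm_serve.py,
    # assumed to be listening on localhost:8000.
    import requests

    prompts = ["The capital of France is", "1 + 1 ="]

    # Existing endpoint: returns only the generated token IDs.
    resp = requests.post("http://localhost:8000/generate/", json={"prompts": prompts})
    print(resp.json()["completion_ids"])  # e.g. [[101, 102, 103], [201, 202, 203]]

    # New endpoint: returns the serialized RequestOutput objects, which carry
    # metadata beyond token IDs (e.g. finish reasons and, if requested, logprobs).
    resp = requests.post("http://localhost:8000/generate_raw/", json={"prompts": prompts})
    print(resp.json())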