diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 2cc5bd361d..42de9973c9 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -109,4 +109,6 @@
 - sections:
   - local: bco_trainer
     title: BCO
+  - local: openenv
+    title: OpenEnv Integration
   title: Experimental
\ No newline at end of file
diff --git a/docs/source/openenv.md b/docs/source/openenv.md
new file mode 100644
index 0000000000..6aa161902c
--- /dev/null
+++ b/docs/source/openenv.md
@@ -0,0 +1,178 @@
+# OpenEnv Integration for Training LLMs with Environments
+
+## Overview
+
+[OpenEnv](https://github.com/meta-pytorch/OpenEnv) is an open-source framework from Meta's PyTorch team for defining, deploying, and interacting with environments in RL and agentic workflows. It provides [Gymnasium-style APIs](https://gymnasium.farama.org) (`reset()`, `step()`, `state()`) to interface with environments in a standard way, and supports running them as backend servers, for example over HTTP or as Docker containers. A collection of ready-to-use OpenEnv environments is available on the [Hugging Face Hub](https://huggingface.co/collections/openenv/environment-hub), hosted under dedicated [orgs](https://huggingface.co/openenv).
+
+Here, we’ll focus on the **integration of OpenEnv with TRL**; check out the resources above to learn more about the framework itself.
+
+## Installation
+
+To use OpenEnv with TRL, install the framework:
+
+```bash
+pip install git+https://github.com/meta-pytorch/OpenEnv.git
+```
+
+## Using `rollout_func` with OpenEnv environments
+
+TRL's [`GRPOTrainer`] supports _custom rollout logic_ through the `rollout_func` argument. This lets you override the trainer's default text-generation loop and directly interact with OpenEnv environments — for example, to compute environment-based rewards instead of purely model-based ones.
+
+### Rollout Function Signature
+
+A rollout function must have the following signature:
+
+```python
+def rollout_func(
+    prompts: list[str],
+    args: GRPOConfig,
+    processing_class
+) -> dict[str, list]:
+    """
+    Custom rollout function for generation and reward computation.
+
+    Args:
+        prompts: List of prompts to generate from
+        args: GRPOConfig containing sampling parameters (temperature, top_p, etc.)
+        processing_class: Tokenizer/processor for encoding/decoding
+
+    Returns:
+        Dictionary containing:
+        - prompt_ids: List of token IDs for each prompt
+        - completion_ids: List of token IDs for each completion
+        - logprobs: List of log probabilities for each token
+        - Any additional fields are forwarded to reward functions as kwargs
+    """
+    pass
+```
+
+> [!NOTE]
+> Any extra fields in the returned dictionary (beyond the required three) are automatically forwarded to your reward functions. This makes it easy to propagate signals such as environment rewards or auxiliary metrics from the rollout step.
+
+### Integration pattern
+
+The typical pattern when combining OpenEnv with TRL looks like this:
+
+1. Start or connect to an OpenEnv environment (e.g., an HTTP endpoint or Dockerized env).
+2. 
Generate completions from your model — for example, via a vLLM inference server (`use_vllm=True`, `vllm_mode="server"`).
+3. Step through the environment using each completion to compute rewards or metrics.
+4. Add environment results (e.g., `env_reward`) to the rollout result dict.
+5. Access those rewards inside your reward function via `**kwargs`.
+
+By using OpenEnv in this loop, you can:
+
+* Train with realistic or interactive feedback (not just static reward functions).
+* Plug in custom simulators, web APIs, or evaluators as environments.
+* Pass structured reward signals back into RL training seamlessly.
+
+## A simple example
+
+The [echo.py](../../examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the Echo environment rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards:
+
+```python
+import requests
+
+from datasets import Dataset
+from envs.echo_env import EchoEnv, EchoAction
+from trl import GRPOConfig, GRPOTrainer
+
+# Create HTTP client for Echo Environment
+client = EchoEnv.from_docker_image("echo-env:latest")
+
+def rollout_func(prompts, args, processing_class):
+    # 1. Generate completions via vLLM inference server (running on port 8000)
+    payload = {
+        "prompts": prompts,
+        "n": args.num_generations,
+        "temperature": args.temperature,
+        "max_tokens": args.max_completion_length,
+    }
+    response = requests.post("http://0.0.0.0:8000/generate/", json=payload)
+    result = response.json()
+
+    completions_text = processing_class.batch_decode(
+        result["completion_ids"],
+        skip_special_tokens=True
+    )
+
+    # 2. Step through the environment to get rewards
+    client.reset()
+    env_rewards = []
+    for msg in completions_text:
+        env_result = client.step(EchoAction(message=msg))
+        env_rewards.append(env_result.reward)
+
+    # 3. Add environment rewards as extra field
+    result["env_reward"] = env_rewards
+    return result
+
+def reward_from_env(completions, **kwargs):
+    """Extract environment rewards passed via rollout_func kwargs."""
+    env_rewards = kwargs.get("env_reward", [])
+    return [float(reward) for reward in env_rewards] if env_rewards else [0.0] * len(completions)
+
+dataset = Dataset.from_dict({"prompt": ["You are an AI that interacts with an *Echo* environment. Word to echo:"] * 64})
+
+# Setup trainer with custom rollout
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2.5-0.5B-Instruct",
+    reward_funcs=reward_from_env,
+    train_dataset=dataset,
+    rollout_func=rollout_func,  # Use custom rollout
+    args=GRPOConfig(
+        vllm_mode="server",
+        use_vllm=True,
+        num_train_epochs=1,
+        num_generations=8,
+        max_completion_length=2048,
+        per_device_train_batch_size=8,
+        gradient_accumulation_steps=4,
+    ),
+)
+trainer.train()
+```
+
+That's it! Now that you’ve seen the full example, let’s unpack how the main pieces fit together.
+
+1. **Environment Client:** `EchoEnv` implements an HTTP interface to interact with the environment server.
+2. **Custom rollout:** The `rollout_func` generates completions and steps through the environment to collect rewards.
+3. **Extra fields:** The rollout adds `env_reward` to the result dictionary, which is automatically passed to reward functions.
+4. **Reward function:** Extracts `env_reward` from `kwargs` to apply environment-computed rewards during training.
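+
+Environment feedback doesn't have to be the only training signal: [`GRPOTrainer`] also accepts a list of reward functions, and each of them receives the rollout's extra fields through `**kwargs`. The snippet below is a minimal sketch that reuses `reward_from_env`, `rollout_func`, and `dataset` from the example above; the `brevity_penalty` function and its scaling are illustrative and not part of the example script:
+
+```python
+def brevity_penalty(completions, **kwargs):
+    # Illustrative auxiliary reward: mildly discourage very long outputs. Completions are plain
+    # strings here because the prompts use the standard (non-conversational) format.
+    return [-0.001 * len(completion) for completion in completions]
+
+trainer = GRPOTrainer(
+    model="Qwen/Qwen2.5-0.5B-Instruct",
+    # Per-function rewards are combined into one scalar; relative weights can be set via
+    # `reward_weights` in GRPOConfig.
+    reward_funcs=[reward_from_env, brevity_penalty],
+    train_dataset=dataset,
+    rollout_func=rollout_func,
+    args=GRPOConfig(use_vllm=True, vllm_mode="server", num_generations=8),
+)
+```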
+ +> [!WARNING] +> The `rollout_func` is currently only supported when using vLLM in server mode (`use_vllm=True`, `vllm_mode="server"`). + +### Running the Example + +The example requires two GPUs: + +```bash +# Terminal 1: Start vLLM inference server +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 + +# Terminal 2: Run GRPO training with OpenEnv +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py +``` + +To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md). + +## Another example: Catch + +The [catch.py](../../examples/scripts/openenv/catch.py) script demonstrates training an LLM to play the Catch environment from OpenEnv. +In this example, the catch environment is a simple 10×5 grid game where a ball falls from the top and you control a paddle at the bottom. Move left, right, or stay to catch the ball for +1 reward or miss it for –1. + +```txt +· · ● · · +· · · · · +· · · · · +· · · · · +· · · · · +· · · · · +· · · · · +· · · · · +· · · · · +· · █ · · +``` + +The model is prompted with a description of the environment and the current state, and trained to output actions to maximize the environment reward. Below is the reward curve from training: + + \ No newline at end of file diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py new file mode 100644 index 0000000000..9a976542a9 --- /dev/null +++ b/examples/scripts/openenv/catch.py @@ -0,0 +1,251 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: T201 +import os +import re +import subprocess +import sys +import time +from pathlib import Path + +import requests +from datasets import Dataset +from envs.openspiel_env import OpenSpielEnv +from envs.openspiel_env.models import OpenSpielAction + +from trl import GRPOConfig, GRPOTrainer, RichProgressCallback, apply_chat_template + + +""" +Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function +is based on the catch game where the agent tries to catch falling balls. + +Setup: + +```sh +uv pip install git+https://github.com/meta-pytorch/OpenEnv.git +uv pip install open_spiel rich trackio +``` + +Usage (2 GPUs required): + +# Spin up vLLM server + +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 +``` + +# Run training + +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py +``` +""" + +GEN_URL = "http://0.0.0.0:8000/generate/" +ENV_URL = "http://0.0.0.0:8001" + +BASE_PROMPT = """You are an AI agent playing the game **Catch**. + +### Game Description +- The game is played on a **10×5 grid**. +- There is one **falling ball** and one **paddle** that you control at the bottom. +- The objective is to **move the paddle left or right to catch the ball** as it falls. 
+- The episode ends when the ball reaches the bottom row: + - You get **+1 reward** if you catch it. + - You get **–1 reward** if you miss it. + +### Observation Format + +- `observation`: a list of **50 numbers (floats)** representing the entire grid, flattened row by row. + - Each cell contains `1.0` if it is occupied (either by the ball or the paddle), or `0.0` if it is empty. + - The positions of the two `1.0` values indicate where the **ball** and **paddle** currently are. +- `legal_actions`: a list of integers representing which actions are currently allowed. + +### Actions Each action is a discrete integer: +- `0` → Move paddle **left** +- `1` → **Stay** (no movement) +- `2` → Move paddle **right** + +### Output Format Respond **only with one integer** representing your chosen action: `0`, `1`, or `2`. + +### Current Observation +""" + +# Start the OpenSpiel server in background +print("⚡ Starting FastAPI server for OpenSpiel Catch Environment...") + +# Determine the correct path +work_dir = str(Path.cwd().parent.absolute()) + +server_process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir, +) + +print("⏳ Waiting for server to start...") +time.sleep(5) + +# Check if server is running +try: + response = requests.get(f"{ENV_URL}/health", timeout=2) + print("\n✅ OpenSpiel Catch Environment server is running!") +except Exception as e: + print(f"\n❌ Server failed to start: {e}") + print("\n📋 Checking error output...") + server_process.poll() + if server_process.stderr: + stderr = server_process.stderr.read() + if stderr: + print(stderr) + raise + + +# Create HTTP client for OpenSpiel Catch Environment +client = OpenSpielEnv(base_url=f"{ENV_URL}") + + +def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]: + """ + Custom rollout function that generates completions via vLLM server and computes environment rewards. + + The catch game expects action IDs (integers). We'll parse the model's text output to extract action choices. + + Args: + prompts: List of prompts to generate from + args: GRPOConfig containing all sampling parameters + processing_class: Tokenizer/processor for decoding completions + + Returns: + Dict containing prompt_ids, completion_ids, logprobs, and env_reward + """ + # Run full episodes for each generation to get episode rewards + env_rewards = [] + all_prompt_ids = [] + all_completion_ids = [] + all_logprobs = [] + + for base_prompt in prompts: + for _ in range(args.num_generations): + # Run episode: Reset environment and loop until done + env_result = client.reset() + obs = env_result.observation + total_reward = 0.0 + + episode_prompt_ids = [] + episode_completion_ids = [] + episode_logprobs = [] + + # TODO: parallelise! 
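+            # Each episode is a multi-turn rollout: at every step the current observation is
+            # appended to the base prompt, the model emits one action, and the step's prompt IDs,
+            # completion IDs, and per-token logprobs are concatenated so that the whole episode is
+            # scored as a single completion with the episode's total reward.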
+ while not obs.done: + # FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset + episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]} + episode_prompt = apply_chat_template(episode_msg, processing_class) + + # Generate action from model + gen_payload = { + "prompts": [episode_prompt["prompt"]], + "n": 1, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": -1 if args.top_k is None else args.top_k, + "min_p": 0.0 if args.min_p is None else args.min_p, + "max_tokens": args.max_completion_length, + "repetition_penalty": args.repetition_penalty, + } + gen_response = requests.post(GEN_URL, json=gen_payload) + gen_response.raise_for_status() + gen_result = gen_response.json() + + # Collect prompt_ids, completion_ids, and logprobs from this step + episode_prompt_ids.extend(gen_result["prompt_ids"][0]) + episode_completion_ids.extend(gen_result["completion_ids"][0]) + episode_logprobs.extend(gen_result["logprobs"][0]) + + completion_text = processing_class.batch_decode( + gen_result["completion_ids"], skip_special_tokens=True + )[0] + + # Parse action from completion + action_id = 0 # default + numbers = re.findall(r"\b([0-2])\b", completion_text) + if numbers: + action_id = int(numbers[0]) + elif obs.legal_actions: + action_id = obs.legal_actions[0] + + # Take action in environment + env_result = client.step(OpenSpielAction(action_id=action_id, game_name="catch")) + reward = env_result.reward if env_result.reward is not None else 0.0 + total_reward += reward + obs = env_result.observation + + # Store episode results + env_rewards.append(total_reward) + all_prompt_ids.append(episode_prompt_ids) + all_completion_ids.append(episode_completion_ids) + all_logprobs.append(episode_logprobs) + + return { + "prompt_ids": all_prompt_ids, + "completion_ids": all_completion_ids, + "logprobs": all_logprobs, + "env_reward": env_rewards, + } + + +dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * 1000}) + + +def reward_from_env(completions, **kwargs): + """Reward function that uses the environment reward from the catch game.""" + # Extract environment rewards from kwargs (propagated via extra_fields) + env_rewards = kwargs.get("env_reward", []) + if env_rewards: + return [float(reward) for reward in env_rewards] + else: + # Fallback if env_reward is not available + return [0.0] * len(completions) + + +training_args = GRPOConfig( + output_dir="Qwen2.5-0.5B-GRPO-Catch", + vllm_mode="server", + use_vllm=True, + logging_steps=1, + report_to="trackio", + num_train_epochs=1, + max_completion_length=4, + gradient_accumulation_steps=4, +) +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_from_env, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, + callbacks=[RichProgressCallback()], +) +trainer.train() + +# Give time for background threads to finish +time.sleep(5) + +print("🛑 Terminating environment server...") +server_process.terminate() diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py new file mode 100644 index 0000000000..f2d05aa015 --- /dev/null +++ b/examples/scripts/openenv/echo.py @@ -0,0 +1,174 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: T201 +import os +import subprocess +import sys +import time +from pathlib import Path + +import requests +from datasets import load_dataset +from envs.echo_env import EchoEnv +from envs.echo_env.models import EchoAction + +from trl import GRPOConfig, GRPOTrainer, RichProgressCallback + + +""" +Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages +longer completions. + +Setup: + +```sh +uv pip install git+https://github.com/meta-pytorch/OpenEnv.git +``` + +Usage (2 GPUs required): + +# Spin up server + +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 +``` + +# Run training + +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py +``` +""" + +GEN_URL = "http://0.0.0.0:8000/generate/" +ENV_URL = "http://0.0.0.0:8001" + +print("⚡ Starting FastAPI server for Echo Environment...") +# Workaround if you can't run the env with Docker +work_dir = str(Path.cwd().parent.absolute()) +server_process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir, +) + +print("⏳ Waiting for server to start...") +time.sleep(5) + +try: + response = requests.get(f"{ENV_URL}/health", timeout=2) + print("\n✅ Echo Environment server is running!") +except Exception as e: + print(f"\n❌ Server failed to start: {e}") + print("\n📋 Checking error output...") + server_process.poll() + if server_process.stderr: + stderr = server_process.stderr.read() + if stderr: + print(stderr) + raise + + +# Create HTTP client for Echo Environment +client = EchoEnv(base_url=f"{ENV_URL}") + + +def rollout_func(prompts: list[str], args: GRPOConfig, processing_class) -> dict[str, list]: + """ + Custom rollout function that generates completions via vLLM server and computes environment rewards. + + Args: + prompts: List of prompts to generate from + args: GRPOConfig containing all sampling parameters + processing_class: Tokenizer/processor for decoding completions + + Returns: + Dict containing prompt_ids, completion_ids, logprobs, and env_reward + """ + # 1. Generate completions via vLLM inference server (running on port 8000) + payload = { + "prompts": prompts, + "n": args.num_generations, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": -1 if args.top_k is None else args.top_k, + "min_p": 0.0 if args.min_p is None else args.min_p, + "max_tokens": args.max_completion_length, + "repetition_penalty": args.repetition_penalty, + } + response = requests.post(GEN_URL, json=payload) + + if response.status_code != 200: + print(f"Error response: {response.text}") + + response.raise_for_status() + result = response.json() + + completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) + + # 2. 
Step through the environment to get rewards + env_result = client.reset() + env_rewards = [] + for msg in completions_text: + env_result = client.step(EchoAction(message=msg)) + env_rewards.append(env_result.reward) + + # 3. Add environment rewards as extra field + result["env_reward"] = env_rewards + + return result + + +def reward_from_env(completions, **kwargs): + """Reward function that uses the environment reward.""" + # Extract environment rewards from kwargs (propagated via extra_fields) + env_rewards = kwargs.get("env_reward", []) + if env_rewards: + return [float(reward) for reward in env_rewards] + else: + # Fallback if env_reward is not available + return [0.0] * len(completions) + + +dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") + +training_args = GRPOConfig( + output_dir="Qwen2.5-0.5B-GRPO-Rollout", + vllm_mode="server", + use_vllm=True, + logging_steps=1, + report_to="trackio", + num_train_epochs=1, + max_completion_length=2048, + gradient_accumulation_steps=4, +) +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_from_env, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, + callbacks=[RichProgressCallback()], +) +trainer.train() + +# Give time for background threads to finish +time.sleep(5) + +print("🛑 Terminating Echo Environment server...") +server_process.terminate() diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 875d187309..4e9786c616 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -15,6 +15,7 @@ import inspect import os import textwrap +import warnings from collections import defaultdict, deque from contextlib import nullcontext from functools import partial @@ -99,6 +100,11 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +# What we call a rollout function is a callable that takes prompts (list), args (GRPOConfig), and processing_class as +# parameters and returns a dict of generation results. Those results must include "prompt_ids", "completion_ids", and +# "logprobs" fields. Any extra fields (per-completion) are forwarded to the reward functions. +RolloutFunc = Callable[[list[str], Any, Any], dict[str, Any]] + class GRPOTrainer(BaseTrainer): """ @@ -200,6 +206,11 @@ def reward_func(completions, **kwargs): model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + rollout_func (`RolloutFunc`, *optional*): + Function to use for generating completions. It must take prompts, args, and processing_class as parameters + and return a dict with `"prompt_ids"`, `"completion_ids"`, and `"logprobs"` fields. Any other fields that + are forwarded to the reward functions. This feature is experimental and may change or be removed at any + time without prior notice. 
""" _tag_names = ["trl", "grpo"] @@ -230,6 +241,7 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + rollout_func: Optional[RolloutFunc] = None, ): # Args if args is None: @@ -345,6 +357,17 @@ def __init__( self.reward_processing_classes = reward_processing_classes + # Rollout function + if rollout_func is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1": + warnings.warn( + "You are importing from 'rollout_func', which is an experimental feature. This API may change or be " + "removed at any time without prior notice. Silence this warning by setting environment variable " + "TRL_EXPERIMENTAL_SILENCE=1.", + UserWarning, + stacklevel=2, + ) + self.rollout_func = rollout_func + # Training arguments self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper @@ -1089,7 +1112,8 @@ def _generate_single_turn(self, prompts: list): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step - prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + if is_conversational({"prompt": prompts[0]}): + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": @@ -1114,18 +1138,34 @@ def _generate_single_turn(self, prompts: list): "generation_kwargs": self.args.generation_kwargs, } with profiling_context(self, "vLLM.generate"): - if is_conversational({"prompt": ordered_set_of_prompts[0]}): - output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + if self.rollout_func is not None: + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + ordered_set_of_prompts = [ + apply_chat_template({"prompt": p}, self.processing_class)["prompt"] + for p in ordered_set_of_prompts + ] + output = self.rollout_func( + ordered_set_of_prompts, + self.args, + self.processing_class, + ) else: - output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + # FIXME: this endpoint doesn't exist in vllm_client + output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) + # Extract required fields and collect any extra fields for reward functions + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) else: payload = None # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. 
obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] + all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] @@ -1138,6 +1178,14 @@ def _generate_single_turn(self, prompts: list): completion_ids = all_completion_ids[process_slice] logprobs = all_logprobs[process_slice] + # Slice extra fields dict-of-lists per process (extra fields are per-completion, like completion_ids) + extra_fields = {} + for key, values in all_extra_fields.items(): + if isinstance(values, list): + extra_fields[key] = values[process_slice] + else: + extra_fields[key] = values + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": if self.guided_decoding_regex: @@ -1201,6 +1249,8 @@ def _generate_single_turn(self, prompts: list): completion_ids = all_completion_ids logprobs = all_logprobs + extra_fields = {} # No extra fields for colocate mode + if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) @@ -1235,6 +1285,7 @@ def _generate_single_turn(self, prompts: list): completion_ids = [output.generated_tokens for output in all_outputs.values()] prompt_ids = generate_inputs["inputs"] logprobs = None # not used in this case + extra_fields = {} # No extra fields for paged mode else: # Regular generation path @@ -1279,14 +1330,15 @@ def _generate_single_turn(self, prompts: list): prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] logprobs = None # not used in this case + extra_fields = {} # No extra fields for non-rollout_func paths - return prompt_ids, completion_ids, logprobs + return prompt_ids, completion_ids, logprobs, extra_fields def _generate(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts) + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1318,7 +1370,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs + return prompt_ids, completion_ids, total_completion_tokens, logprobs, extra_fields def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1344,8 +1396,8 @@ def _generate_and_score_completions( if images is not None: prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] - prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( - prompts + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( + self._generate(prompts) ) # Convert lists of token IDs to padded tensors @@ -1464,6 +1516,15 @@ def _generate_and_score_completions( else: 
completions = completions_text + # Merge extra_fields from rollout_func into inputs for reward functions + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + inp[key] = values + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset.