From 7fa1863fb48dd60f43db4162fca50e547ab5f4e7 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 20 Oct 2025 14:52:44 +0000 Subject: [PATCH 01/18] Add rollout function for multi-step RL --- trl/trainer/grpo_trainer.py | 57 ++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index eaeb6eb5a3..ce4fd0b93b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -94,6 +94,11 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +# What we call a rollout function is a callable that takes in the data and sampling parameters and returns a list of +# generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include an optional "tools" +# field. Any extra fields are forwarded to the reward functions. +RolloutFunc = Callable[[dict[str, Any], dict[str, Any]], list[dict[str, Any]]] + class GRPOTrainer(BaseTrainer): """ @@ -195,6 +200,11 @@ def reward_func(completions, **kwargs): model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + rollout_func (`RolloutFunc`, *optional*, defaults to `None`): + Function to use for generating completions. It must take in the data and sampling parameters and return a + list of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" + fields and can include optional "tools" field and any other fields that are forwarded to the reward + functions. 
""" _tag_names = ["trl", "grpo"] @@ -225,6 +235,7 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + rollout_func: Optional[RolloutFunc] = None, ): # Args if args is None: @@ -340,6 +351,9 @@ def __init__( self.reward_processing_classes = reward_processing_classes + # Rollout function + self.rollout_func = rollout_func + # Training arguments self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper @@ -1116,20 +1130,35 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): ordered_set_of_images = None with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) + if self.rollout_func is not None: + output = self.rollout_func( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) + else: + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None From 9b6e0202c2e0f3df44ca3dcd04666c8d214a0f8b Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 20 Oct 2025 18:53:11 +0000 Subject: [PATCH 02/18] Make multi-step RL work --- trl/trainer/grpo_trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ce4fd0b93b..361383814f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -94,10 +94,10 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] -# What we call a rollout function is a callable that takes in the data and sampling parameters and returns a list of +# What we call a rollout function is a callable that takes in the data and returns a list of # generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include an optional "tools" # field. Any extra fields are forwarded to the reward functions. 
-RolloutFunc = Callable[[dict[str, Any], dict[str, Any]], list[dict[str, Any]]] +RolloutFunc = Callable[[dict[str, Any]], list[dict[str, Any]]] class GRPOTrainer(BaseTrainer): @@ -201,10 +201,9 @@ def reward_func(completions, **kwargs): peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. rollout_func (`RolloutFunc`, *optional*, defaults to `None`): - Function to use for generating completions. It must take in the data and sampling parameters and return a - list of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" - fields and can include optional "tools" field and any other fields that are forwarded to the reward - functions. + Function to use for generating completions. It must take in the data and return a list of generation + results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include + optional "tools" field and any other fields that are forwarded to the reward functions. """ _tag_names = ["trl", "grpo"] From d84bb0477570062166a24d56adc35d3d83171506 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 20 Oct 2025 21:18:29 +0000 Subject: [PATCH 03/18] Restore sampling params --- trl/trainer/grpo_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 361383814f..86954bb39f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -94,10 +94,10 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] -# What we call a rollout function is a callable that takes in the data and returns a list of +# What we call a rollout function is a callable that takes in the data and sampling parameters and returns a list of # generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include an optional "tools" # field. Any extra fields are forwarded to the reward functions. -RolloutFunc = Callable[[dict[str, Any]], list[dict[str, Any]]] +RolloutFunc = Callable[[dict[str, Any], dict[str, Any]], list[dict[str, Any]]] class GRPOTrainer(BaseTrainer): @@ -201,9 +201,9 @@ def reward_func(completions, **kwargs): peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. rollout_func (`RolloutFunc`, *optional*, defaults to `None`): - Function to use for generating completions. It must take in the data and return a list of generation - results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include - optional "tools" field and any other fields that are forwarded to the reward functions. + Function to use for generating completions. It must take in the data sampling parameters and return a list + of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and + can include optional "tools" field and any other fields that are forwarded to the reward functions. 
""" _tag_names = ["trl", "grpo"] From c72bcc9429110a64f7ced40d65e2514b2aa8be1e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 13:02:55 +0000 Subject: [PATCH 04/18] Add OpenEnv integratoin --- trl/experimental/openenv/echo.py | 146 +++++++++++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 1 + 2 files changed, 147 insertions(+) create mode 100644 trl/experimental/openenv/echo.py diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py new file mode 100644 index 0000000000..afe83f5886 --- /dev/null +++ b/trl/experimental/openenv/echo.py @@ -0,0 +1,146 @@ +from datasets import load_dataset +import requests +from trl import GRPOConfig, GRPOTrainer +import subprocess +import time +import sys +import os +from envs.echo_env.models import ( + EchoAction, +) +from envs.echo_env import EchoEnv +from pathlib import Path +""" +Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages +longer completions. + +Usage (2 GPUs required): + +-- Spin up server -- CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port +8000 + +-- Run this script -- CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/echo.py +""" + +GEN_URL = "http://0.0.0.0:8000/generate/" +ENV_URL = "http://0.0.0.0:8001" + +# Start the Echo server in background +print("⚔ Starting FastAPI server for Echo Environment...") + +# Determine the correct path +work_dir = str(Path.cwd().parent.absolute()) + +server_process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", + "envs.echo_env.server.app:app", + "--host", "0.0.0.0", + "--port", "8001"], + env={**os.environ, + "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir +) + +# Wait for server to start +print("ā³ Waiting for server to start...") +time.sleep(5) + +# Check if server is running +import requests +try: + response = requests.get(f'{ENV_URL}/health', timeout=2) + print("\nāœ… Echo Environment server is running!") +except Exception as e: + print(f"\nāŒ Server failed to start: {e}") + print("\nšŸ“‹ Checking error output...") + server_process.poll() + if server_process.stderr: + stderr = server_process.stderr.read() + if stderr: + print(stderr) + raise + + +# Create HTTP client for Echo Environment +client = EchoEnv(base_url=f"{ENV_URL}") +print("āœ… Client created!") + +def rollout_func(prompts, **sampling_kwargs): + + # Make request to TRL's custom /generate/ endpoint + payload = { + "prompts": prompts, + "n": sampling_kwargs.get("n", 1), + "temperature": sampling_kwargs.get("temperature", 1.0), + "top_p": sampling_kwargs.get("top_p", 1.0), + "top_k": sampling_kwargs.get("top_k", -1), + "min_p": sampling_kwargs.get("min_p", 0.0), + "max_tokens": sampling_kwargs.get("max_tokens", 128), + "repetition_penalty": sampling_kwargs.get("repetition_penalty", 1.0), + } + + print(f"Sending request to {GEN_URL}") + response = requests.post(GEN_URL, json=payload) + + if response.status_code != 200: + print(f"Error response: {response.text}") + + response.raise_for_status() + result = response.json() + + # Decode for env communication + processing_class = sampling_kwargs.get("processing_class", None) + + completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) + + # Flush env + env_result = client.reset() + + # Take an action (HTTP POST /step) + print("\nšŸ“¤ Calling client.step()...") + + for msg in completions_text: + env_result = 
client.step(EchoAction(message=msg)) + + # Get state (HTTP GET /state) + state = client.state() + # print(f"\nšŸ“Š Episode state:") + # print(f" • episode_id: {state.episode_id}") + # print(f" • step_count: {state.step_count}") + + # print(f"Response keys: {result.keys()}") + # print(f"Response shapes: {[(k, len(v) if isinstance(v, list) else 'not-list') for k, v in result.items()]}") + # print(f"=== rollout_func completed ===\n") + + return result + +dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") + +def reward_len(completions, **kwargs): + """Reward function that rewards longer completions.""" + completion_contents = [completion[0]["content"] for completion in completions] + return [float(len(content)) for content in completion_contents] + +training_args = GRPOConfig( + output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout", + vllm_mode="server", + use_vllm=True, + logging_steps=1, + report_to=["trackio", "wandb"], + num_train_epochs=1, + num_generations=16, + max_completion_length=4096, + per_device_train_batch_size=8, + gradient_accumulation_steps=4, +) +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_len, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, +) +trainer.train() \ No newline at end of file diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 86954bb39f..16aa9d977a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1142,6 +1142,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): truncate_prompt_tokens=self.max_prompt_length, guided_decoding_regex=self.guided_decoding_regex, generation_kwargs=self.args.generation_kwargs, + processing_class=self.processing_class, ) else: output = self.vllm_client.generate( From ff04634e3e56b5f8c7d21cb4b91d67182292ae94 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 14:16:46 +0000 Subject: [PATCH 05/18] it works! --- trl/experimental/openenv/echo.py | 99 ++++++++++++++++++++------------ trl/trainer/grpo_trainer.py | 42 +++++++++++--- 2 files changed, 95 insertions(+), 46 deletions(-) diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py index afe83f5886..be680eca91 100644 --- a/trl/experimental/openenv/echo.py +++ b/trl/experimental/openenv/echo.py @@ -1,25 +1,51 @@ -from datasets import load_dataset -import requests -from trl import GRPOConfig, GRPOTrainer +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: T201 +import os import subprocess -import time import sys -import os +import time +from pathlib import Path + +import requests +from datasets import load_dataset +from envs.echo_env import EchoEnv from envs.echo_env.models import ( EchoAction, ) -from envs.echo_env import EchoEnv -from pathlib import Path + +from trl import GRPOConfig, GRPOTrainer + + """ Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. 
The reward function encourages longer completions. +Setup: + +```bash +uv pip install git+https://github.com/meta-pytorch/OpenEnv.git +``` + Usage (2 GPUs required): --- Spin up server -- CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port -8000 +-- Spin up server -- +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 --- Run this script -- CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/echo.py +-- Run this script -- +CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/echo.py """ GEN_URL = "http://0.0.0.0:8000/generate/" @@ -32,16 +58,12 @@ work_dir = str(Path.cwd().parent.absolute()) server_process = subprocess.Popen( - [sys.executable, "-m", "uvicorn", - "envs.echo_env.server.app:app", - "--host", "0.0.0.0", - "--port", "8001"], - env={**os.environ, - "PYTHONPATH": f"{work_dir}/src"}, + [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - cwd=work_dir + cwd=work_dir, ) # Wait for server to start @@ -49,9 +71,8 @@ time.sleep(5) # Check if server is running -import requests try: - response = requests.get(f'{ENV_URL}/health', timeout=2) + response = requests.get(f"{ENV_URL}/health", timeout=2) print("\nāœ… Echo Environment server is running!") except Exception as e: print(f"\nāŒ Server failed to start: {e}") @@ -68,8 +89,8 @@ client = EchoEnv(base_url=f"{ENV_URL}") print("āœ… Client created!") -def rollout_func(prompts, **sampling_kwargs): +def rollout_func(prompts, **sampling_kwargs): # Make request to TRL's custom /generate/ endpoint payload = { "prompts": prompts, @@ -91,7 +112,7 @@ def rollout_func(prompts, **sampling_kwargs): response.raise_for_status() result = response.json() - # Decode for env communication + # FIXME: we should not need to propagate the processing_class like this processing_class = sampling_kwargs.get("processing_class", None) completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) @@ -99,30 +120,32 @@ def rollout_func(prompts, **sampling_kwargs): # Flush env env_result = client.reset() - # Take an action (HTTP POST /step) + # Take an action (HTTP POST /step) and collect environment rewards print("\nšŸ“¤ Calling client.step()...") + env_rewards = [] for msg in completions_text: env_result = client.step(EchoAction(message=msg)) + env_rewards.append(env_result.reward) - # Get state (HTTP GET /state) - state = client.state() - # print(f"\nšŸ“Š Episode state:") - # print(f" • episode_id: {state.episode_id}") - # print(f" • step_count: {state.step_count}") - - # print(f"Response keys: {result.keys()}") - # print(f"Response shapes: {[(k, len(v) if isinstance(v, list) else 'not-list') for k, v in result.items()]}") - # print(f"=== rollout_func completed ===\n") + result["env_reward"] = env_rewards return result + dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") -def reward_len(completions, **kwargs): - """Reward function that rewards longer completions.""" - completion_contents = [completion[0]["content"] for completion in completions] - return [float(len(content)) for content in completion_contents] + +def reward_from_env(completions, **kwargs): + """Reward function that uses the environment reward.""" + # Extract environment rewards from kwargs (propagated via extra_fields) + env_rewards = kwargs.get("env_reward", []) + if 
env_rewards: + return [float(reward) for reward in env_rewards] + else: + # Fallback if env_reward is not available + return [0.0] * len(completions) + training_args = GRPOConfig( output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout", @@ -131,16 +154,16 @@ def reward_len(completions, **kwargs): logging_steps=1, report_to=["trackio", "wandb"], num_train_epochs=1, - num_generations=16, + num_generations=8, max_completion_length=4096, per_device_train_batch_size=8, gradient_accumulation_steps=4, ) trainer = GRPOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", - reward_funcs=reward_len, + reward_funcs=reward_from_env, args=training_args, train_dataset=dataset, rollout_func=rollout_func, ) -trainer.train() \ No newline at end of file +trainer.train() diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 16aa9d977a..e0bf4f581b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -96,7 +96,7 @@ # What we call a rollout function is a callable that takes in the data and sampling parameters and returns a list of # generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include an optional "tools" -# field. Any extra fields are forwarded to the reward functions. +# field. Any extra fields (per-completion) are forwarded to the reward functions. RolloutFunc = Callable[[dict[str, Any], dict[str, Any]], list[dict[str, Any]]] @@ -1159,14 +1159,17 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): guided_decoding_regex=self.guided_decoding_regex, generation_kwargs=self.args.generation_kwargs, ) - payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) + # Extract required fields and collect any extra fields for reward functions + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) else: payload = None # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. 
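# Note (illustrative): with a custom rollout_func, `payload` above is the 4-tuple
# (prompt_ids, completion_ids, logprobs, extra_fields), where `extra_fields` holds any
# per-completion lists returned by the rollout function (e.g. {"env_reward": [...]})
# that are later forwarded to the reward functions.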
obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] + all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] @@ -1179,6 +1182,15 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): completion_ids = all_completion_ids[process_slice] logprobs = all_logprobs[process_slice] + # Slice extra fields dict-of-lists per process (extra fields are per-completion, like completion_ids) + extra_fields = {} + for key, values in all_extra_fields.items(): + if isinstance(values, list): + extra_fields[key] = values[process_slice] + else: + # Scalar value, keep as-is + extra_fields[key] = values + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": if self.guided_decoding_regex: @@ -1252,6 +1264,8 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): completion_ids = all_completion_ids logprobs = all_logprobs + extra_fields = {} # No extra fields for colocate mode + if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1288,6 +1302,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn logprobs = None # not used in this case + extra_fields = {} # No extra fields for paged mode else: # Regular generation path @@ -1328,14 +1343,15 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] logprobs = None # not used in this case + extra_fields = {} # No extra fields for non-rollout_func paths - return prompt_ids, completion_ids, logprobs + return prompt_ids, completion_ids, logprobs, extra_fields def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts, images) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1367,7 +1383,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs + return prompt_ids, completion_ids, total_completion_tokens, logprobs, extra_fields def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1387,8 +1403,8 @@ def _generate_and_score_completions( if images is not None and all(img_list == [] for img_list in images): images = None - prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( - prompts, images + prompt_ids_list, completion_ids_list, num_items_in_batch, 
sampling_per_token_logps_list, extra_fields = ( + self._generate(prompts, images) ) # Convert lists of token IDs to padded tensors @@ -1507,6 +1523,16 @@ def _generate_and_score_completions( else: completions = completions_text + # Merge extra_fields from rollout_func into inputs for reward functions + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + # Scalar value, add to all inputs + inp[key] = values + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. From a9d80a6f3a1e428ef1f8b293cf48c3a3cfefe1c1 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 14:52:16 +0000 Subject: [PATCH 06/18] Clean up --- trl/experimental/openenv/echo.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py index be680eca91..bdf9438bce 100644 --- a/trl/experimental/openenv/echo.py +++ b/trl/experimental/openenv/echo.py @@ -35,17 +35,23 @@ Setup: -```bash +```sh uv pip install git+https://github.com/meta-pytorch/OpenEnv.git ``` Usage (2 GPUs required): --- Spin up server -- +# Spin up server + +```sh CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 +``` + +# Run training --- Run this script -- +```sh CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/echo.py +``` """ GEN_URL = "http://0.0.0.0:8000/generate/" @@ -87,7 +93,6 @@ # Create HTTP client for Echo Environment client = EchoEnv(base_url=f"{ENV_URL}") -print("āœ… Client created!") def rollout_func(prompts, **sampling_kwargs): @@ -102,8 +107,6 @@ def rollout_func(prompts, **sampling_kwargs): "max_tokens": sampling_kwargs.get("max_tokens", 128), "repetition_penalty": sampling_kwargs.get("repetition_penalty", 1.0), } - - print(f"Sending request to {GEN_URL}") response = requests.post(GEN_URL, json=payload) if response.status_code != 200: @@ -119,9 +122,6 @@ def rollout_func(prompts, **sampling_kwargs): # Flush env env_result = client.reset() - - # Take an action (HTTP POST /step) and collect environment rewards - print("\nšŸ“¤ Calling client.step()...") env_rewards = [] for msg in completions_text: From cb76c08c735d0f659740bf29f2f70e3debecb858 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 15:44:18 +0000 Subject: [PATCH 07/18] Refactor sig --- trl/experimental/openenv/echo.py | 23 ++++++++++------------- trl/trainer/grpo_trainer.py | 19 ++++++------------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py index bdf9438bce..7fc693d5a4 100644 --- a/trl/experimental/openenv/echo.py +++ b/trl/experimental/openenv/echo.py @@ -95,17 +95,17 @@ client = EchoEnv(base_url=f"{ENV_URL}") -def rollout_func(prompts, **sampling_kwargs): +def rollout_func(prompts, images, args, processing_class): # Make request to TRL's custom /generate/ endpoint payload = { "prompts": prompts, - "n": sampling_kwargs.get("n", 1), - "temperature": sampling_kwargs.get("temperature", 1.0), - "top_p": sampling_kwargs.get("top_p", 1.0), - "top_k": sampling_kwargs.get("top_k", -1), - "min_p": sampling_kwargs.get("min_p", 0.0), - "max_tokens": 
sampling_kwargs.get("max_tokens", 128), - "repetition_penalty": sampling_kwargs.get("repetition_penalty", 1.0), + "n": args.num_generations, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": -1 if args.top_k is None else args.top_k, + "min_p": 0.0 if args.min_p is None else args.min_p, + "max_tokens": args.max_completion_length, + "repetition_penalty": args.repetition_penalty, } response = requests.post(GEN_URL, json=payload) @@ -115,15 +115,12 @@ def rollout_func(prompts, **sampling_kwargs): response.raise_for_status() result = response.json() - # FIXME: we should not need to propagate the processing_class like this - processing_class = sampling_kwargs.get("processing_class", None) - completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) # Flush env env_result = client.reset() - env_rewards = [] + env_rewards = [] for msg in completions_text: env_result = client.step(EchoAction(message=msg)) env_rewards.append(env_result.reward) @@ -155,7 +152,7 @@ def reward_from_env(completions, **kwargs): report_to=["trackio", "wandb"], num_train_epochs=1, num_generations=8, - max_completion_length=4096, + max_completion_length=2048, per_device_train_batch_size=8, gradient_accumulation_steps=4, ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e0bf4f581b..ff355ec9fb 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -201,9 +201,10 @@ def reward_func(completions, **kwargs): peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. rollout_func (`RolloutFunc`, *optional*, defaults to `None`): - Function to use for generating completions. It must take in the data sampling parameters and return a list - of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and - can include optional "tools" field and any other fields that are forwarded to the reward functions. + Function to use for generating completions. It must take prompts, images (optional), args, and + processing_class as parameters and return a dict with "prompt_ids", "completion_ids", and "logprobs" + fields. It can include optional "tools" field and any other fields that are forwarded to the reward + functions. 
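            As an illustration only (the endpoint URL, helper name, and payload fields below are borrowed
            from the example scripts in this PR and are not a fixed API), a conforming rollout function
            could look like:

            ```python
            import requests

            def vllm_rollout(prompts, images, args, processing_class):
                # Ask TRL's vLLM server for completions; the response contains
                # "prompt_ids", "completion_ids", and "logprobs" for every sample.
                payload = {
                    "prompts": prompts,
                    "n": args.num_generations,
                    "temperature": args.temperature,
                    "max_tokens": args.max_completion_length,
                }
                result = requests.post("http://0.0.0.0:8000/generate/", json=payload).json()
                # Any extra per-completion field added here is forwarded to the reward functions.
                completions = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)
                result["completion_length"] = [float(len(text)) for text in completions]
                return result
            ```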
""" _tag_names = ["trl", "grpo"] @@ -1132,16 +1133,8 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if self.rollout_func is not None: output = self.rollout_func( prompts=ordered_set_of_prompts, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, + images=ordered_set_of_images, + args=self.args, processing_class=self.processing_class, ) else: From f295fb27f29b6b3f6d376eb54055776cba33cf07 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 15:47:25 +0000 Subject: [PATCH 08/18] Fix doc --- trl/experimental/openenv/echo.py | 14 +++++++++++++- trl/trainer/grpo_trainer.py | 9 +++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py index 7fc693d5a4..26660e5f7e 100644 --- a/trl/experimental/openenv/echo.py +++ b/trl/experimental/openenv/echo.py @@ -95,7 +95,19 @@ client = EchoEnv(base_url=f"{ENV_URL}") -def rollout_func(prompts, images, args, processing_class): +def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]: + """ + Custom rollout function that generates completions via vLLM server and computes environment rewards. + + Args: + prompts: List of prompt strings to generate from + images: Optional images for vision models (not used in this example) + args: GRPOConfig containing all sampling parameters + processing_class: Tokenizer/processor for decoding completions + + Returns: + Dict containing prompt_ids, completion_ids, logprobs, and env_reward + """ # Make request to TRL's custom /generate/ endpoint payload = { "prompts": prompts, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ff355ec9fb..9db6a83d4f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -94,10 +94,11 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] -# What we call a rollout function is a callable that takes in the data and sampling parameters and returns a list of -# generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" fields and can include an optional "tools" -# field. Any extra fields (per-completion) are forwarded to the reward functions. -RolloutFunc = Callable[[dict[str, Any], dict[str, Any]], list[dict[str, Any]]] +# What we call a rollout function is a callable that takes prompts (list), images (optional), args (GRPOConfig), +# and processing_class as parameters and returns a dict of generation results. Those results must include "prompt_ids", +# "completion_ids", and "logprobs" fields and can include an optional "tools" field. Any extra fields (per-completion) +# are forwarded to the reward functions. 
+RolloutFunc = Callable[[list[str], Any, Any, Any], dict[str, Any]] class GRPOTrainer(BaseTrainer): From 2bf9de830ab25ec54e2cee6a0c2965c10ac5426d Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 15:51:04 +0000 Subject: [PATCH 09/18] Clean up --- trl/trainer/grpo_trainer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9db6a83d4f..8b9407f985 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -96,8 +96,7 @@ # What we call a rollout function is a callable that takes prompts (list), images (optional), args (GRPOConfig), # and processing_class as parameters and returns a dict of generation results. Those results must include "prompt_ids", -# "completion_ids", and "logprobs" fields and can include an optional "tools" field. Any extra fields (per-completion) -# are forwarded to the reward functions. +# "completion_ids", and "logprobs" fields. Any extra fields (per-completion) are forwarded to the reward functions. RolloutFunc = Callable[[list[str], Any, Any, Any], dict[str, Any]] @@ -204,8 +203,7 @@ def reward_func(completions, **kwargs): rollout_func (`RolloutFunc`, *optional*, defaults to `None`): Function to use for generating completions. It must take prompts, images (optional), args, and processing_class as parameters and return a dict with "prompt_ids", "completion_ids", and "logprobs" - fields. It can include optional "tools" field and any other fields that are forwarded to the reward - functions. + fields. Any other fields that are forwarded to the reward functions. """ _tag_names = ["trl", "grpo"] @@ -1182,7 +1180,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if isinstance(values, list): extra_fields[key] = values[process_slice] else: - # Scalar value, keep as-is extra_fields[key] = values # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts @@ -1524,7 +1521,6 @@ def _generate_and_score_completions( if isinstance(values, list) and i < len(values): inp[key] = values[i] elif not isinstance(values, list): - # Scalar value, add to all inputs inp[key] = values # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is From 47327f3de552178d5efcbea8cff336f9ef91a0a5 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 16:52:16 +0000 Subject: [PATCH 10/18] Add WIP Catch --- trl/experimental/openenv/catch.py | 249 ++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 trl/experimental/openenv/catch.py diff --git a/trl/experimental/openenv/catch.py b/trl/experimental/openenv/catch.py new file mode 100644 index 0000000000..b01adbbebd --- /dev/null +++ b/trl/experimental/openenv/catch.py @@ -0,0 +1,249 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ruff: noqa: T201 +import os +import subprocess +import sys +import time +from pathlib import Path + +import requests +from datasets import Dataset +from envs.openspiel_env import OpenSpielEnv +from envs.openspiel_env.models import OpenSpielAction + +from trl import GRPOConfig, GRPOTrainer + + +""" +Simple script to run GRPO training with OpenEnv's Catch environment (OpenSpiel) and a vLLM server. The reward function +is based on the catch game where the agent tries to catch falling balls. + +Setup: + +```sh +uv pip install git+https://github.com/meta-pytorch/OpenEnv.git +uv pip install open_spiel +``` + +Usage (2 GPUs required): + +# Spin up vLLM server + +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 +``` + +# Run training + +```sh +CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/catch.py +``` +""" + +GEN_URL = "http://0.0.0.0:8000/generate/" +ENV_URL = "http://0.0.0.0:8002" + +BASE_PROMPT = """You are an AI agent playing the game **Catch**. + +### Game Description +- The game is played on a **5Ɨ5 grid**. +- There is one **falling ball** and one **paddle** that you control at the bottom. +- The objective is to **move the paddle left or right to catch the ball** as it falls. +- The episode ends when the ball reaches the bottom row: + - You get **+1 reward** if you catch it. + - You get **–1 reward** if you miss it. + +### Observation Format You will receive: +- `observation`: a list of **50 numbers (floats)**. + - The first **25 numbers** (indices `0–24`) represent the **ball layer**, flattened from a 5Ɨ5 grid. Each cell is + `1.0` if the ball is there, `0.0` otherwise. + - The next **25 numbers** (indices `25–49`) represent the **paddle layer**, also flattened from a 5Ɨ5 grid. Each cell + is `1.0` if the paddle occupies that column in the bottom row, `0.0` otherwise. +- `legal_actions`: a list of integers representing which actions are currently allowed. + +### Actions Each action is a discrete integer: +- `0` → Move paddle **left** +- `1` → **Stay** (no movement) +- `2` → Move paddle **right** + +### Output Format Respond **only with one integer** representing your chosen action: `0`, `1`, or `2`. + +### Current Observation +""" + +# Start the OpenSpiel server in background +print("⚔ Starting FastAPI server for OpenSpiel Catch Environment...") + +# Determine the correct path +work_dir = str(Path.cwd().parent.absolute()) + +server_process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8002"], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir, +) + +# Wait for server to start +print("ā³ Waiting for server to start...") +time.sleep(5) + +# Check if server is running +try: + response = requests.get(f"{ENV_URL}/health", timeout=2) + print("\nāœ… OpenSpiel Catch Environment server is running!") +except Exception as e: + print(f"\nāŒ Server failed to start: {e}") + print("\nšŸ“‹ Checking error output...") + server_process.poll() + if server_process.stderr: + stderr = server_process.stderr.read() + if stderr: + print(stderr) + raise + + +# Create HTTP client for OpenSpiel Catch Environment +client = OpenSpielEnv(base_url=f"{ENV_URL}") + + +def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]: + """ + Custom rollout function that generates completions via vLLM server and computes environment rewards. 
+ + The catch game expects action IDs (integers). We'll parse the model's text output to extract action choices. + + Args: + prompts: List of prompt strings to generate from + images: Optional images for vision models (not used in this example) + args: GRPOConfig containing all sampling parameters + processing_class: Tokenizer/processor for decoding completions + + Returns: + Dict containing prompt_ids, completion_ids, logprobs, and env_reward + """ + import re + + # Run full episodes for each generation to get episode rewards + env_rewards = [] + all_prompt_ids = [] + all_completion_ids = [] + all_logprobs = [] + + for prompt_idx, base_prompt in enumerate(prompts): + for _ in range(args.num_generations): + # Run episode: Reset environment and loop until done + env_result = client.reset() + obs = env_result.observation + total_reward = 0.0 + + episode_prompt_ids = [] + episode_completion_ids = [] + episode_logprobs = [] + + while not obs.done: + # Build prompt with current observation and legal actions + episode_prompt = ( + f"{base_prompt}\n\n" + f"{obs.info_state}\n" + ) + + # Generate action from model + gen_payload = { + "prompts": [episode_prompt], + "n": 1, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": -1 if args.top_k is None else args.top_k, + "min_p": 0.0 if args.min_p is None else args.min_p, + "max_tokens": args.max_completion_length, + "repetition_penalty": args.repetition_penalty, + } + gen_response = requests.post(GEN_URL, json=gen_payload) + gen_response.raise_for_status() + gen_result = gen_response.json() + + # Collect prompt_ids, completion_ids, and logprobs from this step + episode_prompt_ids.extend(gen_result["prompt_ids"][0]) + episode_completion_ids.extend(gen_result["completion_ids"][0]) + episode_logprobs.extend(gen_result["logprobs"][0]) + + completion_text = processing_class.batch_decode(gen_result["completion_ids"], skip_special_tokens=True)[0] + + # Parse action from completion + action_id = 0 # default + numbers = re.findall(r'\b([0-2])\b', completion_text) + if numbers: + action_id = int(numbers[0]) + elif obs.legal_actions: + action_id = obs.legal_actions[0] + + # Take action in environment + env_result = client.step(OpenSpielAction(action_id=action_id, game_name="catch")) + reward = env_result.reward if env_result.reward is not None else 0.0 + total_reward += reward + obs = env_result.observation + + # Store episode results + env_rewards.append(total_reward) + all_prompt_ids.append(episode_prompt_ids) + all_completion_ids.append(episode_completion_ids) + all_logprobs.append(episode_logprobs) + + return { + "prompt_ids": all_prompt_ids, + "completion_ids": all_completion_ids, + "logprobs": all_logprobs, + "env_reward": env_rewards, + } + + +dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * 1000}) + + +def reward_from_env(completions, **kwargs): + """Reward function that uses the environment reward from the catch game.""" + # Extract environment rewards from kwargs (propagated via extra_fields) + env_rewards = kwargs.get("env_reward", []) + if env_rewards: + return [float(reward) for reward in env_rewards] + else: + # Fallback if env_reward is not available + return [0.0] * len(completions) + + +training_args = GRPOConfig( + output_dir="scratch/Qwen2.5-0.5B-GRPO-Catch", + vllm_mode="server", + use_vllm=True, + logging_steps=1, + report_to=["trackio", "wandb"], + num_train_epochs=1, + num_generations=8, + max_completion_length=64, # Shorter for catch game (just need action selection) + per_device_train_batch_size=8, + 
gradient_accumulation_steps=4, +) +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_from_env, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, +) +trainer.train() From 3739b8c3fbbd03d6ac7092e2c328d19d4431aa31 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 21 Oct 2025 20:05:17 +0200 Subject: [PATCH 11/18] terminate after training (#4318) --- trl/experimental/openenv/echo.py | 203 +++++++++++++++---------------- 1 file changed, 101 insertions(+), 102 deletions(-) diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py index 26660e5f7e..e23d84642e 100644 --- a/trl/experimental/openenv/echo.py +++ b/trl/experimental/openenv/echo.py @@ -17,6 +17,7 @@ import subprocess import sys import time +from contextlib import suppress from pathlib import Path import requests @@ -66,9 +67,6 @@ server_process = subprocess.Popen( [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, cwd=work_dir, ) @@ -76,103 +74,104 @@ print("ā³ Waiting for server to start...") time.sleep(5) -# Check if server is running try: - response = requests.get(f"{ENV_URL}/health", timeout=2) - print("\nāœ… Echo Environment server is running!") -except Exception as e: - print(f"\nāŒ Server failed to start: {e}") - print("\nšŸ“‹ Checking error output...") - server_process.poll() - if server_process.stderr: - stderr = server_process.stderr.read() - if stderr: - print(stderr) - raise - - -# Create HTTP client for Echo Environment -client = EchoEnv(base_url=f"{ENV_URL}") - - -def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]: - """ - Custom rollout function that generates completions via vLLM server and computes environment rewards. 
- - Args: - prompts: List of prompt strings to generate from - images: Optional images for vision models (not used in this example) - args: GRPOConfig containing all sampling parameters - processing_class: Tokenizer/processor for decoding completions - - Returns: - Dict containing prompt_ids, completion_ids, logprobs, and env_reward - """ - # Make request to TRL's custom /generate/ endpoint - payload = { - "prompts": prompts, - "n": args.num_generations, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": -1 if args.top_k is None else args.top_k, - "min_p": 0.0 if args.min_p is None else args.min_p, - "max_tokens": args.max_completion_length, - "repetition_penalty": args.repetition_penalty, - } - response = requests.post(GEN_URL, json=payload) - - if response.status_code != 200: - print(f"Error response: {response.text}") - - response.raise_for_status() - result = response.json() - - completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) - - # Flush env - env_result = client.reset() - - env_rewards = [] - for msg in completions_text: - env_result = client.step(EchoAction(message=msg)) - env_rewards.append(env_result.reward) - - result["env_reward"] = env_rewards - - return result - - -dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") - - -def reward_from_env(completions, **kwargs): - """Reward function that uses the environment reward.""" - # Extract environment rewards from kwargs (propagated via extra_fields) - env_rewards = kwargs.get("env_reward", []) - if env_rewards: - return [float(reward) for reward in env_rewards] - else: - # Fallback if env_reward is not available - return [0.0] * len(completions) - - -training_args = GRPOConfig( - output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout", - vllm_mode="server", - use_vllm=True, - logging_steps=1, - report_to=["trackio", "wandb"], - num_train_epochs=1, - num_generations=8, - max_completion_length=2048, - per_device_train_batch_size=8, - gradient_accumulation_steps=4, -) -trainer = GRPOTrainer( - model="Qwen/Qwen2.5-0.5B-Instruct", - reward_funcs=reward_from_env, - args=training_args, - train_dataset=dataset, - rollout_func=rollout_func, -) -trainer.train() + # Check if server is running + try: + response = requests.get(f"{ENV_URL}/health", timeout=2) + print("\nāœ… Echo Environment server is running!") + except Exception as e: + print(f"\nāŒ Server failed to start: {e}") + raise + + # Create HTTP client for Echo Environment + client = EchoEnv(base_url=f"{ENV_URL}") + + def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]: + """ + Custom rollout function that generates completions via vLLM server and computes environment rewards. 
+ + Args: + prompts: List of prompt strings to generate from + images: Optional images for vision models (not used in this example) + args: GRPOConfig containing all sampling parameters + processing_class: Tokenizer/processor for decoding completions + + Returns: + Dict containing prompt_ids, completion_ids, logprobs, and env_reward + """ + # Make request to TRL's custom /generate/ endpoint + payload = { + "prompts": prompts, + "n": args.num_generations, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": -1 if args.top_k is None else args.top_k, + "min_p": 0.0 if args.min_p is None else args.min_p, + "max_tokens": args.max_completion_length, + "repetition_penalty": args.repetition_penalty, + } + response = requests.post(GEN_URL, json=payload) + + if response.status_code != 200: + print(f"Error response: {response.text}") + + response.raise_for_status() + result = response.json() + + completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) + + # Flush env + env_result = client.reset() + + env_rewards = [] + for msg in completions_text: + env_result = client.step(EchoAction(message=msg)) + env_rewards.append(env_result.reward) + + result["env_reward"] = env_rewards + + return result + + dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") + + def reward_from_env(completions, **kwargs): + """Reward function that uses the environment reward.""" + # Extract environment rewards from kwargs (propagated via extra_fields) + env_rewards = kwargs.get("env_reward", []) + if env_rewards: + return [float(reward) for reward in env_rewards] + else: + # Fallback if env_reward is not available + return [0.0] * len(completions) + + training_args = GRPOConfig( + output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout", + vllm_mode="server", + use_vllm=True, + logging_steps=1, + report_to=["trackio", "wandb"], + num_train_epochs=1, + num_generations=8, + max_completion_length=2048, + per_device_train_batch_size=8, + gradient_accumulation_steps=4, + ) + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_from_env, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, + ) + trainer.train() +finally: + print("\nšŸ›‘ Stopping Echo Environment server...") + if server_process.poll() is None: + server_process.terminate() + try: + server_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print("āš ļø Termination timeout reached. 
Forcing shutdown.") + server_process.kill() + with suppress(subprocess.TimeoutExpired): + server_process.wait(timeout=5) From d8baa434f03cc66db809244dcdf74c27280f1833 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 20:00:18 +0000 Subject: [PATCH 12/18] Move --- .../scripts}/openenv/catch.py | 0 examples/scripts/openenv/catch_demo.py | 271 ++++++++++++++++++ examples/scripts/openenv/echo.py | 185 ++++++++++++ 3 files changed, 456 insertions(+) rename {trl/experimental => examples/scripts}/openenv/catch.py (100%) create mode 100644 examples/scripts/openenv/catch_demo.py create mode 100644 examples/scripts/openenv/echo.py diff --git a/trl/experimental/openenv/catch.py b/examples/scripts/openenv/catch.py similarity index 100% rename from trl/experimental/openenv/catch.py rename to examples/scripts/openenv/catch.py diff --git a/examples/scripts/openenv/catch_demo.py b/examples/scripts/openenv/catch_demo.py new file mode 100644 index 0000000000..4a40f17286 --- /dev/null +++ b/examples/scripts/openenv/catch_demo.py @@ -0,0 +1,271 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: T201 +""" +Simple demo script for the OpenSpiel Catch environment. + +This demonstrates the basic workflow: +1. Start the environment server +2. Connect to it +3. Reset and observe initial state +4. Take actions and see rewards +5. 
Clean up + +Setup: + +```sh +uv pip install git+https://github.com/meta-pytorch/OpenEnv.git +uv pip install open_spiel +``` + +Usage: + +```sh +python trl/experimental/openenv/catch_demo.py +``` +""" + +import os +import subprocess +import sys +import time +from pathlib import Path + +import requests +from envs.openspiel_env import OpenSpielEnv +from envs.openspiel_env.models import OpenSpielAction + +ENV_URL = "http://0.0.0.0:8002" + + +class Policy: + """Base policy class.""" + + def __init__(self, name): + self.name = name + + def select_action(self, obs): + """Select an action given an observation.""" + raise NotImplementedError + + +class RandomPolicy(Policy): + """Policy that selects random legal actions.""" + + def __init__(self): + super().__init__("šŸŽ² Random Policy") + + def select_action(self, obs): + import random + + return random.choice(obs.legal_actions) if obs.legal_actions else 0 + + +class LeftPolicy(Policy): + """Policy that always moves left.""" + + def __init__(self): + super().__init__("ā¬…ļø Always Left") + + def select_action(self, obs): + return obs.legal_actions[0] if obs.legal_actions else 0 + + +class RightPolicy(Policy): + """Policy that always moves right.""" + + def __init__(self): + super().__init__("āž”ļø Always Right") + + def select_action(self, obs): + return obs.legal_actions[-1] if obs.legal_actions else 2 + + +class SmartPolicy(Policy): + """Policy that tries to move towards the ball.""" + + def __init__(self): + super().__init__("🧠 Smart Policy") + + def select_action(self, obs): + # In catch, the info_state contains information about paddle and ball positions + # This is a simple heuristic - in practice you'd need to understand the state representation + if not obs.legal_actions: + return 0 + + # For catch game, often legal actions are [0, 1, 2] = [left, stay, right] + # Simple heuristic: choose middle action (stay) or random + import random + + return random.choice(obs.legal_actions) + + +def run_episode(env, policy, visualize=True, delay=0.3): + """Run one episode with a policy against OpenSpiel environment.""" + + # RESET + result = env.reset() + obs = result.observation + + if visualize: + print(f"\n{'='*60}") + print(f" šŸŽ® {policy.name}") + print(f" šŸŽ² Playing against OpenSpiel Catch") + print("=" * 60 + "\n") + time.sleep(delay) + + total_reward = 0 + step = 0 + action_names = ["ā¬…ļø LEFT", "šŸ›‘ STAY", "āž”ļø RIGHT"] + + # THE RL LOOP + while not obs.done: + # 1. Policy chooses action + action_id = policy.select_action(obs) + + # 2. Environment executes (via HTTP!) + action = OpenSpielAction(action_id=action_id, game_name="catch") + result = env.step(action) + obs = result.observation + + # 3. Collect reward + if result.reward is not None: + total_reward += result.reward + + if visualize: + action_name = action_names[action_id] if action_id < len(action_names) else f"ACTION {action_id}" + print(f"šŸ“ Step {step + 1}: {action_name} → Reward: {result.reward}") + time.sleep(delay) + + step += 1 + + if visualize: + result_text = "šŸŽ‰ CAUGHT!" 
if total_reward > 0 else "😢 MISSED" + print(f"\n{'='*60}") + print(f" {result_text} Total Reward: {total_reward}") + print("=" * 60) + + return total_reward > 0 + + +def start_server(): + """Start the OpenSpiel environment server.""" + print("⚔ Starting FastAPI server for OpenSpiel Catch Environment...") + + work_dir = str(Path.cwd().parent.absolute()) + + server_process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8002"], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir, + ) + + # Wait for server to start + print("ā³ Waiting for server to start...") + time.sleep(5) + + # Check if server is running + try: + response = requests.get(f"{ENV_URL}/health", timeout=2) + print("āœ… OpenSpiel Catch Environment server is running!\n") + return server_process + except Exception as e: + print(f"āŒ Server failed to start: {e}") + print("\nšŸ“‹ Checking error output...") + server_process.poll() + if server_process.stderr: + stderr = server_process.stderr.read() + if stderr: + print(stderr) + raise + + +def run_demo(): + """Run a simple demo of the Catch environment.""" + print("šŸŽÆ OpenSpiel Catch Environment Demo") + print("=" * 60) + + # Connect to environment server + client = OpenSpielEnv(base_url=ENV_URL) + + try: + # Reset environment and show initial state + print("\nšŸ“ Resetting environment...") + result = client.reset() + + print(f" Initial observation shape: {len(result.observation.info_state)}") + print(f" Info state (first 10 values): {result.observation.info_state[:10]}") + print(f" Legal actions: {result.observation.legal_actions}") + print(f" Game phase: {result.observation.game_phase}") + print(f" Done: {result.done}") + print(f" Initial reward: {result.reward}") + + # Demo different policies + print("\nšŸ“ŗ " + "=" * 64 + " šŸ“ŗ") + print(" Watch Policies Play Against OpenSpiel!") + print("šŸ“ŗ " + "=" * 64 + " šŸ“ŗ\n") + + policies = [SmartPolicy(), RandomPolicy(), LeftPolicy(), RightPolicy()] + + for policy in policies: + caught = run_episode(client, policy, visualize=True, delay=0.5) + + print("\nšŸ’” You just watched REAL OpenSpiel Catch being played!") + print(" • Every action was an HTTP call") + print(" • Game logic runs in the server") + print(" • Client only sends actions and receives observations\n") + + # Get final environment state + state = client.state() + print(f"\nšŸ“Š Final Environment State:") + print(f" Episode ID: {state.episode_id}") + print(f" Step count: {state.step_count}") + print(f" Game: {state.game_name}") + print(f" Num players: {state.num_players}") + print(f" Agent player: {state.agent_player}") + + except Exception as e: + print(f"\nāŒ Error during demo: {e}") + import traceback + + traceback.print_exc() + + finally: + # Always close the environment + client.close() + print("\nāœ… Demo complete!") + + +def main(): + """Main function to run the demo.""" + server_process = None + try: + server_process = start_server() + run_demo() + except KeyboardInterrupt: + print("\n\nāš ļø Interrupted by user") + finally: + if server_process: + print("\nšŸ›‘ Shutting down server...") + server_process.terminate() + server_process.wait(timeout=5) + print("šŸ‘‹ Done!") + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py new file mode 100644 index 0000000000..899d9403c1 --- /dev/null +++ b/examples/scripts/openenv/echo.py @@ -0,0 
+1,185 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: T201 +import os +import subprocess +import sys +import time +from pathlib import Path + +import requests +from datasets import load_dataset +from envs.echo_env import EchoEnv +from envs.echo_env.models import ( + EchoAction, +) + +from trl import GRPOConfig, GRPOTrainer + + +""" +Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages +longer completions. + +Setup: + +```sh +uv pip install git+https://github.com/meta-pytorch/OpenEnv.git +``` + +Usage (2 GPUs required): + +# Spin up server + +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 +``` + +# Run training + +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py +``` +""" + +GEN_URL = "http://0.0.0.0:8000/generate/" +ENV_URL = "http://0.0.0.0:8001" + +# Start the Echo server in background +print("⚔ Starting FastAPI server for Echo Environment...") + +# Determine the correct path +work_dir = str(Path.cwd().parent.absolute()) + +server_process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir, +) + +# Wait for server to start +print("ā³ Waiting for server to start...") +time.sleep(5) + +# Check if server is running +try: + response = requests.get(f"{ENV_URL}/health", timeout=2) + print("\nāœ… Echo Environment server is running!") +except Exception as e: + print(f"\nāŒ Server failed to start: {e}") + print("\nšŸ“‹ Checking error output...") + server_process.poll() + if server_process.stderr: + stderr = server_process.stderr.read() + if stderr: + print(stderr) + raise + + +# Create HTTP client for Echo Environment +client = EchoEnv(base_url=f"{ENV_URL}") + + +def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]: + """ + Custom rollout function that generates completions via vLLM server and computes environment rewards. 
+ + Args: + prompts: List of prompt strings to generate from + images: Optional images for vision models (not used in this example) + args: GRPOConfig containing all sampling parameters + processing_class: Tokenizer/processor for decoding completions + + Returns: + Dict containing prompt_ids, completion_ids, logprobs, and env_reward + """ + # Make request to TRL's custom /generate/ endpoint + payload = { + "prompts": prompts, + "n": args.num_generations, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": -1 if args.top_k is None else args.top_k, + "min_p": 0.0 if args.min_p is None else args.min_p, + "max_tokens": args.max_completion_length, + "repetition_penalty": args.repetition_penalty, + } + response = requests.post(GEN_URL, json=payload) + + if response.status_code != 200: + print(f"Error response: {response.text}") + + response.raise_for_status() + result = response.json() + + completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) + + # Flush env + env_result = client.reset() + + env_rewards = [] + for msg in completions_text: + env_result = client.step(EchoAction(message=msg)) + env_rewards.append(env_result.reward) + + result["env_reward"] = env_rewards + + return result + + +dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") + + +def reward_from_env(completions, **kwargs): + """Reward function that uses the environment reward.""" + # Extract environment rewards from kwargs (propagated via extra_fields) + env_rewards = kwargs.get("env_reward", []) + if env_rewards: + return [float(reward) for reward in env_rewards] + else: + # Fallback if env_reward is not available + return [0.0] * len(completions) + + +training_args = GRPOConfig( + output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout", + vllm_mode="server", + use_vllm=True, + logging_steps=1, + report_to=["trackio", "wandb"], + num_train_epochs=1, + num_generations=8, + max_completion_length=2048, + per_device_train_batch_size=8, + gradient_accumulation_steps=4, + max_steps=1 +) +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_from_env, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, +) +trainer.train() + +# Give time for background threads to finish +time.sleep(5) + +print("šŸ›‘ Terminating Echo Environment server...") +server_process.terminate() From 6001f021d1a4a49343979e7107f8f805004c81d8 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 20:00:45 +0000 Subject: [PATCH 13/18] lint --- examples/scripts/openenv/catch.py | 11 +++++------ examples/scripts/openenv/catch_demo.py | 9 +++++---- examples/scripts/openenv/echo.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index b01adbbebd..ebc6c0968b 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -158,10 +158,7 @@ def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, proc while not obs.done: # Build prompt with current observation and legal actions - episode_prompt = ( - f"{base_prompt}\n\n" - f"{obs.info_state}\n" - ) + episode_prompt = f"{base_prompt}\n\n{obs.info_state}\n" # Generate action from model gen_payload = { @@ -183,11 +180,13 @@ def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, proc episode_completion_ids.extend(gen_result["completion_ids"][0]) episode_logprobs.extend(gen_result["logprobs"][0]) - completion_text = 
processing_class.batch_decode(gen_result["completion_ids"], skip_special_tokens=True)[0] + completion_text = processing_class.batch_decode( + gen_result["completion_ids"], skip_special_tokens=True + )[0] # Parse action from completion action_id = 0 # default - numbers = re.findall(r'\b([0-2])\b', completion_text) + numbers = re.findall(r"\b([0-2])\b", completion_text) if numbers: action_id = int(numbers[0]) elif obs.legal_actions: diff --git a/examples/scripts/openenv/catch_demo.py b/examples/scripts/openenv/catch_demo.py index 4a40f17286..760def5938 100644 --- a/examples/scripts/openenv/catch_demo.py +++ b/examples/scripts/openenv/catch_demo.py @@ -47,6 +47,7 @@ from envs.openspiel_env import OpenSpielEnv from envs.openspiel_env.models import OpenSpielAction + ENV_URL = "http://0.0.0.0:8002" @@ -120,9 +121,9 @@ def run_episode(env, policy, visualize=True, delay=0.3): obs = result.observation if visualize: - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f" šŸŽ® {policy.name}") - print(f" šŸŽ² Playing against OpenSpiel Catch") + print(" šŸŽ² Playing against OpenSpiel Catch") print("=" * 60 + "\n") time.sleep(delay) @@ -153,7 +154,7 @@ def run_episode(env, policy, visualize=True, delay=0.3): if visualize: result_text = "šŸŽ‰ CAUGHT!" if total_reward > 0 else "😢 MISSED" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f" {result_text} Total Reward: {total_reward}") print("=" * 60) @@ -232,7 +233,7 @@ def run_demo(): # Get final environment state state = client.state() - print(f"\nšŸ“Š Final Environment State:") + print("\nšŸ“Š Final Environment State:") print(f" Episode ID: {state.episode_id}") print(f" Step count: {state.step_count}") print(f" Game: {state.game_name}") diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py index 899d9403c1..ad971b007b 100644 --- a/examples/scripts/openenv/echo.py +++ b/examples/scripts/openenv/echo.py @@ -167,7 +167,7 @@ def reward_from_env(completions, **kwargs): max_completion_length=2048, per_device_train_batch_size=8, gradient_accumulation_steps=4, - max_steps=1 + max_steps=1, ) trainer = GRPOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", From 89aef4491bae87c441097b06489094593879fcfa Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 20:01:35 +0000 Subject: [PATCH 14/18] fix --- examples/scripts/openenv/catch_demo.py | 272 ------------------------- examples/scripts/openenv/echo.py | 2 +- 2 files changed, 1 insertion(+), 273 deletions(-) delete mode 100644 examples/scripts/openenv/catch_demo.py diff --git a/examples/scripts/openenv/catch_demo.py b/examples/scripts/openenv/catch_demo.py deleted file mode 100644 index 760def5938..0000000000 --- a/examples/scripts/openenv/catch_demo.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2020-2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ruff: noqa: T201 -""" -Simple demo script for the OpenSpiel Catch environment. - -This demonstrates the basic workflow: -1. Start the environment server -2. Connect to it -3. 
Reset and observe initial state -4. Take actions and see rewards -5. Clean up - -Setup: - -```sh -uv pip install git+https://github.com/meta-pytorch/OpenEnv.git -uv pip install open_spiel -``` - -Usage: - -```sh -python trl/experimental/openenv/catch_demo.py -``` -""" - -import os -import subprocess -import sys -import time -from pathlib import Path - -import requests -from envs.openspiel_env import OpenSpielEnv -from envs.openspiel_env.models import OpenSpielAction - - -ENV_URL = "http://0.0.0.0:8002" - - -class Policy: - """Base policy class.""" - - def __init__(self, name): - self.name = name - - def select_action(self, obs): - """Select an action given an observation.""" - raise NotImplementedError - - -class RandomPolicy(Policy): - """Policy that selects random legal actions.""" - - def __init__(self): - super().__init__("šŸŽ² Random Policy") - - def select_action(self, obs): - import random - - return random.choice(obs.legal_actions) if obs.legal_actions else 0 - - -class LeftPolicy(Policy): - """Policy that always moves left.""" - - def __init__(self): - super().__init__("ā¬…ļø Always Left") - - def select_action(self, obs): - return obs.legal_actions[0] if obs.legal_actions else 0 - - -class RightPolicy(Policy): - """Policy that always moves right.""" - - def __init__(self): - super().__init__("āž”ļø Always Right") - - def select_action(self, obs): - return obs.legal_actions[-1] if obs.legal_actions else 2 - - -class SmartPolicy(Policy): - """Policy that tries to move towards the ball.""" - - def __init__(self): - super().__init__("🧠 Smart Policy") - - def select_action(self, obs): - # In catch, the info_state contains information about paddle and ball positions - # This is a simple heuristic - in practice you'd need to understand the state representation - if not obs.legal_actions: - return 0 - - # For catch game, often legal actions are [0, 1, 2] = [left, stay, right] - # Simple heuristic: choose middle action (stay) or random - import random - - return random.choice(obs.legal_actions) - - -def run_episode(env, policy, visualize=True, delay=0.3): - """Run one episode with a policy against OpenSpiel environment.""" - - # RESET - result = env.reset() - obs = result.observation - - if visualize: - print(f"\n{'=' * 60}") - print(f" šŸŽ® {policy.name}") - print(" šŸŽ² Playing against OpenSpiel Catch") - print("=" * 60 + "\n") - time.sleep(delay) - - total_reward = 0 - step = 0 - action_names = ["ā¬…ļø LEFT", "šŸ›‘ STAY", "āž”ļø RIGHT"] - - # THE RL LOOP - while not obs.done: - # 1. Policy chooses action - action_id = policy.select_action(obs) - - # 2. Environment executes (via HTTP!) - action = OpenSpielAction(action_id=action_id, game_name="catch") - result = env.step(action) - obs = result.observation - - # 3. Collect reward - if result.reward is not None: - total_reward += result.reward - - if visualize: - action_name = action_names[action_id] if action_id < len(action_names) else f"ACTION {action_id}" - print(f"šŸ“ Step {step + 1}: {action_name} → Reward: {result.reward}") - time.sleep(delay) - - step += 1 - - if visualize: - result_text = "šŸŽ‰ CAUGHT!" 
if total_reward > 0 else "😢 MISSED" - print(f"\n{'=' * 60}") - print(f" {result_text} Total Reward: {total_reward}") - print("=" * 60) - - return total_reward > 0 - - -def start_server(): - """Start the OpenSpiel environment server.""" - print("⚔ Starting FastAPI server for OpenSpiel Catch Environment...") - - work_dir = str(Path.cwd().parent.absolute()) - - server_process = subprocess.Popen( - [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8002"], - env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - cwd=work_dir, - ) - - # Wait for server to start - print("ā³ Waiting for server to start...") - time.sleep(5) - - # Check if server is running - try: - response = requests.get(f"{ENV_URL}/health", timeout=2) - print("āœ… OpenSpiel Catch Environment server is running!\n") - return server_process - except Exception as e: - print(f"āŒ Server failed to start: {e}") - print("\nšŸ“‹ Checking error output...") - server_process.poll() - if server_process.stderr: - stderr = server_process.stderr.read() - if stderr: - print(stderr) - raise - - -def run_demo(): - """Run a simple demo of the Catch environment.""" - print("šŸŽÆ OpenSpiel Catch Environment Demo") - print("=" * 60) - - # Connect to environment server - client = OpenSpielEnv(base_url=ENV_URL) - - try: - # Reset environment and show initial state - print("\nšŸ“ Resetting environment...") - result = client.reset() - - print(f" Initial observation shape: {len(result.observation.info_state)}") - print(f" Info state (first 10 values): {result.observation.info_state[:10]}") - print(f" Legal actions: {result.observation.legal_actions}") - print(f" Game phase: {result.observation.game_phase}") - print(f" Done: {result.done}") - print(f" Initial reward: {result.reward}") - - # Demo different policies - print("\nšŸ“ŗ " + "=" * 64 + " šŸ“ŗ") - print(" Watch Policies Play Against OpenSpiel!") - print("šŸ“ŗ " + "=" * 64 + " šŸ“ŗ\n") - - policies = [SmartPolicy(), RandomPolicy(), LeftPolicy(), RightPolicy()] - - for policy in policies: - caught = run_episode(client, policy, visualize=True, delay=0.5) - - print("\nšŸ’” You just watched REAL OpenSpiel Catch being played!") - print(" • Every action was an HTTP call") - print(" • Game logic runs in the server") - print(" • Client only sends actions and receives observations\n") - - # Get final environment state - state = client.state() - print("\nšŸ“Š Final Environment State:") - print(f" Episode ID: {state.episode_id}") - print(f" Step count: {state.step_count}") - print(f" Game: {state.game_name}") - print(f" Num players: {state.num_players}") - print(f" Agent player: {state.agent_player}") - - except Exception as e: - print(f"\nāŒ Error during demo: {e}") - import traceback - - traceback.print_exc() - - finally: - # Always close the environment - client.close() - print("\nāœ… Demo complete!") - - -def main(): - """Main function to run the demo.""" - server_process = None - try: - server_process = start_server() - run_demo() - except KeyboardInterrupt: - print("\n\nāš ļø Interrupted by user") - finally: - if server_process: - print("\nšŸ›‘ Shutting down server...") - server_process.terminate() - server_process.wait(timeout=5) - print("šŸ‘‹ Done!") - - -if __name__ == "__main__": - main() diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py index ad971b007b..c8916b15ee 100644 --- a/examples/scripts/openenv/echo.py +++ b/examples/scripts/openenv/echo.py 
@@ -63,6 +63,7 @@ # Determine the correct path work_dir = str(Path.cwd().parent.absolute()) +# Workaround if you can't run the env with Docker server_process = subprocess.Popen( [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, @@ -167,7 +168,6 @@ def reward_from_env(completions, **kwargs): max_completion_length=2048, per_device_train_batch_size=8, gradient_accumulation_steps=4, - max_steps=1, ) trainer = GRPOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", From 46810da8d290345b2928b50eb5a502bd31b93dfb Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 20:27:15 +0000 Subject: [PATCH 15/18] fix --- examples/scripts/openenv/catch.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index ebc6c0968b..939dc7a6a5 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -49,12 +49,12 @@ # Run training ```sh -CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/catch.py +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/catch.py ``` """ GEN_URL = "http://0.0.0.0:8000/generate/" -ENV_URL = "http://0.0.0.0:8002" +ENV_URL = "http://0.0.0.0:8001" BASE_PROMPT = """You are an AI agent playing the game **Catch**. @@ -91,7 +91,7 @@ work_dir = str(Path.cwd().parent.absolute()) server_process = subprocess.Popen( - [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8002"], + [sys.executable, "-m", "uvicorn", "envs.openspiel_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -157,8 +157,9 @@ def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, proc episode_logprobs = [] while not obs.done: - # Build prompt with current observation and legal actions - episode_prompt = f"{base_prompt}\n\n{obs.info_state}\n" + # FIXME: handle this better + episode_msg = [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}] + episode_prompt = processing_class.apply_chat_template(episode_msg, tokenize=False) # Generate action from model gen_payload = { From 011667242ec50e1711f9930d78880070792ac8f4 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 21:04:18 +0000 Subject: [PATCH 16/18] Make catch run --- examples/scripts/openenv/catch.py | 21 +++++++++++++-------- examples/scripts/openenv/echo.py | 6 +----- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index 939dc7a6a5..f509cb4e3d 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -24,7 +24,7 @@ from envs.openspiel_env import OpenSpielEnv from envs.openspiel_env.models import OpenSpielAction -from trl import GRPOConfig, GRPOTrainer +from trl import GRPOConfig, GRPOTrainer, apply_chat_template """ @@ -99,7 +99,6 @@ cwd=work_dir, ) -# Wait for server to start print("ā³ Waiting for server to start...") time.sleep(5) @@ -145,7 +144,7 @@ def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, proc all_completion_ids = [] all_logprobs = [] - for prompt_idx, base_prompt in enumerate(prompts): + for base_prompt in prompts: for _ in range(args.num_generations): # Run episode: Reset environment and loop until done env_result = client.reset() @@ -157,13 +156,13 @@ def 
rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, proc episode_logprobs = [] while not obs.done: - # FIXME: handle this better - episode_msg = [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}] - episode_prompt = processing_class.apply_chat_template(episode_msg, tokenize=False) + # FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset + episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]} + episode_prompt = apply_chat_template(episode_msg, processing_class) # Generate action from model gen_payload = { - "prompts": [episode_prompt], + "prompts": [episode_prompt["prompt"]], "n": 1, "temperature": args.temperature, "top_p": args.top_p, @@ -235,7 +234,7 @@ def reward_from_env(completions, **kwargs): report_to=["trackio", "wandb"], num_train_epochs=1, num_generations=8, - max_completion_length=64, # Shorter for catch game (just need action selection) + max_completion_length=64, per_device_train_batch_size=8, gradient_accumulation_steps=4, ) @@ -247,3 +246,9 @@ def reward_from_env(completions, **kwargs): rollout_func=rollout_func, ) trainer.train() + +# Give time for background threads to finish +time.sleep(5) + +print("šŸ›‘ Terminating environment server...") +server_process.terminate() diff --git a/examples/scripts/openenv/echo.py b/examples/scripts/openenv/echo.py index c8916b15ee..8af8ce073a 100644 --- a/examples/scripts/openenv/echo.py +++ b/examples/scripts/openenv/echo.py @@ -57,13 +57,11 @@ GEN_URL = "http://0.0.0.0:8000/generate/" ENV_URL = "http://0.0.0.0:8001" -# Start the Echo server in background print("⚔ Starting FastAPI server for Echo Environment...") -# Determine the correct path -work_dir = str(Path.cwd().parent.absolute()) # Workaround if you can't run the env with Docker +work_dir = str(Path.cwd().parent.absolute()) server_process = subprocess.Popen( [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, @@ -73,11 +71,9 @@ cwd=work_dir, ) -# Wait for server to start print("ā³ Waiting for server to start...") time.sleep(5) -# Check if server is running try: response = requests.get(f"{ENV_URL}/health", timeout=2) print("\nāœ… Echo Environment server is running!") From 3970d4c68e3f69145ecedd47285f1ce0ca1c00f9 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 21:14:17 +0000 Subject: [PATCH 17/18] Note --- examples/scripts/openenv/catch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/scripts/openenv/catch.py b/examples/scripts/openenv/catch.py index f509cb4e3d..9405f6d0b2 100644 --- a/examples/scripts/openenv/catch.py +++ b/examples/scripts/openenv/catch.py @@ -155,6 +155,7 @@ def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, proc episode_completion_ids = [] episode_logprobs = [] + # TODO: parallelise! 
while not obs.done: # FIXME: handle the addition of observation to prompt more cleanly, ideally without a train_dataset episode_msg = {"prompt": [{"role": "user", "content": f"{base_prompt}\n\n{obs.info_state}\n"}]} @@ -234,7 +235,7 @@ def reward_from_env(completions, **kwargs): report_to=["trackio", "wandb"], num_train_epochs=1, num_generations=8, - max_completion_length=64, + max_completion_length=4, per_device_train_batch_size=8, gradient_accumulation_steps=4, ) From f297fb6ee9a4fd7ac7f32b5805ae9562fd4953b6 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Tue, 21 Oct 2025 21:20:07 +0000 Subject: [PATCH 18/18] Nuke --- trl/experimental/openenv/echo.py | 177 ------------------------------- 1 file changed, 177 deletions(-) delete mode 100644 trl/experimental/openenv/echo.py diff --git a/trl/experimental/openenv/echo.py b/trl/experimental/openenv/echo.py deleted file mode 100644 index e23d84642e..0000000000 --- a/trl/experimental/openenv/echo.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2020-2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ruff: noqa: T201 -import os -import subprocess -import sys -import time -from contextlib import suppress -from pathlib import Path - -import requests -from datasets import load_dataset -from envs.echo_env import EchoEnv -from envs.echo_env.models import ( - EchoAction, -) - -from trl import GRPOConfig, GRPOTrainer - - -""" -Simple script to run GRPO training with OpenEnv's Echo environment and a vLLM server. The reward function encourages -longer completions. - -Setup: - -```sh -uv pip install git+https://github.com/meta-pytorch/OpenEnv.git -``` - -Usage (2 GPUs required): - -# Spin up server - -```sh -CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 -``` - -# Run training - -```sh -CUDA_VISIBLE_DEVICES=1 python trl/experimental/openenv/echo.py -``` -""" - -GEN_URL = "http://0.0.0.0:8000/generate/" -ENV_URL = "http://0.0.0.0:8001" - -# Start the Echo server in background -print("⚔ Starting FastAPI server for Echo Environment...") - -# Determine the correct path -work_dir = str(Path.cwd().parent.absolute()) - -server_process = subprocess.Popen( - [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"], - env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, - cwd=work_dir, -) - -# Wait for server to start -print("ā³ Waiting for server to start...") -time.sleep(5) - -try: - # Check if server is running - try: - response = requests.get(f"{ENV_URL}/health", timeout=2) - print("\nāœ… Echo Environment server is running!") - except Exception as e: - print(f"\nāŒ Server failed to start: {e}") - raise - - # Create HTTP client for Echo Environment - client = EchoEnv(base_url=f"{ENV_URL}") - - def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]: - """ - Custom rollout function that generates completions via vLLM server and computes environment rewards. 
- - Args: - prompts: List of prompt strings to generate from - images: Optional images for vision models (not used in this example) - args: GRPOConfig containing all sampling parameters - processing_class: Tokenizer/processor for decoding completions - - Returns: - Dict containing prompt_ids, completion_ids, logprobs, and env_reward - """ - # Make request to TRL's custom /generate/ endpoint - payload = { - "prompts": prompts, - "n": args.num_generations, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": -1 if args.top_k is None else args.top_k, - "min_p": 0.0 if args.min_p is None else args.min_p, - "max_tokens": args.max_completion_length, - "repetition_penalty": args.repetition_penalty, - } - response = requests.post(GEN_URL, json=payload) - - if response.status_code != 200: - print(f"Error response: {response.text}") - - response.raise_for_status() - result = response.json() - - completions_text = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True) - - # Flush env - env_result = client.reset() - - env_rewards = [] - for msg in completions_text: - env_result = client.step(EchoAction(message=msg)) - env_rewards.append(env_result.reward) - - result["env_reward"] = env_rewards - - return result - - dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]") - - def reward_from_env(completions, **kwargs): - """Reward function that uses the environment reward.""" - # Extract environment rewards from kwargs (propagated via extra_fields) - env_rewards = kwargs.get("env_reward", []) - if env_rewards: - return [float(reward) for reward in env_rewards] - else: - # Fallback if env_reward is not available - return [0.0] * len(completions) - - training_args = GRPOConfig( - output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout", - vllm_mode="server", - use_vllm=True, - logging_steps=1, - report_to=["trackio", "wandb"], - num_train_epochs=1, - num_generations=8, - max_completion_length=2048, - per_device_train_batch_size=8, - gradient_accumulation_steps=4, - ) - trainer = GRPOTrainer( - model="Qwen/Qwen2.5-0.5B-Instruct", - reward_funcs=reward_from_env, - args=training_args, - train_dataset=dataset, - rollout_func=rollout_func, - ) - trainer.train() -finally: - print("\nšŸ›‘ Stopping Echo Environment server...") - if server_process.poll() is None: - server_process.terminate() - try: - server_process.wait(timeout=5) - except subprocess.TimeoutExpired: - print("āš ļø Termination timeout reached. Forcing shutdown.") - server_process.kill() - with suppress(subprocess.TimeoutExpired): - server_process.wait(timeout=5)
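
Taken together, the series leaves `rollout_func` as the single extension point for custom generation: it receives the prompts plus sampling parameters, must return `prompt_ids`, `completion_ids`, and `logprobs`, and any extra fields it attaches (such as `env_reward`) are forwarded to the reward functions. Below is a minimal sketch of that contract, condensed from `examples/scripts/openenv/echo.py` above. It is illustrative only: it assumes the TRL vLLM server from the usage notes is already running on port 8000, and `env_score` is a hypothetical stand-in for an OpenEnv client step (e.g. `client.step(EchoAction(...)).reward`); the other names mirror the patches.

```python
# Illustrative sketch of the rollout_func contract; not part of the patch series.
# Assumes: `trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --port 8000` is running.
import requests
from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer

GEN_URL = "http://0.0.0.0:8000/generate/"


def env_score(text: str) -> float:
    """Hypothetical stand-in for an environment reward (replace with an OpenEnv client step)."""
    return float(len(text.split()))


def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]:
    # 1. Generate completions via TRL's custom /generate/ endpoint.
    payload = {
        "prompts": prompts,
        "n": args.num_generations,
        "temperature": args.temperature,
        "max_tokens": args.max_completion_length,
    }
    result = requests.post(GEN_URL, json=payload).json()

    # 2. Score each completion; any extra field is forwarded to the reward functions.
    texts = processing_class.batch_decode(result["completion_ids"], skip_special_tokens=True)
    result["env_reward"] = [env_score(text) for text in texts]

    # 3. Return the required fields (prompt_ids, completion_ids, logprobs) plus extras.
    return result


def reward_from_env(completions, **kwargs):
    # The extra "env_reward" field attached in rollout_func arrives here via kwargs.
    rewards = kwargs.get("env_reward", [0.0] * len(completions))
    return [float(r) for r in rewards]


trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_from_env,
    args=GRPOConfig(
        output_dir="Qwen2.5-0.5B-GRPO-Rollout",
        use_vllm=True,
        vllm_mode="server",
        num_generations=8,
        per_device_train_batch_size=8,
    ),
    train_dataset=load_dataset("trl-lib/ultrafeedback-prompt", split="train[:64]"),
    rollout_func=rollout_func,
)
trainer.train()
```

With this wiring, the trainer bypasses its built-in `vllm_client.generate` path and consumes the returned `prompt_ids`/`completion_ids`/`logprobs` directly, while `env_reward` reaches `reward_from_env` through `kwargs`.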