Skip to content

Conversation

lewtun
Copy link
Member

@lewtun lewtun commented Oct 20, 2025

What does this PR do?

A simplified version of #3469 that exposes a rollout_func, enabling users to define custom logic for tool calling, environments, etc. Very much WIP :)

Examples below.

Tool calling (single-step)

W B Chart 20_10_2025, 11_14_16 pm
Script: click to expand
from datasets import Dataset
import requests
from trl import GRPOConfig, GRPOTrainer
import random
import json
import re
from transformers.utils import get_json_schema
"""
Usage:

-- Spin up server --
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000

-- Run this script --
CUDA_VISIBLE_DEVICES=1 python scratch/grpo_rollout_tool.py
"""

def multiply(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    # NOTE: this function is handed to `get_json_schema` when the system prompt
    # is built, so the docstring above is presumably part of the generated tool
    # schema (see the transformers docs) -- treat it as runtime text and keep
    # the Google-style `Args:` section intact.
    product = a * b
    return product

def create_tool_prompt(a, b):
    """Build the chat prompt asking the model to multiply ``a`` and ``b`` with the tool.

    The system message mirrors Qwen's tool-calling preamble: it embeds the
    schema of ``multiply`` inside ``<tools></tools>`` tags and tells the model
    to answer with a JSON object wrapped in ``<tool_call></tool_call>`` tags.

    Args:
        a: First factor, interpolated into the user message.
        b: Second factor, interpolated into the user message.

    Returns:
        A two-message conversation (system + user) as a list of
        ``{"role": ..., "content": ...}`` dicts.
    """
    # Serialize the schema as real JSON. Interpolating the dict directly would
    # emit its Python repr (single quotes, True/False/None), which is not the
    # JSON the prompt promises inside <tools>.
    tool_schema = json.dumps(get_json_schema(multiply))

    system_message = (
        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
        "# Tools\n\n"
        "You may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n"
        "<tools>\n"
        f"{tool_schema}\n"
        "</tools>\n\n"
        "For each function call, return a json object with function name and arguments "
        "within <tool_call></tool_call> XML tags:\n"
        "<tool_call>\n"
        '{"name": <function-name>, "arguments": <args-json-object>}\n'
        "</tool_call>"
    )

    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": f"Multiply {a} and {b}."},
    ]

# 100 random factor pairs, each factor drawn uniformly from [0, 9999].
factors = [(random.randint(0, 9999), random.randint(0, 9999)) for _ in range(100)]

# One row per pair: the chat prompt, the expected product, and the raw factors
# (extra columns are forwarded to the reward function by name).
dataset = Dataset.from_dict(
    {
        "prompt": [create_tool_prompt(*pair) for pair in factors],
        "result": [x * y for x, y in factors],
        "factor_a": [pair[0] for pair in factors],
        "factor_b": [pair[1] for pair in factors],
    }
)

# Generation endpoint of the `trl vllm-serve` server from the usage notes (port 8000).
GEN_URL = "http://0.0.0.0:8000/generate/"

def _iter_json_objects(text):
    """Yield each JSON object embedded in ``text``, scanning left to right."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except (json.JSONDecodeError, RecursionError):
            # This brace does not open valid JSON; try the next one.
            start = text.find("{", start + 1)
            continue
        yield obj
        # Resume after the decoded object so its nested braces are not re-scanned.
        start = text.find("{", end)


def _as_multiply_call(obj):
    """Return a normalized ``multiply`` call built from ``obj``, or None if it is not one."""
    if not isinstance(obj, dict) or obj.get("name") != "multiply":
        return None
    args = obj.get("arguments")
    if not isinstance(args, dict) or "a" not in args or "b" not in args:
        return None
    try:
        return {"name": "multiply", "arguments": {"a": int(args["a"]), "b": int(args["b"])}}
    except (TypeError, ValueError, OverflowError):
        # Arguments are present but not convertible to integers (null, "abc", inf, ...).
        return None


def parse_tool_call(completion_text):
    """
    Parse a ``multiply`` tool call from the model's completion.

    Two layouts are recognized, in priority order:

    1. Qwen's XML format::

           <tool_call>
           {"name": "multiply", "arguments": {"a": 123, "b": 456}}
           </tool_call>

    2. The same JSON object appearing anywhere in the text (fallback).

    JSON is located with ``json.JSONDecoder.raw_decode`` rather than a regex,
    because a brace-matching regex cannot handle the nested ``arguments``
    object.

    Args:
        completion_text: Raw text generated by the model.

    Returns:
        dict with 'name' and 'arguments' (``a`` and ``b`` coerced to ``int``),
        or None if no valid tool call found
    """
    # Bodies of explicit <tool_call> tags are searched before the raw text.
    tagged_segments = re.findall(r"<tool_call>(.*?)</tool_call>", completion_text, re.DOTALL)
    for segment in [*tagged_segments, completion_text]:
        for obj in _iter_json_objects(segment):
            call = _as_multiply_call(obj)
            if call is not None:
                return call
    return None

def accuracy_reward(completions, result, **kwargs):
    """Score each completion 1 if it leads to the expected product, else 0.

    A completion is scored in one of two ways:

    * If it contains a valid ``multiply`` tool call, the tool is executed with
      the parsed arguments and its output is compared with the expected value.
    * Otherwise the last integer found in the final 100 characters of the
      completion is compared with the expected value.

    Args:
        completions: One conversational completion per sample; the text is
            read from ``completion[0]["content"]``.
        result: Expected products, aligned with ``completions``.
        **kwargs: Ignored; accepts any extra columns passed by the caller.

    Returns:
        List of integer rewards (0 or 1), one per (completion, result) pair.
    """
    print("\n=== accuracy_reward called ===")
    print(f"Number of completions: {len(completions)}")
    print(f"Number of results: {len(result)}")
    print(f"First result sample: {result[0]}")
    print(f"First completion sample: {completions[0]}")

    rewards = []
    for idx, (completion, expected) in enumerate(zip(completions, result)):
        text = completion[0]["content"]
        call = parse_tool_call(text)

        if not (call and call["name"] == "multiply"):
            # No usable tool call: fall back to the last number near the end of the text.
            try:
                expected_answer = int(expected)
                tail_numbers = [int(token) for token in re.findall(r'\b\d+\b', text[-100:])]
                final_number = tail_numbers[-1] if tail_numbers else None
                reward = int(final_number == expected_answer)
                print(f"\nCompletion {idx}: No valid tool call found. Final number: {final_number}, Expected: {expected_answer}, Reward: {reward}")
            except ValueError:
                reward = 0
                print(f"\nCompletion {idx}: No valid tool call found and expected answer is not an integer. Reward: {reward}")
            rewards.append(reward)
            continue

        # Run the tool on the parsed arguments and compare with the ground truth.
        args = call["arguments"]
        computed_answer = multiply(args["a"], args["b"])
        reward = int(computed_answer == expected)
        print(f"\nCompletion {idx}: Tool call {args}, Computed: {computed_answer}, Expected: {expected}, Reward: {reward}")
        rewards.append(reward)

    mean_reward = sum(rewards) / len(rewards)
    print(f"\n=== Returning {len(rewards)} rewards, mean: {mean_reward:.3f} ===\n")
    return rewards

def rollout_func(prompts, **sampling_kwargs):
    """
    Generate completions for chat-formatted prompts via the server at ``GEN_URL``.

    Args:
        prompts: List of chat-formatted prompts (list of dicts with "role" and "content")
        **sampling_kwargs: Sampling parameters (n, temperature, top_p, top_k,
            min_p, max_completion_length, repetition_penalty). Missing entries
            fall back to the defaults listed in the body.

    Returns:
        The decoded JSON response; expected to be a dict with keys
        "prompt_ids", "completion_ids", "logprobs".

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    print("\n=== rollout_func called ===")
    print(f"Number of prompts: {len(prompts)}")
    print(f"Sampling kwargs: {sampling_kwargs}")
    if prompts and isinstance(prompts[0], list):
        # Chat prompt: show only its first two messages.
        first_prompt = prompts[0][:2]
    else:
        first_prompt = prompts[0]
    print(f"First prompt sample: {first_prompt}")

    # (payload field, sampling kwarg, default) for TRL's custom /generate/ endpoint.
    field_specs = (
        ("n", "n", 1),
        ("temperature", "temperature", 1.0),
        ("top_p", "top_p", 1.0),
        ("top_k", "top_k", -1),
        ("min_p", "min_p", 0.0),
        ("max_tokens", "max_completion_length", 128),
        ("repetition_penalty", "repetition_penalty", 1.0),
    )
    payload = {"prompts": prompts}
    for field, kwarg, default in field_specs:
        payload[field] = sampling_kwargs.get(kwarg, default)

    print(f"Sending request to {GEN_URL}")
    response = requests.post(GEN_URL, json=payload)

    # Log the server's message before raise_for_status() discards it.
    if response.status_code != 200:
        print(f"Error response: {response.text}")
    response.raise_for_status()

    generation = response.json()
    print(f"Response keys: {generation.keys()}")
    shapes = []
    for key, value in generation.items():
        size = len(value) if isinstance(value, list) else 'not-list'
        shapes.append((key, size))
    print(f"Response shapes: {shapes}")
    print("=== rollout_func completed ===\n")

    return generation


# GRPO configuration: generation is delegated to an external vLLM server.
training_args = GRPOConfig(
    output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout",
    # Generation backend
    use_vllm=True,
    vllm_mode="server",
    # Sampling
    num_generations=4,
    max_completion_length=4096,
    # Schedule and logging
    num_train_epochs=1,
    logging_steps=1,
    report_to=["trackio", "wandb"],
)

# Wire the custom rollout and reward functions into the trainer, then train.
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    train_dataset=dataset,
    reward_funcs=accuracy_reward,
    rollout_func=rollout_func,
)
trainer.train()

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

if self.rollout_func is not None:
output = self.rollout_func(
prompts=ordered_set_of_prompts,
n=self.num_generations,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it necessary to propagate the sampling parameters to rollout_func: for example, if num_generations is set in the GRPOConfig, the rollout_func must use the same value to stay consistent with the trainer.

The alternative would be to remove the sampling parameters altogether and assume that users align the relevant parameters themselves in their implementation of rollout_func.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant