Closed

Changes from all commits (46 commits)
552e899 Refactor image handling: replace `image_split_sizes` with `image_grid… (qgallouedec, Sep 19, 2025)
449ef07 simpler (qgallouedec, Sep 19, 2025)
c8933aa gfpo (qgallouedec, Sep 19, 2025)
229c554 multi-image grpo (qgallouedec, Sep 19, 2025)
3ca6ad5 log with wandb (qgallouedec, Sep 19, 2025)
dcf4b92 no vlm reward models (qgallouedec, Sep 20, 2025)
30ad7ca rloo (qgallouedec, Sep 20, 2025)
86cc30b gfpo (qgallouedec, Sep 20, 2025)
088897b fix (qgallouedec, Sep 20, 2025)
d2adc63 test peft (qgallouedec, Sep 20, 2025)
f4c82bf fix gfpo (qgallouedec, Sep 20, 2025)
1257796 rloo test (qgallouedec, Sep 20, 2025)
099a39b peft rloo (qgallouedec, Sep 20, 2025)
529add6 oops (qgallouedec, Sep 20, 2025)
fc6b11f update test (qgallouedec, Sep 20, 2025)
ae1f497 generate method (qgallouedec, Sep 20, 2025)
f998432 debug (qgallouedec, Sep 20, 2025)
fa73876 skip failing test (qgallouedec, Sep 20, 2025)
52d8bd9 Merge branch 'main' into drop-image_split_sizes (qgallouedec, Sep 20, 2025)
dfc0d38 Merge branch 'drop-image_split_sizes' into multi-image-support (qgallouedec, Sep 20, 2025)
fc52e68 test fixed! (qgallouedec, Sep 20, 2025)
4d12aeb Merge branch 'multi-image-support' into generate-method (qgallouedec, Sep 20, 2025)
4fc2b5b gfpo (qgallouedec, Sep 20, 2025)
b628744 rm vllm (qgallouedec, Sep 20, 2025)
d3a769f fix doc (qgallouedec, Sep 20, 2025)
c9693b2 a bit messy! (qgallouedec, Sep 21, 2025)
e17ec42 Merge branch 'main' into drop-image_split_sizes (qgallouedec, Sep 22, 2025)
efbb03a Merge branch 'drop-image_split_sizes' into multi-image-support (qgallouedec, Sep 22, 2025)
562c662 Merge branch 'main' into multi-image-support (qgallouedec, Sep 22, 2025)
485781c Merge branch 'main' into multi-image-support (qgallouedec, Sep 22, 2025)
05270f8 update layers to ignore (qgallouedec, Sep 22, 2025)
1c53094 clarify image column desc (qgallouedec, Sep 22, 2025)
9b6652e rm VLM x RM warning (qgallouedec, Sep 23, 2025)
c500440 Merge branch 'multi-image-support' into generate-method (qgallouedec, Sep 23, 2025)
a6a8c44 Merge branch 'main' into generate-method (qgallouedec, Sep 23, 2025)
b8656e0 Merge branch 'generate-method' into multi-turn (qgallouedec, Sep 23, 2025)
d8665e1 Merge branch 'main' into generate-method (qgallouedec, Sep 23, 2025)
365d501 Merge branch 'main' into generate-method (qgallouedec, Sep 23, 2025)
acb44bc Merge branch 'generate-method' into multi-turn (qgallouedec, Sep 23, 2025)
cdb4c76 Merge branch 'main' into generate-method (qgallouedec, Sep 24, 2025)
c83e710 same for rloo (qgallouedec, Sep 24, 2025)
ec6ad25 nits style and align (qgallouedec, Sep 24, 2025)
b4cadde Merge branch 'main' into generate-method (qgallouedec, Sep 24, 2025)
594a07d Merge branch 'generate-method' into multi-turn (qgallouedec, Sep 24, 2025)
04e4bd7 Update trl/trainer/grpo_trainer.py (qgallouedec, Sep 26, 2025)
242d66a Merge branch 'main' into multi-turn (qgallouedec, Sep 26, 2025)
2 changes: 1 addition & 1 deletion trl/data_utils.py
@@ -148,7 +148,7 @@ def apply_chat_template(
     # Apply the chat template to the prompt, adding the generation prompt
     if "prompt" in example:
         last_role = example["prompt"][-1]["role"]
-        if last_role == "user":
+        if last_role in ["user", "tool"]:
             add_generation_prompt = True
             continue_final_message = False
         elif last_role == "assistant":
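For context, a minimal sketch of a conversation this one-line change now handles: when the last turn is a tool result, the template appends the generation prompt so the model produces the next assistant turn. Message contents are made up; the field names follow the common chat schema.

prompt = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant", "content": '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'},
    # Last turn has role "tool", so add_generation_prompt is now set to True.
    {"role": "tool", "name": "get_weather", "content": "Sunny, 18°C"},
]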
74 changes: 72 additions & 2 deletions trl/trainer/grpo_trainer.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import json
 import os
 import re
 import textwrap
@@ -61,6 +62,10 @@
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
+    flush_left,
+    flush_right,
+    generate_model_card,
+    get_comet_experiment_url,
     identity,
     nanmax,
     nanmin,
@@ -97,7 +102,21 @@
 RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


-class GRPOTrainer(BaseTrainer):
+def extract_tool_calls(text: str) -> Optional[dict[str, Any]]:
Review comment (Member): What about moving this function to a non-specific trainer module, so it can be used by any trainer in the future?

"""
Given a list of strings, extract all <tool_call> JSON blocks and return them as a list of dictionaries.
"""
pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
Review comment (Member): There is unfortunately no standardisation for the tool call tags across model families, but Matt is working on extending the chat templates so they can auto-parse tool calls (internal Slack thread): https://huggingface.slack.com/archives/C06JKEMK6BZ/p1757691450090859
Note that vLLM works around this by providing a dedicated set of parsers that can be set when spinning up the server: https://docs.vllm.ai/en/stable/features/tool_calling.html
I'm not sure we want to go down this route, since it's quite messy in my experience to match the parser to the desired model (e.g. some Qwen models use the hermes parser, others not).
So in the meantime, we might want to give users the ability to provide their own parsing function and default to yours (which is the most common I've seen).

Review comment (Member): That is the approach we usually followed on smolagents: to provide a sensible default, but allow users to fully customize the function/object instance.
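The comments above converge on a pluggable parser with a sensible default. A minimal sketch of that hook; the `tool_call_parser` name and the helper are hypothetical, not part of this PR:

from typing import Any, Callable, Optional

# A parser maps raw completion text to one parsed tool call, or None.
ToolCallParser = Callable[[str], Optional[dict[str, Any]]]

def resolve_tool_call_parser(custom: Optional[ToolCallParser] = None) -> ToolCallParser:
    # extract_tool_calls is the regex-based default defined in this diff.
    return custom if custom is not None else extract_tool_calls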

Review comment (@albertvillanova, Member, Sep 23, 2025): Note that `re.compile` runs on every call of this function. If performance matters here, we should make the compiled pattern a module-level constant, so it is built only once at import time.


+    for match in pattern.findall(text):
+        try:
+            return json.loads(match)
Review comment (Member): You only return the first match?

+        except json.JSONDecodeError:
+            pass
+    return None
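Taking the two review notes above together (compile the pattern once, and surface every match rather than only the first), a sketch of an alternative; this is not what the PR ships:

import json
import re
from typing import Any

# Compiled once at import time, as suggested above.
TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def extract_all_tool_calls(text: str) -> list[dict[str, Any]]:
    """Return every parseable <tool_call> JSON block in `text`, in order."""
    calls = []
    for match in TOOL_CALL_PATTERN.findall(text):
        try:
            calls.append(json.loads(match))
        except json.JSONDecodeError:
            continue  # skip malformed blocks instead of stopping at the first one
    return calls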
+
+
+class GRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
Expand Down Expand Up @@ -227,7 +246,10 @@ def __init__(
         callbacks: Optional[list[TrainerCallback]] = None,
         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional["PeftConfig"] = None,
+        tools=None,
Review comment (Member): You will need to add the tools param to the trainer docstring. And give it a type hint.

     ):
+        self.tools = tools or []
+        self._tool_dict = {tool.__name__: tool for tool in self.tools}
         # Args
         if args is None:
             model_name = model if isinstance(model, str) else model.config._name_or_path
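A sketch of the annotated signature and docstring entry the comment above asks for; the exact wording is an assumption:

from typing import Any, Callable, Optional

class _ToolsParamSketch:
    """
    Args:
        tools (`list[Callable]`, *optional*):
            Functions the model may call. Each tool is looked up by its `__name__`
            against the `name` field of a parsed `<tool_call>` block.
    """

    def __init__(self, tools: Optional[list[Callable[..., Any]]] = None):
        self.tools = tools or []
        self._tool_dict = {tool.__name__: tool for tool in self.tools}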
@@ -1085,7 +1107,8 @@ def _generate(self, prompts: list[str], images: Optional[list]):
                 prepare_multimodal_messages(prompt, num_images=len(image_list))

         prompts_text = [
-            maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            maybe_apply_chat_template({"prompt": prompt}, self.processing_class, tools=self.tools)["prompt"]
+            for prompt in prompts
         ]

         prompt_inputs = self.processing_class(
@@ -1413,6 +1436,53 @@ def _generate_and_score_completions(
             sampling_per_token_logps,
             forward_kwargs,
         ) = self._generate(prompts, images)
+        completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+        tool_calls = [extract_tool_calls(completion) for completion in completions]
+        tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls]
+        tool_messages = [
+            [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None
+            for tc, tr in zip(tool_calls, tool_results)
+        ]
Review comment on lines +1440 to +1445 (Member): Not sure this handles potential multiple tool calls in a single completion...

Review comment on lines +1440 to +1445 (Member): Before the messages with the tool results ("role": "tool"), shouldn't we prepend the messages with the tool calls themselves ("role": "assistant", "tool_calls": ...)? Not sure about this, though; a real question! 😅
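A sketch of the turn layout that comment suggests, with the assistant turn that issued the call recorded before the tool result. The OpenAI-style `tool_calls` field is an assumption, not what this code currently emits:

tool_call = {"name": "get_weather", "arguments": {"city": "Paris"}}  # made-up example
tool_result = "Sunny, 18°C"
appended_turns = [
    # The completion that issued the call, kept as an assistant turn...
    {"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}]},
    # ...followed by the tool result, rendered by the chat template as a tool turn.
    {"role": "tool", "name": tool_call["name"], "content": str(tool_result)},
]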

+        new_prompts = [
+            p + [{"role": "user", "content": c}] + t for p, c, t in zip(prompts, completions, tool_messages) if t
+        ]
+        needs_tool = torch.tensor([tc is not None for tc in tool_calls], device=device)
+        if new_prompts:
+            (
+                new_prompt_ids,
+                new_completion_ids,
+                new_prompt_mask,
+                new_completion_mask,
+                new_num_items_in_batch,
+                new_sampling_per_token_logps,
+                new_forward_kwargs,
+            ) = self._generate(new_prompts, images)
+            num_tool_ids = new_prompt_mask.sum(-1) - torch.cat(
+                [prompt_mask[needs_tool], completion_mask[needs_tool]], dim=1
+            ).sum(-1)
+            tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)]
+            tool_mask = [torch.ones_like(ids) for ids in tool_ids]
Review comment (Member): Would be cool to have a unit test for this masking so we're confident it is behaving as expected.
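A minimal sketch of such a test, assuming `flush_left`, `flush_right`, and `pad` keep the `(mask, *tensors)` and `(tensors, padding_value)` signatures used in this diff; all token values are made up:

import torch
from trl.trainer.utils import flush_left, flush_right, pad

def test_tool_segment_concatenation():
    # Two first-turn completions, right-padded with 0: [5, 6] and [5].
    completion_ids = torch.tensor([[5, 6], [5, 0]])
    completion_mask = torch.tensor([[1, 1], [1, 0]])

    # flush_right pushes real tokens to the right edge so tool tokens can be
    # appended directly after each row's last completion token.
    r_mask, r_ids = flush_right(completion_mask, completion_ids)
    assert (r_ids * r_mask).tolist() == [[5, 6], [0, 5]]

    # Hypothetical per-row tool-result and second-turn tokens.
    tool_ids = [torch.tensor([7]), torch.tensor([7, 8])]
    new_ids = [torch.tensor([9, 9]), torch.tensor([9])]
    tool_mask = [torch.ones_like(t) for t in tool_ids]
    new_mask = [torch.ones_like(t) for t in new_ids]

    ci = [torch.cat(x) for x in zip(r_ids, tool_ids, new_ids)]
    cm = [torch.cat(x) for x in zip(r_mask, tool_mask, new_mask)]

    completion_ids = pad(ci, 0)
    completion_mask = pad(cm, 0)
    # flush_left removes the left padding that flush_right introduced.
    completion_mask, completion_ids = flush_left(completion_mask, completion_ids)
    assert (completion_ids * completion_mask).tolist() == [[5, 6, 7, 9, 9], [5, 7, 8, 9, 0]]
    assert completion_mask.tolist() == [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]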

+            r_completion_mask, r_completion_ids = flush_right(completion_mask[needs_tool], completion_ids[needs_tool])
+            ci = [torch.cat(x) for x in zip(r_completion_ids, tool_ids, new_completion_ids)]
+            cm = [torch.cat(x) for x in zip(r_completion_mask, tool_mask, new_completion_mask)]
+
+            new_ci = []
+            new_cm = []
+            true_idx = 0
+            for i, m in enumerate(needs_tool):
+                if m:
+                    # take the next tensor from the tool-augmented lists
+                    new_ci.append(ci[true_idx])
+                    new_cm.append(cm[true_idx])
+                    true_idx += 1
+                else:
+                    new_ci.append(completion_ids[i])
+                    new_cm.append(completion_mask[i])
+
+            completion_ids = pad(new_ci, self.pad_token_id)
+            completion_mask = pad(new_cm, 0)
+            completion_mask, completion_ids = flush_left(completion_mask, completion_ids)
+            num_items_in_batch += new_num_items_in_batch

         # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
         # to re-tokenize completions if the reward is computed from tokens.
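Finally, a hedged end-to-end sketch of how the new `tools` argument might be used once this lands; the model name, reward function, and dataset are placeholders, not part of the PR:

from datasets import Dataset
from trl import GRPOTrainer

def get_weather(city: str) -> str:
    """Toy tool: return a canned weather string for `city`."""
    return f"Sunny in {city}"

def reward_short(completions, **kwargs):
    # Placeholder reward: prefer shorter completions.
    return [-float(len(str(c))) for c in completions]

dataset = Dataset.from_dict(
    {"prompt": [[{"role": "user", "content": "What's the weather in Paris?"}]]}
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_short,
    train_dataset=dataset,
    tools=[get_weather],  # new in this PR
)
trainer.train()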