-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Multi-turn tool calling support #4115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
552e899
449ef07
c8933aa
229c554
3ca6ad5
dcf4b92
30ad7ca
86cc30b
088897b
d2adc63
f4c82bf
1257796
099a39b
529add6
fc6b11f
ae1f497
f998432
fa73876
52d8bd9
dfc0d38
fc52e68
4d12aeb
4fc2b5b
b628744
d3a769f
c9693b2
e17ec42
efbb03a
562c662
485781c
05270f8
1c53094
9b6652e
c500440
a6a8c44
b8656e0
d8665e1
365d501
acb44bc
cdb4c76
c83e710
ec6ad25
b4cadde
594a07d
04e4bd7
242d66a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| # limitations under the License. | ||
|
|
||
| import inspect | ||
| import json | ||
| import os | ||
| import re | ||
| import textwrap | ||
|
|
@@ -61,6 +62,10 @@ | |
| disable_dropout_in_model, | ||
| ensure_master_addr_port, | ||
| entropy_from_logits, | ||
| flush_left, | ||
| flush_right, | ||
| generate_model_card, | ||
| get_comet_experiment_url, | ||
| identity, | ||
| nanmax, | ||
| nanmin, | ||
|
|
@@ -97,7 +102,21 @@ | |
| RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] | ||
|
|
||
|
|
||
| class GRPOTrainer(BaseTrainer): | ||
| def extract_tool_calls(text: str) -> dict[str, Any]: | ||
| """ | ||
| Given a string, extract the first <tool_call> JSON block and return it as a dictionary (or None if no valid block is found). | ||
| """ | ||
| pattern = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is unfortunately no standardisation for the tool call tags across model families, but Matt is working on extending the chat templates so they can auto-parse tool calls (internal Slack thread): https://huggingface.slack.com/archives/C06JKEMK6BZ/p1757691450090859 Note that I'm not sure we want to go down this route, since it's quite messy in my experience to match the parser to the desired model (e.g. some Qwen models use the […]). So in the meantime, we might want to give users the ability to provide their own parsing function and default to yours (which is the most common I've seen)
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is the approach we usually followed on smolagents: to provide a sensible default, but allow users to fully customize the function/object instance.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note that `re.compile` will be invoked on every call to this function. If performance matters here, we should compile the pattern once as a module-level constant, so compilation happens only once, at import time. |
||
|
|
||
| for match in pattern.findall(text): | ||
| try: | ||
| return json.loads(match) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You only return the first match? |
||
| except json.JSONDecodeError: | ||
| pass | ||
| return None | ||
|
|
||
|
|
||
| class GRPOTrainer(Trainer): | ||
| """ | ||
| Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the | ||
| paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language | ||
|
|
@@ -227,7 +246,10 @@ def __init__( | |
| callbacks: Optional[list[TrainerCallback]] = None, | ||
| optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), | ||
| peft_config: Optional["PeftConfig"] = None, | ||
| tools=None, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You will need to add the |
||
| ): | ||
| self.tools = tools or [] | ||
| self._tool_dict = {tool.__name__: tool for tool in self.tools} | ||
| # Args | ||
| if args is None: | ||
| model_name = model if isinstance(model, str) else model.config._name_or_path | ||
|
|
@@ -1085,7 +1107,8 @@ def _generate(self, prompts: list[str], images: Optional[list]): | |
| prepare_multimodal_messages(prompt, num_images=len(image_list)) | ||
|
|
||
| prompts_text = [ | ||
| maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts | ||
| maybe_apply_chat_template({"prompt": prompt}, self.processing_class, tools=self.tools)["prompt"] | ||
| for prompt in prompts | ||
| ] | ||
|
|
||
| prompt_inputs = self.processing_class( | ||
|
|
@@ -1413,6 +1436,53 @@ def _generate_and_score_completions( | |
| sampling_per_token_logps, | ||
| forward_kwargs, | ||
| ) = self._generate(prompts, images) | ||
| completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) | ||
| tool_calls = [extract_tool_calls(completion) for completion in completions] | ||
| tool_results = [self._tool_dict[tc["name"]](**tc["arguments"]) if tc else None for tc in tool_calls] | ||
| tool_messages = [ | ||
| [{"role": "tool", "name": tc["name"], "content": str(tr)}] if tc else None | ||
| for tc, tr in zip(tool_calls, tool_results) | ||
| ] | ||
|
Comment on lines
+1440
to
+1445
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure if this handles potential multiple tool calls in a single completion...
Comment on lines
+1440
to
+1445
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before the messages with the tool results ( |
||
| new_prompts = [ | ||
| p + [{"role": "user", "content": c}] + t for p, c, t in zip(prompts, completions, tool_messages) if t | ||
| ] | ||
| needs_tool = torch.tensor([tc is not None for tc in tool_calls], device=device) | ||
| if new_prompts: | ||
| ( | ||
| new_prompt_ids, | ||
| new_completion_ids, | ||
| new_prompt_mask, | ||
| new_completion_mask, | ||
| new_num_items_in_batch, | ||
| new_sampling_per_token_logps, | ||
| new_forward_kwargs, | ||
| ) = self._generate(new_prompts, images) | ||
| num_tool_ids = new_prompt_mask.sum(-1) - torch.cat( | ||
| [prompt_mask[needs_tool], completion_mask[needs_tool]], dim=1 | ||
| ).sum(-1) | ||
| tool_ids = [ids[-num:] for ids, num in zip(new_prompt_ids, num_tool_ids)] | ||
| tool_mask = [torch.ones_like(ids) for ids in tool_ids] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be cool to have a unit test for this masking so we're confident it is behaving as expected |
||
| r_completion_mask, r_completion_ids = flush_right(completion_mask[needs_tool], completion_ids[needs_tool]) | ||
| ci = [torch.cat(x) for x in zip(r_completion_ids, tool_ids, new_completion_ids)] | ||
| cm = [torch.cat(x) for x in zip(r_completion_mask, tool_mask, new_completion_mask)] | ||
|
|
||
| new_ci = [] | ||
| new_cm = [] | ||
| true_idx = 0 | ||
| for i, m in enumerate(needs_tool): | ||
| if m: | ||
| # take the next tensor from list_true | ||
| new_ci.append(ci[true_idx]) | ||
| new_cm.append(cm[true_idx]) | ||
| true_idx += 1 | ||
| else: | ||
| new_ci.append(completion_ids[i]) | ||
| new_cm.append(completion_mask[i]) | ||
|
|
||
| completion_ids = pad(new_ci, self.pad_token_id) | ||
| completion_mask = pad(new_cm, 0) | ||
| completion_mask, completion_ids = flush_left(completion_mask, completion_ids) | ||
| num_items_in_batch += new_num_items_in_batch | ||
|
|
||
| # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need | ||
| # to re-tokenize completions if the reward is computed from tokens. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about moving this function to a non-specific trainer module, so it can be used by any trainer in the future?