trl/trainer/grpo_trainer.py (56 changes: 42 additions & 14 deletions)

@@ -94,6 +94,11 @@
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]

# A rollout function is a callable that takes the prompts and sampling parameters as keyword arguments and returns
# a dict of generation results. The dict must include "prompt_ids", "completion_ids", and "logprobs" keys and can
# include an optional "tools" key. Any extra keys are forwarded to the reward functions.
RolloutFunc = Callable[..., dict[str, Any]]
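For illustration, a minimal sketch of a callable satisfying this contract. It assumes a server started with `trl vllm-serve` reachable through TRL's VLLMClient, whose generate method accepts the same keyword arguments and returns the same fields the trainer expects below; my_rollout_func is an invented name, not part of this diff.

from typing import Any

from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # assumes a server started with `trl vllm-serve` on the default host/port


def my_rollout_func(prompts: list[str], n: int, max_tokens: int, **sampling_kwargs) -> dict[str, Any]:
    """Hypothetical rollout: delegate generation to the vLLM server, honoring
    the sampling parameters that GRPOTrainer propagates from its config."""
    output = client.generate(prompts=prompts, n=n, max_tokens=max_tokens, **sampling_kwargs)
    return {
        "prompt_ids": output["prompt_ids"],  # required: token ids of each prompt
        "completion_ids": output["completion_ids"],  # required: token ids of each completion
        "logprobs": output["logprobs"],  # required: per-token logprobs of the completions
        # Any extra keys added here are forwarded to the reward functions.
    }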


class GRPOTrainer(BaseTrainer):
"""
@@ -195,6 +200,10 @@ def reward_func(completions, **kwargs):
            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
        rollout_func (`RolloutFunc`, *optional*, defaults to `None`):
            Function to use for generating completions. It must take in the prompts and sampling parameters as
            keyword arguments and return a dict of generation results. The dict must include "prompt_ids",
            "completion_ids", and "logprobs" keys, and can include an optional "tools" key; any extra keys are
            forwarded to the reward functions.
    """

    _tag_names = ["trl", "grpo"]
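A hedged usage sketch of wiring the new argument in, assuming vLLM server mode. The model id, dataset, output directory, and reward function below are placeholders, not part of this diff.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder prompt dataset


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(c)) for c in completions]


trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=GRPOConfig(output_dir="grpo-rollout", use_vllm=True),
    train_dataset=dataset,
    rollout_func=my_rollout_func,  # the custom rollout sketched earlier
)
trainer.train()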
@@ -225,6 +234,7 @@ def __init__(
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        rollout_func: Optional[RolloutFunc] = None,
    ):
        # Args
        if args is None:
@@ -340,6 +350,9 @@ def __init__

        self.reward_processing_classes = reward_processing_classes

        # Rollout function
        self.rollout_func = rollout_func

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
@@ -1116,20 +1129,35 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
    ordered_set_of_images = None

    with profiling_context(self, "vLLM.generate"):
        if self.rollout_func is not None:
            output = self.rollout_func(
                prompts=ordered_set_of_prompts,
                n=self.num_generations,
                repetition_penalty=self.repetition_penalty,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=-1 if self.top_k is None else self.top_k,
                min_p=0.0 if self.min_p is None else self.min_p,
                max_tokens=self.max_completion_length,
                truncate_prompt_tokens=self.max_prompt_length,
                guided_decoding_regex=self.guided_decoding_regex,
                generation_kwargs=self.args.generation_kwargs,
            )
        else:
            output = self.vllm_client.generate(
                prompts=ordered_set_of_prompts,
                images=ordered_set_of_images,
                n=self.num_generations,
                repetition_penalty=self.repetition_penalty,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=-1 if self.top_k is None else self.top_k,
                min_p=0.0 if self.min_p is None else self.min_p,
                max_tokens=self.max_completion_length,
                truncate_prompt_tokens=self.max_prompt_length,
                guided_decoding_regex=self.guided_decoding_regex,
                generation_kwargs=self.args.generation_kwargs,
            )
    payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
    payload = None

Review comment (Member, Author), on the sampling parameters passed to rollout_func: I found it necessary to propagate the sampling parameters because, for example, if we set num_generations in the GRPOConfig, this information must be aligned in the rollout_func for consistency. The alternative would be to remove the sampling parameters altogether and assume the user aligns some of these params in their implementation of rollout_func.
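Because extra keys in the rollout output are forwarded to the reward functions, a rollout that attaches, say, a per-completion environment score can be consumed there. A short sketch; env_score is an invented key, not part of this diff.

def reward_from_env(completions, env_score=None, **kwargs):
    # "env_score" arrives only if the rollout function added it to its output;
    # fall back to a neutral reward otherwise.
    if env_score is None:
        return [0.0] * len(completions)
    return [float(s) for s in env_score]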