
Commit 9874b3a

lewtun and qgallouedec authored
[GRPO] Add metrics for low and high clipped token probabilities (#3289)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 1e61f6c commit 9874b3a

File tree

2 files changed, +55 -10 lines changed

docs/source/grpo_trainer.md (8 additions, 4 deletions)

@@ -155,12 +155,16 @@ This constant is recommended to be the maximum completion length. To use this fo
 - `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
 - `reward`: The overall average reward after applying reward weights.
 - `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
-- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
-- `clip_ratio`: The fraction of tokens where the PPO objective is clipped to stay within the trust region:
+- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
+- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:
 $$
-\text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right)
+\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
 $$
-A higher value means more tokens were affected by clipping, limiting how much the policy can change.
+A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
+- `clip_ratio/low_mean`: The average ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
+- `clip_ratio/low_min`: The minimum ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
+- `clip_ratio/high_mean`: The average ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
+- `clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
 
 ## Customization
 
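To make the new metric names concrete, here is a minimal, self-contained sketch, not taken from the trainer itself: `ratio`, `advantages`, `mask`, `eps_low`, and `eps_high` are illustrative stand-ins for the trainer's internal tensors and its `epsilon_low`/`epsilon_high` settings. It shows how the per-batch clip fractions behind `clip_ratio/low_mean`, `clip_ratio/high_mean`, and `clip_ratio/region_mean` can be computed from \\(r_{i,t}(\theta)\\):

```python
import torch

# Illustrative per-token probability ratios r_{i,t}(theta) for one completion of 4 tokens.
ratio = torch.tensor([[0.70, 0.95, 1.05, 1.40]])   # pi_theta / pi_theta_old
advantages = torch.tensor([-1.0])                   # one (group-relative) advantage per completion
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])         # 1 for completion tokens, 0 for padding
eps_low, eps_high = 0.2, 0.2

# A token counts as clipped only where clipping can actually bind the objective:
# the lower bound matters for negative advantages, the upper bound for positive ones.
is_low = (ratio < 1 - eps_low) & (advantages.unsqueeze(1) < 0)
is_high = (ratio > 1 + eps_high) & (advantages.unsqueeze(1) > 0)
is_region = is_low | is_high

low_frac = (is_low * mask).sum() / mask.sum()        # logged as clip_ratio/low_mean
high_frac = (is_high * mask).sum() / mask.sum()      # logged as clip_ratio/high_mean
region_frac = (is_region * mask).sum() / mask.sum()  # logged as clip_ratio/region_mean
print(low_frac.item(), high_frac.item(), region_frac.item())  # 0.333..., 0.0, 0.333...
```

In the trainer these per-batch fractions are additionally gathered across processes, with `clip_ratio/low_min` and `clip_ratio/high_max` taken as the minimum and maximum of the gathered values, as the second diff hunk below shows.
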
trl/trainer/grpo_trainer.py (47 additions, 6 deletions)

@@ -183,6 +183,36 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor:
     return torch.sqrt(variance)
 
 
+def nanmin(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
+
+    Args:
+        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.
+
+    Returns:
+        `torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
+    """
+    if torch.isnan(tensor).all():
+        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
+    return torch.min(tensor[~torch.isnan(tensor)])
+
+
+def nanmax(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
+
+    Args:
+        tensor (`torch.Tensor`): Input tensor of shape `(N,)`.
+
+    Returns:
+        `torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN.
+    """
+    if torch.isnan(tensor).all():
+        return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
+    return torch.max(tensor[~torch.isnan(tensor)])
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
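A quick usage sketch of the two helpers, not part of the commit; it assumes they can be imported from the module where they are defined. It illustrates the NaN handling that the plain reductions do not provide:

```python
import torch

# Assumed import path for this sketch: the helpers live at module level in grpo_trainer.py.
from trl.trainer.grpo_trainer import nanmax, nanmin

x = torch.tensor([0.25, float("nan"), 0.75])
print(nanmin(x))     # tensor(0.2500) -- NaN entries are ignored
print(nanmax(x))     # tensor(0.7500)
print(torch.min(x))  # typically tensor(nan) -- the plain reduction does not skip NaNs

all_nan = torch.tensor([float("nan"), float("nan")])
print(nanmin(all_nan))  # tensor(nan) -- all-NaN input returns NaN by design
```
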
@@ -1086,12 +1116,23 @@ def _compute_loss(self, model, inputs):
             mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
             self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())
 
-        # Compute the clip ratio
-        is_clipped = ((coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
-            (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
-        )
-        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
-        self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).nanmean().item())
+        # Compute the clipped probability ratios
+        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
+        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
+        is_region_clipped = is_low_clipped | is_high_clipped
+
+        low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum()
+        high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
+        clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()
+
+        gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
+        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
+        self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
+        gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
+        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
+        self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
+        gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
+        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
         return loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
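The advantage-sign conditions in `is_low_clipped`/`is_high_clipped` flag the tokens where clipping actually binds the surrogate objective, not every token whose ratio falls outside \\([1-\epsilon_\mathrm{low}, 1+\epsilon_\mathrm{high}]\\). A toy check of that equivalence, not part of the commit, with made-up values for `eps_low`, `eps_high`, `ratio`, and `adv`:

```python
import torch

eps_low, eps_high = 0.2, 0.2
ratio = torch.tensor([0.5, 0.9, 1.5, 1.5])   # r_{i,t}(theta) for four tokens
adv = torch.tensor([-1.0, -1.0, 1.0, -1.0])  # advantages with mixed signs

unclipped = ratio * adv
clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * adv
# Per-token surrogate before the leading minus sign: min(r * A, clip(r) * A).
surrogate = torch.min(unclipped, clipped)

# Same conditions as the trainer (low clip needs A < 0, high clip needs A > 0)...
flagged = ((ratio < 1 - eps_low) & (adv < 0)) | ((ratio > 1 + eps_high) & (adv > 0))
# ...match exactly the tokens where min() selects the clipped term.
binds = clipped < unclipped

print(flagged)  # tensor([ True, False,  True, False])
print(binds)    # tensor([ True, False,  True, False])
```
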
