Skip to content

Commit 6bdb716

Browse files
authored
Update ppo_model_utils.py (#10593)
* Update ppo_model_utils.py * Update pp_model_utils.py
1 parent c654d1a commit 6bdb716

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

paddlenlp/rl/models/pp_model_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def fwd_step_patch(func, output, self, *args, **kwargs):
3838
# Training patch
3939
if self.training and self.is_pipeline_last_stage():
4040
if getattr(self, "_step_losses", None):
41-
self._step_losses.append(output.detach())
41+
self._step_losses.append(output[0].detach())
4242
else:
43-
self._step_losses = [output.detach()]
43+
self._step_losses = [output[0].detach()]
4444

4545

4646
def make_wrapper(func, pre_patch=None, post_patch=None):

paddlenlp/rl/models/ppo_model_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def backward(ctx, grad_output: paddle.Tensor) -> paddle.Tensor:
351351

352352

353353
def entropy_from_logits(logits: paddle.Tensor, tensor_parallel_output=False):
354-
return VocabParallelEntropy.apply(logits, tensor_parallel_output)
354+
return VocabParallelEntropy.apply(logits.astype("float32"), tensor_parallel_output)
355355

356356

357357
@merge_fwd_labels

0 commit comments

Comments
 (0)