Skip to content

Commit 54f70b1

Browse files
authored
Update ppo_model_utils.py
1 parent cd2d2dc commit 54f70b1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

paddlenlp/rl/models/ppo_model_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def backward(ctx, grad_output: paddle.Tensor) -> paddle.Tensor:
351351

352352

353353
def entropy_from_logits(logits: paddle.Tensor, tensor_parallel_output=False):
354-
return VocabParallelEntropy.apply(logits, tensor_parallel_output)
354+
return VocabParallelEntropy.apply(logits.astype("float32"), tensor_parallel_output)
355355

356356

357357
@merge_fwd_labels

0 commit comments

Comments
 (0)