diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 54b62394a8..885757767b 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -107,7 +107,7 @@ def __init__( ref_model: Optional[nn.Module], reward_model: nn.Module, train_dataset: Dataset, - value_model: Optional[nn.Module] = None, + value_model: nn.Module, data_collator: Optional[DataCollatorWithPadding] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, # less commonly used