
Commit 6d40971

Refactor RL (#10569)
* pathes to RLTrainerBase
* refactor ActorReferenceTrainerBase
1 parent fe442e4 commit 6d40971

File tree

3 files changed, +467 -497 lines changed


paddlenlp/rl/trainer/actor_trainer.py

Lines changed: 42 additions & 54 deletions
@@ -27,62 +27,10 @@
 from .trainer_utils import guard_set_args
 
 
-class ActorReferenceTrainer(RLTrainer):
+class ActorReferenceTrainerBase(RLTrainer):
     loss_cls = RLHFPPOMixedLoss
     trainer_type = "policy"
-
-    def loss_identifier(self, inputs: Dict) -> str:
-        """
-        Identify whether to use the ptx loss function or the actor loss function based on the input dictionary.
-        If labels are present, return "ptx_loss"; otherwise, return "actor_loss".
-
-        Args:
-            inputs (Dict): A dictionary containing two key-value pairs, "inputs" and "labels".
-                "inputs" represents the model's input, while "labels" is optional and indicates whether to use the ptx loss function.
-                The default value for "labels" is None.
-
-        Returns:
-            str: A string indicating whether to use the ptx loss function or the actor loss function, either "ptx_loss" or "actor_loss".
-        """
-        return "actor_loss"
-
-    @paddle.no_grad()
-    def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]:
-        """Rollout a batch of experiences."""
-        input_ids = prompt_only_batch["input_ids"]
-
-        repeat_num = 1 if do_eval else self.args.rollout_n
-
-        with guard_set_args(self.model.config, {"use_fused_head_and_loss_fn": False}):
-            sequences = self.get_model(False).generate(
-                input_ids=input_ids,
-                attention_mask=None,
-                position_ids=None,
-                do_eval=do_eval,
-                repeat_num=repeat_num,
-            )[0]
-
-        if repeat_num > 1:
-            input_ids = input_ids.repeat_interleave(repeat_num, axis=0)
-
-        if self.args.use_rm_server:
-            label_ids = prompt_only_batch["label_ids"]
-            if repeat_num > 1:
-                label_ids = label_ids.repeat_interleave(repeat_num, axis=0)
-
-        sequences = sequences.reshape([input_ids.shape[0] // repeat_num, repeat_num, -1])
-        if do_eval:
-            sequences = sequences.transpose([1, 0, 2])
-        # prompt, sequence, attention_mask
-        return [
-            {
-                "prompt": input_ids,
-                "input_ids": seq,
-                **({"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} if self.args.use_rm_server else {}),
-                "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object),
-            }
-            for idx, seq in enumerate(sequences)
-        ]
+    loss_identifier = lambda self, inputs: "actor_loss"
 
     @paddle.no_grad()
     def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs) -> paddle.Tensor:
@@ -280,3 +228,43 @@ def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
             "train_max_generated_length": max_generated_length,
             "train_min_generated_length": min_generated_length,
         }
+
+
+class ActorReferenceTrainer(ActorReferenceTrainerBase):
+    @paddle.no_grad()
+    def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]:
+        """Rollout a batch of experiences."""
+        input_ids = prompt_only_batch["input_ids"]
+
+        repeat_num = 1 if do_eval else self.args.rollout_n
+
+        with guard_set_args(self.model.config, {"use_fused_head_and_loss_fn": False}):
+            sequences = self.get_model(False).generate(
+                input_ids=input_ids,
+                attention_mask=None,
+                position_ids=None,
+                do_eval=do_eval,
+                repeat_num=repeat_num,
+            )[0]
+
+        if repeat_num > 1:
+            input_ids = input_ids.repeat_interleave(repeat_num, axis=0)
+
+        if self.args.use_rm_server:
+            label_ids = prompt_only_batch["label_ids"]
+            if repeat_num > 1:
+                label_ids = label_ids.repeat_interleave(repeat_num, axis=0)
+
+        sequences = sequences.reshape([input_ids.shape[0] // repeat_num, repeat_num, -1])
+        if do_eval:
+            sequences = sequences.transpose([1, 0, 2])
+        # prompt, sequence, attention_mask
+        return [
+            {
+                "prompt": input_ids,
+                "input_ids": seq,
+                **({"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} if self.args.use_rm_server else {}),
+                "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object),
+            }
+            for idx, seq in enumerate(sequences)
+        ]
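
The net effect in actor_trainer.py is that rollout generation now lives on the ActorReferenceTrainer subclass, while ActorReferenceTrainerBase keeps the loss selection (loss_identifier), compute_logprob, and update_actor logic. A minimal sketch of how a downstream trainer could swap in a different rollout strategy after this split; the subclass name and its single-rollout behavior are hypothetical, while the generate() keyword arguments and the returned dict layout are taken from the diff above:

import uuid
from typing import Any, Dict, List

import numpy as np
import paddle

from paddlenlp.rl.trainer.actor_trainer import ActorReferenceTrainerBase
from paddlenlp.rl.trainer.trainer_utils import guard_set_args


class SingleRolloutActorTrainer(ActorReferenceTrainerBase):
    """Hypothetical subclass: reuses compute_logprob/update_actor from the base
    class and only replaces how rollouts are produced (one sequence per prompt)."""

    @paddle.no_grad()
    def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]:
        input_ids = prompt_only_batch["input_ids"]
        # Same generate() call as the stock trainer shown above, but with repeat_num pinned to 1.
        with guard_set_args(self.model.config, {"use_fused_head_and_loss_fn": False}):
            sequences = self.get_model(False).generate(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                do_eval=do_eval,
                repeat_num=1,
            )[0]
        # One rollout per prompt: [batch_size, 1, seq_len].
        sequences = sequences.reshape([input_ids.shape[0], 1, -1])
        return [
            {
                "prompt": input_ids,
                "input_ids": seq,
                "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object),
            }
            for seq in sequences
        ]

Because compute_logprob and update_actor stay on the base class, only the rollout behavior changes, and ActorReferenceTrainer's public interface is unchanged by the move.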

paddlenlp/rl/trainer/ppo_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -82,6 +82,7 @@
 from .actor_trainer import ActorReferenceTrainer
 from .critic_trainer import CriticTrainer
 from .reward_trainer import RewardTrainer
+from .rl_trainer import RLTrainerBase
 from .trainer_utils import (
     MuteDefaultFlowCallback,
     batch_retokenize,
@@ -205,7 +206,7 @@ def update(self, metrics: Dict[str, paddle.Tensor]) -> Union[None, Dict[str, flo
         return out_metrics
 
 
-class PPOTrainer(Trainer):
+class PPOTrainer(RLTrainerBase):
     def __init__(
         self,
         actor_model: Union[PretrainedModel, nn.Layer],
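
Taken together with the actor_trainer.py change, the inheritance picture implied by this commit looks roughly as follows. These are orientation-only stubs: RLTrainerBase comes from paddlenlp/rl/trainer/rl_trainer.py (presumably the third changed file, not shown in this excerpt), and RLTrainer's own base class is an assumption rather than something visible in this diff:

# Orientation-only stubs; the real classes live under paddlenlp/rl/trainer/.
class RLTrainerBase: ...                                      # rl_trainer.py, not shown in this excerpt
class RLTrainer(RLTrainerBase): ...                           # assumed base, per the commit title
class ActorReferenceTrainerBase(RLTrainer): ...               # loss_cls, loss_identifier, compute_logprob, update_actor
class ActorReferenceTrainer(ActorReferenceTrainerBase): ...   # adds generate_sequences (rollout)
class PPOTrainer(RLTrainerBase): ...                          # previously PPOTrainer(Trainer)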
