@@ -27,62 +27,10 @@
 from .trainer_utils import guard_set_args
 
 
-class ActorReferenceTrainer(RLTrainer):
+class ActorReferenceTrainerBase(RLTrainer):
     loss_cls = RLHFPPOMixedLoss
     trainer_type = "policy"
-
-    def loss_identifier(self, inputs: Dict) -> str:
-        """
-        Identify whether to use the ptx loss function or the actor loss function based on the input dictionary.
-        If labels are present, return "ptx_loss"; otherwise, return "actor_loss".
-
-        Args:
-            inputs (Dict): A dictionary containing two key-value pairs, "inputs" and "labels".
-                "inputs" represents the model's input, while "labels" is optional and indicates whether to use the ptx loss function.
-                The default value for "labels" is None.
-
-        Returns:
-            str: A string indicating whether to use the ptx loss function or the actor loss function, either "ptx_loss" or "actor_loss".
-        """
-        return "actor_loss"
-
-    @paddle.no_grad()
-    def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]:
-        """Rollout a batch of experiences."""
-        input_ids = prompt_only_batch["input_ids"]
-
-        repeat_num = 1 if do_eval else self.args.rollout_n
-
-        with guard_set_args(self.model.config, {"use_fused_head_and_loss_fn": False}):
-            sequences = self.get_model(False).generate(
-                input_ids=input_ids,
-                attention_mask=None,
-                position_ids=None,
-                do_eval=do_eval,
-                repeat_num=repeat_num,
-            )[0]
-
-        if repeat_num > 1:
-            input_ids = input_ids.repeat_interleave(repeat_num, axis=0)
-
-        if self.args.use_rm_server:
-            label_ids = prompt_only_batch["label_ids"]
-            if repeat_num > 1:
-                label_ids = label_ids.repeat_interleave(repeat_num, axis=0)
-
-        sequences = sequences.reshape([input_ids.shape[0] // repeat_num, repeat_num, -1])
-        if do_eval:
-            sequences = sequences.transpose([1, 0, 2])
-        # prompt, sequence, attention_mask
-        return [
-            {
-                "prompt": input_ids,
-                "input_ids": seq,
-                **({"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} if self.args.use_rm_server else {}),
-                "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object),
-            }
-            for idx, seq in enumerate(sequences)
-        ]
+    loss_identifier = lambda self, inputs: "actor_loss"
 
     @paddle.no_grad()
     def compute_logprob(self, input_ids: paddle.Tensor, position_ids: paddle.Tensor = None, **kwargs) -> paddle.Tensor:
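The `use_fused_head_and_loss_fn` override relies on `guard_set_args` from `.trainer_utils`, whose implementation is not part of this diff. Below is a minimal sketch of a context manager with the assumed semantics, inferred only from the call site (temporarily override attributes, restore the originals on exit); the real helper may differ:

from contextlib import contextmanager

@contextmanager
def guard_set_args(obj, overrides):
    # Hypothetical sketch: the real guard_set_args is imported from
    # .trainer_utils and may differ. Assumed behavior: temporarily set
    # the given attributes, then restore the originals on exit, so
    # generation can run with use_fused_head_and_loss_fn=False without
    # permanently mutating the model config.
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)

Under this assumption, the fused head/loss path is disabled only for the duration of the generate call.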
@@ -280,3 +228,43 @@ def update_actor(self, rl_batch: Dict[str, paddle.Tensor]) -> Dict[str, Any]:
             "train_max_generated_length": max_generated_length,
             "train_min_generated_length": min_generated_length,
         }
+
+
+class ActorReferenceTrainer(ActorReferenceTrainerBase):
+    @paddle.no_grad()
+    def generate_sequences(self, prompt_only_batch: Dict, do_eval=False) -> List[Dict[str, Any]]:
+        """Rollout a batch of experiences."""
+        input_ids = prompt_only_batch["input_ids"]
+
+        repeat_num = 1 if do_eval else self.args.rollout_n
+
+        with guard_set_args(self.model.config, {"use_fused_head_and_loss_fn": False}):
+            sequences = self.get_model(False).generate(
+                input_ids=input_ids,
+                attention_mask=None,
+                position_ids=None,
+                do_eval=do_eval,
+                repeat_num=repeat_num,
+            )[0]
+
+        if repeat_num > 1:
+            input_ids = input_ids.repeat_interleave(repeat_num, axis=0)
+
+        if self.args.use_rm_server:
+            label_ids = prompt_only_batch["label_ids"]
+            if repeat_num > 1:
+                label_ids = label_ids.repeat_interleave(repeat_num, axis=0)
+
+        sequences = sequences.reshape([input_ids.shape[0] // repeat_num, repeat_num, -1])
+        if do_eval:
+            sequences = sequences.transpose([1, 0, 2])
+        # prompt, sequence, attention_mask
+        return [
+            {
+                "prompt": input_ids,
+                "input_ids": seq,
+                **({"label_ids": label_ids[idx * len(seq) : (idx + 1) * len(seq)]} if self.args.use_rm_server else {}),
+                "index": np.array([str(uuid.uuid4())] * len(seq), dtype=object),
+            }
+            for idx, seq in enumerate(sequences)
+        ]
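The reshape/transpose bookkeeping at the end of `generate_sequences` is easier to follow with concrete shapes. The numpy sketch below mirrors the paddle ops from the diff; `B`, `R`, and `L` are illustrative sizes, not names from the code:

import numpy as np

B, R, L = 2, 3, 4  # prompts, rollouts per prompt (rollout_n), sequence length

# np.repeat mirrors input_ids.repeat_interleave(repeat_num, axis=0):
# each prompt is duplicated R times with the copies kept adjacent.
input_ids = np.arange(B * L).reshape(B, L)
repeated = np.repeat(input_ids, R, axis=0)                  # (B * R, L)

# The generator returns one sequence per repeated prompt; the reshape
# groups the R rollouts of each prompt back together, matching
# sequences.reshape([input_ids.shape[0] // repeat_num, repeat_num, -1]).
sequences = np.zeros((B * R, L), dtype=np.int64)            # stand-in for generated ids
grouped = sequences.reshape(repeated.shape[0] // R, R, -1)  # (B, R, L)

# The eval branch transposes with [1, 0, 2]. In the diff, do_eval forces
# repeat_num == 1, so (B, 1, L) becomes (1, B, L): a single group holding
# one sequence per prompt. R = 3 is used here only to make the swap visible.
eval_view = grouped.transpose(1, 0, 2)                      # (R, B, L)
print(grouped.shape, eval_view.shape)                       # (2, 3, 4) (3, 2, 4)

In training mode, iterating over the grouped tensor yields one entry per prompt containing all R of its rollouts, which is what the final list comprehension packages into dicts (and why the label_ids slice steps in chunks of len(seq)).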