
Commit 477048c: fix_eos_mask (#10588)

1 parent: 6d40971

File tree: 3 files changed (+39, -3 lines)

llm/alignment/rl/run_rl.py (8 additions, 0 deletions)
@@ -23,6 +23,7 @@
 from paddle.distributed import fleet
 
 from paddlenlp.datasets.rlhf_datasets import RLHFDataset, collate_fn
+from paddlenlp.generation import GenerationConfig
 from paddlenlp.rl.models.score_model import AutoModelForScore
 from paddlenlp.rl.trainer.ppo_trainer import PPOTrainer
 from paddlenlp.rl.utils.config_utils import (
@@ -358,6 +359,12 @@ def compute_metrics(eval_preds):
         accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
         return {"accuracy": accuracy}
 
+    try:
+        generation_config = GenerationConfig.from_pretrained(model_args.actor_model_name_or_path)
+    except:
+        logger.warning("Can't find generation config, so it will not use generation_config field in the model config")
+        generation_config = None
+
     trainer = PPOTrainer(
         actor_model=actor_model,
         reference_model=reference_model,
@@ -379,6 +386,7 @@ def compute_metrics(eval_preds):
             max_prompt_len=data_args.max_prompt_len if training_args.balance_batch else None,
         ),  # NOTE: enforce prompt padding to max_prompt_len when using balance_batch
         compute_metrics=compute_metrics,  # TODO: only used for grpo (kk datasets)
+        generation_config=generation_config,
     )
 
     # TODO(gongenlei) resume_from_checkpoint is not ready
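
The try/except above makes the generation config optional: a checkpoint that ships without a generation_config.json falls back to None, and the trainer later resolves EOS ids from the tokenizer alone. A minimal standalone sketch of the same pattern; the wrapper name and checkpoint path are illustrative, not from the commit:

```python
from paddlenlp.generation import GenerationConfig


def load_generation_config(model_name_or_path: str):
    """Return the checkpoint's GenerationConfig, or None when it has none."""
    try:
        return GenerationConfig.from_pretrained(model_name_or_path)
    except Exception:
        # No generation_config.json next to the weights: return None so the
        # caller falls back to tokenizer-derived EOS ids.
        return None


# generation_config = load_generation_config("path/to/actor-checkpoint")
```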

paddlenlp/rl/trainer/ppo_trainer.py (10 additions, 2 deletions)

@@ -33,6 +33,7 @@
 from rich.table import Table
 
 from ...data import DataCollator
+from ...generation import GenerationConfig
 from ...trainer.trainer import (
     EvalLoopOutput,
     EvalPrediction,
@@ -53,6 +54,7 @@
     PretrainedTokenizer,
 )
 from ...transformers.model_utils import _add_variant
+from ...trl import llm_utils
 from ...utils.env import PADDLE_WEIGHTS_NAME
 from ..algos.advantage import (
     add_kl_divergence_regularization,
@@ -71,6 +73,7 @@
     filter_valid_reward_groups,
     gather_and_pad,
     get_timer_label,
+    make_eos_mask,
     new_timer_log,
     pad_tensor,
     split_batch_by_rank,
@@ -228,6 +231,7 @@ def __init__(
         callbacks: Optional[List[TrainerCallback]] = None,
         optimizers: Tuple[paddle.optimizer.Optimizer, paddle.optimizer.lr.LRScheduler] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[paddle.Tensor, paddle.Tensor], paddle.Tensor]] = None,
+        generation_config: Optional[GenerationConfig] = None,
     ):
         """
         Args:
@@ -359,6 +363,7 @@ def __init__(
         self.model = self.model_wrapped = self.DummyPPOModel()
         if self.timers:
             self.timers.log = types.MethodType(new_timer_log, self.timers)
+        self.generation_config = generation_config
 
     def create_actor_trainer(
         self,
@@ -1142,12 +1147,15 @@ def pad_batch_data(
             dtype=label_ids[0].dtype,
             padding_side="right",
         )
-        position_ids = make_position_ids_from_input_ids(input_ids)
+        position_ids = make_position_ids_from_input_ids(input_ids, pad_token_id=self.tokenizer.pad_token_id)
         return input_ids, label_ids, position_ids
 
     def distribute_gather_and_pad_data(self, batch):
         # group index for grpo
-        eos_mask = (batch["input_ids"] != self.tokenizer.pad_token_id)[:, batch["prompt"].shape[-1] :].to(
+        eos_mask = make_eos_mask(
+            batch["input_ids"][:, batch["prompt"].shape[-1] :],
+            eos_token_ids=llm_utils.get_eos_token_id(self.tokenizer, self.generation_config),
+        ).to(
             batch["log_probs"].dtype  # fix dtype
         )
         try:
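
The eos_mask change in the last hunk is the heart of the fix: the old expression treated every non-pad token as valid, which miscounts whenever the pad id also serves as the EOS id or a pad-valued token appears inside a response, while make_eos_mask keeps everything up to and including the first EOS. A small demonstration with toy ids, assuming eos_token_id == pad_token_id == 0:

```python
import paddle

from paddlenlp.rl.utils.comm_utils import make_eos_mask

# Two toy responses; pad and EOS are both 0 here, the case the old
# pad-based mask gets wrong.
response = paddle.to_tensor([[5, 9, 0, 0, 0],
                             [7, 0, 3, 0, 0]])

old_mask = (response != 0).astype("int64")           # pad-based, pre-fix
new_mask = make_eos_mask(response, eos_token_ids=0)  # EOS-based, post-fix

print(old_mask.tolist())  # [[1, 1, 0, 0, 0], [1, 0, 1, 0, 0]] drops the EOS and "revives" token 3
print(new_mask.tolist())  # [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]] valid through the first EOS
```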

paddlenlp/rl/utils/comm_utils.py (21 additions, 1 deletion)

@@ -1015,7 +1015,7 @@ def process_prompt_and_response(micro_batch, pad_token_id=0):
     response = paddle.stack(padded_response_tensors, axis=0)
 
     micro_batch["input_ids"] = paddle.concat([micro_batch["prompt"], response], axis=1)
-    micro_batch["position_ids"] = make_position_ids_from_input_ids(micro_batch["input_ids"])
+    micro_batch["position_ids"] = make_position_ids_from_input_ids(micro_batch["input_ids"], pad_token_id=pad_token_id)
     key_to_slice = [
         "eos_mask",
         "kl_rewards",
@@ -1072,3 +1072,23 @@ def split_batch_into_micro_batches(total_batch, batch_size, pad_token_id=0):
         micro_batches.append(micro_batch)
 
     return micro_batches
+
+
+def make_eos_mask(response_id, eos_token_ids=0, dtype=paddle.int64):
+    """
+    The end-of-sentence token can be an int or a list: 1 or [1, 2].
+    e.g. eos_token=1
+    response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
+    eos_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]
+    """
+    if isinstance(eos_token_ids, int):
+        eos_token_ids = [eos_token_ids]
+
+    eos_mask = paddle.zeros_like(response_id, dtype=paddle.bool)
+    for token_id in eos_token_ids:
+        eos_mask |= response_id == token_id
+
+    # cumsum - eos_mask counts the EOS tokens strictly before each position,
+    # so the final mask is 1 up to and including the first EOS, 0 after it.
+    eos_mask = eos_mask.to("int64")
+    eos_mask = (paddle.cumsum(eos_mask, axis=1) - eos_mask).to("bool")
+    eos_mask = paddle.logical_not(eos_mask).to(dtype)
+    return eos_mask
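
As the docstring says, the helper accepts several terminator ids at once, which is why the trainer feeds it the combined list returned by llm_utils.get_eos_token_id. A quick usage sketch with two EOS ids (toy values, not from the commit):

```python
import paddle

from paddlenlp.rl.utils.comm_utils import make_eos_mask

# First row is the docstring example; second row ends on the other terminator.
response_id = paddle.to_tensor([[0, 0, 2, 42, 3, 5, 1, 0, 0],
                                [4, 4, 9, 42, 3, 5, 7, 2, 6]])
mask = make_eos_mask(response_id, eos_token_ids=[1, 2])

print(mask.tolist())
# [[1, 1, 1, 0, 0, 0, 0, 0, 0],   first hit is terminator 2 at position 2
#  [1, 1, 1, 1, 1, 1, 1, 1, 0]]   terminator 2 at position 7; trailing token masked
```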
