Commit e6e4928

【RL】Add DAPO (#10380)
* add dapo
* fix dapo
* fix conflict
1 parent d854740 commit e6e4928

8 files changed: +582 -37 lines changed

paddlenlp/datasets/rlhf_datasets/protocol.py

Lines changed: 2 additions & 6 deletions
@@ -219,9 +219,6 @@ class DataProtoItem:
 class DataProto:
     """
     A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
-    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
-    TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
-    same batch size should be put inside batch.
     """

     batch: TensorDict = None
@@ -343,7 +340,7 @@ def to(self, device) -> "DataProto":
         """move the batch to device

         Args:
-            device (torch.device, str): torch device
+            device (paddle.device, str): paddle device

         Returns:
             DataProto: the current DataProto
@@ -466,8 +463,7 @@ def union(self, other: "DataProto") -> "DataProto":
         return self

     def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None):
-        """Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch
-        dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.
+        """Make an iterator from the DataProto.

         Args:
             mini_batch_size (int): mini-batch size when iterating the dataset. We require that

paddlenlp/datasets/rlhf_datasets/rl_dataset.py

Lines changed: 2 additions & 0 deletions
@@ -42,10 +42,12 @@ def padding_batch_data(samples: list[dict], pad_token_id: int, requires_label: b
     # attention_mask = [np.ones(input_id.shape, dtype=bool) for input_id in input_ids]
     input_dict["input_ids"] = left_padding(input_ids, padding_value=pad_token_id)
     # input_dict["attention_mask"] = left_padding(attention_mask, padding_value=0)
+    input_dict["raw_prompt_len"] = paddle.to_tensor([len(sample["input_ids"]) for sample in samples])

     if requires_label:
         label_ids = [sample["label_ids"] for sample in samples]
         input_dict["label_ids"] = left_padding(label_ids, padding_value=pad_token_id)
+        input_dict["raw_label_ids_len"] = paddle.to_tensor([len(sample["label_ids"]) for sample in samples])

     return input_dict
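The two added fields record the true, unpadded sequence lengths, which left padding would otherwise hide (DAPO's length-based penalty needs them later). A minimal, self-contained sketch of the idea, using made-up token ids instead of the repo's left_padding helper:

import paddle

# Hedged illustration: after left padding, row shapes no longer reveal the original
# prompt lengths, so the raw lengths are stored alongside the padded batch.
pad_token_id = 0
samples = [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}]

max_len = max(len(s["input_ids"]) for s in samples)
padded = paddle.to_tensor(
    [[pad_token_id] * (max_len - len(s["input_ids"])) + s["input_ids"] for s in samples]
)
raw_prompt_len = paddle.to_tensor([len(s["input_ids"]) for s in samples])

print(padded.tolist())          # [[5, 6, 7], [0, 8, 9]]
print(raw_prompt_len.tolist())  # [3, 2], the true lengths survive the left padding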

paddlenlp/rl/algos/penalty.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+
+def apply_overlong_penalty(response_length, max_dec_len, overlong_buffer_len, penalty_factor):
+    """
+    Apply length penalty to overlong responses.
+
+    Args:
+        response_length (paddle.Tensor): Tensor of shape (B,) indicating the length of each response.
+        max_dec_len (int): The maximum allowed decoding length.
+        overlong_buffer_len (int): The allowed buffer before applying penalty.
+        penalty_factor (float): The penalty factor to scale the length overflow.
+
+    Returns:
+        paddle.Tensor: A tensor of shape (B,) representing the length penalty for each response.
+    """
+    expected_len = max_dec_len - overlong_buffer_len
+    exceed_len = response_length - expected_len
+
+    reward_penalty = -exceed_len / overlong_buffer_len * penalty_factor
+    # Only apply negative penalty if response exceeds limit, otherwise zero
+    overlong_penalty = paddle.minimum(reward_penalty, paddle.zeros_like(reward_penalty))
+
+    return overlong_penalty
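For intuition, a small worked example of the penalty ramp; the numbers are illustrative and the import simply mirrors the new file's location:

import paddle

from paddlenlp.rl.algos.penalty import apply_overlong_penalty

# Illustrative settings: generation capped at 512 tokens with a 128-token buffer,
# so penalties start once a response exceeds 512 - 128 = 384 tokens.
response_length = paddle.to_tensor([300.0, 400.0, 512.0])
penalty = apply_overlong_penalty(
    response_length, max_dec_len=512, overlong_buffer_len=128, penalty_factor=1.0
)
print(penalty.tolist())  # [0.0, -0.125, -1.0]

Responses at or below 384 tokens are untouched; inside the buffer the penalty grows linearly, reaching -penalty_factor once the response length hits max_dec_len.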

paddlenlp/rl/models/ppo_model_utils.py

Lines changed: 36 additions & 10 deletions
@@ -239,7 +239,7 @@ def create_startend_row_indices(input_ids, pad_token_id=0):


 class RLHFPPOLoss(nn.Layer):
-    def __init__(self, config, clip_range_ratio=0.2):
+    def __init__(self, config, clip_range_ratio=0.2, clip_range_ratio_low=None, clip_range_ratio_high=None):
         """
         Initialize the `ClipRewardRange` object.

@@ -257,6 +257,8 @@ def __init__(self, config, clip_range_ratio=0.2):
         """
         super().__init__()
         self.clip_range_ratio = clip_range_ratio
+        self.clip_range_ratio_low = clip_range_ratio_low
+        self.clip_range_ratio_high = clip_range_ratio_high
         self.config = config

     def actor_loss_fn(
@@ -283,8 +285,8 @@ def actor_loss_fn(
         pg_loss1 = -advantages * ratio
         pg_loss2 = -advantages * paddle.clip(
             ratio,
-            1.0 - self.clip_range_ratio,
-            1.0 + self.clip_range_ratio,
+            1.0 - self.clip_range_ratio_low,
+            1.0 + self.clip_range_ratio_high,
         )
         return paddle.sum(paddle.maximum(pg_loss1, pg_loss2) * mask) / mask.sum()

@@ -361,6 +363,8 @@ def __init__(
         config,
         ptx_coeff=16,
         clip_range_ratio=0.2,
+        clip_range_ratio_low=None,
+        clip_range_ratio_high=None,
         kl_loss_coeff=0.001,
         clip_range_score=10,
         info_buffer=None,
@@ -379,10 +383,14 @@ def __init__(
         self.config = config
         self.ptx_coeff = ptx_coeff
         # if self.config.use_fused_head_and_loss_fn:
-        #     self.ppo_criterion = FusedPPOLoss(config, clip_range_ratio)
+        #     self.ppo_criterion = FusedPPOLoss(config, clip_range_ratio, clip_range_ratio_low, clip_range_ratio_high)
         # else:
-        #     self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio)
-        self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio)
+        #     self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio, clip_range_ratio_low, clip_range_ratio_high)
+        self.clip_range_ratio_low = clip_range_ratio_low if clip_range_ratio_low is not None else clip_range_ratio
+        self.clip_range_ratio_high = clip_range_ratio_high if clip_range_ratio_high is not None else clip_range_ratio
+        self.ppo_criterion = RLHFPPOLoss(
+            config, clip_range_ratio, self.clip_range_ratio_low, self.clip_range_ratio_high
+        )
         self.sft_criterion = PretrainingCriterion(config)
         self.kl_loss_coeff = kl_loss_coeff
         self.clip_range_score = clip_range_score
@@ -449,6 +457,8 @@ def forward(
                 tensor_parallel_output=self.config.tensor_parallel_output,
                 pg_loss_coeff=self.pg_loss_coeff,  # donot use this
                 clip_range_ratio=self.clip_range_ratio,
+                clip_range_ratio_low=self.clip_range_ratio_low,
+                clip_range_ratio_high=self.clip_range_ratio_high,
                 entropy_coeff=self.entropy_coeff,  # donot support this
                 clip_range_score=self.clip_range_score,
                 kl_loss_coeff=self.kl_loss_coeff,
@@ -642,6 +652,8 @@ def forward(
         ref_log_probs: paddle.Tensor,
         advantages: paddle.Tensor,
         clip_range_ratio: float,
+        clip_range_ratio_low: float,
+        clip_range_ratio_high: float,
         clip_range_score: float,
         kl_loss_coeff: float,  # KL loss coefficient
         temperature: float,
@@ -777,7 +789,9 @@ def forward(

             # ratio
             ratio_chunk = paddle.exp(log_probs_chunk - old_log_probs_chunk)
-            clipped_ratio_chunk = paddle.clip(ratio_chunk, min=1.0 - clip_range_ratio, max=1.0 + clip_range_ratio)
+            clipped_ratio_chunk = paddle.clip(
+                ratio_chunk, min=1.0 - clip_range_ratio_low, max=1.0 + clip_range_ratio_high
+            )

             # final loss
             pg_loss1_chunk = -advantages_chunk * ratio_chunk
@@ -913,10 +927,12 @@ def backward(ctx, grad_output, *args):
 class FusedPPOLoss(nn.Layer):
     """Fused PPOLoss"""

-    def __init__(self, config, clip_range_ratio=0.2):
+    def __init__(self, config, clip_range_ratio=0.2, clip_range_ratio_low=None, clip_range_ratio_high=None):
         """Initialize FusedPPOLoss class."""
         super().__init__()
         self.clip_range_ratio = clip_range_ratio
+        self.clip_range_ratio_low = clip_range_ratio_low
+        self.clip_range_ratio_high = clip_range_ratio_high
         self.config = config

     def forward(
@@ -970,6 +986,8 @@ def forward(
             old_log_probs=old_log_probs,
             advantages=reward_advantages,
             clip_range_ratio=self.clip_range_ratio,
+            clip_range_ratio_low=self.clip_range_ratio_low,
+            clip_range_ratio_high=self.clip_range_ratio_high,
         )
         return actor_loss

@@ -994,6 +1012,8 @@ def forward(
         tensor_parallel_output: bool,
         pg_loss_coeff: float,
         clip_range_ratio: float,  # pg loss
+        clip_range_ratio_low: float,
+        clip_range_ratio_high: float,
         entropy_coeff: float,  # entropy loss
         clip_range_score: float,  # clip loss
         kl_loss_coeff: float,  # clip loss
@@ -1092,8 +1112,8 @@ def maybe_transpose(x):
             ratio_chunk = paddle.exp(log_probs_chunk - old_log_prob_chunk)
             clipped_ratio_chunk = paddle.clip(
                 ratio_chunk,
-                min=1.0 - clip_range_ratio,
-                max=1.0 + clip_range_ratio,
+                min=1.0 - clip_range_ratio_low,
+                max=1.0 + clip_range_ratio_high,
             )

             pg_loss1_chunk = -advantages_chunk * ratio_chunk
@@ -1249,6 +1269,8 @@ def actor_fused_pg_entropy_kl_loss(
     tensor_parallel_output: bool = False,
     pg_loss_coeff: float = 1.0,
     clip_range_ratio: float = 0.2,
+    clip_range_ratio_low: float = None,
+    clip_range_ratio_high: float = None,
     entropy_coeff: float = 0.001,
     clip_range_score: float = 10.0,
     kl_loss_coeff: float = 0.001,
@@ -1280,6 +1302,8 @@ def actor_fused_pg_entropy_kl_loss(
         fused_linear=fused_linear,
         loop_chunk_size=loop_chunk_size,
         clip_range_ratio=clip_range_ratio,
+        clip_range_ratio_low=clip_range_ratio_low,
+        clip_range_ratio_high=clip_range_ratio_high,
         clip_range_score=clip_range_score,
         kl_loss_coeff=kl_loss_coeff,
         ignore_index=-100,
@@ -1301,6 +1325,8 @@ def actor_fused_pg_entropy_kl_loss(
         tensor_parallel_output=tensor_parallel_output,
         pg_loss_coeff=pg_loss_coeff,
         clip_range_ratio=clip_range_ratio,  # pg loss
+        clip_range_ratio_low=clip_range_ratio_low,
+        clip_range_ratio_high=clip_range_ratio_high,
         entropy_coeff=entropy_coeff,  # entropy loss
         clip_range_score=clip_range_score,  # clip loss
         kl_loss_coeff=kl_loss_coeff,  # clip loss
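The recurring change in this file replaces PPO's symmetric clip range with separate low and high bounds (DAPO's clip-higher); as the __init__ hunk above shows, both fall back to clip_range_ratio when left as None. A toy, self-contained sketch of the asymmetric clipping, with invented values rather than anything taken from the PR:

import paddle

# DAPO-style decoupled clipping: the importance ratio is clipped to
# [1 - clip_range_ratio_low, 1 + clip_range_ratio_high] instead of a symmetric range.
clip_range_ratio_low, clip_range_ratio_high = 0.2, 0.28

ratio = paddle.to_tensor([0.5, 1.0, 1.5])       # exp(log_prob - old_log_prob)
advantages = paddle.to_tensor([1.0, 1.0, 1.0])  # toy positive advantages

pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * paddle.clip(ratio, 1.0 - clip_range_ratio_low, 1.0 + clip_range_ratio_high)
loss = paddle.maximum(pg_loss1, pg_loss2).mean()

print(loss.item())  # about -0.9267; the third ratio is capped at 1.28 by the higher bound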
