@@ -239,7 +239,7 @@ def create_startend_row_indices(input_ids, pad_token_id=0):
239
239
240
240
241
241
class RLHFPPOLoss (nn .Layer ):
242
- def __init__ (self , config , clip_range_ratio = 0.2 ):
242
+ def __init__ (self , config , clip_range_ratio = 0.2 , clip_range_ratio_low = None , clip_range_ratio_high = None ):
243
243
"""
244
244
Initialize the `RLHFPPOLoss` object.
245
245
@@ -257,6 +257,8 @@ def __init__(self, config, clip_range_ratio=0.2):
257
257
"""
258
258
super ().__init__ ()
259
259
self .clip_range_ratio = clip_range_ratio
260
+ self .clip_range_ratio_low = clip_range_ratio_low
261
+ self .clip_range_ratio_high = clip_range_ratio_high
260
262
self .config = config
261
263
262
264
def actor_loss_fn (
@@ -283,8 +285,8 @@ def actor_loss_fn(
283
285
pg_loss1 = - advantages * ratio
284
286
pg_loss2 = - advantages * paddle .clip (
285
287
ratio ,
286
- 1.0 - self .clip_range_ratio ,
287
- 1.0 + self .clip_range_ratio ,
288
+ 1.0 - self .clip_range_ratio_low ,
289
+ 1.0 + self .clip_range_ratio_high ,
288
290
)
289
291
return paddle .sum (paddle .maximum (pg_loss1 , pg_loss2 ) * mask ) / mask .sum ()
290
292
@@ -361,6 +363,8 @@ def __init__(
361
363
config ,
362
364
ptx_coeff = 16 ,
363
365
clip_range_ratio = 0.2 ,
366
+ clip_range_ratio_low = None ,
367
+ clip_range_ratio_high = None ,
364
368
kl_loss_coeff = 0.001 ,
365
369
clip_range_score = 10 ,
366
370
info_buffer = None ,
@@ -379,10 +383,14 @@ def __init__(
379
383
self .config = config
380
384
self .ptx_coeff = ptx_coeff
381
385
# if self.config.use_fused_head_and_loss_fn:
382
- # self.ppo_criterion = FusedPPOLoss(config, clip_range_ratio)
386
+ # self.ppo_criterion = FusedPPOLoss(config, clip_range_ratio, clip_range_ratio_low, clip_range_ratio_high )
383
387
# else:
384
- # self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio)
385
- self .ppo_criterion = RLHFPPOLoss (config , clip_range_ratio )
388
+ # self.ppo_criterion = RLHFPPOLoss(config, clip_range_ratio, clip_range_ratio_low, clip_range_ratio_high)
389
+ self .clip_range_ratio_low = clip_range_ratio_low if clip_range_ratio_low is not None else clip_range_ratio
390
+ self .clip_range_ratio_high = clip_range_ratio_high if clip_range_ratio_high is not None else clip_range_ratio
391
+ self .ppo_criterion = RLHFPPOLoss (
392
+ config , clip_range_ratio , self .clip_range_ratio_low , self .clip_range_ratio_high
393
+ )
386
394
self .sft_criterion = PretrainingCriterion (config )
387
395
self .kl_loss_coeff = kl_loss_coeff
388
396
self .clip_range_score = clip_range_score
@@ -449,6 +457,8 @@ def forward(
449
457
tensor_parallel_output = self .config .tensor_parallel_output ,
450
458
pg_loss_coeff = self .pg_loss_coeff , # do not use this
451
459
clip_range_ratio = self .clip_range_ratio ,
460
+ clip_range_ratio_low = self .clip_range_ratio_low ,
461
+ clip_range_ratio_high = self .clip_range_ratio_high ,
452
462
entropy_coeff = self .entropy_coeff , # not supported
453
463
clip_range_score = self .clip_range_score ,
454
464
kl_loss_coeff = self .kl_loss_coeff ,
@@ -642,6 +652,8 @@ def forward(
642
652
ref_log_probs : paddle .Tensor ,
643
653
advantages : paddle .Tensor ,
644
654
clip_range_ratio : float ,
655
+ clip_range_ratio_low : float ,
656
+ clip_range_ratio_high : float ,
645
657
clip_range_score : float ,
646
658
kl_loss_coeff : float , # KL loss coefficient
647
659
temperature : float ,
@@ -777,7 +789,9 @@ def forward(
777
789
778
790
# ratio
779
791
ratio_chunk = paddle .exp (log_probs_chunk - old_log_probs_chunk )
780
- clipped_ratio_chunk = paddle .clip (ratio_chunk , min = 1.0 - clip_range_ratio , max = 1.0 + clip_range_ratio )
792
+ clipped_ratio_chunk = paddle .clip (
793
+ ratio_chunk , min = 1.0 - clip_range_ratio_low , max = 1.0 + clip_range_ratio_high
794
+ )
781
795
782
796
# final loss
783
797
pg_loss1_chunk = - advantages_chunk * ratio_chunk
@@ -913,10 +927,12 @@ def backward(ctx, grad_output, *args):
913
927
class FusedPPOLoss (nn .Layer ):
914
928
"""Fused PPOLoss"""
915
929
916
- def __init__ (self , config , clip_range_ratio = 0.2 ):
930
+ def __init__ (self , config , clip_range_ratio = 0.2 , clip_range_ratio_low = None , clip_range_ratio_high = None ):
917
931
"""Initialize FusedPPOLoss class."""
918
932
super ().__init__ ()
919
933
self .clip_range_ratio = clip_range_ratio
934
+ self .clip_range_ratio_low = clip_range_ratio_low
935
+ self .clip_range_ratio_high = clip_range_ratio_high
920
936
self .config = config
921
937
922
938
def forward (
@@ -970,6 +986,8 @@ def forward(
970
986
old_log_probs = old_log_probs ,
971
987
advantages = reward_advantages ,
972
988
clip_range_ratio = self .clip_range_ratio ,
989
+ clip_range_ratio_low = self .clip_range_ratio_low ,
990
+ clip_range_ratio_high = self .clip_range_ratio_high ,
973
991
)
974
992
return actor_loss
975
993
@@ -994,6 +1012,8 @@ def forward(
994
1012
tensor_parallel_output : bool ,
995
1013
pg_loss_coeff : float ,
996
1014
clip_range_ratio : float , # pg loss
1015
+ clip_range_ratio_low : float ,
1016
+ clip_range_ratio_high : float ,
997
1017
entropy_coeff : float , # entropy loss
998
1018
clip_range_score : float , # clip loss
999
1019
kl_loss_coeff : float , # KL loss coefficient
@@ -1092,8 +1112,8 @@ def maybe_transpose(x):
1092
1112
ratio_chunk = paddle .exp (log_probs_chunk - old_log_prob_chunk )
1093
1113
clipped_ratio_chunk = paddle .clip (
1094
1114
ratio_chunk ,
1095
- min = 1.0 - clip_range_ratio ,
1096
- max = 1.0 + clip_range_ratio ,
1115
+ min = 1.0 - clip_range_ratio_low ,
1116
+ max = 1.0 + clip_range_ratio_high ,
1097
1117
)
1098
1118
1099
1119
pg_loss1_chunk = - advantages_chunk * ratio_chunk
@@ -1249,6 +1269,8 @@ def actor_fused_pg_entropy_kl_loss(
1249
1269
tensor_parallel_output : bool = False ,
1250
1270
pg_loss_coeff : float = 1.0 ,
1251
1271
clip_range_ratio : float = 0.2 ,
1272
+ clip_range_ratio_low : float = None ,
1273
+ clip_range_ratio_high : float = None ,
1252
1274
entropy_coeff : float = 0.001 ,
1253
1275
clip_range_score : float = 10.0 ,
1254
1276
kl_loss_coeff : float = 0.001 ,
@@ -1280,6 +1302,8 @@ def actor_fused_pg_entropy_kl_loss(
1280
1302
fused_linear = fused_linear ,
1281
1303
loop_chunk_size = loop_chunk_size ,
1282
1304
clip_range_ratio = clip_range_ratio ,
1305
+ clip_range_ratio_low = clip_range_ratio_low ,
1306
+ clip_range_ratio_high = clip_range_ratio_high ,
1283
1307
clip_range_score = clip_range_score ,
1284
1308
kl_loss_coeff = kl_loss_coeff ,
1285
1309
ignore_index = - 100 ,
@@ -1301,6 +1325,8 @@ def actor_fused_pg_entropy_kl_loss(
1301
1325
tensor_parallel_output = tensor_parallel_output ,
1302
1326
pg_loss_coeff = pg_loss_coeff ,
1303
1327
clip_range_ratio = clip_range_ratio , # pg loss
1328
+ clip_range_ratio_low = clip_range_ratio_low ,
1329
+ clip_range_ratio_high = clip_range_ratio_high ,
1304
1330
entropy_coeff = entropy_coeff , # entropy loss
1305
1331
clip_range_score = clip_range_score , # clip loss
1306
1332
kl_loss_coeff = kl_loss_coeff , # KL loss coefficient
0 commit comments