File tree Expand file tree Collapse file tree 3 files changed +46
-0
lines changed Expand file tree Collapse file tree 3 files changed +46
-0
lines changed Original file line number Diff line number Diff line change @@ -70,6 +70,13 @@ def set_training_args_kwargs(cls, cfg):
70
70
if trl .scale_rewards is not None :
71
71
grpo_args_kwargs ["scale_rewards" ] = trl .scale_rewards
72
72
73
+ if trl .loss_type is not None :
74
+ grpo_args_kwargs ["loss_type" ] = trl .loss_type
75
+ if trl .mask_truncated_completions is not None :
76
+ grpo_args_kwargs ["mask_truncated_completions" ] = (
77
+ trl .mask_truncated_completions
78
+ )
79
+
73
80
if trl .temperature is not None :
74
81
grpo_args_kwargs ["temperature" ] = trl .temperature
75
82
if trl .top_p is not None :
@@ -85,6 +92,11 @@ def set_training_args_kwargs(cls, cfg):
85
92
grpo_args_kwargs ["num_iterations" ] = trl .num_iterations
86
93
if trl .epsilon is not None :
87
94
grpo_args_kwargs ["epsilon" ] = trl .epsilon
95
+ if trl .epsilon_high is not None :
96
+ grpo_args_kwargs ["epsilon_high" ] = trl .epsilon_high
97
+
98
+ if trl .use_liger_loss is not None :
99
+ grpo_args_kwargs ["use_liger_loss" ] = trl .use_liger_loss
88
100
89
101
return grpo_args_kwargs
90
102
Original file line number Diff line number Diff line change @@ -1150,6 +1150,18 @@ def check_kto_config(cls, data):
1150
1150
1151
1151
return data
1152
1152
1153
@model_validator(mode="before")
@classmethod
def check_grpo_peft_liger(cls, data):
    """Reject configs that combine GRPO, the Liger loss and an adapter.

    Runs as a "before" validator on the raw input. ``data`` is accessed
    through ``.get`` only, so it is assumed to be a dict-like mapping
    (NOTE(review): confirm against how the model is constructed).

    Args:
        data: Raw configuration mapping being validated.

    Returns:
        ``data`` unchanged when the unsupported combination is not present.

    Raises:
        ValueError: If ``rl`` equals ``"grpo"``, ``trl.use_liger_loss`` is
            truthy and ``adapter`` is truthy.
    """
    # Only GRPO runs are subject to this restriction.
    if data.get("rl") != "grpo":
        return data

    # A missing, None or empty `trl` section means the Liger loss is not requested.
    trl_settings = data.get("trl", {})
    if not trl_settings or not trl_settings.get("use_liger_loss"):
        return data

    # Any truthy `adapter` value together with the above is rejected.
    if data.get("adapter"):
        raise ValueError("PEFT + GRPO + Liger is not yet supported")

    return data
1164
+
1153
1165
@model_validator (mode = "after" )
1154
1166
def check_sequence_parallel_degree (self ):
1155
1167
if not self .sequence_parallel_degree :
Original file line number Diff line number Diff line change @@ -133,3 +133,25 @@ class TRLConfig(BaseModel):
133
133
"description" : "Epsilon value for clipping in the GRPO algorithm."
134
134
},
135
135
)
136
+ epsilon_high : float | None = Field (
137
+ default = None ,
138
+ json_schema_extra = {
139
+ "description" : "Upper-bound epsilon value for clipping in the GRPO algorithm."
140
+ },
141
+ )
142
+ use_liger_loss : bool | None = Field (
143
+ default = None ,
144
+ json_schema_extra = {"description" : "Whether to use Liger loss for GRPO." },
145
+ )
146
+ loss_type : str | None = Field (
147
+ default = None ,
148
+ json_schema_extra = {
149
+ "description" : "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
150
+ },
151
+ )
152
+ mask_truncated_completions : bool = Field (
153
+ default = False ,
154
+ json_schema_extra = {
155
+ "description" : "When enabled, truncated completions are excluded from the loss calculation."
156
+ },
157
+ )
You can’t perform that action at this time.
0 commit comments