Skip to content

Commit 24ff5f5

Browse files
authored
additional args for grpo config/trainer (axolotl-ai-cloud#2598)
1 parent 5e949ea commit 24ff5f5

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

src/axolotl/core/trainers/grpo/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ def set_training_args_kwargs(cls, cfg):
7070
if trl.scale_rewards is not None:
7171
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
7272

73+
if trl.loss_type is not None:
74+
grpo_args_kwargs["loss_type"] = trl.loss_type
75+
if trl.mask_truncated_completions is not None:
76+
grpo_args_kwargs["mask_truncated_completions"] = (
77+
trl.mask_truncated_completions
78+
)
79+
7380
if trl.temperature is not None:
7481
grpo_args_kwargs["temperature"] = trl.temperature
7582
if trl.top_p is not None:
@@ -85,6 +92,11 @@ def set_training_args_kwargs(cls, cfg):
8592
grpo_args_kwargs["num_iterations"] = trl.num_iterations
8693
if trl.epsilon is not None:
8794
grpo_args_kwargs["epsilon"] = trl.epsilon
95+
if trl.epsilon_high is not None:
96+
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
97+
98+
if trl.use_liger_loss is not None:
99+
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
88100

89101
return grpo_args_kwargs
90102

src/axolotl/utils/schemas/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,18 @@ def check_kto_config(cls, data):
11501150

11511151
return data
11521152

1153+
@model_validator(mode="before")
1154+
@classmethod
1155+
def check_grpo_peft_liger(cls, data):
1156+
if (
1157+
data.get("rl") == "grpo"
1158+
and data.get("trl", {})
1159+
and data.get("trl").get("use_liger_loss")
1160+
and data.get("adapter")
1161+
):
1162+
raise ValueError("PEFT + GRPO + Liger is not yet supported")
1163+
return data
1164+
11531165
@model_validator(mode="after")
11541166
def check_sequence_parallel_degree(self):
11551167
if not self.sequence_parallel_degree:

src/axolotl/utils/schemas/trl.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,25 @@ class TRLConfig(BaseModel):
133133
"description": "Epsilon value for clipping in the GRPO algorithm."
134134
},
135135
)
136+
epsilon_high: float | None = Field(
137+
default=None,
138+
json_schema_extra={
139+
"description": "Upper-bound epsilon value for clipping in the GRPO algorithm."
140+
},
141+
)
142+
use_liger_loss: bool | None = Field(
143+
default=None,
144+
json_schema_extra={"description": "Whether to use Liger loss for GRPO."},
145+
)
146+
loss_type: str | None = Field(
147+
default=None,
148+
json_schema_extra={
149+
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
150+
},
151+
)
152+
mask_truncated_completions: bool = Field(
153+
default=False,
154+
json_schema_extra={
155+
"description": "When enabled, truncated completions are excluded from the loss calculation."
156+
},
157+
)

0 commit comments

Comments
 (0)