@@ -112,21 +112,21 @@ def _validate(self) -> None:
112
112
@config_class
class LossMaskSpansConfig(Config):
    """Column names used when preparing loss-masking spans for a dataset.

    `masking_column` names the input field holding the character spans, and
    `loss_masking_spans` names the field that would contain the masked spans.
    """

    # Input column with the character spans; defaults to an empty string.
    masking_column: str = Field(
        default="",
        desc="Field containing (input) character spans for loss masking ",
        hint=FieldHint.core,
    )
    # Column that would contain the masked spans.
    loss_masking_spans: str = Field(
        default="fast_llm_loss_masking_spans",
        desc="Column name of field that would contain the masked spans. ",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        """Check that `loss_masking_spans` is a string, then run the parent validation.

        Raises:
            AssertionError: if `loss_masking_spans` is not a `str`.
        """
        # Raise explicitly instead of using a bare `assert`: assert statements are
        # removed under `python -O`, which would silently disable this check.
        # AssertionError is kept so callers observe the same exception type.
        if not isinstance(self.loss_masking_spans, str):
            raise AssertionError("loss_masking_spans col name must be a string.")
        super()._validate()
127
127
128
128
@config_class
129
- class FieldCombinePreparatorConfig (Config ):
129
+ class CombineFieldsConfig (Config ):
130
130
col_names : typing .List [str ] = Field (
131
131
default_factory = list ,
132
132
desc = "Fields of the dataset to combine." ,
@@ -205,7 +205,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
205
205
" Does not shuffle samples." ,
206
206
hint = FieldHint .optional ,
207
207
)
208
- combine_fields : FieldCombinePreparatorConfig = Field (
208
+ combine_fields : CombineFieldsConfig = Field (
209
209
default = None ,
210
210
desc = "Combine all files into a single file." ,
211
211
hint = FieldHint .optional ,
0 commit comments