@@ -112,21 +112,21 @@ def _validate(self) -> None:
112
112
@config_class
113
113
class LossMaskSpansConfig (Config ):
# Configuration naming the two columns involved in loss masking: the input
# column that holds character spans (`masking_column`) and the output column
# that will hold the masked spans (`loss_masking_spans`).
#
# NOTE(review): this text is a diff rendering, not plain source. Bare numbers
# are gutter line numbers, lines starting with `-` are the removed version and
# lines starting with `+` are the added version of the same statement.
114
114
masking_column : str = Field (
# Changed by this diff: the default goes from None to "" and the hint from
# FieldHint.optional to FieldHint.core. The empty string acts as the "not set"
# value: the preparator's `_validate` further down compares
# `masking_column != ""` before assigning `dataset.loss_masking_spans`.
115
- default = None ,
116
- desc = "Field containing character spans to mask for loss computation " ,
117
- hint = FieldHint .optional ,
115
+ default = "" ,
116
+ desc = "Field containing (input) character spans for loss masking " ,
117
+ hint = FieldHint .core ,
118
118
)
119
119
loss_masking_spans : str = Field (
# Only the description changes here; the default column name and the
# optional hint are untouched context lines.
120
120
default = "fast_llm_loss_masking_spans" ,
121
- desc = "Field containing character spans to mask for loss computation " ,
121
+ desc = "Column name of field that would contain the masked spans. " ,
122
122
hint = FieldHint .optional ,
123
123
)
124
124
def _validate (self ) -> None :
# Asserts that `loss_masking_spans` is a string, then defers to the parent
# class's `_validate`. `masking_column` is not checked in this method.
125
125
assert isinstance (self .loss_masking_spans , str ), "loss_masking_spans col name must be a string."
126
126
super ()._validate ()
127
127
128
128
@config_class
129
- class FieldCombinePreparatorConfig (Config ):
129
+ class CombineFieldsConfig (Config ):
130
130
col_names : typing .List [str ] = Field (
131
131
default_factory = list ,
132
132
desc = "Fields of the dataset to combine." ,
@@ -143,7 +143,7 @@ class FieldCombinePreparatorConfig(Config):
143
143
hint = FieldHint .optional ,
144
144
)
145
145
set_masking_span : LossMaskSpansConfig = Field (
146
- default = None ,
146
+ default_factory = LossMaskSpansConfig ,
147
147
desc = "Compute loss_masking_spans for the newly combined field." ,
148
148
hint = FieldHint .optional ,
149
149
)
@@ -205,8 +205,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
205
205
" Does not shuffle samples." ,
206
206
hint = FieldHint .optional ,
207
207
)
208
- combine_fields : FieldCombinePreparatorConfig = Field (
209
- default = None ,
208
+ combine_fields : CombineFieldsConfig = Field (
209
+ default_factory = CombineFieldsConfig ,
210
210
desc = "Combine all files into a single file." ,
211
211
hint = FieldHint .optional ,
212
212
)
@@ -215,12 +215,12 @@ def _validate(self) -> None:
215
215
assert self .tokenizer .path is not None
216
216
if self .dataset .data_type is not None :
217
217
Assert .incl (DataType .from_numpy (self .dataset .data_type .numpy ), MEMMAP_DTYPES_INV )
218
- if self .combine_fields is not None :
218
+ if len ( self .combine_fields . col_names ) > 0 :
219
219
logger .info (
220
220
f"Setting dataset.field to { self .combine_fields .new_field_name } " ,
221
221
)
222
222
self .dataset .field = self .combine_fields .new_field_name
223
- if self .combine_fields .set_masking_span is not None :
223
+ if self .combine_fields .set_masking_span . masking_column != "" :
224
224
logger .info (f"Setting dataset.loss_masking_spans to { self .combine_fields .set_masking_span .loss_masking_spans } " )
225
225
self .dataset .loss_masking_spans = self .combine_fields .set_masking_span .loss_masking_spans
226
226
super ()._validate ()
0 commit comments