@@ -77,7 +77,6 @@ class GPTHuggingfaceDatasetConfig(Config):
77
77
hint = FieldHint .optional ,
78
78
)
79
79
80
-
81
80
@config_class
82
81
class DatasetPreparatorDistributedConfig (Config ):
83
82
# TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
@@ -110,6 +109,22 @@ def _validate(self) -> None:
110
109
super ()._validate ()
111
110
Assert .in_range (self .rank , 0 , self .world_size )
112
111
112
+ @config_class
113
+ class LossMaskSpansConfig (Config ):
114
+ masking_column : str = Field (
115
+ default = None ,
116
+ desc = "Field containing character spans to mask for loss computation" ,
117
+ hint = FieldHint .optional ,
118
+ )
119
+ loss_masking_spans : str = Field (
120
+ default = "fast_llm_loss_masking_spans" ,
121
+ desc = "Field containing character spans to mask for loss computation" ,
122
+ hint = FieldHint .optional ,
123
+ )
124
+ def _validate (self ) -> None :
125
+ assert isinstance (self .loss_masking_spans , str ), "loss_masking_spans col name must be a string."
126
+ super ()._validate ()
127
+
113
128
@config_class
114
129
class FieldCombinePreparatorConfig (Config ):
115
130
col_names : typing .List [str ] = Field (
@@ -127,13 +142,13 @@ class FieldCombinePreparatorConfig(Config):
127
142
desc = "Name of the new field to create." ,
128
143
hint = FieldHint .optional ,
129
144
)
130
-
145
+ set_masking_span : LossMaskSpansConfig = Field (
146
+ default = None ,
147
+ desc = "Compute loss_masking_spans for the newly combined field." ,
148
+ hint = FieldHint .optional ,
149
+ )
131
150
def _validate (self ) -> None :
132
- # Assert.gt(len(self.fields), 0)
133
- # assert isinstance(self.fields, list), "Fields must be a list."
134
- # assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
135
151
assert isinstance (self .delimiter , str ), "Delimiter must be a string."
136
- # assert isinstance(self.new_field_name, str), "New field name must be a string."
137
152
super ()._validate ()
138
153
139
154
@config_class ()
@@ -201,11 +216,13 @@ def _validate(self) -> None:
201
216
if self .dataset .data_type is not None :
202
217
Assert .incl (DataType .from_numpy (self .dataset .data_type .numpy ), MEMMAP_DTYPES_INV )
203
218
if self .combine_fields is not None :
204
- if self .dataset .field != self .combine_fields .new_field_name :
205
- logger .warning (
206
- f"Combine mode activated yet dataset.field != combine_fields.new_field_name ({ self .dataset .field } != { self .combine_fields .new_field_name } ). Setting dataset.field to { self .combine_fields .new_field_name } " ,
207
- )
208
- self .dataset .field = self .combine_fields .new_field_name
219
+ logger .info (
220
+ f"Setting dataset.field to { self .combine_fields .new_field_name } " ,
221
+ )
222
+ self .dataset .field = self .combine_fields .new_field_name
223
+ if self .combine_fields .set_masking_span is not None :
224
+ logger .info (f"Setting dataset.loss_masking_spans to { self .combine_fields .set_masking_span .loss_masking_spans } " )
225
+ self .dataset .loss_masking_spans = self .combine_fields .set_masking_span .loss_masking_spans
209
226
super ()._validate ()
210
227
211
228
@classmethod
0 commit comments