Skip to content

Commit 2f4df5d

Browse files
committed
updates to config names and desc.
1 parent 80584e7 commit 2f4df5d

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,21 @@ def _validate(self) -> None:
112112
@config_class
113113
class LossMaskSpansConfig(Config):
114114
masking_column: str = Field(
115-
default=None,
116-
desc="Field containing character spans to mask for loss computation",
117-
hint=FieldHint.optional,
115+
default="",
116+
desc="Field containing (input) character spans for loss masking",
117+
hint=FieldHint.core,
118118
)
119119
loss_masking_spans: str = Field(
120120
default="fast_llm_loss_masking_spans",
121-
desc="Field containing character spans to mask for loss computation",
121+
desc="Column name of field that would contain the masked spans.",
122122
hint=FieldHint.optional,
123123
)
124124
def _validate(self) -> None:
125125
assert isinstance(self.loss_masking_spans, str), "loss_masking_spans col name must be a string."
126126
super()._validate()
127127

128128
@config_class
129-
class FieldCombinePreparatorConfig(Config):
129+
class CombineFieldsConfig(Config):
130130
col_names: typing.List[str] = Field(
131131
default_factory=list,
132132
desc="Fields of the dataset to combine.",
@@ -143,7 +143,7 @@ class FieldCombinePreparatorConfig(Config):
143143
hint=FieldHint.optional,
144144
)
145145
set_masking_span: LossMaskSpansConfig = Field(
146-
default=None,
146+
default_factory=LossMaskSpansConfig,
147147
desc="Compute loss_masking_spans for the newly combined field.",
148148
hint=FieldHint.optional,
149149
)
@@ -205,8 +205,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
205205
" Does not shuffle samples.",
206206
hint=FieldHint.optional,
207207
)
208-
combine_fields: FieldCombinePreparatorConfig = Field(
209-
default=None,
208+
combine_fields: CombineFieldsConfig = Field(
209+
default_factory=CombineFieldsConfig,
210210
desc="Combine all files into a single file.",
211211
hint=FieldHint.optional,
212212
)
@@ -215,12 +215,12 @@ def _validate(self) -> None:
215215
assert self.tokenizer.path is not None
216216
if self.dataset.data_type is not None:
217217
Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
218-
if self.combine_fields is not None:
218+
if len(self.combine_fields.col_names) > 0:
219219
logger.info(
220220
f"Setting dataset.field to {self.combine_fields.new_field_name}",
221221
)
222222
self.dataset.field = self.combine_fields.new_field_name
223-
if self.combine_fields.set_masking_span is not None:
223+
if self.combine_fields.set_masking_span.masking_column != "":
224224
logger.info(f"Setting dataset.loss_masking_spans to {self.combine_fields.set_masking_span.loss_masking_spans}")
225225
self.dataset.loss_masking_spans = self.combine_fields.set_masking_span.loss_masking_spans
226226
super()._validate()

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def run(self) -> None:
210210
assert isinstance(dataset, datasets.Dataset)
211211

212212
# Check for combining fields
213-
if self._config.combine_fields:
213+
if len(self._config.combine_fields.col_names) > 0:
214214
Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names),\
215215
msg=f"Some columns to combine are not in the dataset. {set(self._config.combine_fields.col_names).difference(dataset.column_names)}")
216216

@@ -227,9 +227,8 @@ def run(self) -> None:
227227
logger.info(f"Sample after combining fields:\n{dataset[0]}")
228228
# Note: self.dataset.field is set to new_field_name for the rest of the operation see config validation
229229

230-
if self._config.combine_fields.set_masking_span is not None:
230+
if self._config.combine_fields.set_masking_span.masking_column != "":
231231
Assert.incl(self._config.combine_fields.set_masking_span.masking_column, dataset.column_names)
232-
233232
dataset = dataset.map(
234233
lambda example: {
235234
self._config.dataset.loss_masking_spans: [

0 commit comments

Comments
 (0)