Skip to content

Commit 80584e7

Browse files
committed
add loss masking col when combining cols
1 parent 7d57592 commit 80584e7

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ class GPTHuggingfaceDatasetConfig(Config):
7777
hint=FieldHint.optional,
7878
)
7979

80-
8180
@config_class
8281
class DatasetPreparatorDistributedConfig(Config):
8382
# TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
@@ -110,6 +109,22 @@ def _validate(self) -> None:
110109
super()._validate()
111110
Assert.in_range(self.rank, 0, self.world_size)
112111

112+
@config_class
class LossMaskSpansConfig(Config):
    """Options for creating a loss-masking-spans column when fields are combined.

    The preparator reads the text in `masking_column` and writes a single
    inclusive character span, `(0, len(text) - 1)`, into the column named by
    `loss_masking_spans`.
    """

    # Source column: the length of its (stringified) value defines the masked span.
    masking_column: str = Field(
        default=None,
        desc="Name of the column whose text length defines the character span to mask for loss computation.",
        hint=FieldHint.optional,
    )
    # Destination column: receives the computed list of (start, end) spans.
    loss_masking_spans: str = Field(
        default="fast_llm_loss_masking_spans",
        desc="Name of the column in which the computed loss masking spans are stored.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        """Check that the output column name is a string, then run the base validation."""
        assert isinstance(self.loss_masking_spans, str), "loss_masking_spans col name must be a string."
        super()._validate()
127+
113128
@config_class
114129
class FieldCombinePreparatorConfig(Config):
115130
col_names: typing.List[str] = Field(
@@ -127,13 +142,13 @@ class FieldCombinePreparatorConfig(Config):
127142
desc="Name of the new field to create.",
128143
hint=FieldHint.optional,
129144
)
130-
145+
set_masking_span: LossMaskSpansConfig = Field(
146+
default=None,
147+
desc="Compute loss_masking_spans for the newly combined field.",
148+
hint=FieldHint.optional,
149+
)
131150
def _validate(self) -> None:
132-
# Assert.gt(len(self.fields), 0)
133-
# assert isinstance(self.fields, list), "Fields must be a list."
134-
# assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
135151
assert isinstance(self.delimiter, str), "Delimiter must be a string."
136-
# assert isinstance(self.new_field_name, str), "New field name must be a string."
137152
super()._validate()
138153

139154
@config_class()
@@ -201,11 +216,13 @@ def _validate(self) -> None:
201216
if self.dataset.data_type is not None:
202217
Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
203218
if self.combine_fields is not None:
204-
if self.dataset.field != self.combine_fields.new_field_name:
205-
logger.warning(
206-
f"Combine mode activated yet dataset.field != combine_fields.new_field_name ({self.dataset.field} != {self.combine_fields.new_field_name}). Setting dataset.field to {self.combine_fields.new_field_name}",
207-
)
208-
self.dataset.field = self.combine_fields.new_field_name
219+
logger.info(
220+
f"Setting dataset.field to {self.combine_fields.new_field_name}",
221+
)
222+
self.dataset.field = self.combine_fields.new_field_name
223+
if self.combine_fields.set_masking_span is not None:
224+
logger.info(f"Setting dataset.loss_masking_spans to {self.combine_fields.set_masking_span.loss_masking_spans}")
225+
self.dataset.loss_masking_spans = self.combine_fields.set_masking_span.loss_masking_spans
209226
super()._validate()
210227

211228
@classmethod

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def run(self) -> None:
211211

212212
# Check for combining fields
213213
if self._config.combine_fields:
214-
Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names))
214+
Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names),\
215+
msg=f"Some columns to combine are not in the dataset. {set(self._config.combine_fields.col_names).difference(dataset.column_names)}")
216+
215217
logger.info(f"Combining fields {self._config.combine_fields.col_names} into {self._config.combine_fields.new_field_name}")
216218
dataset = dataset.map(
217219
lambda example: {
@@ -224,6 +226,20 @@ def run(self) -> None:
224226
)
225227
logger.info(f"Sample after combining fields:\n{dataset[0]}")
226228
# Note: self.dataset.field is set to new_field_name for the rest of the operation; see config validation.
229+
230+
if self._config.combine_fields.set_masking_span is not None:
231+
Assert.incl(self._config.combine_fields.set_masking_span.masking_column, dataset.column_names)
232+
233+
dataset = dataset.map(
234+
lambda example: {
235+
self._config.dataset.loss_masking_spans: [
236+
(0, len(str(example[self._config.combine_fields.set_masking_span.masking_column])) - 1)
237+
]# spans are inclusive
238+
},
239+
batched=False,
240+
desc="Setting loss masking spans",
241+
)
242+
logger.info(f"Sample after setting loss masking spans:\n{dataset[0]}")
227243

228244
dataset = dataset.shard(
229245
num_shards=self._config.distributed.world_size,

0 commit comments

Comments
 (0)