diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 2c4311c3..0273d7ea 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,6 +1,7 @@
 import os
 import pathlib
 import typing
+import logging
 
 from fast_llm.config import Config, Field, FieldHint, check_field, config_class
 from fast_llm.data.config import TokenizerConfig
@@ -23,7 +24,7 @@
 MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()}
 MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"
 
-
+logger = logging.getLogger(__name__)
 @config_class
 class GPTHuggingfaceDatasetConfig(Config):
     path: str = Field(
@@ -76,7 +77,6 @@ class GPTHuggingfaceDatasetConfig(Config):
         hint=FieldHint.optional,
     )
 
-
 @config_class
 class DatasetPreparatorDistributedConfig(Config):
     # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
@@ -109,6 +109,47 @@ def _validate(self) -> None:
         super()._validate()
         Assert.in_range(self.rank, 0, self.world_size)
 
+@config_class
+class LossMaskSpansConfig(Config):
+    masking_column: str = Field(
+        default="",
+        desc="Column whose text should be covered by a loss-masking span (e.g. the input/prompt column).",
+        hint=FieldHint.core,
+    )
+    loss_masking_spans: str = Field(
+        default="fast_llm_loss_masking_spans",
+        desc="Name of the column that will hold the computed loss-masking spans.",
+        hint=FieldHint.optional,
+    )
+    def _validate(self) -> None:
+        assert isinstance(self.loss_masking_spans, str), "loss_masking_spans column name must be a string."
+        super()._validate()
+
+@config_class
+class CombineFieldsConfig(Config):
+    col_names: typing.List[str] = Field(
+        default_factory=list,
+        desc="Fields of the dataset to combine.",
+        hint=FieldHint.core,
+    )
+    delimiter: str = Field(
+        default=" ",
+        desc="Delimiter to use when combining fields.",
+        hint=FieldHint.optional,
+    )
+    new_field_name: str = Field(
+        default="fast_llm_combined_field",
+        desc="Name of the new field to create.",
+        hint=FieldHint.optional,
+    )
+    set_masking_span: LossMaskSpansConfig = Field(
+        default_factory=LossMaskSpansConfig,
+        desc="Compute loss_masking_spans for the newly combined field.",
+        hint=FieldHint.optional,
+    )
+    def _validate(self) -> None:
+        assert isinstance(self.delimiter, str), "Delimiter must be a string."
+        super()._validate()
 
 @config_class()
 class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
@@ -164,11 +205,24 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
         " Does not shuffle samples.",
         hint=FieldHint.optional,
     )
+    combine_fields: CombineFieldsConfig = Field(
+        default_factory=CombineFieldsConfig,
+        desc="Combine multiple dataset columns into a single text field.",
+        hint=FieldHint.optional,
+    )
 
     def _validate(self) -> None:
         assert self.tokenizer.path is not None
         if self.dataset.data_type is not None:
             Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
+        if len(self.combine_fields.col_names) > 0:
+            logger.info(
+                f"Setting dataset.field to {self.combine_fields.new_field_name}",
+            )
+            self.dataset.field = self.combine_fields.new_field_name
+            if self.combine_fields.set_masking_span.masking_column != "":
+                logger.info(f"Setting dataset.loss_masking_spans to {self.combine_fields.set_masking_span.loss_masking_spans}")
+                self.dataset.loss_masking_spans = self.combine_fields.set_masking_span.loss_masking_spans
         super()._validate()
 
     @classmethod
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 23e497bf..bc528005 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -208,6 +208,38 @@ def run(self) -> None:
            torch.distributed.barrier()
 
        assert isinstance(dataset, datasets.Dataset)
+
+        # Optionally combine several columns into a single text field.
+        if len(self._config.combine_fields.col_names) > 0:
+            Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names),
+                      msg=f"Some columns to combine are not in the dataset: {set(self._config.combine_fields.col_names).difference(dataset.column_names)}")
+
+            logger.info(f"Combining fields {self._config.combine_fields.col_names} into {self._config.combine_fields.new_field_name}")
+            dataset = dataset.map(
+                lambda example: {
+                    self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
+                        str(example[column]) for column in self._config.combine_fields.col_names
+                    )
+                },
+                batched=False,
+                desc="Combining fields",
+            )
+            logger.info(f"Sample after combining fields:\n{dataset[0]}")
+            # Note: self._config.dataset.field was set to new_field_name during config validation and is used for the rest of the run.
+
+        if self._config.combine_fields.set_masking_span.masking_column != "":
+            Assert.incl(self._config.combine_fields.set_masking_span.masking_column, dataset.column_names)
+            dataset = dataset.map(
+                lambda example: {
+                    self._config.dataset.loss_masking_spans: [
+                        (0, len(str(example[self._config.combine_fields.set_masking_span.masking_column])) - 1)
+                    ]  # spans are inclusive
+                },
+                batched=False,
+                desc="Setting loss masking spans",
+            )
+            logger.info(f"Sample after setting loss masking spans:\n{dataset[0]}")
+
        dataset = dataset.shard(
            num_shards=self._config.distributed.world_size,
            index=self._config.distributed.rank,
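
For reference, a minimal standalone sketch (not part of the PR) of the transformation the new `combine_fields` and `set_masking_span` options perform, using only the `datasets` library. The toy columns "question"/"answer", the delimiter, and the hard-coded names below are illustrative assumptions standing in for the config values:

```python
import datasets

# Toy dataset standing in for the Hugging Face dataset loaded in prepare.py.
dataset = datasets.Dataset.from_dict({"question": ["What is 2+2?"], "answer": ["4"]})

col_names = ["question", "answer"]          # combine_fields.col_names (assumed)
delimiter = " "                             # combine_fields.delimiter
new_field_name = "fast_llm_combined_field"  # combine_fields.new_field_name
masking_column = "question"                 # set_masking_span.masking_column (assumed)
loss_masking_spans = "fast_llm_loss_masking_spans"  # set_masking_span.loss_masking_spans

# First map(): join the selected columns into a single text field.
dataset = dataset.map(
    lambda example: {new_field_name: delimiter.join(str(example[c]) for c in col_names)},
    batched=False,
)

# Second map(): mask the full character span of the masking column (spans are inclusive).
dataset = dataset.map(
    lambda example: {loss_masking_spans: [(0, len(str(example[masking_column])) - 1)]},
    batched=False,
)

print(dataset[0])
# {'question': 'What is 2+2?', 'answer': '4',
#  'fast_llm_combined_field': 'What is 2+2? 4',
#  'fast_llm_loss_masking_spans': [[0, 11]]}
```

This mirrors the two dataset.map() calls added to prepare.py; in the real run the names come from CombineFieldsConfig / LossMaskSpansConfig, and dataset.field / dataset.loss_masking_spans are redirected to the new columns during config validation.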