From 9c371ab99474788e496bf23a3fe8a62453dc9d34 Mon Sep 17 00:00:00 2001
From: "luke.kumar"
Date: Thu, 1 May 2025 10:56:11 -0400
Subject: [PATCH 1/5] combined cols to create a new col

---
 fast_llm/data/preparator/gpt_memmap/config.py | 31 +++++++++++++++++++
 .../data/preparator/gpt_memmap/prepare.py     | 16 ++++++++++
 2 files changed, 47 insertions(+)

diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 2c4311c3..c8fe08c2 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,6 +1,7 @@
 import os
 import pathlib
 import typing
+import dataclasses
 
 from fast_llm.config import Config, Field, FieldHint, check_field, config_class
 from fast_llm.data.config import TokenizerConfig
@@ -109,6 +110,31 @@ def _validate(self) -> None:
         super()._validate()
         Assert.in_range(self.rank, 0, self.world_size)
 
+@config_class
+class FieldCombinePreparatorConfig(Config):
+    fields: typing.List[str] = Field(
+        default_factory=list,
+        desc="Fields of the dataset to combine.",
+        hint=FieldHint.core,
+    )
+    delimiter: str = Field(
+        default=" ",
+        desc="Delimiter to use when combining fields.",
+        hint=FieldHint.optional,
+    )
+    new_field_name: str = Field(
+        default="fast_llm_combined_field",
+        desc="Name of the new field to create.",
+        hint=FieldHint.optional,
+    )
+
+    def _validate(self) -> None:
+        # Assert.gt(len(self.fields), 0)
+        # assert isinstance(self.fields, list), "Fields must be a list."
+        # assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
+        assert isinstance(self.delimiter, str), "Delimiter must be a string."
+        # assert isinstance(self.new_field_name, str), "New field name must be a string."
+        super()._validate()
 
 @config_class()
 class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
@@ -164,6 +190,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
         " Does not shuffle samples.",
         hint=FieldHint.optional,
     )
+    combine_fields: FieldCombinePreparatorConfig = Field(
+        default=None,
+        desc="Combine the given dataset fields into a single field.",
+        hint=FieldHint.optional,
+    )
 
     def _validate(self) -> None:
         assert self.tokenizer.path is not None
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 23e497bf..c8ad235d 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -208,6 +208,22 @@ def run(self) -> None:
             torch.distributed.barrier()
 
         assert isinstance(dataset, datasets.Dataset)
+
+        # Check for combining fields
+        if self._config.combine_fields:
+            Assert.eq(len(set(self._config.combine_fields.fields).intersection(dataset.column_names)), len(self._config.combine_fields.fields))
+            dataset = dataset.map(
+                lambda example: {
+                    self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
+                        str(example[column]) for column in self._config.combine_fields.fields
+                    )
+                },
+                batched=False,
+                desc="Combining fields",
+            )
+            # Set the new field name in the config for following operations
+            self._config.dataset.field = self._config.combine_fields.new_field_name
+
         dataset = dataset.shard(
             num_shards=self._config.distributed.world_size,
             index=self._config.distributed.rank,

From 52b91797bfffc4ae1b48014eb44a4304273c172c Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Thu, 1 May 2025 17:19:10 +0000
Subject: [PATCH 2/5] rename fields to col_names (name re-used)

---
 fast_llm/data/preparator/gpt_memmap/config.py  | 2 +-
 fast_llm/data/preparator/gpt_memmap/prepare.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index c8fe08c2..1f21e41b 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -112,7 +112,7 @@ def _validate(self) -> None:
 
 @config_class
 class FieldCombinePreparatorConfig(Config):
-    fields: typing.List[str] = Field(
+    col_names: typing.List[str] = Field(
         default_factory=list,
         desc="Fields of the dataset to combine.",
         hint=FieldHint.core,
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index c8ad235d..40bb65e1 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -211,11 +211,11 @@ def run(self) -> None:
 
         # Check for combining fields
         if self._config.combine_fields:
-            Assert.eq(len(set(self._config.combine_fields.fields).intersection(dataset.column_names)), len(self._config.combine_fields.fields))
+            Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names))
             dataset = dataset.map(
                 lambda example: {
                     self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
-                        str(example[column]) for column in self._config.combine_fields.fields
+                        str(example[column]) for column in self._config.combine_fields.col_names
                     )
                 },
                 batched=False,

From 7d57592781fc87d8028d58a7becdbd02ca2f0f30 Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Thu, 1 May 2025 18:12:59 +0000
Subject: [PATCH 3/5] set dataset.field to the new field

---
 fast_llm/data/preparator/gpt_memmap/config.py  | 10 ++++++++--
 fast_llm/data/preparator/gpt_memmap/prepare.py |  7 ++++---
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 1f21e41b..4c6a1bf2 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,7 +1,7 @@
 import os
 import pathlib
 import typing
-import dataclasses
+import logging
 
 from fast_llm.config import Config, Field, FieldHint, check_field, config_class
 from fast_llm.data.config import TokenizerConfig
@@ -24,7 +24,7 @@
 MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()}
 MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"
 
-
+logger = logging.getLogger(__name__)
 @config_class
 class GPTHuggingfaceDatasetConfig(Config):
     path: str = Field(
@@ -200,6 +200,12 @@ def _validate(self) -> None:
         assert self.tokenizer.path is not None
         if self.dataset.data_type is not None:
             Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
+        if self.combine_fields is not None:
+            if self.dataset.field != self.combine_fields.new_field_name:
+                logger.warning(
+                    f"Combine mode activated yet dataset.field != combine_fields.new_field_name ({self.dataset.field} != {self.combine_fields.new_field_name}). Setting dataset.field to {self.combine_fields.new_field_name}",
+                )
+                self.dataset.field = self.combine_fields.new_field_name
         super()._validate()
 
     @classmethod
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 40bb65e1..cd2425fa 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -212,6 +212,7 @@ def run(self) -> None:
         # Check for combining fields
         if self._config.combine_fields:
             Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names))
+            logger.info(f"Combining fields {self._config.combine_fields.col_names} into {self._config.combine_fields.new_field_name}")
             dataset = dataset.map(
                 lambda example: {
                     self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
@@ -221,9 +222,9 @@ def run(self) -> None:
                 batched=False,
                 desc="Combining fields",
             )
-            # Set the new field name in the config for following operations
-            self._config.dataset.field = self._config.combine_fields.new_field_name
-
+            logger.info(f"Sample after combining fields:\n{dataset[0]}")
+            # Note: self.dataset.field is set to new_field_name for the rest of the operation; see config validation
+
         dataset = dataset.shard(
             num_shards=self._config.distributed.world_size,
             index=self._config.distributed.rank,

From 80584e71e657078a252bf6c675459bdd95f64060 Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Thu, 1 May 2025 21:27:58 +0000
Subject: [PATCH 4/5] add loss masking col when combining cols

---
 fast_llm/data/preparator/gpt_memmap/config.py | 39 +++++++++++++------
 .../data/preparator/gpt_memmap/prepare.py     | 18 ++++++++-
 2 files changed, 45 insertions(+), 12 deletions(-)

diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 4c6a1bf2..15e51e09 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -77,7 +77,6 @@ class GPTHuggingfaceDatasetConfig(Config):
         hint=FieldHint.optional,
     )
 
-
 @config_class
 class DatasetPreparatorDistributedConfig(Config):
     # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
@@ -110,6 +109,22 @@ def _validate(self) -> None:
         super()._validate()
         Assert.in_range(self.rank, 0, self.world_size)
 
+@config_class
+class LossMaskSpansConfig(Config):
+    masking_column: str = Field(
+        default=None,
+        desc="Field containing character spans to mask for loss computation",
+        hint=FieldHint.optional,
+    )
+    loss_masking_spans: str = Field(
+        default="fast_llm_loss_masking_spans",
+        desc="Field containing character spans to mask for loss computation",
+        hint=FieldHint.optional,
+    )
+    def _validate(self) -> None:
+        assert isinstance(self.loss_masking_spans, str), "loss_masking_spans col name must be a string."
+        super()._validate()
+
 @config_class
 class FieldCombinePreparatorConfig(Config):
     col_names: typing.List[str] = Field(
@@ -127,13 +142,13 @@ class FieldCombinePreparatorConfig(Config):
         desc="Name of the new field to create.",
         hint=FieldHint.optional,
     )
-
+    set_masking_span: LossMaskSpansConfig = Field(
+        default=None,
+        desc="Compute loss_masking_spans for the newly combined field.",
+        hint=FieldHint.optional,
+    )
     def _validate(self) -> None:
-        # Assert.gt(len(self.fields), 0)
-        # assert isinstance(self.fields, list), "Fields must be a list."
-        # assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
         assert isinstance(self.delimiter, str), "Delimiter must be a string."
-        # assert isinstance(self.new_field_name, str), "New field name must be a string."
         super()._validate()
 
 @config_class()
@@ -201,11 +216,13 @@ def _validate(self) -> None:
         assert self.tokenizer.path is not None
         if self.dataset.data_type is not None:
             Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
         if self.combine_fields is not None:
-            if self.dataset.field != self.combine_fields.new_field_name:
-                logger.warning(
-                    f"Combine mode activated yet dataset.field != combine_fields.new_field_name ({self.dataset.field} != {self.combine_fields.new_field_name}). Setting dataset.field to {self.combine_fields.new_field_name}",
-                )
-                self.dataset.field = self.combine_fields.new_field_name
+            logger.info(
+                f"Setting dataset.field to {self.combine_fields.new_field_name}",
+            )
+            self.dataset.field = self.combine_fields.new_field_name
+            if self.combine_fields.set_masking_span is not None:
+                logger.info(f"Setting dataset.loss_masking_spans to {self.combine_fields.set_masking_span.loss_masking_spans}")
+                self.dataset.loss_masking_spans = self.combine_fields.set_masking_span.loss_masking_spans
         super()._validate()
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index cd2425fa..a1ff7c24 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -211,7 +211,9 @@ def run(self) -> None:
 
         # Check for combining fields
         if self._config.combine_fields:
-            Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names))
+            Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names),\
+                msg=f"Some columns to combine are not in the dataset. {set(self._config.combine_fields.col_names).difference(dataset.column_names)}")
+
             logger.info(f"Combining fields {self._config.combine_fields.col_names} into {self._config.combine_fields.new_field_name}")
             dataset = dataset.map(
                 lambda example: {
@@ -224,6 +226,20 @@ def run(self) -> None:
             )
             logger.info(f"Sample after combining fields:\n{dataset[0]}")
             # Note: self.dataset.field is set to new_field_name for the rest of the operation; see config validation
+
+            if self._config.combine_fields.set_masking_span is not None:
+                Assert.incl(self._config.combine_fields.set_masking_span.masking_column, dataset.column_names)
+
+                dataset = dataset.map(
+                    lambda example: {
+                        self._config.dataset.loss_masking_spans: [
+                            (0, len(str(example[self._config.combine_fields.set_masking_span.masking_column])) - 1)
+                        ]  # spans are inclusive
+                    },
+                    batched=False,
+                    desc="Setting loss masking spans",
+                )
+                logger.info(f"Sample after setting loss masking spans:\n{dataset[0]}")
 
         dataset = dataset.shard(
             num_shards=self._config.distributed.world_size,

From 2f4df5d1afef0d6bf3ce75cd86a6c88ce7c87bff Mon Sep 17 00:00:00 2001
From: Luke Kumar
Date: Fri, 2 May 2025 14:59:18 +0000
Subject: [PATCH 5/5] updates to config names and desc.
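
An illustrative YAML excerpt of the resulting preparator options (the key
nesting and values here are assumptions inferred from the config classes,
not taken from a real run; the column names are made up):

    combine_fields:
      col_names: [question, answer]
      delimiter: " "
      new_field_name: fast_llm_combined_field
      set_masking_span:
        masking_column: question
        loss_masking_spans: fast_llm_loss_masking_spans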
---
 fast_llm/data/preparator/gpt_memmap/config.py | 20 ++++++++++----------
 .../data/preparator/gpt_memmap/prepare.py     |  5 ++---
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py
index 15e51e09..0273d7ea 100644
--- a/fast_llm/data/preparator/gpt_memmap/config.py
+++ b/fast_llm/data/preparator/gpt_memmap/config.py
@@ -112,13 +112,13 @@ def _validate(self) -> None:
 @config_class
 class LossMaskSpansConfig(Config):
     masking_column: str = Field(
-        default=None,
-        desc="Field containing character spans to mask for loss computation",
-        hint=FieldHint.optional,
+        default="",
+        desc="Column whose characters are masked for loss computation in the combined field.",
+        hint=FieldHint.core,
     )
     loss_masking_spans: str = Field(
         default="fast_llm_loss_masking_spans",
-        desc="Field containing character spans to mask for loss computation",
+        desc="Name of the column that will contain the loss-masking spans.",
         hint=FieldHint.optional,
     )
     def _validate(self) -> None:
@@ -126,7 +126,7 @@ def _validate(self) -> None:
         super()._validate()
 
 @config_class
-class FieldCombinePreparatorConfig(Config):
+class CombineFieldsConfig(Config):
     col_names: typing.List[str] = Field(
         default_factory=list,
         desc="Fields of the dataset to combine.",
@@ -143,7 +143,7 @@ class FieldCombinePreparatorConfig(Config):
         hint=FieldHint.optional,
     )
     set_masking_span: LossMaskSpansConfig = Field(
-        default=None,
+        default_factory=LossMaskSpansConfig,
         desc="Compute loss_masking_spans for the newly combined field.",
         hint=FieldHint.optional,
     )
@@ -205,8 +205,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
         " Does not shuffle samples.",
         hint=FieldHint.optional,
     )
-    combine_fields: FieldCombinePreparatorConfig = Field(
-        default=None,
+    combine_fields: CombineFieldsConfig = Field(
+        default_factory=CombineFieldsConfig,
         desc="Combine the given dataset fields into a single field.",
         hint=FieldHint.optional,
     )
@@ -215,12 +215,12 @@ def _validate(self) -> None:
         assert self.tokenizer.path is not None
         if self.dataset.data_type is not None:
             Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
-        if self.combine_fields is not None:
+        if len(self.combine_fields.col_names) > 0:
             logger.info(
                 f"Setting dataset.field to {self.combine_fields.new_field_name}",
             )
             self.dataset.field = self.combine_fields.new_field_name
-            if self.combine_fields.set_masking_span is not None:
+            if self.combine_fields.set_masking_span.masking_column != "":
                 logger.info(f"Setting dataset.loss_masking_spans to {self.combine_fields.set_masking_span.loss_masking_spans}")
                 self.dataset.loss_masking_spans = self.combine_fields.set_masking_span.loss_masking_spans
         super()._validate()
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index a1ff7c24..bc528005 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -210,7 +210,7 @@ def run(self) -> None:
         assert isinstance(dataset, datasets.Dataset)
 
         # Check for combining fields
-        if self._config.combine_fields:
+        if len(self._config.combine_fields.col_names) > 0:
             Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names),\
                 msg=f"Some columns to combine are not in the dataset. {set(self._config.combine_fields.col_names).difference(dataset.column_names)}")
 
@@ -227,9 +227,8 @@ def run(self) -> None:
             logger.info(f"Sample after combining fields:\n{dataset[0]}")
             # Note: self.dataset.field is set to new_field_name for the rest of the operation; see config validation
 
-            if self._config.combine_fields.set_masking_span is not None:
+            if self._config.combine_fields.set_masking_span.masking_column != "":
                 Assert.incl(self._config.combine_fields.set_masking_span.masking_column, dataset.column_names)
-
                 dataset = dataset.map(
                     lambda example: {
                         self._config.dataset.loss_masking_spans: [
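
For reference, a minimal standalone sketch of what the two dataset.map steps
in this series do, assuming a Hugging Face datasets.Dataset and made-up
column names (question, answer); this is not part of the patches themselves:

    import datasets

    ds = datasets.Dataset.from_dict({"question": ["What is 2+2?"], "answer": ["4"]})

    col_names, delimiter = ["question", "answer"], " "
    new_field = "fast_llm_combined_field"
    mask_col, span_field = "question", "fast_llm_loss_masking_spans"

    # Step 1: combine the selected columns into one field, as in the first map.
    ds = ds.map(
        lambda ex: {new_field: delimiter.join(str(ex[c]) for c in col_names)},
        batched=False,
    )
    # Step 2: record the span covered by the masking column's text; spans are
    # inclusive, hence the "- 1". This assumes the masking column comes first
    # in col_names, so its text leads the combined field.
    ds = ds.map(
        lambda ex: {span_field: [(0, len(str(ex[mask_col])) - 1)]},
        batched=False,
    )
    print(ds[0][new_field])   # "What is 2+2? 4"
    print(ds[0][span_field])  # [[0, 11]]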