
Combine columns to create a new field in prepare datasets #248


Closed · wants to merge 5 commits
58 changes: 56 additions & 2 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,6 +1,7 @@
import logging
import os
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.data.config import TokenizerConfig
@@ -23,7 +24,7 @@
MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()}
MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"


logger = logging.getLogger(__name__)

@config_class
class GPTHuggingfaceDatasetConfig(Config):
    path: str = Field(
@@ -76,7 +77,6 @@ class GPTHuggingfaceDatasetConfig(Config):
        hint=FieldHint.optional,
    )


@config_class
class DatasetPreparatorDistributedConfig(Config):
    # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
@@ -109,6 +109,47 @@ def _validate(self) -> None:
        super()._validate()
        Assert.in_range(self.rank, 0, self.world_size)

@config_class
class LossMaskSpansConfig(Config):
    masking_column: str = Field(
        default="",
        desc="Name of the input column whose text should be loss-masked in the combined field.",
        hint=FieldHint.core,
    )
    loss_masking_spans: str = Field(
        default="fast_llm_loss_masking_spans",
        desc="Name of the output column that will contain the computed loss-masking spans.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        assert isinstance(self.loss_masking_spans, str), "loss_masking_spans column name must be a string."
        super()._validate()

@config_class
class CombineFieldsConfig(Config):
    col_names: typing.List[str] = Field(
        default_factory=list,
        desc="Fields of the dataset to combine.",
        hint=FieldHint.core,
    )
    delimiter: str = Field(
        default=" ",
        desc="Delimiter to use when combining fields.",
        hint=FieldHint.optional,
    )
    new_field_name: str = Field(
        default="fast_llm_combined_field",
        desc="Name of the new field to create.",
        hint=FieldHint.optional,
    )
    set_masking_span: LossMaskSpansConfig = Field(
        default_factory=LossMaskSpansConfig,
        desc="Compute loss_masking_spans for the newly combined field.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        assert isinstance(self.delimiter, str), "Delimiter must be a string."
        super()._validate()

@config_class()
class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
@@ -164,11 +205,24 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
        " Does not shuffle samples.",
        hint=FieldHint.optional,
    )
    combine_fields: CombineFieldsConfig = Field(
        default_factory=CombineFieldsConfig,
        desc="Combine the listed dataset columns into a single new field.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        assert self.tokenizer.path is not None
        if self.dataset.data_type is not None:
            Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
        if len(self.combine_fields.col_names) > 0:
            logger.info(f"Setting dataset.field to {self.combine_fields.new_field_name}")
            self.dataset.field = self.combine_fields.new_field_name
        if self.combine_fields.set_masking_span.masking_column != "":
            logger.info(
                f"Setting dataset.loss_masking_spans to {self.combine_fields.set_masking_span.loss_masking_spans}"
            )
            self.dataset.loss_masking_spans = self.combine_fields.set_masking_span.loss_masking_spans
        super()._validate()

    @classmethod
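For concreteness, here is a minimal standalone sketch of what the new combine_fields options do to a single example. The column names and sample record are hypothetical; the join order, delimiter, and inclusive character-span convention follow the diff above.

# Hypothetical columns and record; the logic mirrors the PR's map() calls.
col_names = ["question", "answer"]
delimiter = " "
example = {"question": "What is 2 + 2?", "answer": "4"}

# Combine the configured columns into one string, in col_names order.
combined = delimiter.join(str(example[c]) for c in col_names)
assert combined == "What is 2 + 2? 4"

# Loss-masking span over the masking column: character offsets (0, len - 1),
# inclusive on both ends. This only covers that column's text because it
# comes first in col_names.
masking_column = "question"
span = (0, len(str(example[masking_column])) - 1)
assert combined[span[0] : span[1] + 1] == example[masking_column]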
32 changes: 32 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -208,6 +208,38 @@ def run(self) -> None:
            torch.distributed.barrier()

        assert isinstance(dataset, datasets.Dataset)

        # Check for combining fields
        if len(self._config.combine_fields.col_names) > 0:
            missing_columns = set(self._config.combine_fields.col_names).difference(dataset.column_names)
            Assert.eq(
                len(missing_columns),
                0,
                msg=f"Some columns to combine are not in the dataset: {missing_columns}",
            )

            logger.info(
                f"Combining fields {self._config.combine_fields.col_names}"
                f" into {self._config.combine_fields.new_field_name}"
            )
            dataset = dataset.map(
                lambda example: {
                    self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
                        str(example[column]) for column in self._config.combine_fields.col_names
                    )
                },
                batched=False,
                desc="Combining fields",
            )
            logger.info(f"Sample after combining fields:\n{dataset[0]}")
            # Note: self._config.dataset.field is set to new_field_name for the rest of the run;
            # see the config validation above.

        if self._config.combine_fields.set_masking_span.masking_column != "":
            Assert.incl(self._config.combine_fields.set_masking_span.masking_column, dataset.column_names)
            dataset = dataset.map(
                lambda example: {
                    self._config.dataset.loss_masking_spans: [
                        # Spans are inclusive on both ends; this one covers the masking
                        # column's text, which is expected to come first in the combined field.
                        (0, len(str(example[self._config.combine_fields.set_masking_span.masking_column])) - 1)
                    ]
                },
                batched=False,
                desc="Setting loss masking spans",
            )
            logger.info(f"Sample after setting loss masking spans:\n{dataset[0]}")

        dataset = dataset.shard(
            num_shards=self._config.distributed.world_size,
            index=self._config.distributed.rank,
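As a usage illustration, this self-contained sketch runs the same two steps with the Hugging Face datasets library directly; the toy data and column names are hypothetical and not part of the PR.

import datasets

ds = datasets.Dataset.from_dict({"question": ["Q1?", "Q2?"], "answer": ["A1", "A2"]})

# Step 1: combine the columns, as in the first map() above.
ds = ds.map(
    lambda ex: {"fast_llm_combined_field": " ".join(str(ex[c]) for c in ["question", "answer"])},
    batched=False,
)

# Step 2: add an inclusive character span covering the masking column ("question").
ds = ds.map(
    lambda ex: {"fast_llm_loss_masking_spans": [(0, len(str(ex["question"])) - 1)]},
    batched=False,
)

print(ds[0])
# -> {'question': 'Q1?', 'answer': 'A1', 'fast_llm_combined_field': 'Q1? A1',
#     'fast_llm_loss_masking_spans': [[0, 2]]}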