Skip to content

Commit 9c371ab

Browse files
committed
Combined columns to create a new column
1 parent 929c1cf commit 9c371ab

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pathlib
33
import typing
4+
import dataclasses
45

56
from fast_llm.config import Config, Field, FieldHint, check_field, config_class
67
from fast_llm.data.config import TokenizerConfig
@@ -109,6 +110,31 @@ def _validate(self) -> None:
109110
super()._validate()
110111
Assert.in_range(self.rank, 0, self.world_size)
111112

113+
@config_class()
class FieldCombinePreparatorConfig(Config):
    """Configuration for combining several dataset fields into a single new field.

    The listed ``fields`` are joined with ``delimiter`` and stored under
    ``new_field_name`` (see the corresponding logic in ``prepare.py``).
    """

    fields: typing.List[str] = Field(
        default_factory=list,
        desc="Fields of the dataset to combine.",
        hint=FieldHint.core,
    )
    delimiter: str = Field(
        default=" ",
        desc="Delimiter to use when combining fields.",
        hint=FieldHint.optional,
    )
    new_field_name: str = Field(
        default="fast_llm_combined_field",
        desc="Name of the new field to create.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        # Combining zero fields is meaningless: require at least one source field.
        assert self.fields, "At least one field must be specified to combine."
        assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
        assert isinstance(self.delimiter, str), "Delimiter must be a string."
        assert isinstance(self.new_field_name, str), "New field name must be a string."
        super()._validate()
112138

113139
@config_class()
114140
class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
@@ -164,6 +190,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
164190
" Does not shuffle samples.",
165191
hint=FieldHint.optional,
166192
)
193+
# Optional sub-config; left as None (falsy) when field combining is disabled —
# prepare.py gates on `if self._config.combine_fields:`.
combine_fields: FieldCombinePreparatorConfig = Field(
    default=None,
    desc="Combine the specified dataset fields into a single new field.",
    hint=FieldHint.optional,
)
167198

168199
def _validate(self) -> None:
169200
assert self.tokenizer.path is not None

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,22 @@ def run(self) -> None:
208208
torch.distributed.barrier()
209209

210210
assert isinstance(dataset, datasets.Dataset)
211+
212+
# Check for combining fields
213+
if self._config.combine_fields:
214+
Assert.eq(len(set(self._config.combine_fields.fields).intersection(dataset.column_names)), len(self._config.combine_fields.fields))
215+
dataset = dataset.map(
216+
lambda example: {
217+
self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
218+
str(example[column]) for column in self._config.combine_fields.fields
219+
)
220+
},
221+
batched=False,
222+
desc="Combining fields",
223+
)
224+
# Set the new field name in the config for following operations
225+
self._config.dataset.field = self._config.combine_fields.new_field_name
226+
211227
dataset = dataset.shard(
212228
num_shards=self._config.distributed.world_size,
213229
index=self._config.distributed.rank,

0 commit comments

Comments
 (0)