Skip to content

Commit 7d57592

Browse files
committed
set dataset.field to the new field
1 parent 52b9179 commit 7d57592

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import pathlib
33
import typing
4-
import dataclasses
4+
import logging
55

66
from fast_llm.config import Config, Field, FieldHint, check_field, config_class
77
from fast_llm.data.config import TokenizerConfig
@@ -24,7 +24,7 @@
2424
MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()}
2525
MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"
2626

27-
27+
logger = logging.getLogger(__name__)
2828
@config_class
2929
class GPTHuggingfaceDatasetConfig(Config):
3030
path: str = Field(
@@ -200,6 +200,12 @@ def _validate(self) -> None:
200200
assert self.tokenizer.path is not None
201201
if self.dataset.data_type is not None:
202202
Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
203+
if self.combine_fields is not None:
204+
if self.dataset.field != self.combine_fields.new_field_name:
205+
logger.warning(
206+
f"Combine mode activated yet dataset.field != combine_fields.new_field_name ({self.dataset.field} != {self.combine_fields.new_field_name}). Setting dataset.field to {self.combine_fields.new_field_name}",
207+
)
208+
self.dataset.field = self.combine_fields.new_field_name
203209
super()._validate()
204210

205211
@classmethod

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def run(self) -> None:
212212
# Check for combining fields
213213
if self._config.combine_fields:
214214
Assert.eq(len(set(self._config.combine_fields.col_names).intersection(dataset.column_names)), len(self._config.combine_fields.col_names))
215+
logger.info(f"Combining fields {self._config.combine_fields.col_names} into {self._config.combine_fields.new_field_name}")
215216
dataset = dataset.map(
216217
lambda example: {
217218
self._config.combine_fields.new_field_name: self._config.combine_fields.delimiter.join(
@@ -221,9 +222,9 @@ def run(self) -> None:
221222
batched=False,
222223
desc="Combining fields",
223224
)
224-
# Set the new field name in the config for following operations
225-
self._config.dataset.field = self._config.combine_fields.new_field_name
226-
225+
logger.info(f"Sample after combining fields:\n{dataset[0]}")
226+
# Note: self.dataset.field is set to new_field_name for the rest of the operation see config validation
227+
227228
dataset = dataset.shard(
228229
num_shards=self._config.distributed.world_size,
229230
index=self._config.distributed.rank,

0 commit comments

Comments
 (0)