|
1 | 1 | import os
|
2 | 2 | import pathlib
|
3 | 3 | import typing
|
| 4 | +import dataclasses |
4 | 5 |
|
5 | 6 | from fast_llm.config import Config, Field, FieldHint, check_field, config_class
|
6 | 7 | from fast_llm.data.config import TokenizerConfig
|
@@ -109,6 +110,31 @@ def _validate(self) -> None:
|
109 | 110 | super()._validate()
|
110 | 111 | Assert.in_range(self.rank, 0, self.world_size)
|
111 | 112 |
|
@config_class()
class FieldCombinePreparatorConfig(Config):
    """Configuration for combining several dataset fields into a single text field.

    The selected ``fields`` are joined with ``delimiter`` and stored under
    ``new_field_name`` before tokenization.
    """

    fields: typing.List[str] = Field(
        default_factory=list,
        desc="Fields of the dataset to combine.",
        hint=FieldHint.core,
    )
    delimiter: str = Field(
        default=" ",
        desc="Delimiter to use when combining fields.",
        hint=FieldHint.optional,
    )
    new_field_name: str = Field(
        default="fast_llm_combined_field",
        desc="Name of the new field to create.",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        # Follow the file's convention of running the parent validation first,
        # then checking this config's own invariants.
        super()._validate()
        # Combining zero fields would silently produce empty output, so require
        # at least one source field; the remaining checks were previously
        # commented out but are needed to catch malformed configs early.
        assert isinstance(self.fields, list), "Fields must be a list."
        assert len(self.fields) > 0, "At least one field must be specified."
        assert all(isinstance(field, str) for field in self.fields), "All fields must be strings."
        assert isinstance(self.delimiter, str), "Delimiter must be a string."
        assert isinstance(self.new_field_name, str), "New field name must be a string."
112 | 138 |
|
113 | 139 | @config_class()
|
114 | 140 | class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
|
@@ -164,6 +190,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
|
164 | 190 | " Does not shuffle samples.",
|
165 | 191 | hint=FieldHint.optional,
|
166 | 192 | )
|
| 193 | + combine_fields: FieldCombinePreparatorConfig = Field( |
| 194 | + default=None, |
| 195 | + desc="Combine all files into a single file.", |
| 196 | + hint=FieldHint.optional, |
| 197 | + ) |
167 | 198 |
|
168 | 199 | def _validate(self) -> None:
|
169 | 200 | assert self.tokenizer.path is not None
|
|
0 commit comments