23
23
SamplingParameters ,
24
24
)
25
25
from fast_llm .engine .distributed .config import PhaseType
26
- from fast_llm .utils import Assert , Registry , normalize_probabilities , padded_cumsum
26
+ from fast_llm .utils import Assert , normalize_probabilities , padded_cumsum
27
27
28
28
if typing .TYPE_CHECKING :
29
29
from fast_llm .data .dataset .gpt .indexed import GPTConcatenatedDataset , GPTDatasetSlice , GPTIndexedDataset
@@ -93,61 +93,9 @@ class GPTSamplingData(SamplingData):
93
93
truncate_documents : bool = True
94
94
95
95
96
- @config_class ()
96
+ @config_class (registry = True )
97
97
class GPTSampledDatasetConfig (SampledDatasetConfig ):
98
-
99
- # TODO: Generalize dynamic types?
100
- _registry : typing .ClassVar [Registry [str , type ["GPTSampledDatasetConfig" ]]] = Registry [
101
- str , type ["GPTDatasetConfig" ]
102
- ]("gpt_dataset_class" , {})
103
- type_ : typing .ClassVar [str | None ] = None
104
- type : str | None = Field (
105
- default = None ,
106
- desc = "The type of dataset." ,
107
- hint = FieldHint .core ,
108
- )
109
-
110
- def _validate (self ) -> None :
111
- if self .type is None :
112
- self .type = self .type_
113
- # Should be handled in `from_dict`, but can fail if instantiating directly.
114
- Assert .eq (self .type , self .__class__ .type_ )
115
- super ()._validate ()
116
-
117
- @classmethod
118
- def _from_dict (
119
- cls ,
120
- default : dict [str , typing .Any ],
121
- strict : bool = True ,
122
- flat : bool = False ,
123
- ) -> typing .Self :
124
- type_ = default .get ("type" )
125
- if type_ is None :
126
- actual_cls = cls
127
- else :
128
- if type_ not in cls ._registry :
129
- raise ValueError (
130
- f"Unknown { cls ._registry .name } type { type_ } ." f" Available types: { list (cls ._registry .keys ())} "
131
- )
132
- actual_cls = cls ._registry [type_ ]
133
- Assert .custom (issubclass , actual_cls , cls )
134
- if actual_cls == cls :
135
- return super ()._from_dict (default , strict = strict , flat = flat )
136
- else :
137
- return actual_cls ._from_dict (default , strict = strict , flat = flat )
138
-
139
- def __init_subclass__ (cls ) -> None :
140
- if cls ._abstract and cls .type_ is not None :
141
- # Abstract classes should not have a `type_`
142
- raise ValueError (f"Abstract class { cls .__name__ } has type = { cls .type_ } , expected None." )
143
- if cls .type_ is not None :
144
- if cls .type_ in cls ._registry :
145
- raise ValueError (
146
- f"Registry { cls ._registry .name } already contains type { cls .type_ } ."
147
- f" Make sure all classes either have a unique or `None` type."
148
- )
149
- GPTSampledDatasetConfig ._registry [cls .type_ ] = cls
150
- super ().__init_subclass__ ()
98
+ pass
151
99
152
100
153
101
@config_class ()
@@ -161,10 +109,9 @@ def build(self) -> "GPTIndexedDataset":
161
109
raise NotImplementedError ()
162
110
163
111
164
- @config_class ()
112
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "random" } )
165
113
class GPTRandomDatasetConfig (GPTSamplableDatasetConfig ):
166
114
_abstract : typing .ClassVar [bool ] = False
167
- type_ : typing .ClassVar [str | None ] = "random"
168
115
name : str = Field (
169
116
default = "dummy" ,
170
117
desc = "The name of the dataset." ,
@@ -177,10 +124,9 @@ def build(self) -> "GPTRandomDataset":
177
124
return GPTRandomDataset (self .name )
178
125
179
126
180
- @config_class ()
127
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "memmap" } )
181
128
class GPTMemmapDatasetConfig (GPTIndexedDatasetConfig ):
182
129
_abstract : typing .ClassVar [bool ] = False
183
- type_ : typing .ClassVar [str | None ] = "memmap"
184
130
path : pathlib .Path = Field (
185
131
default = None ,
186
132
desc = "The path to the dataset, excluding the `.bin` or `.idx` suffix." ,
@@ -203,10 +149,9 @@ def build(self) -> "GPTMemmapDataset":
203
149
return GPTMemmapDataset (str (self .path ).replace ("/" , "__" ), self .path , self .num_documents , self .num_tokens )
204
150
205
151
206
- @config_class ()
152
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "concatenated" } )
207
153
class GPTConcatenatedDatasetConfig (ConcatenatedDatasetConfig , GPTIndexedDatasetConfig ):
208
154
_abstract : typing .ClassVar [bool ] = False
209
- type_ : typing .ClassVar [str | None ] = "concatenated"
210
155
datasets : list [GPTIndexedDatasetConfig ] = FieldUpdate ()
211
156
212
157
def build (self ) -> "GPTConcatenatedDataset" :
@@ -215,10 +160,9 @@ def build(self) -> "GPTConcatenatedDataset":
215
160
return self ._build (GPTConcatenatedDataset )
216
161
217
162
218
- @config_class ()
163
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "slice" } )
219
164
class GPTDatasetSliceConfig (DatasetSliceConfig , GPTIndexedDatasetConfig ):
220
165
_abstract : typing .ClassVar [bool ] = False
221
- type_ : typing .ClassVar [str | None ] = "slice"
222
166
dataset : GPTIndexedDatasetConfig = FieldUpdate ()
223
167
224
168
def build (self ) -> "GPTDatasetSlice" :
@@ -227,25 +171,22 @@ def build(self) -> "GPTDatasetSlice":
227
171
return self ._build (GPTDatasetSlice )
228
172
229
173
230
- @config_class ()
174
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "sampled" } )
231
175
class GPTSampledDatasetUpdateConfig (SampledDatasetUpdateConfig , GPTSampledDatasetConfig ):
232
176
_abstract = False
233
- type_ : typing .ClassVar [str | None ] = "sampled"
234
177
sampling : GPTSamplingConfig = FieldUpdate ()
235
178
dataset : GPTSampledDatasetConfig = FieldUpdate ()
236
179
237
180
238
- @config_class ()
181
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "blended" } )
239
182
class GPTBlendedDatasetConfig (BlendedDatasetConfig , GPTSampledDatasetConfig ):
240
183
_abstract : typing .ClassVar [bool ] = False
241
- type_ : typing .ClassVar [str | None ] = "blended"
242
184
datasets : list [GPTSampledDatasetConfig ] = FieldUpdate ()
243
185
244
186
245
- @config_class ()
187
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "file" } )
246
188
class GPTDatasetFromFileConfig (GPTSamplableDatasetConfig ):
247
189
_abstract : typing .ClassVar [bool ] = False
248
- type_ : typing .ClassVar [str | None ] = "file"
249
190
path : pathlib .Path = Field (
250
191
default = None ,
251
192
desc = "The path to a dataset config file." ,
@@ -281,11 +222,11 @@ def _convert_paths(self, config):
281
222
return config
282
223
283
224
284
- @config_class ()
225
+ # Add user-friendly names for the configs.
226
+ @config_class (dynamic_type = {GPTSampledDatasetConfig : "concatenated_memmap" })
285
227
class GPTConcatenatedMemmapConfig (GPTIndexedDatasetConfig ):
286
228
# TODO v0.3: Remove.
287
229
_abstract : typing .ClassVar [bool ] = False
288
- type_ : typing .ClassVar [str | None ] = "concatenated_memmap"
289
230
path : pathlib .Path = Field (
290
231
default = None ,
291
232
desc = "The path to a dataset directory." ,
@@ -388,14 +329,13 @@ class FimConfig(Config):
388
329
)
389
330
390
331
391
- @config_class ()
332
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "fim" } )
392
333
class GPTFimSampledDatasetConfig (GPTSampledDatasetConfig , FimConfig ):
393
334
"""
394
335
Configuration for FIM.
395
336
"""
396
337
397
338
_abstract : typing .ClassVar [bool ] = False
398
- type_ : typing .ClassVar [str | None ] = "fim"
399
339
400
340
dataset : GPTSampledDatasetConfig = Field (
401
341
default = None ,
@@ -456,10 +396,9 @@ class GPTLegacyConfig(Config):
456
396
)
457
397
458
398
459
- @config_class ()
399
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "legacy" } )
460
400
class GPTLegacyDatasetConfig (GPTSampledDatasetConfig , GPTLegacyConfig ):
461
401
_abstract : typing .ClassVar [bool ] = False
462
- type_ : typing .ClassVar [str | None ] = "legacy"
463
402
464
403
def build_and_sample (self , sampling : GPTSamplingData ) -> SampledDataset :
465
404
@@ -538,15 +477,14 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
538
477
return GPTSampledDatasetConfig .from_dict (dataset_config ).build_and_sample (sampling )
539
478
540
479
541
- @config_class ()
480
+ @config_class (dynamic_type = { GPTSampledDatasetConfig : "test_slow" } )
542
481
class GPTTestSlowDatasetConfig (GPTSampledDatasetConfig ):
543
482
"""
544
483
A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
545
484
"""
546
485
547
486
# TODO: This belongs to a testing plugin.
548
487
_abstract : typing .ClassVar [bool ] = False
549
- type_ : typing .ClassVar [str | None ] = "test_slow"
550
488
sleep : float = Field (
551
489
default = 1 ,
552
490
desc = "Sleep time during build, in seconds." ,
0 commit comments