Commit 3e5edc3

Minimalistic dynamic configs (#268)
1 parent e1a3d13 commit 3e5edc3

5 files changed: 115 additions & 85 deletions

fast_llm/config.py

Lines changed: 78 additions & 3 deletions
@@ -12,7 +12,7 @@
 
 import yaml
 
-from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
+from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log
 
 logger = logging.getLogger(__name__)
 
@@ -243,7 +243,9 @@ def _process_config_class(cls: type["Config"]):
     return cls
 
 
-def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
+def config_class[
+    T: Config
+](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
     """
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
@@ -253,7 +255,7 @@ def wrap(cls):
         if hasattr(cls, "__post_init__"):
            raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")
 
-        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True))
+        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False))
 
        wrapped_init = cls.__init__
 
@@ -267,6 +269,14 @@ def __init__(self, **kwargs):
            self.validate()
 
        wrapped.__init__ = __init__
+
+        wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None
+
+        if dynamic_type is not None:
+            for cls_, name in dynamic_type.items():
+                print(cls_, name, wrapped)
+                cls_.register_subclass(name, wrapped)
+
        return wrapped
 
    return wrap
@@ -305,6 +315,9 @@ class Config(metaclass=ConfigMeta):
    # without them being automatically added to `_explicit_fields`.
    _setting_implicit_default: bool | None = Field(init=False)
 
+    # A registry for all the config classes.
+    _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None
+
    def __setattr__(self, key: str, value: typing.Any) -> None:
        """
        Make the class read-only after validation.
@@ -358,6 +371,17 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
        Validate a class and mark it as read-only
        This should not be overridden in derived classes.
        """
+        # Should be handled in `from_dict`, but can fail if instantiating directly.
+        try:
+            expected_class = self.get_subclass(self.type)
+        except KeyError as e:
+            # Delayed instantiation error in `from_dict`.
+            raise ValidationError(*e.args)
+
+        if expected_class is not None:
+            # Should be handled in `from_dict`, but can fail if instantiating directly.
+            Assert.is_(self.__class__, expected_class)
+
        if not self._validated:
            try:
                self._validate()
@@ -738,6 +762,14 @@ def _from_dict(
        if "__class__" in default:
            del default["__class__"]
 
+        try:
+            actual_cls = cls.get_subclass(default.get("type"))
+            if actual_cls is not None and actual_cls is not cls:
+                return actual_cls._from_dict(default, strict=strict, flat=flat)
+        except KeyError:
+            # Postpone error to validation.
+            pass
+
        # Do not validate yet in case the root class sets cross-dependencies in validation.
        with NoAutoValidate():
            for name, field in cls.fields():
@@ -864,6 +896,42 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
        )
        return None
 
+    @classmethod
+    def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
+        Assert.custom(issubclass, cls_, cls)
+        if cls._registry is None:
+            raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..")
+        if name in cls._registry:
+            old_cls = cls._registry[name]
+            if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__:
+                del cls._registry[name]
+            else:
+                raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {cls.__name__}.")
+        cls._registry[name] = cls_
+
+    @classmethod
+    def get_subclass(cls, name: str | None):
+        # TODO: Make it case-insensitive?
+        if name is None:
+            return None
+        cls_ = None
+        for base_class in cls.__mro__:
+            if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry:
+                if cls_ is None:
+                    cls_ = base_class._registry[name]
+                    if not issubclass(cls_, cls):
+                        raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})")
+                elif base_class._registry[name] is not cls_:
+                    # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization.
+                    # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit?
+                    raise KeyError(
+                        f"Ambiguous type `{name}` for base class {cls.__name__}."
+                        f" ({cls_.__name__} vs {base_class._registry[name]})"
+                    )
+        if cls_ is None:
+            raise KeyError(f"Unknown type {name} for base class {cls.__name__}")
+        return cls_
+
    def __init_subclass__(cls):
        """
        We need to postpone validation until the class has been processed by the dataclass wrapper.
@@ -913,6 +981,13 @@ def __init_subclass__(cls):
            # dataclasses expects an annotation, so we use the one from the base class.
            cls.__annotations__[name] = base_class_field.type
 
+    # Type for the field. At the end of class definition to avoid shadowing builtin.
+    type: str | None = Field(
+        default=None,
+        desc="The config class name.",
+        hint=FieldHint.feature,
+    )
+
 
 class Configurable[ConfigType: Config]:
    config_class: typing.ClassVar[type[Config]] = Config
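The net effect of these changes: a base config can opt into a name registry with `registry=True`, a subclass can register itself under a name via `dynamic_type`, and `_from_dict` dispatches on the serialized `type` field. A minimal sketch of the intended usage, with hypothetical `ColorConfig`/`RgbColorConfig` classes that are not part of this commit:

import typing

from fast_llm.config import Config, config_class


@config_class(registry=True)
class ColorConfig(Config):
    # Hypothetical base class: owns a registry of named subclasses.
    _abstract: typing.ClassVar[bool] = False


@config_class(dynamic_type={ColorConfig: "rgb"})
class RgbColorConfig(ColorConfig):
    # Hypothetical subclass, registered under the name "rgb".
    _abstract: typing.ClassVar[bool] = False


# `_from_dict` looks up "rgb" in the base class registry and dispatches, so
# deserializing through the base class yields the concrete subclass.
config = ColorConfig.from_dict({"type": "rgb"})
assert isinstance(config, RgbColorConfig)

The `type` field added at the end of the `Config` class body is what carries the registered name through serialization.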

fast_llm/data/dataset/gpt/config.py

Lines changed: 15 additions & 77 deletions
@@ -23,7 +23,7 @@
    SamplingParameters,
 )
 from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum
+from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
 
 if typing.TYPE_CHECKING:
    from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -93,61 +93,9 @@ class GPTSamplingData(SamplingData):
    truncate_documents: bool = True
 
 
-@config_class()
+@config_class(registry=True)
 class GPTSampledDatasetConfig(SampledDatasetConfig):
-
-    # TODO: Generalize dynamic types?
-    _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[
-        str, type["GPTDatasetConfig"]
-    ]("gpt_dataset_class", {})
-    type_: typing.ClassVar[str | None] = None
-    type: str | None = Field(
-        default=None,
-        desc="The type of dataset.",
-        hint=FieldHint.core,
-    )
-
-    def _validate(self) -> None:
-        if self.type is None:
-            self.type = self.type_
-        # Should be handled in `from_dict`, but can fail if instantiating directly.
-        Assert.eq(self.type, self.__class__.type_)
-        super()._validate()
-
-    @classmethod
-    def _from_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-        flat: bool = False,
-    ) -> typing.Self:
-        type_ = default.get("type")
-        if type_ is None:
-            actual_cls = cls
-        else:
-            if type_ not in cls._registry:
-                raise ValueError(
-                    f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
-                )
-            actual_cls = cls._registry[type_]
-            Assert.custom(issubclass, actual_cls, cls)
-        if actual_cls == cls:
-            return super()._from_dict(default, strict=strict, flat=flat)
-        else:
-            return actual_cls._from_dict(default, strict=strict, flat=flat)
-
-    def __init_subclass__(cls) -> None:
-        if cls._abstract and cls.type_ is not None:
-            # Abstract classes should not have a `type_`
-            raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.")
-        if cls.type_ is not None:
-            if cls.type_ in cls._registry:
-                raise ValueError(
-                    f"Registry {cls._registry.name} already contains type {cls.type_}."
-                    f" Make sure all classes either have a unique or `None` type."
-                )
-            GPTSampledDatasetConfig._registry[cls.type_] = cls
-        super().__init_subclass__()
+    pass
 
 
 @config_class()
@@ -161,10 +109,9 @@ def build(self) -> "GPTIndexedDataset":
        raise NotImplementedError()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "random"})
 class GPTRandomDatasetConfig(GPTSamplableDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "random"
    name: str = Field(
        default="dummy",
        desc="The name of the dataset.",
@@ -177,10 +124,9 @@ def build(self) -> "GPTRandomDataset":
        return GPTRandomDataset(self.name)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"})
 class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "memmap"
    path: pathlib.Path = Field(
        default=None,
        desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
@@ -203,10 +149,9 @@ def build(self) -> "GPTMemmapDataset":
        return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
 class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "concatenated"
    datasets: list[GPTIndexedDatasetConfig] = FieldUpdate()
 
    def build(self) -> "GPTConcatenatedDataset":
@@ -215,10 +160,9 @@ def build(self) -> "GPTConcatenatedDataset":
        return self._build(GPTConcatenatedDataset)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"})
 class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "slice"
    dataset: GPTIndexedDatasetConfig = FieldUpdate()
 
    def build(self) -> "GPTDatasetSlice":
@@ -227,25 +171,22 @@ def build(self) -> "GPTDatasetSlice":
        return self._build(GPTDatasetSlice)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"})
 class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig):
    _abstract = False
-    type_: typing.ClassVar[str | None] = "sampled"
    sampling: GPTSamplingConfig = FieldUpdate()
    dataset: GPTSampledDatasetConfig = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"})
 class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "blended"
    datasets: list[GPTSampledDatasetConfig] = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "file"})
 class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "file"
    path: pathlib.Path = Field(
        default=None,
        desc="The path to a dataset config file.",
@@ -281,11 +222,11 @@ def _convert_paths(self, config):
        return config
 
 
-@config_class()
+# Add user-friendly names for the configs.
+@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"})
 class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
    # TODO v0.3: Remove.
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "concatenated_memmap"
    path: pathlib.Path = Field(
        default=None,
        desc="The path to a dataset directory.",
@@ -388,14 +329,13 @@ class FimConfig(Config):
    )
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"})
 class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig):
    """
    Configuration for FIM.
    """
 
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "fim"
 
    dataset: GPTSampledDatasetConfig = Field(
        default=None,
@@ -456,10 +396,9 @@ class GPTLegacyConfig(Config):
    )
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"})
 class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig):
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "legacy"
 
    def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
 
@@ -538,15 +477,14 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
        return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"})
 class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
    """
    A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
    """
 
    # TODO: This belongs to a testing plugin.
    _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "test_slow"
    sleep: float = Field(
        default=1,
        desc="Sleep time during build, in seconds.",

fast_llm/layers/common/config.py

Lines changed: 7 additions & 1 deletion
@@ -33,7 +33,7 @@ class NormalizationType(str, enum.Enum):
    rms_norm = "rms_norm"
 
 
-@config_class()
+@config_class(registry=True)
 class NormalizationConfig(BaseModelConfig):
    _abstract = False
 
@@ -107,6 +107,12 @@ def _from_dict(
        return super()._from_dict(default, strict, flat)
 
 
+for name in NormalizationType:
+    # We need this because we are using the reserved field name `type`.
+    # TODO: Implement proper dynamic typing.
+    NormalizationConfig.register_subclass(name.value, NormalizationConfig)
+
+
 class PeftType(str, enum.Enum):
    # TODO : Use a dynamic config type instead.
    none = "none"
