diff --git a/fast_llm/config.py b/fast_llm/config.py
index 4928cdbd..380100e3 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -12,7 +12,7 @@ import yaml
 
-from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
+from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log
 
 logger = logging.getLogger(__name__)
 
@@ -243,7 +243,9 @@ def _process_config_class(cls: type["Config"]):
     return cls
 
 
-def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
+def config_class[
+    T: Config
+](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
     """
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
@@ -253,7 +255,7 @@ def wrap(cls):
         if hasattr(cls, "__post_init__"):
             raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")
 
-        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True))
+        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False))
 
         wrapped_init = cls.__init__
@@ -267,6 +269,13 @@ def __init__(self, **kwargs):
             self.validate()
 
         wrapped.__init__ = __init__
+
+        wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None
+
+        if dynamic_type is not None:
+            for cls_, name in dynamic_type.items():
+                cls_.register_subclass(name, wrapped)
+
         return wrapped
 
     return wrap
@@ -305,6 +315,9 @@ class Config(metaclass=ConfigMeta):
     # without them being automatically added to `_explicit_fields`.
     _setting_implicit_default: bool | None = Field(init=False)
 
+    # A registry for all the config classes.
+    _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None
+
     def __setattr__(self, key: str, value: typing.Any) -> None:
         """
         Make the class read-only after validation.
@@ -358,6 +371,17 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
         Validate a class and mark it as read-only
         This should not be overridden in derived classes.
         """
+        # Should be handled in `from_dict`, but can fail if instantiating directly.
+        try:
+            expected_class = self.get_subclass(self.type)
+        except KeyError as e:
+            # Delayed instantiation error in `from_dict`.
+            raise ValidationError(*e.args)
+
+        if expected_class is not None:
+            # Should be handled in `from_dict`, but can fail if instantiating directly.
+            Assert.is_(self.__class__, expected_class)
+
         if not self._validated:
             try:
                 self._validate()
@@ -738,6 +762,14 @@ def _from_dict(
         if "__class__" in default:
             del default["__class__"]
 
+        try:
+            actual_cls = cls.get_subclass(default.get("type"))
+            if actual_cls is not None and actual_cls is not cls:
+                return actual_cls._from_dict(default, strict=strict, flat=flat)
+        except KeyError:
+            # Postpone the error to validation.
+            pass
+
         # Do not validate yet in case the root class sets cross-dependencies in validation.
         with NoAutoValidate():
             for name, field in cls.fields():
@@ -864,6 +896,42 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
             )
         return None
 
+    @classmethod
+    def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
+        Assert.custom(issubclass, cls_, cls)
+        if cls._registry is None:
+            raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry.")
+        if name in cls._registry:
+            old_cls = cls._registry[name]
+            if old_cls.__name__ == cls_.__name__ and old_cls.__module__ == cls_.__module__:
+                del cls._registry[name]
+            else:
+                raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {old_cls.__name__}.")
+        cls._registry[name] = cls_
+
+    @classmethod
+    def get_subclass(cls, name: str | None):
+        # TODO: Make it case-insensitive?
+        if name is None:
+            return None
+        cls_ = None
+        for base_class in cls.__mro__:
+            if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry:
+                if cls_ is None:
+                    cls_ = base_class._registry[name]
+                    if not issubclass(cls_, cls):
+                        raise KeyError(f"{cls_.__name__} is not a subclass of {cls.__name__} (from type {name})")
+                elif base_class._registry[name] is not cls_:
+                    # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization.
+                    # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit?
+                    raise KeyError(
+                        f"Ambiguous type `{name}` for base class {cls.__name__}."
+                        f" ({cls_.__name__} vs {base_class._registry[name]})"
+                    )
+        if cls_ is None:
+            raise KeyError(f"Unknown type {name} for base class {cls.__name__}")
+        return cls_
+
     def __init_subclass__(cls):
         """
         We need to postpone validation until the class has been processed by the dataclass wrapper.
         """
@@ -913,6 +981,13 @@ def __init_subclass__(cls):
                 # dataclasses expects an annotation, so we use the one from the base class.
                 cls.__annotations__[name] = base_class_field.type
 
+    # The `type` field, declared at the end of the class definition to avoid shadowing the builtin.
+    type: str | None = Field(
+        default=None,
+        desc="The config class name.",
+        hint=FieldHint.feature,
+    )
+
 
 class Configurable[ConfigType: Config]:
     config_class: typing.ClassVar[type[Config]] = Config
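
This generalizes the ad-hoc registry previously hardcoded in `GPTSampledDatasetConfig` (removed below). A minimal sketch of how the pieces compose, using hypothetical `OptimizerConfig`/`AdamConfig` classes that are not part of this diff: the base class opts into a registry with `registry=True`, each concrete subclass registers itself under a `type` name via `dynamic_type`, and `from_dict` dispatches on the serialized `type` key.

    from fast_llm.config import Config, config_class

    @config_class(registry=True)
    class OptimizerConfig(Config):
        pass

    @config_class(dynamic_type={OptimizerConfig: "adam"})
    class AdamConfig(OptimizerConfig):
        _abstract = False

    # `from_dict` looks up "adam" in the `OptimizerConfig` registry and
    # instantiates the matching subclass; `validate` re-checks the match.
    config = OptimizerConfig.from_dict({"type": "adam"})
    assert type(config) is AdamConfig
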
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py
index 90e9c573..ae87e0e7 100644
--- a/fast_llm/data/dataset/gpt/config.py
+++ b/fast_llm/data/dataset/gpt/config.py
@@ -23,7 +23,7 @@
     SamplingParameters,
 )
 from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum
+from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
 
 if typing.TYPE_CHECKING:
     from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset
@@ -93,61 +93,9 @@ class GPTSamplingData(SamplingData):
     truncate_documents: bool = True
 
 
-@config_class()
+@config_class(registry=True)
 class GPTSampledDatasetConfig(SampledDatasetConfig):
-
-    # TODO: Generalize dynamic types?
-    _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[
-        str, type["GPTDatasetConfig"]
-    ]("gpt_dataset_class", {})
-    type_: typing.ClassVar[str | None] = None
-    type: str | None = Field(
-        default=None,
-        desc="The type of dataset.",
-        hint=FieldHint.core,
-    )
-
-    def _validate(self) -> None:
-        if self.type is None:
-            self.type = self.type_
-        # Should be handled in `from_dict`, but can fail if instantiating directly.
-        Assert.eq(self.type, self.__class__.type_)
-        super()._validate()
-
-    @classmethod
-    def _from_dict(
-        cls,
-        default: dict[str, typing.Any],
-        strict: bool = True,
-        flat: bool = False,
-    ) -> typing.Self:
-        type_ = default.get("type")
-        if type_ is None:
-            actual_cls = cls
-        else:
-            if type_ not in cls._registry:
-                raise ValueError(
-                    f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}"
-                )
-            actual_cls = cls._registry[type_]
-            Assert.custom(issubclass, actual_cls, cls)
-        if actual_cls == cls:
-            return super()._from_dict(default, strict=strict, flat=flat)
-        else:
-            return actual_cls._from_dict(default, strict=strict, flat=flat)
-
-    def __init_subclass__(cls) -> None:
-        if cls._abstract and cls.type_ is not None:
-            # Abstract classes should not have a `type_`
-            raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.")
-        if cls.type_ is not None:
-            if cls.type_ in cls._registry:
-                raise ValueError(
-                    f"Registry {cls._registry.name} already contains type {cls.type_}."
-                    f" Make sure all classes either have a unique or `None` type."
-                )
-            GPTSampledDatasetConfig._registry[cls.type_] = cls
-        super().__init_subclass__()
+    pass
 
 
 @config_class()
@@ -161,10 +109,9 @@ def build(self) -> "GPTIndexedDataset":
         raise NotImplementedError()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "random"})
 class GPTRandomDatasetConfig(GPTSamplableDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "random"
     name: str = Field(
         default="dummy",
         desc="The name of the dataset.",
@@ -177,10 +124,9 @@ def build(self) -> "GPTRandomDataset":
         return GPTRandomDataset(self.name)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"})
 class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "memmap"
     path: pathlib.Path = Field(
         default=None,
         desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.",
@@ -203,10 +149,9 @@ def build(self) -> "GPTMemmapDataset":
         return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
 class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "concatenated"
     datasets: list[GPTIndexedDatasetConfig] = FieldUpdate()
 
     def build(self) -> "GPTConcatenatedDataset":
@@ -215,10 +160,9 @@ def build(self) -> "GPTConcatenatedDataset":
         return self._build(GPTConcatenatedDataset)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"})
 class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "slice"
     dataset: GPTIndexedDatasetConfig = FieldUpdate()
 
     def build(self) -> "GPTDatasetSlice":
@@ -227,25 +171,22 @@ def build(self) -> "GPTDatasetSlice":
         return self._build(GPTDatasetSlice)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"})
 class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig):
     _abstract = False
-    type_: typing.ClassVar[str | None] = "sampled"
     sampling: GPTSamplingConfig = FieldUpdate()
     dataset: GPTSampledDatasetConfig = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"})
 class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "blended"
     datasets: list[GPTSampledDatasetConfig] = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "file"})
 class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "file"
     path: pathlib.Path = Field(
         default=None,
         desc="The path to a dataset config file.",
@@ -281,11 +222,11 @@ def _convert_paths(self, config):
         return config
 
 
-@config_class()
+# Add user-friendly names for the configs.
+@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"})
 class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig):
     # TODO v0.3: Remove.
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "concatenated_memmap"
     path: pathlib.Path = Field(
         default=None,
         desc="The path to a dataset directory.",
@@ -388,14 +329,13 @@ class FimConfig(Config):
     )
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"})
 class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig):
     """
     Configuration for FIM.
     """
 
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "fim"
 
     dataset: GPTSampledDatasetConfig = Field(
         default=None,
@@ -456,10 +396,9 @@ class GPTLegacyConfig(Config):
     )
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"})
 class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "legacy"
 
     def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
@@ -538,7 +477,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset:
         return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling)
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"})
 class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
     """
     A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout.
@@ -546,7 +485,6 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig):
 
     # TODO: This belongs to a testing plugin.
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "test_slow"
     sleep: float = Field(
         default=1,
         desc="Sleep time during build, in seconds.",
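
With the dataset classes now registered through `dynamic_type`, selection by `type` key behaves exactly as before. A quick sketch (the path value is a placeholder):

    from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig

    # Importing the module runs the decorators, so "memmap" is already registered.
    config = GPTSampledDatasetConfig.from_dict({"type": "memmap", "path": "/path/to/dataset"})
    assert isinstance(config, GPTMemmapDatasetConfig)
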
diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py
index 269989ce..054c26c3 100644
--- a/fast_llm/layers/common/config.py
+++ b/fast_llm/layers/common/config.py
@@ -33,7 +33,7 @@ class NormalizationType(str, enum.Enum):
     rms_norm = "rms_norm"
 
 
-@config_class()
+@config_class(registry=True)
 class NormalizationConfig(BaseModelConfig):
     _abstract = False
 
@@ -107,6 +107,12 @@ def _from_dict(
         return super()._from_dict(default, strict, flat)
 
 
+for name in NormalizationType:
+    # We need this because we are using the reserved field name `type`.
+    # TODO: Implement proper dynamic typing.
+    NormalizationConfig.register_subclass(name.value, NormalizationConfig)
+
+
 class PeftType(str, enum.Enum):
     # TODO : Use a dynamic config type instead.
     none = "none"
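
Since every `NormalizationType` value maps back to `NormalizationConfig` itself, lookup succeeds for each enum value without changing the instantiated class. A sketch of the intended behavior:

    from fast_llm.layers.common.config import NormalizationConfig

    # The name resolves, but the class stays the same; `_from_dict` then proceeds
    # with `NormalizationConfig` and its enum-typed `type` field as before.
    assert NormalizationConfig.get_subclass("rms_norm") is NormalizationConfig
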
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index c621139c..e7ef0b15 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -95,7 +95,7 @@ class RotaryEmbeddingType(str, enum.Enum):
     yarn = "yarn"
 
 
-@config_class()
+@config_class(registry=True)
 class RotaryConfig(BaseModelConfig):
     _abstract = False
     type: RotaryEmbeddingType = Field(
@@ -158,6 +158,12 @@ def _validate(self) -> None:
             warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.")
 
 
+for name in RotaryEmbeddingType:
+    # We need this because we are using the reserved field name `type`.
+    # TODO: Implement proper dynamic typing.
+    RotaryConfig.register_subclass(name.value, RotaryConfig)
+
+
 class AddLinearBiasChoices(str, enum.Enum):
     nowhere = "nowhere"
     everywhere = "everywhere"
@@ -175,7 +181,7 @@ class TransformerSubLayerName(str, enum.Enum):
     mlp_2 = "mlp_2"
 
 
-@config_class()
+@config_class(registry=True)
 class TransformerPeftConfig(PeftConfig):
     layers: list[TransformerSubLayerName] = Field(
         default=None,
@@ -244,6 +250,12 @@ def _validate(self) -> None:
                 )
 
 
+for name in PeftType:
+    # We need this because we are using the reserved field name `type`.
+    # TODO: Implement proper dynamic typing.
+    TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig)
+
+
 @config_class()
 class TransformerConfig(BaseModelConfig):
     _abstract = False
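
A note on re-registration: `register_subclass` tolerates re-registering the same class under the same name (the old entry is dropped and re-added when name and module match, e.g. after a module reload), but rebinding a name to a different class raises a `KeyError`. Sketch:

    from fast_llm.layers.transformer.config import RotaryConfig

    # Allowed: same class under the same name, a no-op apart from re-insertion.
    RotaryConfig.register_subclass("yarn", RotaryConfig)

    # Binding "yarn" to any other config class would raise a KeyError instead
    # of silently overriding the existing entry.
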
diff --git a/tests/data/common.py b/tests/data/common.py
index 00c3ff20..cacb28e6 100644
--- a/tests/data/common.py
+++ b/tests/data/common.py
@@ -189,10 +189,9 @@ def validate_indexed_dataset_sampling(
     return token_ids
 
 
-@config_class()
+@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"})
 class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
-    type_: typing.ClassVar[str | None] = "mock_memmap"
     num_documents: int | None = Field(
         default=None,
         desc="Expected number of documents in the dataset.",
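
With the mock registered under "mock_memmap", tests can select it through the same `type`-based dispatch as production configs. A sketch (the field value is a placeholder):

    from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig
    from tests.data.common import MockGPTMemmapDatasetConfig

    # Importing `tests.data.common` runs the decorator and registers the mock.
    config = GPTSampledDatasetConfig.from_dict({"type": "mock_memmap", "num_documents": 100})
    assert isinstance(config, MockGPTMemmapDatasetConfig)
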