Skip to content

Minimalistic dynamic configs #268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Megatron-LM
168 changes: 114 additions & 54 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import contextlib
import copy
import dataclasses
Expand All @@ -11,7 +12,7 @@

import yaml

from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -137,7 +138,6 @@ def __init__(
default=dataclasses.MISSING,
default_factory=dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash=None,
compare: bool = True,
metadata=None,
Expand All @@ -146,12 +146,12 @@ def __init__(
if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
raise ValueError("cannot specify both default and default_factory")
if isinstance(default_factory, type) and issubclass(default_factory, Config):
default_factory = _ConfigFactory(default_factory)
raise ValueError("Config classes should not be used as `default_factory`")
super().__init__(
default=default,
default_factory=default_factory,
init=init,
repr=repr,
repr=False,
hash=hash,
compare=compare,
metadata=metadata,
Expand Down Expand Up @@ -223,20 +223,6 @@ def valid(x):
return valid


class _ConfigFactory:
"""
A dataclass default factory that prevents early validation.
Validation is still done through the parent config if needed.
"""

def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]):
self._factory = factory

def __call__(self):
with NoAutoValidate():
return self._factory()


class ValidationError(ValueError):
    """Error raised when a config fails validation."""

Expand All @@ -257,7 +243,9 @@ def _process_config_class(cls: type["Config"]):
return cls


def config_class(cls=None):
def config_class[
T: Config
](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
"""
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
"""
Expand All @@ -267,7 +255,7 @@ def wrap(cls):
if hasattr(cls, "__post_init__"):
raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")

wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True))
wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False))

wrapped_init = cls.__init__

Expand All @@ -280,20 +268,31 @@ def __init__(self, **kwargs):
if _AUTO_VALIDATE:
self.validate()

cls.__init__ = __init__
wrapped.__init__ = __init__

wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None

if dynamic_type is not None:
for cls_, name in dynamic_type.items():
print(cls_, name, wrapped)
cls_.register_subclass(name, wrapped)

return wrapped

# See if we're being called as @config_class or @config_class().
if cls is None:
# We're called with parens.
return wrap
return wrap

# We're called as @config_class without parens.
return wrap(cls)

class ConfigMeta(abc.ABCMeta):
    """
    Metaclass that reroutes direct instantiation of a config class through `_from_dict`.
    """

    def __call__(cls: "type[Config]", **kwargs):
        """
        Create an instance of `cls` from keyword arguments.

        Always go through `_from_dict` for correct dynamic class selection and nested
        config instantiation. The `_from_dict_check` keyword marks a call that already
        comes from `_from_dict`; it is removed from `kwargs` and the regular
        construction path is used.
        """
        already_from_dict = kwargs.pop("_from_dict_check", False)
        if already_from_dict:
            return super().__call__(**kwargs)
        return cls._from_dict(kwargs)

@dataclasses.dataclass()
class Config:

@dataclasses.dataclass(kw_only=True, repr=False)
class Config(metaclass=ConfigMeta):
"""
An advanced `dataclass` with basic type checking, validation and argparse support.
Typically, a subclass will:
Expand All @@ -307,14 +306,17 @@ class Config:
# Set to true to prevent instantiation.
_abstract: typing.ClassVar[bool] = False
# Keep track of whether an instance has been validated
_validated: bool = Field(init=False, repr=False)
_validated: bool = Field(init=False)
# Keep track of unknown fields so they can be reported during validation.
_unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False)
_unknown_fields: dict[str, typing.Any] = Field(init=False)
# Keep track of explicitly set fields to ensure they get serialized and used as config updates.
_explicit_fields: set[str] = Field(init=False, repr=False)
_explicit_fields: set[str] = Field(init=False)
# Used within `_set_implicit_default` to set implicit defaults for fields
# without them being automatically added to `_explicit_fields`.
_setting_implicit_default: bool | None = Field(init=False, repr=False)
_setting_implicit_default: bool | None = Field(init=False)

# A registry for all the config classes.
_registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None

def __setattr__(self, key: str, value: typing.Any) -> None:
"""
Expand All @@ -339,7 +341,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None:
)
else:
field = self.get_field(key)
if field.init and field._field_type != dataclasses._FIELD_CLASSVAR:
if field.init and field._field_type == dataclasses._FIELD:
# Adding to explicit field list except within `_set_implicit_default` context,
# during dataclass initialization (`_setting_implicit_default` not yet set)
# and during automated config validation (`_setting_implicit_default=None`)
Expand All @@ -358,17 +360,28 @@ def __delattr__(self, key: str) -> None:
super().__delattr__(key)

@contextlib.contextmanager
def _set_implicit_default(self, _value: bool | int = True):
def _set_implicit_default(self, _value: bool | None = True):
assert self._setting_implicit_default is False
self._setting_implicit_default = _value
yield
self._setting_implicit_default = False

def validate[T](self: T, *, _is_validating: bool = False) -> T:
def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
"""
Validate a class and mark it as read-only
This should not be overridden in derived classes.
"""
# Should be handled in `from_dict`, but can fail if instantiating directly.
try:
expected_class = self.get_subclass(self.type)
except KeyError as e:
# Delayed instantiation error in `from_dict`.
raise ValidationError(*e.args)

if expected_class is not None:
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.is_(self.__class__, expected_class)

if not self._validated:
try:
self._validate()
Expand All @@ -388,11 +401,16 @@ def _validate(self) -> None:
Can be extended to add custom post-processing (typically before the super() call)
and validation (typically after)
"""
self._check_abstract()
if self._abstract:
raise ValidationError(f"{type(self).__name__} is abstract")
if not self.__class_validated__:
raise ValidationError(
f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator."
)
errors = []
with self._set_implicit_default(None):
for name, field in self.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
value = getattr(self, name)
if isinstance(value, Tag):
Expand Down Expand Up @@ -610,11 +628,7 @@ def _add_field_to_args(
all_fields: bool = False,
serializable: bool = True,
) -> None:
if (
field is not None
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
and not all_fields
):
if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields:
# Exclude class variables and derived fields unless requested explicitly.
return
explicit_field = (
Expand Down Expand Up @@ -677,6 +691,9 @@ def to_copy[
) -> T:
return self.from_dict(self, *updates, strict=strict, update_type=update_type)

def __repr__(self):
return self.to_logs(log_fn=str)

def to_logs[
T
](
Expand Down Expand Up @@ -739,16 +756,24 @@ def _from_dict(
flat: bool = False,
) -> typing.Self:
# TODO v0.3: Remove flat format
out_arg_dict = {}
out_arg_dict = {"_from_dict_check": True}

# TODO v0.3: Remove backward compatibility fix
if "__class__" in default:
del default["__class__"]

try:
actual_cls = cls.get_subclass(default.get("type"))
if actual_cls is not None and actual_cls is not cls:
return actual_cls._from_dict(default, strict=strict, flat=flat)
except KeyError:
# Postpone error to validation.
pass

# Do not validate yet in case the root class sets cross-dependencies in validation.
with NoAutoValidate():
for name, field in cls.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
if flat:
if isinstance(field.type, type) and issubclass(field.type, Config):
Expand Down Expand Up @@ -869,22 +894,51 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
f"Config comparison errors:\n " + "\n".join(errors),
log_fn=log_fn,
)
return None

@classmethod
def _check_abstract(cls) -> None:
if cls._abstract:
raise ValidationError(f"{cls.__name__} is abstract")
if not cls.__class_validated__:
raise ValidationError(
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
)
def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
    """
    Register `cls_` under `name` in the registry of `cls`.

    Args:
        name: The registry key for the subclass.
        cls_: The class to register. Must be a subclass of `cls`.

    Raises:
        NotImplementedError: If `cls` has no registry.
        KeyError: If `name` is already registered to a class with a different
            name or module.
    """
    Assert.custom(issubclass, cls_, cls)
    registry = cls._registry
    if registry is None:
        raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..")
    if name in registry:
        old_cls = registry[name]
        if old_cls.__name__ == cls_.__name__ and old_cls.__module__ == cls_.__module__:
            # An entry with the same class name and module is replaced rather than rejected
            # (presumably the same class being defined again, e.g. on reload -- TODO confirm).
            del registry[name]
        else:
            # Report the class that currently holds the entry, not the registry owner.
            raise KeyError(
                f"{cls.__name__} class registry already has an entry {name} from class {old_cls.__name__}."
            )
    registry[name] = cls_

@classmethod
def get_subclass(cls, name: str | None):
# TODO: Make it case-insensitive?
if name is None:
return None
cls_ = None
for base_class in cls.__mro__:
if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry:
if cls_ is None:
cls_ = base_class._registry[name]
if not issubclass(cls_, cls):
raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})")
elif base_class._registry[name] is not cls_:
# We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization.
# TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit?
raise KeyError(
f"Ambiguous type `{name}` for base class {cls.__name__}."
f" ({cls_.__name__} vs {base_class._registry[name]})"
)
if cls_ is None:
raise KeyError(f"Unknown type {name} for base class {cls.__name__}")
return cls_

def __init_subclass__(cls):
"""
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
Assert.eq(cls.__name__, cls.__qualname__)
for base_class in cls.__mro__:
if issubclass(base_class, Config):
if issubclass(base_class, Config) and base_class is not cls:
assert cls.__class_validated__, (
f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated."
f" Make sure to use the @config_class decorator."
Expand Down Expand Up @@ -913,7 +967,6 @@ def __init_subclass__(cls):
valid=value.pop("valid", base_class_field.valid),
default=value.pop("default", base_class_field.default),
default_factory=value.pop("default_factory", base_class_field.default_factory),
repr=value.pop("repr", base_class_field.repr),
hash=value.pop("hash", base_class_field.hash),
compare=value.pop("compare", base_class_field.compare),
metadata=value.pop("metadata", base_class_field.metadata),
Expand All @@ -928,6 +981,13 @@ def __init_subclass__(cls):
# dataclasses expects an annotation, so we use the one from the base class.
cls.__annotations__[name] = base_class_field.type

# Type for the field. At the end of class definition to avoid shadowing builtin.
type: str | None = Field(
default=None,
desc="The config class name.",
hint=FieldHint.feature,
)


class Configurable[ConfigType: Config]:
config_class: typing.ClassVar[type[Config]] = Config
Expand Down
4 changes: 1 addition & 3 deletions fast_llm/data/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,4 @@ class DataConfig(Config):
_abstract = True
_sampling_config_class: typing.ClassVar[type[SamplingData]]

sampling: SamplingConfig = Field(
default_factory=SamplingConfig, desc="Default configuration for dataset sampling."
)
sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.")
3 changes: 1 addition & 2 deletions fast_llm/data/data/gpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
_abstract = False

tokenizer: TokenizerConfig = Field(
default_factory=TokenizerConfig,
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
Expand All @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
sampling: GPTSamplingConfig = FieldUpdate()
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
Expand Down
2 changes: 0 additions & 2 deletions fast_llm/data/dataset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig):

_abstract = True
sampling: SamplingConfig = Field(
default_factory=SamplingConfig,
desc="Optional override to sampling configuration parameters.",
hint=FieldHint.core,
)
dataset: SampledDatasetConfig = Field(
default_factory=SampledDatasetConfig,
desc="The dataset to sample from.",
hint=FieldHint.core,
)
Expand Down
Loading