Misc improvements and fixes #266

Merged 4 commits on May 21, 2025
Changes from 3 commits

93 changes: 39 additions & 54 deletions fast_llm/config.py
@@ -1,3 +1,4 @@
import abc
import contextlib
import copy
import dataclasses
@@ -137,7 +138,6 @@ def __init__(
default=dataclasses.MISSING,
default_factory=dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash=None,
compare: bool = True,
metadata=None,
@@ -146,12 +146,12 @@
if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
raise ValueError("cannot specify both default and default_factory")
if isinstance(default_factory, type) and issubclass(default_factory, Config):
default_factory = _ConfigFactory(default_factory)
raise ValueError("Config classes should not be used as `default_factory`")
super().__init__(
default=default,
default_factory=default_factory,
init=init,
repr=repr,
repr=False,
hash=hash,
compare=compare,
metadata=metadata,
@@ -223,20 +223,6 @@ def valid(x):
return valid


class _ConfigFactory:
"""
A dataclass default factory that prevents early validation.
Validation is still done through the parent config if needed.
"""

def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]):
self._factory = factory

def __call__(self):
with NoAutoValidate():
return self._factory()


class ValidationError(ValueError):
pass

@@ -257,7 +243,7 @@ def _process_config_class(cls: type["Config"]):
return cls


def config_class(cls=None):
def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
"""
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
"""
@@ -280,20 +266,23 @@ def __init__(self, **kwargs):
if _AUTO_VALIDATE:
self.validate()

cls.__init__ = __init__
wrapped.__init__ = __init__
return wrapped

# See if we're being called as @config_class or @config_class().
if cls is None:
# We're called with parens.
return wrap
return wrap


# We're called as @config_class without parens.
return wrap(cls)
class ConfigMeta(abc.ABCMeta):
def __call__(cls: "type[Config]", **kwargs):
# Always go through `_from_dict` for correct dynamic class selection and nested config instantiation.
if not kwargs.pop("_from_dict_check", False):
# with NoAutoValidate():
return cls._from_dict(kwargs)
return super().__call__(**kwargs)


@dataclasses.dataclass()
class Config:
@dataclasses.dataclass(kw_only=True, repr=False)
class Config(metaclass=ConfigMeta):
"""
An advanced `dataclass` with basic type checking, validation and argparse support.
Typically, a subclass will:
@@ -307,14 +296,14 @@ class Config:
# Set to true to prevent instantiation.
_abstract: typing.ClassVar[bool] = False
# Keep track of whether an instance has been validated
_validated: bool = Field(init=False, repr=False)
_validated: bool = Field(init=False)
# Keep track of unknown fields so they can be reported during validation.
_unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False)
_unknown_fields: dict[str, typing.Any] = Field(init=False)
# Keep track of explicitly set fields to ensure they get serialized and used as config updates.
_explicit_fields: set[str] = Field(init=False, repr=False)
_explicit_fields: set[str] = Field(init=False)
# Used within `_set_implicit_default` to set implicit defaults for fields
# without them being automatically added to `_explicit_fields`.
_setting_implicit_default: bool | None = Field(init=False, repr=False)
_setting_implicit_default: bool | None = Field(init=False)

def __setattr__(self, key: str, value: typing.Any) -> None:
"""
@@ -339,7 +328,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None:
)
else:
field = self.get_field(key)
if field.init and field._field_type != dataclasses._FIELD_CLASSVAR:
if field.init and field._field_type == dataclasses._FIELD:
# Adding to explicit field list except within `_set_implicit_default` context,
# during dataclass initialization (`_setting_implicit_default` not yet set)
# and during automated config validation (`_setting_implicit_default=None`)
@@ -358,13 +347,13 @@ def __delattr__(self, key: str) -> None:
super().__delattr__(key)

@contextlib.contextmanager
def _set_implicit_default(self, _value: bool | int = True):
def _set_implicit_default(self, _value: bool | None = True):
assert self._setting_implicit_default is False
self._setting_implicit_default = _value
yield
self._setting_implicit_default = False

def validate[T](self: T, *, _is_validating: bool = False) -> T:
def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
"""
Validate a class and mark it as read-only
This should not be overridden in derived classes.
@@ -388,11 +377,16 @@ def _validate(self) -> None:
Can be extended to add custom post-processing (typically before the super() call)
and validation (typically after)
"""
self._check_abstract()
if self._abstract:
raise ValidationError(f"{type(self).__name__} is abstract")
if not self.__class_validated__:
raise ValidationError(
f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator."
)
errors = []
with self._set_implicit_default(None):
for name, field in self.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
value = getattr(self, name)
if isinstance(value, Tag):
@@ -610,11 +604,7 @@ def _add_field_to_args(
all_fields: bool = False,
serializable: bool = True,
) -> None:
if (
field is not None
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
and not all_fields
):
if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields:
# Exclude class variables and derived fields unless requested explicitly.
return
explicit_field = (
@@ -677,6 +667,9 @@ def to_copy[
) -> T:
return self.from_dict(self, *updates, strict=strict, update_type=update_type)

def __repr__(self):
return self.to_logs(log_fn=str)

def to_logs[
T
](
@@ -739,7 +732,7 @@ def _from_dict(
flat: bool = False,
) -> typing.Self:
# TODO v0.3: Remove flat format
out_arg_dict = {}
out_arg_dict = {"_from_dict_check": True}

# TODO v0.3: Remove backward compatibility fix
if "__class__" in default:
@@ -748,7 +741,7 @@
# Do not validate yet in case the root class sets cross-dependencies in validation.
with NoAutoValidate():
for name, field in cls.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
if flat:
if isinstance(field.type, type) and issubclass(field.type, Config):
@@ -869,22 +862,15 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
f"Config comparison errors:\n " + "\n".join(errors),
log_fn=log_fn,
)

@classmethod
def _check_abstract(cls) -> None:
if cls._abstract:
raise ValidationError(f"{cls.__name__} is abstract")
if not cls.__class_validated__:
raise ValidationError(
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
)
return None

def __init_subclass__(cls):
"""
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
Assert.eq(cls.__name__, cls.__qualname__)
for base_class in cls.__mro__:
if issubclass(base_class, Config):
if issubclass(base_class, Config) and base_class is not cls:
assert cls.__class_validated__, (
f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated."
f" Make sure to use the @config_class decorator."
@@ -913,7 +899,6 @@ def __init_subclass__(cls):
valid=value.pop("valid", base_class_field.valid),
default=value.pop("default", base_class_field.default),
default_factory=value.pop("default_factory", base_class_field.default_factory),
repr=value.pop("repr", base_class_field.repr),
hash=value.pop("hash", base_class_field.hash),
compare=value.pop("compare", base_class_field.compare),
metadata=value.pop("metadata", base_class_field.metadata),
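For context on the `ConfigMeta` metaclass added above: routing every plain `Config(**kwargs)` call through `_from_dict` gives one code path for dynamic class selection and nested config instantiation. Below is a minimal, self-contained sketch of that dispatch pattern with hypothetical class names; it is not the Fast-LLM implementation, whose `_from_dict` also handles defaults, updates, and validation.

```python
import abc
import dataclasses


class ConfigMeta(abc.ABCMeta):
    def __call__(cls, **kwargs):
        # Reroute plain `SomeConfig(...)` calls through `_from_dict`; the
        # sentinel kwarg marks calls that already come from `_from_dict`.
        if not kwargs.pop("_from_dict_check", False):
            return cls._from_dict(kwargs)
        return super().__call__(**kwargs)


@dataclasses.dataclass(kw_only=True)
class Config(metaclass=ConfigMeta):
    @classmethod
    def _from_dict(cls, default: dict) -> "Config":
        # The real method also resolves dynamic subclasses and nested
        # configs; this sketch just forwards the arguments.
        return cls(_from_dict_check=True, **default)


@dataclasses.dataclass(kw_only=True)
class ExampleConfig(Config):  # hypothetical subclass
    value: int = 0


print(ExampleConfig(value=1))  # instantiated via ExampleConfig._from_dict
```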
4 changes: 1 addition & 3 deletions fast_llm/data/data/config.py
@@ -9,6 +9,4 @@ class DataConfig(Config):
_abstract = True
_sampling_config_class: typing.ClassVar[type[SamplingData]]

sampling: SamplingConfig = Field(
default_factory=SamplingConfig, desc="Default configuration for dataset sampling."
)
sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.")
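The `default_factory` removals in this and the following files fall out of the change in `fast_llm/config.py` above: `Field` now rejects `Config` subclasses as `default_factory`, and an unset nested config field is instead built during `_from_dict`. A rough, heavily simplified sketch of that idea with hypothetical stand-in classes:

```python
import dataclasses


@dataclasses.dataclass(kw_only=True)
class SamplingConfig:  # stand-in for a nested config
    seed: int = 0


@dataclasses.dataclass(kw_only=True)
class DataConfig:  # stand-in for the parent config
    sampling: SamplingConfig = None  # filled in by `from_dict` when omitted

    @classmethod
    def from_dict(cls, data: dict) -> "DataConfig":
        # Build the nested config from its sub-dict, defaulting to {},
        # so the field itself needs no `default_factory`.
        sampling = SamplingConfig(**data.pop("sampling", {}))
        return cls(sampling=sampling, **data)


config = DataConfig.from_dict({})
assert config.sampling.seed == 0
```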
3 changes: 1 addition & 2 deletions fast_llm/data/data/gpt/config.py
@@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
_abstract = False

tokenizer: TokenizerConfig = Field(
default_factory=TokenizerConfig,
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
@@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
sampling: GPTSamplingConfig = FieldUpdate()
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
2 changes: 0 additions & 2 deletions fast_llm/data/dataset/config.py
@@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig):

_abstract = True
sampling: SamplingConfig = Field(
default_factory=SamplingConfig,
desc="Optional override to sampling configuration parameters.",
hint=FieldHint.core,
)
dataset: SampledDatasetConfig = Field(
default_factory=SampledDatasetConfig,
desc="The dataset to sample from.",
hint=FieldHint.core,
)
5 changes: 2 additions & 3 deletions fast_llm/data/dataset/gpt/config.py
@@ -230,8 +230,8 @@ def build(self) -> "GPTDatasetSlice":
class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig):
_abstract = False
type_: typing.ClassVar[str | None] = "sampled"
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig)
sampling: GPTSamplingConfig = FieldUpdate()
dataset: GPTSampledDatasetConfig = FieldUpdate()


@config_class()
@@ -450,7 +450,6 @@ class GPTLegacyConfig(Config):
valid=_validate_path,
)
fim: FimConfig = Field(
default_factory=FimConfig,
desc="Configuration for Fill In the Middle (FIM).",
hint=FieldHint.feature,
)
7 changes: 2 additions & 5 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -24,7 +24,7 @@
MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"


@config_class
@config_class()
class GPTHuggingfaceDatasetConfig(Config):
path: str = Field(
default=None,
@@ -77,7 +77,7 @@ class GPTHuggingfaceDatasetConfig(Config):
)


@config_class
@config_class()
class DatasetPreparatorDistributedConfig(Config):
# TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig

@@ -120,7 +120,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
hint=FieldHint.core,
)
distributed: DatasetPreparatorDistributedConfig = Field(
default_factory=DatasetPreparatorDistributedConfig,
desc="Configuration for distributed processing.",
hint=FieldHint.feature,
)
@@ -149,12 +148,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
valid=check_field(Assert.geq, 1),
)
dataset: GPTHuggingfaceDatasetConfig = Field(
default_factory=GPTHuggingfaceDatasetConfig,
desc="Configuration for the dataset.",
hint=FieldHint.feature,
)
tokenizer: TokenizerConfig = Field(
default_factory=TokenizerConfig,
desc="Configuration for the tokenizer.",
hint=FieldHint.feature,
)
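The `@config_class` → `@config_class()` edits in this file match the new decorator signature in `fast_llm/config.py`, which is now a zero-argument factory and no longer supports the bare form. A minimal sketch of why the parentheses are required (simplified, hypothetical body):

```python
import typing


class Config:  # stand-in base class
    pass


def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
    # A decorator factory with no class parameter: it must always be
    # called, as in `@config_class()`.
    def wrap(cls: type[T]) -> type[T]:
        # The real wrapper applies dataclass processing and verification.
        return cls

    return wrap


@config_class()
class MyConfig(Config):  # bare `@config_class` would now raise a TypeError
    pass
```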
2 changes: 1 addition & 1 deletion fast_llm/engine/base_model/base_model.py
@@ -90,7 +90,7 @@ def __init__(
config: BaseModelConfig,
distributed_config: DistributedConfig,
):
self._tensor_space = TensorSpace(distributed_config)
self._tensor_space: TensorSpace = TensorSpace(distributed_config)
config.setup_tensor_space(self._tensor_space)

super().__init__(config)
1 change: 1 addition & 0 deletions fast_llm/engine/base_model/config.py
@@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]:
assert isinstance(field, Field), f"{name}, {field}"
if field.hint == FieldHint.architecture:
architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING))
return architecture

def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
if isinstance(value, BaseModelConfig):
8 changes: 2 additions & 6 deletions fast_llm/engine/config_utils/run.py
@@ -20,9 +20,7 @@

@config_class()
class RunConfig(Config):
tensor_logs: TensorLogsConfig = Field(
default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging
)
tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging)
# TODO v0.3: Adjust (now only affects logging to file).
structured_logs: bool = Field(
default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging
@@ -70,9 +68,7 @@ def _validate(self):

@config_class()
class ExperimentConfig(RunnableConfig):
run: RunConfig = Field(
default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core
)
run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core)

def _show(
self,