Skip to content

Commit de86d4b

Browse files
committed
stuff
1 parent 3ac976b commit de86d4b

File tree

25 files changed

+128
-180
lines changed

25 files changed

+128
-180
lines changed

fast_llm/config.py

Lines changed: 38 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import abc
12
import contextlib
23
import copy
34
import dataclasses
@@ -137,7 +138,6 @@ def __init__(
137138
default=dataclasses.MISSING,
138139
default_factory=dataclasses.MISSING,
139140
init: bool = True,
140-
repr: bool = True,
141141
hash=None,
142142
compare: bool = True,
143143
metadata=None,
@@ -146,12 +146,11 @@ def __init__(
146146
if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
147147
raise ValueError("cannot specify both default and default_factory")
148148
if isinstance(default_factory, type) and issubclass(default_factory, Config):
149-
default_factory = _ConfigFactory(default_factory)
149+
raise ValueError("Config classes should not be used as `default_factory`")
150150
super().__init__(
151151
default=default,
152152
default_factory=default_factory,
153153
init=init,
154-
repr=repr,
155154
hash=hash,
156155
compare=compare,
157156
metadata=metadata,
@@ -223,20 +222,6 @@ def valid(x):
223222
return valid
224223

225224

226-
class _ConfigFactory:
227-
"""
228-
A dataclass default factory that prevents early validation.
229-
Validation is still done through the parent config if needed.
230-
"""
231-
232-
def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]):
233-
self._factory = factory
234-
235-
def __call__(self):
236-
with NoAutoValidate():
237-
return self._factory()
238-
239-
240225
class ValidationError(ValueError):
241226
pass
242227

@@ -257,7 +242,7 @@ def _process_config_class(cls: type["Config"]):
257242
return cls
258243

259244

260-
def config_class(cls=None):
245+
def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
261246
"""
262247
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
263248
"""
@@ -280,20 +265,23 @@ def __init__(self, **kwargs):
280265
if _AUTO_VALIDATE:
281266
self.validate()
282267

283-
cls.__init__ = __init__
268+
wrapped.__init__ = __init__
284269
return wrapped
285270

286-
# See if we're being called as @config_class or @config_class().
287-
if cls is None:
288-
# We're called with parens.
289-
return wrap
271+
return wrap
272+
290273

291-
# We're called as @config_class without parens.
292-
return wrap(cls)
274+
class ConfigMeta(abc.ABCMeta):
275+
def __call__(cls: "type[Config]", **kwargs):
276+
# Always go through `_from_dict` for correct dynamic class selection and nested config instantiation.
277+
if not kwargs.pop("_from_dict_check", False):
278+
# with NoAutoValidate():
279+
return cls._from_dict(kwargs)
280+
return super().__call__(**kwargs)
293281

294282

295-
@dataclasses.dataclass()
296-
class Config:
283+
@dataclasses.dataclass(kw_only=True, repr=False)
284+
class Config(metaclass=ConfigMeta):
297285
"""
298286
An advanced `dataclass` with basic type checking, validation and argparse support.
299287
Typically, a subclass will:
@@ -307,14 +295,14 @@ class Config:
307295
# Set to true to prevent instantiation.
308296
_abstract: typing.ClassVar[bool] = False
309297
# Keep track of whether an instance has been validated
310-
_validated: bool = Field(init=False, repr=False)
298+
_validated: bool = Field(init=False)
311299
# Keep track of unknown fields so they can be reported during validation.
312-
_unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False)
300+
_unknown_fields: dict[str, typing.Any] = Field(init=False)
313301
# Keep track of explicitly set fields to ensure they get serialized and used as config updates.
314-
_explicit_fields: set[str] = Field(init=False, repr=False)
302+
_explicit_fields: set[str] = Field(init=False)
315303
# Used within `_set_implicit_default` to set implicit defaults for fields
316304
# without them being automatically added to `_explicit_fields`.
317-
_setting_implicit_default: bool | None = Field(init=False, repr=False)
305+
_setting_implicit_default: bool | None = Field(init=False)
318306

319307
def __setattr__(self, key: str, value: typing.Any) -> None:
320308
"""
@@ -339,7 +327,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None:
339327
)
340328
else:
341329
field = self.get_field(key)
342-
if field.init and field._field_type != dataclasses._FIELD_CLASSVAR:
330+
if field.init and field._field_type == dataclasses._FIELD:
343331
# Adding to explicit field list except within `_set_implicit_default` context,
344332
# during dataclass initialization (`_setting_implicit_default` not yet set)
345333
# and during automated config validation (`_setting_implicit_default=None`)
@@ -358,13 +346,13 @@ def __delattr__(self, key: str) -> None:
358346
super().__delattr__(key)
359347

360348
@contextlib.contextmanager
361-
def _set_implicit_default(self, _value: bool | int = True):
349+
def _set_implicit_default(self, _value: bool | None = True):
362350
assert self._setting_implicit_default is False
363351
self._setting_implicit_default = _value
364352
yield
365353
self._setting_implicit_default = False
366354

367-
def validate[T](self: T, *, _is_validating: bool = False) -> T:
355+
def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
368356
"""
369357
Validate a class and mark it as read-only
370358
This should not be overridden in derived classes.
@@ -388,11 +376,16 @@ def _validate(self) -> None:
388376
Can be extended to add custom post-processing (typically before the super() call)
389377
and validation (typically after)
390378
"""
391-
self._check_abstract()
379+
if self._abstract:
380+
raise ValidationError(f"{type(self).__name__} is abstract")
381+
if not self.__class_validated__:
382+
raise ValidationError(
383+
f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator."
384+
)
392385
errors = []
393386
with self._set_implicit_default(None):
394387
for name, field in self.fields():
395-
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
388+
if not field.init or field._field_type != dataclasses._FIELD: # noqa
396389
continue
397390
value = getattr(self, name)
398391
if isinstance(value, Tag):
@@ -610,11 +603,7 @@ def _add_field_to_args(
610603
all_fields: bool = False,
611604
serializable: bool = True,
612605
) -> None:
613-
if (
614-
field is not None
615-
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
616-
and not all_fields
617-
):
606+
if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields:
618607
# Exclude class variables and derived fields unless requested explicitly.
619608
return
620609
explicit_field = (
@@ -677,6 +666,9 @@ def to_copy[
677666
) -> T:
678667
return self.from_dict(self, *updates, strict=strict, update_type=update_type)
679668

669+
def __repr__(self):
670+
return self.to_logs(log_fn=str)
671+
680672
def to_logs[
681673
T
682674
](
@@ -739,7 +731,7 @@ def _from_dict(
739731
flat: bool = False,
740732
) -> typing.Self:
741733
# TODO v0.3: Remove flat format
742-
out_arg_dict = {}
734+
out_arg_dict = {"_from_dict_check": True}
743735

744736
# TODO v0.3: Remove backward compatibility fix
745737
if "__class__" in default:
@@ -748,7 +740,7 @@ def _from_dict(
748740
# Do not validate yet in case the root class sets cross-dependencies in validation.
749741
with NoAutoValidate():
750742
for name, field in cls.fields():
751-
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
743+
if not field.init or field._field_type != dataclasses._FIELD: # noqa
752744
continue
753745
if flat:
754746
if isinstance(field.type, type) and issubclass(field.type, Config):
@@ -869,22 +861,15 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
869861
f"Config comparison errors:\n " + "\n".join(errors),
870862
log_fn=log_fn,
871863
)
872-
873-
@classmethod
874-
def _check_abstract(cls) -> None:
875-
if cls._abstract:
876-
raise ValidationError(f"{cls.__name__} is abstract")
877-
if not cls.__class_validated__:
878-
raise ValidationError(
879-
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
880-
)
864+
return None
881865

882866
def __init_subclass__(cls):
883867
"""
884868
We need to postpone validation until the class has been processed by the dataclass wrapper.
885869
"""
870+
Assert.eq(cls.__name__, cls.__qualname__)
886871
for base_class in cls.__mro__:
887-
if issubclass(base_class, Config):
872+
if issubclass(base_class, Config) and base_class is not cls:
888873
assert cls.__class_validated__, (
889874
f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated."
890875
f" Make sure to use the @config_class decorator."
@@ -913,7 +898,6 @@ def __init_subclass__(cls):
913898
valid=value.pop("valid", base_class_field.valid),
914899
default=value.pop("default", base_class_field.default),
915900
default_factory=value.pop("default_factory", base_class_field.default_factory),
916-
repr=value.pop("repr", base_class_field.repr),
917901
hash=value.pop("hash", base_class_field.hash),
918902
compare=value.pop("compare", base_class_field.compare),
919903
metadata=value.pop("metadata", base_class_field.metadata),

fast_llm/data/data/config.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,4 @@ class DataConfig(Config):
99
_abstract = True
1010
_sampling_config_class: typing.ClassVar[type[SamplingData]]
1111

12-
sampling: SamplingConfig = Field(
13-
default_factory=SamplingConfig, desc="Default configuration for dataset sampling."
14-
)
12+
sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.")

fast_llm/data/data/gpt/config.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
2727
_abstract = False
2828

2929
tokenizer: TokenizerConfig = Field(
30-
default_factory=TokenizerConfig,
3130
desc="Configuration for the tokenizer (for FIM).",
3231
hint=FieldHint.feature,
3332
)
@@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
3736
desc="Configuration for the dataset(s).",
3837
hint=FieldHint.core,
3938
)
40-
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
39+
sampling: GPTSamplingConfig = FieldUpdate()
4140
data_sample_warn_time_ms: float = Field(
4241
default=1000,
4342
desc="Warn if a sample takes too long to load.",

fast_llm/data/dataset/config.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig):
174174

175175
_abstract = True
176176
sampling: SamplingConfig = Field(
177-
default_factory=SamplingConfig,
178177
desc="Optional override to sampling configuration parameters.",
179178
hint=FieldHint.core,
180179
)
181180
dataset: SampledDatasetConfig = Field(
182-
default_factory=SampledDatasetConfig,
183181
desc="The dataset to sample from.",
184182
hint=FieldHint.core,
185183
)

fast_llm/data/dataset/gpt/config.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -231,8 +231,8 @@ def build(self) -> "GPTDatasetSlice":
231231
class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig):
232232
_abstract = False
233233
type_: typing.ClassVar[str | None] = "sampled"
234-
sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig)
235-
dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig)
234+
sampling: GPTSamplingConfig = FieldUpdate()
235+
dataset: GPTSampledDatasetConfig = FieldUpdate()
236236

237237

238238
@config_class()
@@ -451,7 +451,6 @@ class GPTLegacyConfig(Config):
451451
valid=_validate_path,
452452
)
453453
fim: FimConfig = Field(
454-
default_factory=FimConfig,
455454
desc="Configuration for Fill In the Middle (FIM).",
456455
hint=FieldHint.feature,
457456
)

fast_llm/data/preparator/gpt_memmap/config.py

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@
2424
MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00"
2525

2626

27-
@config_class
27+
@config_class()
2828
class GPTHuggingfaceDatasetConfig(Config):
2929
path: str = Field(
3030
default=None,
@@ -59,12 +59,6 @@ class GPTHuggingfaceDatasetConfig(Config):
5959
loss_masking_spans: None | str = Field(
6060
default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional
6161
)
62-
chosen_text: None | str = Field(
63-
default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional
64-
)
65-
rejected_text: None | str = Field(
66-
default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
67-
)
6862
data_type: DataType | None = Field(
6963
default=None,
7064
desc="Data type of the dataset field."
@@ -83,7 +77,7 @@ class GPTHuggingfaceDatasetConfig(Config):
8377
)
8478

8579

86-
@config_class
80+
@config_class()
8781
class DatasetPreparatorDistributedConfig(Config):
8882
# TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig
8983

@@ -126,7 +120,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
126120
hint=FieldHint.core,
127121
)
128122
distributed: DatasetPreparatorDistributedConfig = Field(
129-
default_factory=DatasetPreparatorDistributedConfig,
130123
desc="Configuration for distributed processing.",
131124
hint=FieldHint.feature,
132125
)
@@ -155,12 +148,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
155148
valid=check_field(Assert.geq, 1),
156149
)
157150
dataset: GPTHuggingfaceDatasetConfig = Field(
158-
default_factory=GPTHuggingfaceDatasetConfig,
159151
desc="Configuration for the dataset.",
160152
hint=FieldHint.feature,
161153
)
162154
tokenizer: TokenizerConfig = Field(
163-
default_factory=TokenizerConfig,
164155
desc="Configuration for the tokenizer.",
165156
hint=FieldHint.feature,
166157
)

fast_llm/engine/base_model/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ def __init__(
9090
config: BaseModelConfig,
9191
distributed_config: DistributedConfig,
9292
):
93-
self._tensor_space = TensorSpace(distributed_config)
93+
self._tensor_space: TensorSpace = TensorSpace(distributed_config)
9494
config.setup_tensor_space(self._tensor_space)
9595

9696
super().__init__(config)

fast_llm/engine/base_model/config.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]:
4242
assert isinstance(field, Field), f"{name}, {field}"
4343
if field.hint == FieldHint.architecture:
4444
architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING))
45+
return architecture
4546

4647
def _serialize_architecture_field(self, value: typing.Any) -> typing.Any:
4748
if isinstance(value, BaseModelConfig):

fast_llm/engine/config_utils/run.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,7 @@
2020

2121
@config_class()
2222
class RunConfig(Config):
23-
tensor_logs: TensorLogsConfig = Field(
24-
default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging
25-
)
23+
tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging)
2624
# TODO v0.3: Adjust (now only affects logging to file).
2725
structured_logs: bool = Field(
2826
default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging
@@ -70,9 +68,7 @@ def _validate(self):
7068

7169
@config_class()
7270
class ExperimentConfig(RunnableConfig):
73-
run: RunConfig = Field(
74-
default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core
75-
)
71+
run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core)
7672

7773
def _show(
7874
self,

0 commit comments

Comments
 (0)