1
+ import abc
1
2
import contextlib
2
3
import copy
3
4
import dataclasses
@@ -137,7 +138,6 @@ def __init__(
137
138
default = dataclasses .MISSING ,
138
139
default_factory = dataclasses .MISSING ,
139
140
init : bool = True ,
140
- repr : bool = True ,
141
141
hash = None ,
142
142
compare : bool = True ,
143
143
metadata = None ,
@@ -146,12 +146,11 @@ def __init__(
146
146
if default is not dataclasses .MISSING and default_factory is not dataclasses .MISSING :
147
147
raise ValueError ("cannot specify both default and default_factory" )
148
148
if isinstance (default_factory , type ) and issubclass (default_factory , Config ):
149
- default_factory = _ConfigFactory ( default_factory )
149
+ raise ValueError ( "Config classes should not be used as ` default_factory`" )
150
150
super ().__init__ (
151
151
default = default ,
152
152
default_factory = default_factory ,
153
153
init = init ,
154
- repr = repr ,
155
154
hash = hash ,
156
155
compare = compare ,
157
156
metadata = metadata ,
@@ -223,20 +222,6 @@ def valid(x):
223
222
return valid
224
223
225
224
226
- class _ConfigFactory :
227
- """
228
- A dataclass default factory that prevents early validation.
229
- Validation is still done through the parent config if needed.
230
- """
231
-
232
- def __init__ (self , factory : typing .Callable [[], "Config" ] | type ["Config" ]):
233
- self ._factory = factory
234
-
235
- def __call__ (self ):
236
- with NoAutoValidate ():
237
- return self ._factory ()
238
-
239
-
240
225
class ValidationError (ValueError ):
241
226
pass
242
227
@@ -257,7 +242,7 @@ def _process_config_class(cls: type["Config"]):
257
242
return cls
258
243
259
244
260
- def config_class ( cls = None ) :
245
+ def config_class [ T : Config ]() -> typing . Callable [[ type [ T ]], type [ T ]] :
261
246
"""
262
247
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
263
248
"""
@@ -280,20 +265,23 @@ def __init__(self, **kwargs):
280
265
if _AUTO_VALIDATE :
281
266
self .validate ()
282
267
283
- cls .__init__ = __init__
268
+ wrapped .__init__ = __init__
284
269
return wrapped
285
270
286
- # See if we're being called as @config_class or @config_class().
287
- if cls is None :
288
- # We're called with parens.
289
- return wrap
271
+ return wrap
272
+
290
273
291
- # We're called as @config_class without parens.
292
- return wrap (cls )
274
+ class ConfigMeta (abc .ABCMeta ):
275
+ def __call__ (cls : "type[Config]" , ** kwargs ):
276
+ # Always go through `_from_dict` for correct dynamic class selection and nested config instantiation.
277
+ if not kwargs .pop ("_from_dict_check" , False ):
278
+ # with NoAutoValidate():
279
+ return cls ._from_dict (kwargs )
280
+ return super ().__call__ (** kwargs )
293
281
294
282
295
- @dataclasses .dataclass ()
296
- class Config :
283
+ @dataclasses .dataclass (kw_only = True , repr = False )
284
+ class Config ( metaclass = ConfigMeta ) :
297
285
"""
298
286
An advanced `dataclass` with basic type checking, validation and argparse support.
299
287
Typically, a subclass will:
@@ -307,14 +295,14 @@ class Config:
307
295
# Set to true to prevent instantiation.
308
296
_abstract : typing .ClassVar [bool ] = False
309
297
# Keep track of whether an instance has been validated
310
- _validated : bool = Field (init = False , repr = False )
298
+ _validated : bool = Field (init = False )
311
299
# Keep track of unknown fields so they can be reported during validation.
312
- _unknown_fields : dict [str , typing .Any ] = Field (init = False , repr = False )
300
+ _unknown_fields : dict [str , typing .Any ] = Field (init = False )
313
301
# Keep track of explicitly set fields to ensure they get serialized and used as config updates.
314
- _explicit_fields : set [str ] = Field (init = False , repr = False )
302
+ _explicit_fields : set [str ] = Field (init = False )
315
303
# Used within `_set_implicit_default` to set implicit defaults for fields
316
304
# without them being automatically added to `_explicit_fields`.
317
- _setting_implicit_default : bool | None = Field (init = False , repr = False )
305
+ _setting_implicit_default : bool | None = Field (init = False )
318
306
319
307
def __setattr__ (self , key : str , value : typing .Any ) -> None :
320
308
"""
@@ -339,7 +327,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None:
339
327
)
340
328
else :
341
329
field = self .get_field (key )
342
- if field .init and field ._field_type != dataclasses ._FIELD_CLASSVAR :
330
+ if field .init and field ._field_type == dataclasses ._FIELD :
343
331
# Adding to explicit field list except within `_set_implicit_default` context,
344
332
# during dataclass initialization (`_setting_implicit_default` not yet set)
345
333
# and during automated config validation (`_setting_implicit_default=None`)
@@ -358,13 +346,13 @@ def __delattr__(self, key: str) -> None:
358
346
super ().__delattr__ (key )
359
347
360
348
@contextlib .contextmanager
361
- def _set_implicit_default (self , _value : bool | int = True ):
349
+ def _set_implicit_default (self , _value : bool | None = True ):
362
350
assert self ._setting_implicit_default is False
363
351
self ._setting_implicit_default = _value
364
352
yield
365
353
self ._setting_implicit_default = False
366
354
367
- def validate [T ](self : T , * , _is_validating : bool = False ) -> T :
355
+ def validate [T : Config ](self : T , * , _is_validating : bool = False ) -> T :
368
356
"""
369
357
Validate a class and mark it as read-only
370
358
This should not be overridden in derived classes.
@@ -388,11 +376,16 @@ def _validate(self) -> None:
388
376
Can be extended to add custom post-processing (typically before the super() call)
389
377
and validation (typically after)
390
378
"""
391
- self ._check_abstract ()
379
+ if self ._abstract :
380
+ raise ValidationError (f"{ type (self ).__name__ } is abstract" )
381
+ if not self .__class_validated__ :
382
+ raise ValidationError (
383
+ f"{ type (self ).__name__ } hasn't been validated. Make sure to use the @config_class decorator."
384
+ )
392
385
errors = []
393
386
with self ._set_implicit_default (None ):
394
387
for name , field in self .fields ():
395
- if not field .init or field ._field_type == dataclasses ._FIELD_CLASSVAR : # noqa
388
+ if not field .init or field ._field_type != dataclasses ._FIELD : # noqa
396
389
continue
397
390
value = getattr (self , name )
398
391
if isinstance (value , Tag ):
@@ -610,11 +603,7 @@ def _add_field_to_args(
610
603
all_fields : bool = False ,
611
604
serializable : bool = True ,
612
605
) -> None :
613
- if (
614
- field is not None
615
- and (not field .init or field ._field_type == dataclasses ._FIELD_CLASSVAR )
616
- and not all_fields
617
- ):
606
+ if field is not None and (not field .init or field ._field_type != dataclasses ._FIELD ) and not all_fields :
618
607
# Exclude class variables and derived fields unless requested explicitly.
619
608
return
620
609
explicit_field = (
@@ -677,6 +666,9 @@ def to_copy[
677
666
) -> T :
678
667
return self .from_dict (self , * updates , strict = strict , update_type = update_type )
679
668
669
+ def __repr__ (self ):
670
+ return self .to_logs (log_fn = str )
671
+
680
672
def to_logs [
681
673
T
682
674
](
@@ -739,7 +731,7 @@ def _from_dict(
739
731
flat : bool = False ,
740
732
) -> typing .Self :
741
733
# TODO v0.3: Remove flat format
742
- out_arg_dict = {}
734
+ out_arg_dict = {"_from_dict_check" : True }
743
735
744
736
# TODO v0.3: Remove backward compatibility fix
745
737
if "__class__" in default :
@@ -748,7 +740,7 @@ def _from_dict(
748
740
# Do not validate yet in case the root class sets cross-dependencies in validation.
749
741
with NoAutoValidate ():
750
742
for name , field in cls .fields ():
751
- if not field .init or field ._field_type == dataclasses ._FIELD_CLASSVAR : # noqa
743
+ if not field .init or field ._field_type != dataclasses ._FIELD : # noqa
752
744
continue
753
745
if flat :
754
746
if isinstance (field .type , type ) and issubclass (field .type , Config ):
@@ -869,22 +861,15 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
869
861
f"Config comparison errors:\n " + "\n " .join (errors ),
870
862
log_fn = log_fn ,
871
863
)
872
-
873
- @classmethod
874
- def _check_abstract (cls ) -> None :
875
- if cls ._abstract :
876
- raise ValidationError (f"{ cls .__name__ } is abstract" )
877
- if not cls .__class_validated__ :
878
- raise ValidationError (
879
- f"{ cls .__name__ } hasn't been validated. Make sure to use the @config_class decorator."
880
- )
864
+ return None
881
865
882
866
def __init_subclass__ (cls ):
883
867
"""
884
868
We need to postpone validation until the class has been processed by the dataclass wrapper.
885
869
"""
870
+ Assert .eq (cls .__name__ , cls .__qualname__ )
886
871
for base_class in cls .__mro__ :
887
- if issubclass (base_class , Config ):
872
+ if issubclass (base_class , Config ) and base_class is not cls :
888
873
assert cls .__class_validated__ , (
889
874
f"Parent class { get_type_name (base_class )} of config class { get_type_name (cls )} has not been validated."
890
875
f" Make sure to use the @config_class decorator."
@@ -913,7 +898,6 @@ def __init_subclass__(cls):
913
898
valid = value .pop ("valid" , base_class_field .valid ),
914
899
default = value .pop ("default" , base_class_field .default ),
915
900
default_factory = value .pop ("default_factory" , base_class_field .default_factory ),
916
- repr = value .pop ("repr" , base_class_field .repr ),
917
901
hash = value .pop ("hash" , base_class_field .hash ),
918
902
compare = value .pop ("compare" , base_class_field .compare ),
919
903
metadata = value .pop ("metadata" , base_class_field .metadata ),
0 commit comments