From aa3bc0be1368c22a9eceb5d00a1f69db779858b9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 16:43:45 -0400 Subject: [PATCH 1/5] stuff --- fast_llm/config.py | 92 ++++++++----------- fast_llm/data/data/config.py | 4 +- fast_llm/data/data/gpt/config.py | 3 +- fast_llm/data/dataset/config.py | 2 - fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/preparator/gpt_memmap/config.py | 7 +- fast_llm/engine/base_model/base_model.py | 2 +- fast_llm/engine/base_model/config.py | 1 + fast_llm/engine/config_utils/run.py | 8 +- fast_llm/engine/multi_stage/config.py | 17 +--- fast_llm/engine/training/config.py | 32 ++----- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/ssm/config.py | 1 - fast_llm/layers/transformer/config.py | 3 - fast_llm/models/custom/config.py | 8 +- fast_llm/models/gpt/config.py | 8 +- fast_llm/models/ssm/config.py | 9 +- fast_llm/tools/cli.py | 2 +- fast_llm/tools/convert.py | 18 ++-- tests/config/common.py | 8 +- tests/config/test_config.py | 30 ++++++ tests/data/common.py | 2 +- tests/test_checkpoint.py | 16 ++-- tests/test_ssms.py | 2 + tools/push_model.py | 4 +- 25 files changed, 128 insertions(+), 157 deletions(-) create mode 100644 tests/config/test_config.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 46c903f1..3b277202 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,3 +1,4 @@ +import abc import contextlib import copy import dataclasses @@ -137,7 +138,6 @@ def __init__( default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init: bool = True, - repr: bool = True, hash=None, compare: bool = True, metadata=None, @@ -146,12 +146,11 @@ def __init__( if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: raise ValueError("cannot specify both default and default_factory") if isinstance(default_factory, type) and issubclass(default_factory, Config): - default_factory = _ConfigFactory(default_factory) + raise ValueError("Config classes should not be used as `default_factory`") super().__init__( default=default, default_factory=default_factory, init=init, - repr=repr, hash=hash, compare=compare, metadata=metadata, @@ -223,20 +222,6 @@ def valid(x): return valid -class _ConfigFactory: - """ - A dataclass default factory that prevents early validation. - Validation is still done through the parent config if needed. - """ - - def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]): - self._factory = factory - - def __call__(self): - with NoAutoValidate(): - return self._factory() - - class ValidationError(ValueError): pass @@ -257,7 +242,7 @@ def _process_config_class(cls: type["Config"]): return cls -def config_class(cls=None): +def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -280,20 +265,23 @@ def __init__(self, **kwargs): if _AUTO_VALIDATE: self.validate() - cls.__init__ = __init__ + wrapped.__init__ = __init__ return wrapped - # See if we're being called as @config_class or @config_class(). - if cls is None: - # We're called with parens. - return wrap + return wrap + - # We're called as @config_class without parens. - return wrap(cls) +class ConfigMeta(abc.ABCMeta): + def __call__(cls: "type[Config]", **kwargs): + # Always go through `_from_dict` for correct dynamic class selection and nested config instantiation. 
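# A minimal standalone sketch of the pattern above (toy names, not the Fast-LLM API):
# a metaclass whose `__call__` reroutes direct construction through a dict-based
# factory, with a sentinel keyword so the factory's own call does not recurse.
import abc
import dataclasses


class _ToyConfigMeta(abc.ABCMeta):
    def __call__(cls, **kwargs):
        if not kwargs.pop("_from_dict_check", False):
            # Direct construction: delegate to the factory for dynamic class selection.
            return cls.from_dict(kwargs)
        return super().__call__(**kwargs)


@dataclasses.dataclass(kw_only=True)
class ToyConfig(metaclass=_ToyConfigMeta):
    value: int = 0

    @classmethod
    def from_dict(cls, data: dict) -> "ToyConfig":
        # The sentinel marks that construction already went through the factory.
        return cls(_from_dict_check=True, **data)


assert ToyConfig(value=3).value == 3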
+ if not kwargs.pop("_from_dict_check", False): + # with NoAutoValidate(): + return cls._from_dict(kwargs) + return super().__call__(**kwargs) -@dataclasses.dataclass() -class Config: +@dataclasses.dataclass(kw_only=True, repr=False) +class Config(metaclass=ConfigMeta): """ An advanced `dataclass` with basic type checking, validation and argparse support. Typically, a subclass will: @@ -307,14 +295,14 @@ class Config: # Set to true to prevent instantiation. _abstract: typing.ClassVar[bool] = False # Keep track of whether an instance has been validated - _validated: bool = Field(init=False, repr=False) + _validated: bool = Field(init=False) # Keep track of unknown fields so they can be reported during validation. - _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + _unknown_fields: dict[str, typing.Any] = Field(init=False) # Keep track of explicitly set fields to ensure they get serialized and used as config updates. - _explicit_fields: set[str] = Field(init=False, repr=False) + _explicit_fields: set[str] = Field(init=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool | None = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -339,7 +327,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None: ) else: field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + if field.init and field._field_type == dataclasses._FIELD: # Adding to explicit field list except within `_set_implicit_default` context, # during dataclass initialization (`_setting_implicit_default` not yet set) # and during automated config validation (`_setting_implicit_default=None`) @@ -358,13 +346,13 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self, _value: bool | int = True): + def _set_implicit_default(self, _value: bool | None = True): assert self._setting_implicit_default is False self._setting_implicit_default = _value yield self._setting_implicit_default = False - def validate[T](self: T, *, _is_validating: bool = False) -> T: + def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only This should not be overridden in derived classes. @@ -388,11 +376,16 @@ def _validate(self) -> None: Can be extended to add custom post-processing (typically before the super() call) and validation (typically after) """ - self._check_abstract() + if self._abstract: + raise ValidationError(f"{type(self).__name__} is abstract") + if not self.__class_validated__: + raise ValidationError( + f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator." 
+ ) errors = [] with self._set_implicit_default(None): for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue value = getattr(self, name) if isinstance(value, Tag): @@ -610,11 +603,7 @@ def _add_field_to_args( all_fields: bool = False, serializable: bool = True, ) -> None: - if ( - field is not None - and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not all_fields - ): + if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields: # Exclude class variables and derived fields unless requested explicitly. return explicit_field = ( @@ -677,6 +666,9 @@ def to_copy[ ) -> T: return self.from_dict(self, *updates, strict=strict, update_type=update_type) + def __repr__(self): + return self.to_logs(log_fn=str) + def to_logs[ T ]( @@ -739,7 +731,7 @@ def _from_dict( flat: bool = False, ) -> typing.Self: # TODO v0.3: Remove flat format - out_arg_dict = {} + out_arg_dict = {"_from_dict_check": True} # TODO v0.3: Remove backward compatibility fix if "__class__" in default: @@ -748,7 +740,7 @@ def _from_dict( # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): @@ -869,22 +861,15 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) - - @classmethod - def _check_abstract(cls) -> None: - if cls._abstract: - raise ValidationError(f"{cls.__name__} is abstract") - if not cls.__class_validated__: - raise ValidationError( - f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator." - ) + return None def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ + Assert.eq(cls.__name__, cls.__qualname__) for base_class in cls.__mro__: - if issubclass(base_class, Config): + if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated." f" Make sure to use the @config_class decorator." @@ -913,7 +898,6 @@ def __init_subclass__(cls): valid=value.pop("valid", base_class_field.valid), default=value.pop("default", base_class_field.default), default_factory=value.pop("default_factory", base_class_field.default_factory), - repr=value.pop("repr", base_class_field.repr), hash=value.pop("hash", base_class_field.hash), compare=value.pop("compare", base_class_field.compare), metadata=value.pop("metadata", base_class_field.metadata), diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 25850ac3..41dbb5d9 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -9,6 +9,4 @@ class DataConfig(Config): _abstract = True _sampling_config_class: typing.ClassVar[type[SamplingData]] - sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Default configuration for dataset sampling." 
- ) + sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.") diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 6c598c0c..85bcc656 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): _abstract = False tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) + sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7901d6e7..1bb4b6be 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): _abstract = True sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) dataset: SampledDatasetConfig = Field( - default_factory=SampledDatasetConfig, desc="The dataset to sample from.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ed9128c6..f4f6e282 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -230,8 +230,8 @@ def build(self) -> "GPTDatasetSlice": class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False type_: typing.ClassVar[str | None] = "sampled" - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) - dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) + sampling: GPTSamplingConfig = FieldUpdate() + dataset: GPTSampledDatasetConfig = FieldUpdate() @config_class() @@ -450,7 +450,6 @@ class GPTLegacyConfig(Config): valid=_validate_path, ) fim: FimConfig = Field( - default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c3..7091f3c8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -24,7 +24,7 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -@config_class +@config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( default=None, @@ -77,7 +77,7 @@ class GPTHuggingfaceDatasetConfig(Config): ) -@config_class +@config_class() class DatasetPreparatorDistributedConfig(Config): # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig @@ -120,7 +120,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=DatasetPreparatorDistributedConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) @@ -149,12 +148,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): valid=check_field(Assert.geq, 1), ) dataset: GPTHuggingfaceDatasetConfig = Field( - default_factory=GPTHuggingfaceDatasetConfig, desc="Configuration for the dataset.", 
hint=FieldHint.feature, ) tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 2be1e487..df603a91 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -90,7 +90,7 @@ def __init__( config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space = TensorSpace(distributed_config) + self._tensor_space: TensorSpace = TensorSpace(distributed_config) config.setup_tensor_space(self._tensor_space) super().__init__(config) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 25f53e4a..4be42e06 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: if isinstance(value, BaseModelConfig): diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index d6377409..126e0ae8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,9 +20,7 @@ @config_class() class RunConfig(Config): - tensor_logs: TensorLogsConfig = Field( - default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging - ) + tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) # TODO v0.3: Adjust (now only affects logging to file). structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging @@ -70,9 +68,7 @@ def _validate(self): @config_class() class ExperimentConfig(RunnableConfig): - run: RunConfig = Field( - default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core - ) + run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core) def _show( self, diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80..9434fba6 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -211,17 +211,12 @@ class FastLLMModelConfig(Config): FastLLMCheckpointFormat, ) model_name: typing.ClassVar[str] - base_model: BaseModelConfig = Field( - default_factory=BaseModelConfig, desc="Configuration for the base model.", hint=FieldHint.core - ) + base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( - default_factory=MultiStageConfig, desc="Configuration for the stage breakdown of the model.", hint=FieldHint.core, ) - distributed: DistributedConfig = Field( - default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core - ) + distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) @classmethod def __fast_llm_serialize__(cls) -> str: @@ -291,11 +286,8 @@ class PretrainedFastLLMModelConfig(Config): # TODO: Generalize data, schedule, logging, etc. 
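# A small sketch of why the `default_factory` arguments can be dropped (toy classes,
# not the real Field/Config machinery): once every nested config is materialized by a
# dict-based factory, a missing key can simply be built as an empty sub-config, so the
# field itself no longer needs a dataclass default.
import dataclasses


@dataclasses.dataclass(kw_only=True)
class ToySubConfig:
    seed: int = 0


@dataclasses.dataclass(kw_only=True)
class ToyTopConfig:
    sub: ToySubConfig

    @classmethod
    def from_dict(cls, data: dict) -> "ToyTopConfig":
        # Build the nested config from its sub-dictionary, defaulting to empty.
        return cls(sub=ToySubConfig(**data.get("sub", {})))


assert ToyTopConfig.from_dict({}).sub.seed == 0
assert ToyTopConfig.from_dict({"sub": {"seed": 7}}).sub.seed == 7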
_abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. - model: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core - ) + model: FastLLMModelConfig = Field(desc="Configuration for the Fast-LLM model.", hint=FieldHint.core) pretrained: CheckpointLoadConfig = Field( - default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) @@ -315,7 +307,7 @@ def _setup(self) -> None: pass -@config_class +@config_class() class CheckpointMetadata(Config): # TODO: Make entries more flexible? # I.e.. model / format / usage (ex. training) - specific entries instead of a generic metadata? @@ -336,7 +328,6 @@ class CheckpointMetadata(Config): hint=FieldHint.core, ) config: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="The Fast-LLM model configuration for the saved model.", hint=FieldHint.core, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 1e990e9c..a5be2e7e 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -120,7 +120,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) - post_alerts: bool = Field(init=False, repr=False) + post_alerts: bool = Field(init=False) def _validate(self) -> None: if self.status_updates is None: @@ -141,7 +141,6 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( - default_factory=WandbAlertConfig, desc="Configuration for Wandb alerts." 
" The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, @@ -175,7 +174,6 @@ class TrainingCheckpointBaseConfig(IntervalConfig): _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( - default_factory=CallbackConfig, desc="Callback (shell script).", hint=FieldHint.core, ) @@ -257,7 +255,6 @@ class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConf offset = FieldUpdate(desc="Offset for the first export.") callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") - @abc.abstractmethod def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / self.format.name @@ -284,19 +281,11 @@ class TrainingConfig(Config): desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) - logs: MetricsLogsConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core - ) - checkpoint: TrainingCheckpointConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core - ) - export: TrainingExportConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core - ) - shutdown: ShutdownConfig = Field( - default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core - ) - wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) + logs: MetricsLogsConfig = Field(desc="Configuration for metric logging.", hint=FieldHint.core) + checkpoint: TrainingCheckpointConfig = Field(desc="Configuration for checkpoints.", hint=FieldHint.core) + export: TrainingExportConfig = Field(desc="Configuration for exports.", hint=FieldHint.core) + shutdown: ShutdownConfig = Field(desc="Configuration for automated shutdown.", hint=FieldHint.core) + wandb: WandbConfig = Field(desc="Configuration for Wandb.", hint=FieldHint.core) train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -349,30 +338,23 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True # TODO: Generalize data, schedule, logging, etc. 
training: TrainingConfig = Field( - default_factory=TrainingConfig, desc="Configuration for the training phases and global properties.", hint=FieldHint.core, ) batch: BatchConfig = Field( - default_factory=BatchConfig, desc="Configuration for the training, validation and test batches.", hint=FieldHint.core, ) - schedule: ScheduleConfig = Field( - default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core - ) + schedule: ScheduleConfig = Field(desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core) data: DataConfig = Field( - default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) profiling: ProfilingConfig = Field( - default_factory=ProfilingConfig, desc="Configuration for the optional profiling of GPU and CPU CUDA operations.", hint=FieldHint.logging, ) optimizer: OptimizerConfig = Field( - default_factory=OptimizerConfig, desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf..0db76ad1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -40,7 +40,6 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c6fe622e..25ad3d22 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -26,7 +26,6 @@ class SSMConfig(BaseModelConfig): # Normalization normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e69b1841..c621139c 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -248,17 +248,14 @@ def _validate(self) -> None: class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( - default_factory=RotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( - default_factory=TransformerPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 8be45e1c..08902e2c 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -26,7 +26,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). 
model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate() @classmethod def get_model_class(cls) -> type["CustomModel"]: @@ -43,14 +43,14 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM" @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate() @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) + data: CustomDataConfig = FieldUpdate() + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e..0ec3fb51 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -129,7 +129,7 @@ def _from_dict( class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -156,13 +156,13 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate() @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() # TODO: Use dynamic model type? 
reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 0311cc69..771a4fca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -26,7 +26,6 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): _abstract = False ssm: SSMConfig = Field( - default_factory=SSMConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -129,7 +128,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) + base_model: HybridSSMBaseModelConfig = FieldUpdate() checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -154,13 +153,13 @@ def _validate(self): @config_class() class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) + model: HybridSSMModelConfig = FieldUpdate() @config_class() class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 0cc02f42..8df884fe 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -21,7 +21,7 @@ def fast_llm(args=None): if parsed.subcommand == "train": from fast_llm.tools.train import CliTrainingConfig as Runnable elif parsed.subcommand == "convert": - from fast_llm.tools.convert import ConversionConfig as Runnable + from fast_llm.tools.convert import ConvertConfig as Runnable elif parsed.subcommand == "prepare": from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable else: diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index d3db3745..3ee580aa 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -19,13 +19,13 @@ @config_class() -class ConversionConfig(RunnableConfig): - input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) - output: CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) +class ConvertConfig(RunnableConfig): + input: CheckpointLoadConfig = Field() + output: CheckpointSaveConfig = Field() use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) - model_config_class: type[FastLLMModelConfig] = Field(default=None) + model: type[FastLLMModelConfig] = Field(default=None) @classmethod def _get_parser(cls): @@ -44,9 +44,9 @@ def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): return config def _validate(self): - assert self.model_config_class is not None - self.input.setup(self.model_config_class) - self.output.setup(self.model_config_class) + assert self.model is not None + self.input.setup(self.model) + self.output.setup(self.model) super()._validate() def _convert_model_partial( @@ -81,7 +81,7 @@ def run(self): f"Output path {self.output.path} already exists and has been processed. Skipping model conversion..." 
) return - model_class = self.model_config_class.get_model_class() + model_class = self.model.get_model_class() if self.layers_per_step is None: self._convert_model_partial(model_class, self.output) else: @@ -161,4 +161,4 @@ def run(self): if __name__ == "__main__": - ConversionConfig.parse_and_run() + ConvertConfig.parse_and_run() diff --git a/tests/config/common.py b/tests/config/common.py index a2657926..9ccfb597 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -13,7 +13,7 @@ class ExampleEnum(enum.StrEnum): c = "c" -@config_class +@config_class() class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) @@ -40,7 +40,7 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleVerboseConfig(Config): # These fields will have non-empty default serialized values. list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) @@ -56,9 +56,9 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleNestedConfig(ExampleConfig): - nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + nested_field: ExampleConfig = Field(hint=FieldHint.core) def check_config( diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 00000000..4c473fa6 --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,30 @@ +import pytest + +from fast_llm.config import NoAutoValidate +from tests.config.common import ExampleConfig + + +def test_auto_validate(): + assert (config := ExampleConfig())._validated + with pytest.raises(RuntimeError): + config.bool_field = True + config.bool_field = False + + assert ExampleConfig.from_dict({})._validated + + with NoAutoValidate(): + assert not (config := ExampleConfig())._validated + + config.bool_field = True + + config.validate() + + assert config._validated + with pytest.raises(RuntimeError): + config.bool_field = False + config.bool_field = True + + with NoAutoValidate(): + assert not (config := ExampleConfig.from_dict({}))._validated + config.validate() + assert config._validated diff --git a/tests/data/common.py b/tests/data/common.py index 47b53195..00c3ff20 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -31,7 +31,6 @@ def get_sampling_data( *, seed: int = 54983, cache_directory: pathlib.Path | None = None, - distributed: Distributed = Distributed(DistributedConfig(), use_cpu=True), phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, @@ -41,6 +40,7 @@ def get_sampling_data( truncate_documents=True, ) -> GPTSamplingData: # Config with convenient defaults. 
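# An illustrative aside on the change just below (toy example, unrelated to the real
# Distributed class): default argument values are evaluated once, at function
# definition time, so a heavyweight or stateful default object is shared by every
# call. Building it inside the body gives each call a fresh instance instead.
class _ToyContext:
    def __init__(self) -> None:
        self.calls = 0


def shared_default(ctx: _ToyContext = _ToyContext()) -> int:
    # The single default instance accumulates state across calls.
    ctx.calls += 1
    return ctx.calls


def fresh_default(ctx: _ToyContext | None = None) -> int:
    # A new instance is created on every call unless one is passed in.
    ctx = ctx if ctx is not None else _ToyContext()
    ctx.calls += 1
    return ctx.calls


assert shared_default() == 1 and shared_default() == 2
assert fresh_default() == 1 and fresh_default() == 1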
+ distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( config=GPTSamplingConfig( seed=seed, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e9..77a4b482 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,7 +17,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry -from fast_llm.tools.convert import ConversionConfig +from fast_llm.tools.convert import ConvertConfig from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -90,7 +90,7 @@ def test_resume(): ) -def _run_conversion(config: ConversionConfig): +def _run_conversion(config: ConvertConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) if not config.output.path.is_dir(): @@ -106,7 +106,7 @@ def _run_conversion(config: ConversionConfig): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_convert_distributed_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -125,7 +125,7 @@ def test_convert_fast_llm_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, @@ -142,7 +142,7 @@ def test_convert_fast_llm_to_huggingface(): @pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) def test_convert_huggingface_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -161,7 +161,7 @@ def test_convert_distributed_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -178,7 +178,7 @@ def test_convert_distributed_to_huggingface(): @pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) def test_convert_huggingface_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -195,7 +195,7 @@ def test_convert_huggingface_to_fast_llm(): @pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) def test_convert_fast_llm_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, diff --git a/tests/test_ssms.py b/tests/test_ssms.py index e6c9aafd..0fec3741 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +@pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", @@ -207,6 +208,7 @@ def test_mamba_block(distributed_config, distributed): assert not torch.isinf(hidden_states).any() +@pytest.mark.slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( ("hybrid_block_layout"), diff --git 
a/tools/push_model.py b/tools/push_model.py index cd98b93c..edab3312 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -27,7 +27,7 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.tools.convert import ConversionConfig # isort:skip +from fast_llm.tools.convert import ConvertConfig # isort:skip logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ def run(self) -> None: for _, checkpoint_path in new_checkpoint_paths: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=checkpoint_path, format=DistributedCheckpointFormat, From 28d321e320354539f2939fe3f94095a96fc43dcc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 17:02:59 -0400 Subject: [PATCH 2/5] stuff --- fast_llm/config.py | 1 + fast_llm/engine/optimizer/config.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 3b277202..4928cdbd 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -151,6 +151,7 @@ def __init__( default=default, default_factory=default_factory, init=init, + repr=False, hash=hash, compare=compare, metadata=metadata, diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index 3a154c9e..f4303a5d 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -74,12 +74,10 @@ class GradientScalerConfig(Config): class OptimizerConfig(Config): learning_rate: LearningRateScheduleConfig = Field( - default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) gradient_scaler: GradientScalerConfig = Field( - default_factory=GradientScalerConfig, desc="Configuration for the fixed or dynamic gradient scaling.", hint=FieldHint.feature, ) From 1bbd7fb1bf55258a4f7589d6eb0b63d71d2f0fa5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 17:46:02 -0400 Subject: [PATCH 3/5] stuff --- fast_llm/tools/convert.py | 2 +- tests/test_checkpoint.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index 3ee580aa..648138ec 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -40,7 +40,7 @@ def _get_parser(cls): @classmethod def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): config = super()._from_parsed_args(parsed, unparsed) - config.model_config_class = model_registry[parsed.model_type] + config.model = model_registry[parsed.model_type] return config def _validate(self): diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 77a4b482..e0845a4c 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -115,7 +115,7 @@ def test_convert_distributed_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -134,7 +134,7 @@ def test_convert_fast_llm_to_huggingface(): path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -151,7 +151,7 @@ def test_convert_huggingface_to_distributed(): path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -170,7 +170,7 @@ def 
test_convert_distributed_to_huggingface(): path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -187,7 +187,7 @@ def test_convert_huggingface_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -204,7 +204,7 @@ def test_convert_fast_llm_to_distributed(): path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) From 35959491886cd87d6a592b61745c42214a08c4a5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 19:29:34 -0400 Subject: [PATCH 4/5] Minimalistic dynamic configs --- fast_llm/config.py | 73 ++++++++++++++++++++++- fast_llm/data/dataset/gpt/config.py | 92 +++++------------------------ 2 files changed, 85 insertions(+), 80 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 4928cdbd..6e3e92dc 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -12,7 +12,7 @@ import yaml -from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log +from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log logger = logging.getLogger(__name__) @@ -243,7 +243,9 @@ def _process_config_class(cls: type["Config"]): return cls -def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]: +def config_class[ + T: Config +](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -253,7 +255,7 @@ def wrap(cls): if hasattr(cls, "__post_init__"): raise TypeError(f"`__post_init__` should not be implemented for `Config` classes") - wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True)) + wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False)) wrapped_init = cls.__init__ @@ -267,6 +269,13 @@ def __init__(self, **kwargs): self.validate() wrapped.__init__ = __init__ + + wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None + + if dynamic_type is not None: + for cls_, name in dynamic_type.items(): + cls_.register_subclass(name, wrapped) + return wrapped return wrap @@ -305,6 +314,9 @@ class Config(metaclass=ConfigMeta): # without them being automatically added to `_explicit_fields`. _setting_implicit_default: bool | None = Field(init=False) + # A registry for all the config classes. + _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None + def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -358,6 +370,17 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: Validate a class and mark it as read-only This should not be overridden in derived classes. """ + # Should be handled in `from_dict`, but can fail if instantiating directly. + try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValidationError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. 
+ Assert.is_(self.__class__, expected_class) + if not self._validated: try: self._validate() @@ -738,6 +761,14 @@ def _from_dict( if "__class__" in default: del default["__class__"] + try: + actual_cls = cls.get_subclass(default.get("type")) + if actual_cls is not None and actual_cls is not cls: + return actual_cls._from_dict(default, strict=strict, flat=flat) + except KeyError: + # Postpone error to validation. + pass + # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): @@ -864,6 +895,42 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ ) return None + @classmethod + def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: + Assert.custom(issubclass, cls_, cls) + if cls._registry is None: + raise NotImplementedError(f"Subclass `{name}` doesn't have a registry..") + if name in cls._registry: + old_cls = cls._registry[name] + if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: + del cls._registry[name] + else: + raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {cls.__name__}.") + cls._registry[name] = cls_ + + @classmethod + def get_subclass(cls, name: str | None): + # TODO: Make it case-insensitive? + if name is None: + return None + cls_ = None + for base_class in cls.__mro__: + if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry: + if cls_ is None: + cls_ = base_class._registry[name] + if not issubclass(cls_, cls): + raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})") + elif base_class._registry[name] is not cls_: + # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization. + # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit? + raise KeyError( + f"Ambiguous type `{name}` for base class {cls.__name__}." + f" ({cls_.__name__} vs {base_class._registry[name]})" + ) + if cls_ is None: + raise KeyError(f"Unknown type {name} for base class {cls.__name__}") + return cls_ + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index f4f6e282..4ab0b7df 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -23,7 +23,7 @@ SamplingParameters, ) from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset @@ -92,61 +92,9 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True -@config_class() +@config_class(registry=True) class GPTSampledDatasetConfig(SampledDatasetConfig): - - # TODO: Generalize dynamic types? 
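# A compact sketch of the generic registry that replaces the per-class version removed
# below (toy classes, not the actual Config API): each registered name maps a
# serialized `type` value to a concrete subclass, and the dict factory uses the
# registry to pick the class to instantiate.
class ToyBase:
    _registry: dict[str, type["ToyBase"]] = {}

    @classmethod
    def register_subclass(cls, name: str, subclass: type["ToyBase"]) -> None:
        assert issubclass(subclass, cls)
        cls._registry[name] = subclass

    @classmethod
    def from_dict(cls, data: dict) -> "ToyBase":
        data = dict(data)
        # Route to the registered subclass when a `type` key is present.
        subclass = cls._registry.get(data.pop("type", None), cls)
        return subclass(**data)


class ToyChild(ToyBase):
    def __init__(self, value: int = 0) -> None:
        self.value = value


ToyBase.register_subclass("child", ToyChild)
assert isinstance(ToyBase.from_dict({"type": "child", "value": 1}), ToyChild)
assert ToyBase.from_dict({"type": "child", "value": 1}).value == 1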
- _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[ - str, type["GPTDatasetConfig"] - ]("gpt_dataset_class", {}) - type_: typing.ClassVar[str | None] = None - type: str | None = Field( - default=None, - desc="The type of dataset.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - if self.type is None: - self.type = self.type_ - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.eq(self.type, self.__class__.type_) - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - type_ = default.get("type") - if type_ is None: - actual_cls = cls - else: - if type_ not in cls._registry: - raise ValueError( - f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}" - ) - actual_cls = cls._registry[type_] - Assert.custom(issubclass, actual_cls, cls) - if actual_cls == cls: - return super()._from_dict(default, strict=strict, flat=flat) - else: - return actual_cls._from_dict(default, strict=strict, flat=flat) - - def __init_subclass__(cls) -> None: - if cls._abstract and cls.type_ is not None: - # Abstract classes should not have a `type_` - raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") - if cls.type_ is not None: - if cls.type_ in cls._registry: - raise ValueError( - f"Registry {cls._registry.name} already contains type {cls.type_}." - f" Make sure all classes either have a unique or `None` type." - ) - GPTSampledDatasetConfig._registry[cls.type_] = cls - super().__init_subclass__() + pass @config_class() @@ -160,10 +108,9 @@ def build(self) -> "GPTIndexedDataset": raise NotImplementedError() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "random" name: str = Field( default="dummy", desc="The name of the dataset.", @@ -176,10 +123,9 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "memmap" path: pathlib.Path = Field( default=None, desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", @@ -202,10 +148,9 @@ def build(self) -> "GPTMemmapDataset": return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated" datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() def build(self) -> "GPTConcatenatedDataset": @@ -214,10 +159,9 @@ def build(self) -> "GPTConcatenatedDataset": return self._build(GPTConcatenatedDataset) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "slice" dataset: GPTIndexedDatasetConfig = FieldUpdate() def build(self) -> "GPTDatasetSlice": @@ -226,25 +170,22 @@ def build(self) -> "GPTDatasetSlice": return 
self._build(GPTDatasetSlice) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False - type_: typing.ClassVar[str | None] = "sampled" sampling: GPTSamplingConfig = FieldUpdate() dataset: GPTSampledDatasetConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "blended" datasets: list[GPTSampledDatasetConfig] = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "file" path: pathlib.Path = Field( default=None, desc="The path to a dataset config file.", @@ -280,11 +221,11 @@ def _convert_paths(self, config): return config -@config_class() +# Add user-friendly names for the configs. +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"}) class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): # TODO v0.3: Remove. _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated_memmap" path: pathlib.Path = Field( default=None, desc="The path to a dataset directory.", @@ -387,14 +328,13 @@ class FimConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "fim" dataset: GPTSampledDatasetConfig = Field( default=None, @@ -455,10 +395,9 @@ class GPTLegacyConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"}) class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "legacy" def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: @@ -537,7 +476,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. @@ -545,7 +484,6 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): # TODO: This belongs to a testing plugin. 
_abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "test_slow" sleep: float = Field( default=1, desc="Sleep time during build, in seconds.", From 39b1a04fd140718afda39e96a5882754819b49d8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 20:42:56 -0400 Subject: [PATCH 5/5] stuff --- fast_llm/config.py | 10 +++++++++- fast_llm/layers/common/config.py | 8 +++++++- fast_llm/layers/transformer/config.py | 16 ++++++++++++++-- tests/data/common.py | 3 +-- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 6e3e92dc..380100e3 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -274,6 +274,7 @@ def __init__(self, **kwargs): if dynamic_type is not None: for cls_, name in dynamic_type.items(): + print(cls_, name, wrapped) cls_.register_subclass(name, wrapped) return wrapped @@ -899,7 +900,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) if cls._registry is None: - raise NotImplementedError(f"Subclass `{name}` doesn't have a registry..") + raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..") if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: @@ -980,6 +981,13 @@ def __init_subclass__(cls): # dataclasses expects an annotation, so we use the one from the base class. cls.__annotations__[name] = base_class_field.type + # Type for the field. At the end of class definition to avoid shadowing builtin. + type: str | None = Field( + default=None, + desc="The config class name.", + hint=FieldHint.feature, + ) + class Configurable[ConfigType: Config]: config_class: typing.ClassVar[type[Config]] = Config diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 269989ce..054c26c3 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -33,7 +33,7 @@ class NormalizationType(str, enum.Enum): rms_norm = "rms_norm" -@config_class() +@config_class(registry=True) class NormalizationConfig(BaseModelConfig): _abstract = False @@ -107,6 +107,12 @@ def _from_dict( return super()._from_dict(default, strict, flat) +for name in NormalizationType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + NormalizationConfig.register_subclass(name.value, NormalizationConfig) + + class PeftType(str, enum.Enum): # TODO : Use a dynamic config type instead. none = "none" diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c621139c..e7ef0b15 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -95,7 +95,7 @@ class RotaryEmbeddingType(str, enum.Enum): yarn = "yarn" -@config_class() +@config_class(registry=True) class RotaryConfig(BaseModelConfig): _abstract = False type: RotaryEmbeddingType = Field( @@ -158,6 +158,12 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") +for name in RotaryEmbeddingType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. 
+ RotaryConfig.register_subclass(name.value, RotaryConfig) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -175,7 +181,7 @@ class TransformerSubLayerName(str, enum.Enum): mlp_2 = "mlp_2" -@config_class() +@config_class(registry=True) class TransformerPeftConfig(PeftConfig): layers: list[TransformerSubLayerName] = Field( default=None, @@ -244,6 +250,12 @@ def _validate(self) -> None: ) +for name in PeftType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) + + @config_class() class TransformerConfig(BaseModelConfig): _abstract = False diff --git a/tests/data/common.py b/tests/data/common.py index 00c3ff20..cacb28e6 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -189,10 +189,9 @@ def validate_indexed_dataset_sampling( return token_ids -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"}) class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "mock_memmap" num_documents: int | None = Field( default=None, desc="Expected number of documents in the dataset.",