From 4b606b06f3a185a44080fc45838fe8afa7beb322 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 30 Apr 2025 13:05:51 -0400 Subject: [PATCH 01/26] Generalize config classes --- fast_llm/config.py | 41 +++++++++++++++-- fast_llm/data/dataset/gpt/config.py | 71 +++++++---------------------- 2 files changed, 55 insertions(+), 57 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 74563286..4ced6648 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -11,7 +11,7 @@ import yaml -from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log +from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log logger = logging.getLogger(__name__) @@ -313,6 +313,14 @@ class Config: # without them being automatically added to `_explicit_fields`. _setting_implicit_default: bool | None = Field(init=False, repr=False) + # A registry for all the config classes. + _registry: typing.ClassVar[Registry[str, type["Config"]]] = Registry[str, type["Config"]]("config", {}) + type: str | None = Field( + default=None, + desc="The config class name.", + hint=FieldHint.core, + ) + def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -355,7 +363,7 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self, _value: bool | int = True): + def _set_implicit_default(self, _value: bool | None = True): assert self._setting_implicit_default is False self._setting_implicit_default = _value yield @@ -388,6 +396,10 @@ def _validate(self) -> None: self._check_abstract() errors = [] with self._set_implicit_default(None): + if self.type is None: + self.type = self.__class__.__name__ + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.is_(self._registry[self.type], self.__class__) for name, field in self.fields(): if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue @@ -465,6 +477,13 @@ def _validate_element(cls, value, type_, name: str): raise FieldTypeError(f"Not a type.") elif issubclass(type_, Config): cls._validate_element_type(value, type_, strict=False) + # If the value belongs to a proper subclass of `type_`, + # we need an explicitly set `type` field for serialization to remember the actual config class. + if type(value) != type_: + if value.type is None: + value.type = value.__class__.__name__ + value._explicit_fields.add("type") + value.validate(_is_validating=True) else: value = cls._validate_simple(value, type_) @@ -717,7 +736,18 @@ def from_dict( for keys, value in update.items(): set_nested_dict_value(default, keys, value, update_type) - return cls._from_dict(default, strict) + type_ = default.get("type") + if type_ is None: + actual_cls = cls + else: + if type_ not in cls._registry: + raise ValueError(f"Unknown config type {type_}.") + actual_cls = cls._registry[type_] + if not issubclass(actual_cls, cls): + raise ValueError( + f"Config class {actual_cls.__name__} (from type {type_}) is not a subclass of {cls.__name__}" + ) + return actual_cls._from_dict(default, strict=strict) @classmethod def from_flat_dict( @@ -880,6 +910,11 @@ def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. 
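        Each subclass is also recorded in the shared config registry, under both its class name and a
        shortened name without the `Config` suffix, so it can be selected dynamically through the `type` field.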
""" + Assert.is_(cls._registry, Config._registry) + cls._registry[cls.__name__] = cls + short_name = cls.__name__.strip("Config") + if short_name != cls.__name__: + cls._registry[short_name] = cls for base_class in cls.__mro__: if issubclass(base_class, Config): assert cls.__class_validated__, ( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ed9128c6..ae70d01d 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -23,7 +23,7 @@ SamplingParameters, ) from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset @@ -94,59 +94,7 @@ class GPTSamplingData(SamplingData): @config_class() class GPTSampledDatasetConfig(SampledDatasetConfig): - - # TODO: Generalize dynamic types? - _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[ - str, type["GPTDatasetConfig"] - ]("gpt_dataset_class", {}) - type_: typing.ClassVar[str | None] = None - type: str | None = Field( - default=None, - desc="The type of dataset.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - if self.type is None: - self.type = self.type_ - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.eq(self.type, self.__class__.type_) - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - type_ = default.get("type") - if type_ is None: - actual_cls = cls - else: - if type_ not in cls._registry: - raise ValueError( - f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}" - ) - actual_cls = cls._registry[type_] - Assert.custom(issubclass, actual_cls, cls) - if actual_cls == cls: - return super()._from_dict(default, strict=strict, flat=flat) - else: - return actual_cls._from_dict(default, strict=strict, flat=flat) - - def __init_subclass__(cls) -> None: - if cls._abstract and cls.type_ is not None: - # Abstract classes should not have a `type_` - raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") - if cls.type_ is not None: - if cls.type_ in cls._registry: - raise ValueError( - f"Registry {cls._registry.name} already contains type {cls.type_}." - f" Make sure all classes either have a unique or `None` type." - ) - GPTSampledDatasetConfig._registry[cls.type_] = cls - super().__init_subclass__() + pass @config_class() @@ -558,3 +506,18 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset: if sampling.distributed.config.rank == 0: time.sleep(self.sleep) return GPTRandomDatasetConfig().build_and_sample(sampling) + + +# Add old names to the config registry for backward compatibility. +# TODO v0.x: remove. 
+Config._registry["dummy"] = GPTRandomDatasetConfig +Config._registry["memmap"] = GPTMemmapDatasetConfig +Config._registry["concatenated"] = GPTConcatenatedDatasetConfig +Config._registry["slice"] = GPTDatasetSliceConfig +Config._registry["sampled"] = GPTSampledDatasetUpdateConfig +Config._registry["blended"] = GPTBlendedDatasetConfig +Config._registry["file"] = GPTDatasetFromFileConfig +Config._registry["concatenated_memmap"] = GPTConcatenatedMemmapConfig +Config._registry["fim"] = GPTFimSampledDatasetConfig +Config._registry["legacy"] = GPTLegacyDatasetConfig +Config._registry["test_slow"] = GPTTestSlowDatasetConfig From 4a67660a2a7b10d9c953493abeddbdc501d90954 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 30 Apr 2025 14:06:03 -0400 Subject: [PATCH 02/26] cli --- fast_llm/tools/cli.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 0cc02f42..f7148322 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -1,4 +1,3 @@ -import argparse import logging import sys import traceback @@ -6,6 +5,12 @@ from fast_llm.config import ValidationError from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.engine.config_utils.runnable import RunnableConfig + +# TODO: These imports indirectly adds all known runnables to the config registry, need a better way? +from fast_llm.data.preparator.config import DatasetPreparatorConfig # isort: skip +from fast_llm.models.auto import trainer_registry # isort: skip +from fast_llm.tools.convert import ConversionConfig # isort: skip logger = logging.getLogger(__name__) @@ -14,28 +19,15 @@ def fast_llm(args=None): # TODO: Add hook to register model classes? (environment variable?) 
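    # Assumption (dispatch code not shown in this patch): `RunnableConfig.parse_and_run` resolves the
    # subcommand (train / convert / prepare) to a concrete runnable via the config registry populated by
    # the imports above, replacing the argparse-based dispatch removed below.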
# (Pre-)configure logging configure_logging() - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("subcommand", choices=["train", "convert", "prepare"]) - parsed, unparsed = parser.parse_known_args(args) try: - if parsed.subcommand == "train": - from fast_llm.tools.train import CliTrainingConfig as Runnable - elif parsed.subcommand == "convert": - from fast_llm.tools.convert import ConversionConfig as Runnable - elif parsed.subcommand == "prepare": - from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable - else: - raise RuntimeError("Unknown subcommand") - Runnable.parse_and_run(unparsed) - except ValidationError: - if sys.gettrace(): - raise - log_main_rank(traceback.format_exc(), log_fn=logger.error) - sys.exit(1) - except Exception: # noqa + RunnableConfig.parse_and_run(args) + except Exception as e: if sys.gettrace(): raise - logger.critical(traceback.format_exc()) + if isinstance(e, ValidationError): + log_main_rank(traceback.format_exc(), log_fn=logger.error) + else: + logger.critical(traceback.format_exc()) sys.exit(1) From 182340761c53778db3cde69741ba50d042202567 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 5 May 2025 18:54:27 -0400 Subject: [PATCH 03/26] misc --- fast_llm/config.py | 37 ++++- fast_llm/data/dataset/gpt/config.py | 25 ++-- fast_llm/layers/common/config.py | 128 +++++++++++------- fast_llm/layers/transformer/config.py | 188 ++++++++++++++------------ fast_llm/models/gpt/conversion.py | 151 ++++++++++++--------- fast_llm/models/ssm/conversion.py | 5 +- tests/layers/test_lm_head.py | 3 +- 7 files changed, 311 insertions(+), 226 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 90c9f2a2..a443e51c 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -317,7 +317,7 @@ class Config: _setting_implicit_default: bool | None = Field(init=False, repr=False) # A registry for all the config classes. - _registry: typing.ClassVar[Registry[str, type["Config"]]] = Registry[str, type["Config"]]("config", {}) + _registry: typing.ClassVar[Registry[str, type[typing.Self]]] type: str | None = Field( default=None, desc="The config class name.", @@ -399,8 +399,8 @@ def _validate(self) -> None: self._check_abstract() errors = [] with self._set_implicit_default(None): - if self.type is None: - self.type = self.__class__.__name__ + # Set the type field, or override it to the provided type with the actual class for clarity and safety. + self.type = self.__class__.__name__ # Should be handled in `from_dict`, but can fail if instantiating directly. Assert.is_(self._registry[self.type], self.__class__) for name, field in self.fields(): @@ -909,15 +909,40 @@ def _check_abstract(cls) -> None: f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator." ) + @classmethod + def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: + cls._registry[cls.__name__] = cls + + @classmethod + def get_subclass(cls, name): + # TODO: Make it case-insensitive? + cls_ = None + for base_class in cls.__mro__: + if issubclass(base_class, Config) and name in base_class._registry: + if cls_ is None: + cls_ = base_class._registry[name] + if not issubclass(cls_, cls): + raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})") + elif base_class._registry[name] is not cls_: + # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization. + # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit? 
+ raise RuntimeError( + f"Ambiguous type `{name}` for base class {cls.__name__}." + f" ({cls_.__name__} vs {base_class._registry[name]})" + ) + if cls_ is None: + raise KeyError(f"Unknown type {name} for base class {cls.__name__}") + return cls_ + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ - Assert.is_(cls._registry, Config._registry) - cls._registry[cls.__name__] = cls + cls._registry = Registry[str, type[cls]](cls.__name__, {}) + Config._registry[cls.__name__] = cls short_name = cls.__name__.strip("Config") if short_name != cls.__name__: - cls._registry[short_name] = cls + Config._registry[short_name] = cls for base_class in cls.__mro__: if issubclass(base_class, Config): assert cls.__class_validated__, ( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ae70d01d..87feb976 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -508,16 +508,15 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset: return GPTRandomDatasetConfig().build_and_sample(sampling) -# Add old names to the config registry for backward compatibility. -# TODO v0.x: remove. -Config._registry["dummy"] = GPTRandomDatasetConfig -Config._registry["memmap"] = GPTMemmapDatasetConfig -Config._registry["concatenated"] = GPTConcatenatedDatasetConfig -Config._registry["slice"] = GPTDatasetSliceConfig -Config._registry["sampled"] = GPTSampledDatasetUpdateConfig -Config._registry["blended"] = GPTBlendedDatasetConfig -Config._registry["file"] = GPTDatasetFromFileConfig -Config._registry["concatenated_memmap"] = GPTConcatenatedMemmapConfig -Config._registry["fim"] = GPTFimSampledDatasetConfig -Config._registry["legacy"] = GPTLegacyDatasetConfig -Config._registry["test_slow"] = GPTTestSlowDatasetConfig +# Add user-friendly names for the configs. +GPTSampledDatasetConfig.register_subclass("dummy", GPTRandomDatasetConfig) +GPTSampledDatasetConfig.register_subclass("memmap", GPTMemmapDatasetConfig) +GPTSampledDatasetConfig.register_subclass("concatenated", GPTConcatenatedDatasetConfig) +GPTSampledDatasetConfig.register_subclass("slice", GPTDatasetSliceConfig) +GPTSampledDatasetConfig.register_subclass("sampled", GPTSampledDatasetUpdateConfig) +GPTSampledDatasetConfig.register_subclass("blended", GPTBlendedDatasetConfig) +GPTSampledDatasetConfig.register_subclass("file", GPTDatasetFromFileConfig) +GPTSampledDatasetConfig.register_subclass("concatenated_memmap", GPTConcatenatedMemmapConfig) +GPTSampledDatasetConfig.register_subclass("fim", GPTFimSampledDatasetConfig) +GPTSampledDatasetConfig.register_subclass("legacy", GPTLegacyDatasetConfig) +GPTSampledDatasetConfig.register_subclass("test_slow", GPTTestSlowDatasetConfig) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 269989ce..9d2a01ff 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,3 +1,4 @@ +import abc import enum import typing @@ -6,6 +7,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + import torch + from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization import LayerNorm, RMSNorm @@ -23,26 +26,30 @@ class NormalizationImplementation(str, enum.Enum): triton = "triton" -class NormalizationType(str, enum.Enum): - """ - An enum for the available normalization layers. - TODO: Add no_norm type? 
- """ +@config_class() +class NormalizationConfig(BaseModelConfig): + pass - layer_norm = "layer_norm" - rms_norm = "rms_norm" + @abc.abstractmethod + def get_layer(self, hidden_dim: "TensorDim") -> torch.nn.Module: + pass @config_class() -class NormalizationConfig(BaseModelConfig): +class NoNormalizationConfig(NormalizationConfig): _abstract = False - # Normalization type - type: NormalizationType = Field( - default=NormalizationType.layer_norm, - desc="The type of normalization to use, for example Layer Norm or RMS Norm.", - hint=FieldHint.architecture, - ) + @abc.abstractmethod + def get_layer(self, hidden_dim: "TensorDim") -> torch.nn.Module: + return torch.nn.Identity() + + +@config_class() +class LayerNormBaseConfig(NormalizationConfig): + """ + Common configuration for layer norm and rms norm + """ + # TODO: Rename to normalization_epsilon epsilon: float = Field( default=1e-5, @@ -69,7 +76,6 @@ class NormalizationConfig(BaseModelConfig): ) def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ kwargs = { @@ -83,14 +89,12 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": kwargs["weight_init_method"] = init_uniform_( mean - self.initialization_range, mean + self.initialization_range ) - if self.type == NormalizationType.layer_norm: - if self.initialization_range: - kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range) - return LayerNorm(**kwargs) - elif self.type == NormalizationType.rms_norm: - return RMSNorm(**kwargs) - else: - raise ValueError(self.type) + return self.module_class(**kwargs) + + @property + @abc.abstractmethod + def module_class(self): + pass @classmethod def _from_dict( @@ -107,21 +111,52 @@ def _from_dict( return super()._from_dict(default, strict, flat) -class PeftType(str, enum.Enum): - # TODO : Use a dynamic config type instead. - none = "none" - lora = "lora" +@config_class() +class LayerNormalizationConfig(LayerNormBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization import LayerNorm + + return LayerNorm + + +@config_class() +class RMSNormalizationConfig(LayerNormBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization import RMSNorm + + return RMSNorm + + +NormalizationConfig.register_subclass("none", NoNormalizationConfig) +NormalizationConfig.register_subclass("layer_norm", LayerNormalizationConfig) +NormalizationConfig.register_subclass("rms_norm", RMSNormalizationConfig) @config_class() class PeftConfig(BaseModelConfig): + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + pass + + +@config_class() +class NoPeftConfig(PeftConfig): + _abstract = False + + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + return linear + + +@config_class() +class LoRAConfig(PeftConfig): _abstract = False - type: PeftType = Field( - default=PeftType.none, - desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.", - hint=FieldHint.core, - ) rank: int = Field( default=8, desc="The LoRA rank, i.e. 
the size of the intermediate dimension.", @@ -139,20 +174,15 @@ class PeftConfig(BaseModelConfig): ) def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - if self.type == PeftType.none: - return linear - elif self.type == PeftType.lora: - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) - else: - raise NotImplementedError(self.type) + from fast_llm.layers.common.peft import lora_linear + + # TODO: Init method? + return lora_linear( + linear, + linear.weight.param_init_method, + linear.weight.param_init_method, + self.rank, + self.alpha, + self.dropout, + **kwargs, + ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e69b1841..d9607924 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,3 +1,4 @@ +import abc import enum import functools import logging @@ -11,7 +12,13 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationConfig, PeftConfig, PeftType +from fast_llm.layers.common.config import ( + LayerNormalizationConfig, + LoRAConfig, + NoPeftConfig, + NormalizationConfig, + PeftConfig, +) from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -88,21 +95,21 @@ class TransformerLossNames: router_z_loss = "router_z_loss" -class RotaryEmbeddingType(str, enum.Enum): - none = "none" - default = "default" - llama3 = "llama3" - yarn = "yarn" +@config_class() +class RotaryConfig(BaseModelConfig): + # TODO: Move rotary to its own submodule. + + @property + def enabled(self) -> bool: + return False @config_class() -class RotaryConfig(BaseModelConfig): +class NoRotaryConfig(RotaryConfig): _abstract = False - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.none, - desc="The type of rotary embedding to use. Choices: none, default, llama3.", - hint=FieldHint.architecture, - ) + + +class DefaultRotaryConfig(RotaryConfig): theta: float = Field( default=10000, desc="Scale for the rotary positional embeddings", @@ -114,48 +121,51 @@ class RotaryConfig(BaseModelConfig): desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.", hint=FieldHint.architecture, ) - # TODO: These are not really architecture parameters, but we want to import them from huggingface. - scale_factor: float = Field( - default=8.0, desc="Scaling factor for llama3-type scaling.", hint=FieldHint.architecture - ) - low_frequency_factor: float = Field( - default=1.0, desc="Low frequency factor for llama3-type scaling.", hint=FieldHint.feature - ) - high_frequency_factor: float = Field( - default=4.0, desc="High frequency factor for llama3-type scaling.", hint=FieldHint.feature - ) - original_context_length: int = Field( - default=8192, desc="Original context length for llama3/yarn-type scaling.", hint=FieldHint.feature - ) + + @property + def enabled(self) -> bool: + return True + + @property + def complex_format(self) -> bool: + # TODO: Make a backup implementation that doesn't affect the layout. 
+ return not self.triton + + def _validate(self) -> None: + super()._validate() + if self.triton and not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + + +class Llama3RotaryConfig(DefaultRotaryConfig): + # TODO: Add descriptions. + scale_factor: float = Field(default=8.0, hint=FieldHint.feature) + low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature) + high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature) + original_context_length: int = Field(default=8192, hint=FieldHint.feature) + + +class YarnRotaryConfig(DefaultRotaryConfig): + # TODO: Add descriptions. attention_factor: None | float = Field( default=None, - desc="Attention factor for yarn-type scaling.", hint=FieldHint.feature, ) beta_fast: float = Field( default=32.0, - desc="Beta-fast for yarn-type scaling.", hint=FieldHint.feature, ) beta_slow: float = Field( default=1.0, - desc="Beta-slow for yarn-type scaling.", hint=FieldHint.feature, ) + original_context_length: int = Field(default=8192, hint=FieldHint.feature) - @property - def enabled(self) -> bool: - return self.type != RotaryEmbeddingType.none - @property - def complex_format(self) -> bool: - # TODO: Make a backup implementation that doesn't affect the layout. - return self.enabled and not self.triton - - def _validate(self) -> None: - super()._validate() - if self.triton and not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") +RotaryConfig.register_subclass("none", RotaryConfig) +RotaryConfig.register_subclass("default", DefaultRotaryConfig) +RotaryConfig.register_subclass("llama3", Llama3RotaryConfig) +RotaryConfig.register_subclass("yarn", YarnRotaryConfig) class AddLinearBiasChoices(str, enum.Enum): @@ -175,10 +185,21 @@ class TransformerSubLayerName(str, enum.Enum): mlp_2 = "mlp_2" -@config_class() class TransformerPeftConfig(PeftConfig): + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + pass + + +@config_class() +class TransformerNoPeftConfig(TransformerPeftConfig, NoPeftConfig): + _abstract = False + + +@config_class() +class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): layers: list[TransformerSubLayerName] = Field( - default=None, + default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) @@ -188,77 +209,70 @@ class TransformerPeftConfig(PeftConfig): ) def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if self.type != PeftType.none: - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False + if layer_type is None or self.layers is None or layer_type in self.layers: + if layer_type == TransformerSubLayerName.key: + return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + elif layer_type == TransformerSubLayerName.value_: + return super().apply_linear(linear, 
out_channel_begin=div(linear._out_dim.global_size, 2)) + else: + return super().apply_linear(linear) + elif self.freeze_others: + linear.weight.requires_grad = False return linear def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.type != PeftType.none and self.freeze_others: + if self.freeze_others: for parameter in module.parameters(): parameter.requires_grad = False return module def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.type != PeftType.none and self.freeze_others: + if self.freeze_others: parameter.requires_grad = False return parameter def _validate(self) -> None: - if self.layers is None: - with self._set_implicit_default(): - # Setting the default layers only whee PeFT is enabled - # so they don't appear when serializing the default transformer config. - self.layers = ( - [TransformerSubLayerName.query, TransformerSubLayerName.value_] - if self.type == PeftType.lora - else [] - ) - if self.type != PeftType.none: - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") + if TransformerSubLayerName.dense in self.layers: + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for attention dense layer.") + if ( + sum( + name in self.layers + for name in ( + TransformerSubLayerName.key_value, + TransformerSubLayerName.key, + TransformerSubLayerName.value_, ) + ) + > 1 + ): + raise ValueError( + f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." 
+ ) + + +TransformerPeftConfig.register_subclass("none", TransformerNoPeftConfig) +TransformerPeftConfig.register_subclass("lora", TransformerLoRAConfig) @config_class() class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, + default_factory=LayerNormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( - default_factory=RotaryConfig, + default_factory=NoRotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( - default_factory=TransformerPeftConfig, + default_factory=TransformerNoPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) @@ -635,7 +649,7 @@ def _from_dict( default, "use_rotary_embeddings", ("rotary", "type"), - lambda x: RotaryEmbeddingType.default if x else RotaryEmbeddingType.none, + lambda x: "default" if x else "none", ) cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x)) cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bd733692..9a27a079 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,8 +25,14 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.common.config import LayerNormalizationConfig, RMSNormalizationConfig +from fast_llm.layers.transformer.config import ( + DefaultRotaryConfig, + Llama3RotaryConfig, + RoutingType, + TransformerConfig, + YarnRotaryConfig, +) from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTModelConfig, @@ -184,7 +190,7 @@ def _create_transformer_layer_converters( self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False ) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] names_bias_cls = [ # Self-attn @@ -250,7 +256,7 @@ def _create_transformer_layer_converters( def _create_lm_head_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] # Next-token prediction head @@ -318,10 +324,11 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), 
fast_llm_value=RotaryEmbeddingType.default + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=DefaultRotaryConfig ), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=LayerNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) @@ -350,75 +357,90 @@ class CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=RMSNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), RenameParamConverter( fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + export_names=(("head_dim",),), ), ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), - RopeScalingParamConverter( - fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + LLamaRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=( + ("rope_theta",), + ("rope_scaling",), ), - export_names=(("rope_scaling",),), ), ] @dataclasses.dataclass -class RopeScalingParamConverter(ParamConverter): - _HUGGINGFACE_NAMES = ( - "rope_type", - "factor", - "low_freq_factor", - "high_freq_factor", - "original_max_position_embeddings", - "attention_factor", - "beta_fast", - "beta_slow", - ) - +class LLamaRotaryParamConverter(ParamConverter): def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 8) - Assert.eq(len(self.export_names), 1) + Assert.eq(len(self.fast_llm_names), 1) + Assert.eq(len(self.export_names), 2) def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - rope_type, *parameters = fast_llm_values - if rope_type == RotaryEmbeddingType.default: - return (None,) - elif rope_type == RotaryEmbeddingType.llama3: - return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},) - elif rope_type == RotaryEmbeddingType.yarn: - return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("yarn", *parameters), strict=True)},) + (rotary_config,) = fast_llm_values + serialized_config = rotary_config.to_dict() + if type(rotary_config) is DefaultRotaryConfig: + rotary_scaling = { + "rope_type": "default", + } + elif type(rotary_config) is Llama3RotaryConfig: + rotary_scaling = { + "rope_type": "llama3", + "factor": serialized_config["scale_factor"], + "low_freq_factor": serialized_config["low_frequency_factor"], + "high_freq_factor": serialized_config["high_frequency_factor"], + "original_max_position_embeddings": 
serialized_config["original_context_length"], + } + elif type(rotary_config) is YarnRotaryConfig: + rotary_scaling = { + "rope_type": "yarn", + "attention_factor": serialized_config["attention_factor"], + "beta_fast": serialized_config["beta_fast"], + "beta_slow": serialized_config["beta_slow"], + "original_max_position_embeddings": serialized_config["original_context_length"], + } else: - raise ValueError(f"Unsupported rotary scaling type: {rope_type}") + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + return serialized_config["theta"], rotary_scaling def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - (export_value,) = export_values - if ( - export_value is None - or export_value is MISSING - or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default" - ): - return (RotaryEmbeddingType.default,) + (DEFAULT,) * 7 - elif rope_type == RotaryEmbeddingType.llama3: - return ("llama3", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) - elif rope_type == RotaryEmbeddingType.yarn: - return ("yarn", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) - else: - raise ValueError(f"Unsupported rotary scaling type: {rope_type}") + rotary_theta, rope_scaling = export_values + rotary_type = "default" if rope_scaling in (None, MISSING) else rope_scaling.get("rope_type", "default") + rotary_config = { + "type": rotary_type, + "theta": rotary_theta, + } + if rotary_type == "default": + pass + elif rotary_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_scaling.get("factor", DEFAULT), + "low_frequency_factor": rope_scaling.get("low_freq_factor", DEFAULT), + "high_frequency_factor": rope_scaling.get("high_freq_factor", DEFAULT), + "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), + } + ) + elif rotary_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_scaling.get("attention_factor", DEFAULT), + "beta_fast": rope_scaling.get("beta_fast", DEFAULT), + "beta_slow": rope_scaling.get("beta_slow", DEFAULT), + "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), + } + ) + return (rotary_config,) # RotaryConfig.from_dict(rotary_config) class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): @@ -481,7 +503,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=RMSNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) @@ -490,18 +513,12 @@ def _create_config_converters(cls) -> list[ParamConverter]: ConstantImportParamConverter( fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" ), - RopeScalingParamConverter( - fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", 
"beta_fast"), - ("transformer", "rotary", "beta_slow"), + LLamaRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=( + ("rope_theta",), + ("rope_scaling",), ), - export_names=(("rope_scaling",),), ), IgnoreImportQwen2SlidingWindowParamsConverter(), ] @@ -637,7 +654,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_lm_head_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] # Next-token prediction head diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 190b2ffa..783eb621 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -16,7 +16,7 @@ from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.common.config import RMSNormalizationConfig from fast_llm.models.gpt.conversion import MLPLayer2Converter from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.models.ssm.model import HybridSSMModel @@ -51,7 +51,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) ), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=RMSNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("vocab_size",),), diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 79101f34..0619e6ae 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -11,7 +11,6 @@ from fast_llm.engine.multi_stage.config import StageConfig from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.config import CrossEntropyImpl -from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead @@ -78,7 +77,7 @@ def test_lm_head( config = GPTBaseModelConfig.from_dict( { "transformer": { - "normalization": {"type": NormalizationType.rms_norm}, + "normalization": {"type": "rms_norm"}, "hidden_size": HIDDEN_SIZE, "num_layers": 0, }, From fe7acd9c00d5530295b464d83df492de01508264 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 7 May 2025 14:10:24 -0400 Subject: [PATCH 04/26] stuff --- fast_llm/config.py | 12 +++++++----- fast_llm/engine/checkpoint/state_dict.py | 2 ++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index a443e51c..a215163c 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -318,11 +318,6 @@ class Config: # A registry for all the config classes. 
_registry: typing.ClassVar[Registry[str, type[typing.Self]]] - type: str | None = Field( - default=None, - desc="The config class name.", - hint=FieldHint.core, - ) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -988,6 +983,13 @@ def __init_subclass__(cls): # dataclasses expects an annotation, so we use the one from the base class. cls.__annotations__[name] = base_class_field.type + # Type for the field. At the end of class definition to avoid shadowing builtin. + type: str | None = Field( + default=None, + desc="The config class name.", + hint=FieldHint.core, + ) + class Configurable[ConfigType: Config]: config_class: typing.ClassVar[type[Config]] = Config diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 556e97be..d6807138 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -26,6 +26,8 @@ logger = logging.getLogger(__name__) +torch.distributed.gather + class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" From d41be60ed38dd3227edac5bfd67ce3c322e55d5e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 7 May 2025 16:30:30 -0400 Subject: [PATCH 05/26] stuff --- fast_llm/config.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index a215163c..7970f92c 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -257,7 +257,7 @@ def _process_config_class(cls: type["Config"]): return cls -def config_class(cls=None): +def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -283,13 +283,7 @@ def __init__(self, **kwargs): cls.__init__ = __init__ return wrapped - # See if we're being called as @config_class or @config_class(). - if cls is None: - # We're called with parens. - return wrap - - # We're called as @config_class without parens. - return wrap(cls) + return wrap @dataclasses.dataclass() From ec35a5030b93049b7be3ee380120629e013b5d57 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 8 May 2025 11:06:17 -0400 Subject: [PATCH 06/26] fixes --- fast_llm/config.py | 18 ++- fast_llm/data/preparator/gpt_memmap/config.py | 4 +- fast_llm/engine/multi_stage/config.py | 2 +- fast_llm/layers/common/config.py | 4 +- fast_llm/layers/transformer/config.py | 105 ++++++++++++++++++ fast_llm/layers/transformer/preprocessing.py | 9 +- fast_llm/utils.py | 5 + tests/config/common.py | 6 +- tests/test_triton_kernels.py | 13 +-- 9 files changed, 139 insertions(+), 27 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 7970f92c..43badc2c 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -311,7 +311,7 @@ class Config: _setting_implicit_default: bool | None = Field(init=False, repr=False) # A registry for all the config classes. 
- _registry: typing.ClassVar[Registry[str, type[typing.Self]]] + _registry: typing.ClassVar[Registry[str, type[typing.Self]]] = Registry[str, "type[Config]"]("Config", {}) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -900,7 +900,14 @@ def _check_abstract(cls) -> None: @classmethod def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: - cls._registry[cls.__name__] = cls + Assert.custom(issubclass, cls_, cls) + if name in cls._registry: + old_cls = cls._registry[name] + if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: + del cls._registry[name] + else: + raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {cls.__name__}.") + cls._registry[name] = cls_ @classmethod def get_subclass(cls, name): @@ -927,13 +934,14 @@ def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ + Assert.eq(cls.__name__, cls.__qualname__) cls._registry = Registry[str, type[cls]](cls.__name__, {}) - Config._registry[cls.__name__] = cls + Config.register_subclass(cls.__name__, cls) short_name = cls.__name__.strip("Config") if short_name != cls.__name__: - Config._registry[short_name] = cls + Config.register_subclass(short_name, cls) for base_class in cls.__mro__: - if issubclass(base_class, Config): + if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated." f" Make sure to use the @config_class decorator." diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c3..7e0604d4 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -24,7 +24,7 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -@config_class +@config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( default=None, @@ -77,7 +77,7 @@ class GPTHuggingfaceDatasetConfig(Config): ) -@config_class +@config_class() class DatasetPreparatorDistributedConfig(Config): # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80..d37bd890 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -315,7 +315,7 @@ def _setup(self) -> None: pass -@config_class +@config_class() class CheckpointMetadata(Config): # TODO: Make entries more flexible? # I.e.. model / format / usage (ex. training) - specific entries instead of a generic metadata? 
diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9d2a01ff..3678acb5 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -31,7 +31,7 @@ class NormalizationConfig(BaseModelConfig): pass @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> torch.nn.Module: + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": pass @@ -40,7 +40,7 @@ class NoNormalizationConfig(NormalizationConfig): _abstract = False @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> torch.nn.Module: + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": return torch.nn.Identity() diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index d9607924..f551ef13 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -103,12 +103,16 @@ class RotaryConfig(BaseModelConfig): def enabled(self) -> bool: return False + def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> "torch.Tensor": + raise NotImplementedError() + @config_class() class NoRotaryConfig(RotaryConfig): _abstract = False +@config_class() class DefaultRotaryConfig(RotaryConfig): theta: float = Field( default=10000, @@ -136,17 +140,76 @@ def _validate(self) -> None: if self.triton and not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + from fast_llm.functional.rotary import convert_rotary_complex_to_real + + # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # where n is the position in the sequence. + # We preform the calculation in high precision because it matters for rotary embeddings. + positions = torch.arange(sequence_length, device=device, dtype=torch.float64) + angles = torch.outer(positions, self._get_angle_scales(kv_channels, device)) + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + return frequencies + def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + return self.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + + +@config_class() class Llama3RotaryConfig(DefaultRotaryConfig): + """ + Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 + """ + # TODO: Add descriptions. 
scale_factor: float = Field(default=8.0, hint=FieldHint.feature) low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature) high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature) original_context_length: int = Field(default=8192, hint=FieldHint.feature) + def _validate(self) -> None: + super()._validate() + Assert.gt(self.high_frequency_factor, self.low_frequency_factor) + + def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + scales = super()._get_angle_scales(kv_channels, device) + low_frequency_wavelength = self.original_context_length / self.low_frequency_factor + high_frequency_wavelength = self.original_context_length / self.high_frequency_factor + new_scales = [] + for scale in scales: + wavelength = 2 * math.pi / scale + if wavelength < high_frequency_wavelength: + new_scales.append(scale) + elif wavelength > low_frequency_wavelength: + new_scales.append(scale / self.scale_factor) + else: + smooth = (self.original_context_length / wavelength - self.low_frequency_factor) / ( + self.high_frequency_factor - self.low_frequency_factor + ) + new_scales.append((1 - smooth) * scale / self.scale_factor + smooth * scale) + return torch.stack(new_scales) + +@config_class() class YarnRotaryConfig(DefaultRotaryConfig): + """ + Yarn scaling: + https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 + [original paper](https://arxiv.org/abs/2309.00071) + """ + # TODO: Add descriptions. + scale_factor: float = Field(default=8.0, hint=FieldHint.feature) attention_factor: None | float = Field( default=None, hint=FieldHint.feature, @@ -161,6 +224,47 @@ class YarnRotaryConfig(DefaultRotaryConfig): ) original_context_length: int = Field(default=8192, hint=FieldHint.feature) + def _validate(self) -> None: + if self.attention_factor is None: + with self._set_implicit_default(): + self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 + super()._validate() + + def _linear_ramp_factor(self, min, max, dim): + import torch + + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> "torch.Tensor": + return super().get_frequencies(sequence_length, kv_channels, device) * self.attention_factor + + def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + scales = super()._get_angle_scales(kv_channels, device) + # TODO: max_position_embeddings or original_context_length? 
+ # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 + low = max(self._get_correction(self.beta_slow, kv_channels), 0) + high = min(self._get_correction(self.beta_fast, kv_channels), kv_channels - 1) + if low == high: + high += 0.001 # Prevent singularity + + # Get n-dimensional rotational scaling corrected for extrapolation + extrapolation_factor = torch.clamp( + (torch.arange(kv_channels, dtype=torch.float32, device=scales.device) - low) / (high - low), 0, 1 + ) + return scales / self.scale_factor * extrapolation_factor + scales * (1 - extrapolation_factor) + + def _get_correction(self, beta: float, dim: int) -> float: + return math.floor( + dim * math.log(self.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self.theta)) + ) + RotaryConfig.register_subclass("none", RotaryConfig) RotaryConfig.register_subclass("default", DefaultRotaryConfig) @@ -185,6 +289,7 @@ class TransformerSubLayerName(str, enum.Enum): mlp_2 = "mlp_2" +@config_class() class TransformerPeftConfig(PeftConfig): @abc.abstractmethod def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 2415a2f9..dee3bb65 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -8,18 +8,19 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real from fast_llm.layers.transformer.config import ( + Llama3RotaryConfig, RotaryConfig, - RotaryEmbeddingType, TransformerConfig, TransformerDimNames, TransformerKwargs, + YarnRotaryConfig, ) from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) -def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> torch.Tensor: +def apply_llama3_scaling(config: Llama3RotaryConfig, frequencies: torch.Tensor) -> tuple[torch.Tensor, float]: """ Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 """ @@ -41,7 +42,7 @@ def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> tor return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0 -def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_channels, sequence_length) -> torch.Tensor: +def apply_yarn_scaling(config: YarnRotaryConfig, frequencies: torch.Tensor, kv_channels) -> tuple[torch.Tensor, float]: """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 @@ -116,7 +117,7 @@ def get_rotary_frequencies( if config.type == RotaryEmbeddingType.llama3: frequencies, attention_scaling = apply_llama3_scaling(config, frequencies) elif config.type == RotaryEmbeddingType.yarn: - frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels, sequence_length) + frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels) else: attention_scaling = 1.0 angles = torch.outer(positions, frequencies) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 51e0eee5..d5a7352b 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -217,6 +217,11 @@ def __setitem__(self, key: KeyType, value: ValueType): raise KeyError(f"Entry {key} already in 
{self._name} registry") self._data[key] = value + def __delitem__(self, key: KeyType): + if key not in self: + raise KeyError(f"Entry {key} not found in {self._name} registry") + del self._data[key] + def keys(self) -> list[KeyType]: return list(self._data) diff --git a/tests/config/common.py b/tests/config/common.py index a2657926..b671c4af 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -13,7 +13,7 @@ class ExampleEnum(enum.StrEnum): c = "c" -@config_class +@config_class() class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) @@ -40,7 +40,7 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleVerboseConfig(Config): # These fields will have non-empty default serialized values. list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) @@ -56,7 +56,7 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleNestedConfig(ExampleConfig): nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index 108a2898..b44c0580 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -28,8 +28,7 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType -from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies +from fast_llm.layers.transformer.config import DefaultRotaryConfig from fast_llm.utils import Assert, rms_diff from tests.common import requires_cuda @@ -92,8 +91,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, - get_rotary_frequencies( - RotaryConfig(type=RotaryEmbeddingType.default, triton=False), + DefaultRotaryConfig(triton=False).get_frequencies( sequence_length, kv_channels, device="cuda", @@ -103,12 +101,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - get_rotary_frequencies( - RotaryConfig(type=RotaryEmbeddingType.default, triton=True), - sequence_length, - kv_channels, - device="cuda", - ), + DefaultRotaryConfig(triton=True).get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, From 3005c8c89e1515c7ac59bc88fdab874a326fc6e3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 May 2025 11:38:42 -0400 Subject: [PATCH 07/26] stuff --- fast_llm/config.py | 116 ++++++++-------- fast_llm/data/data/config.py | 4 +- fast_llm/data/data/gpt/config.py | 3 +- fast_llm/data/dataset/config.py | 2 - fast_llm/data/dataset/gpt/config.py | 17 +-- fast_llm/data/preparator/gpt_memmap/config.py | 3 - fast_llm/engine/base_model/base_model.py | 2 +- fast_llm/engine/config_utils/run.py | 8 +- fast_llm/engine/multi_stage/config.py | 15 +-- fast_llm/engine/optimizer/config.py | 2 - fast_llm/engine/training/config.py | 29 +--- fast_llm/layers/common/config.py | 18 ++- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/ssm/config.py | 1 - fast_llm/layers/transformer/config.py | 46 +++++-- fast_llm/layers/transformer/preprocessing.py | 125 
+----------------- fast_llm/models/custom/config.py | 8 +- fast_llm/models/gpt/config.py | 8 +- fast_llm/models/ssm/config.py | 9 +- fast_llm/tools/convert.py | 4 +- tests/config/common.py | 2 +- tests/config/test_config.py | 30 +++++ tests/data/common.py | 2 +- tests/data/test_prepare_gpt_memmap.py | 16 +-- 24 files changed, 181 insertions(+), 290 deletions(-) create mode 100644 tests/config/test_config.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 43badc2c..732e7611 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,3 +1,4 @@ +import abc import contextlib import copy import dataclasses @@ -15,7 +16,6 @@ logger = logging.getLogger(__name__) - _AUTO_VALIDATE = True MISSING = Tag("") @@ -146,7 +146,7 @@ def __init__( if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: raise ValueError("cannot specify both default and default_factory") if isinstance(default_factory, type) and issubclass(default_factory, Config): - default_factory = _ConfigFactory(default_factory) + raise ValueError("Config classes should not be used as `default_factory`") super().__init__( default=default, default_factory=default_factory, @@ -223,20 +223,6 @@ def valid(x): return valid -class _ConfigFactory: - """ - A dataclass default factory that prevents early validation. - Validation is still done through the parent config if needed. - """ - - def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]): - self._factory = factory - - def __call__(self): - with NoAutoValidate(): - return self._factory() - - class ValidationError(ValueError): pass @@ -286,8 +272,20 @@ def __init__(self, **kwargs): return wrap +class ConfigMeta(abc.ABCMeta): + def __call__(cls: "type[Config]", **kwargs): + # Always go through `_from_dict` for correct dynamic class selection and nested config instantiation. + print("AIKDNJOINF", cls) + if not kwargs.pop("_from_dict_check", False): + print("AAA") + with NoAutoValidate(): + return cls._from_dict(kwargs) + print("BBB", kwargs) + return super().__call__(**kwargs) + + @dataclasses.dataclass() -class Config: +class Config(metaclass=ConfigMeta): """ An advanced `dataclass` with basic type checking, validation and argparse support. Typically, a subclass will: @@ -336,7 +334,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None: ) else: field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + if field.init and field._field_type == dataclasses._FIELD: # Adding to explicit field list except within `_set_implicit_default` context, # during dataclass initialization (`_setting_implicit_default` not yet set) # and during automated config validation (`_setting_implicit_default=None`) @@ -385,15 +383,29 @@ def _validate(self) -> None: Can be extended to add custom post-processing (typically before the super() call) and validation (typically after) """ - self._check_abstract() + # Should be handled in `from_dict`, but can fail if instantiating directly. + try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValidationError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.is_(self.__class__, expected_class) + + if self._abstract: + raise ValidationError(f"{type(self).__name__} is abstract") + if not self.__class_validated__: + raise ValidationError( + f"{type(self).__name__} hasn't been validated. 
Make sure to use the @config_class decorator." + ) errors = [] with self._set_implicit_default(None): # Set the type field, or override it to the provided type with the actual class for clarity and safety. self.type = self.__class__.__name__ - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.is_(self._registry[self.type], self.__class__) for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue value = getattr(self, name) if isinstance(value, Tag): @@ -618,11 +630,7 @@ def _add_field_to_args( all_fields: bool = False, serializable: bool = True, ) -> None: - if ( - field is not None - and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not all_fields - ): + if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields: # Exclude class variables and derived fields unless requested explicitly. return explicit_field = ( @@ -728,18 +736,7 @@ def from_dict( for keys, value in update.items(): set_nested_dict_value(default, keys, value, update_type) - type_ = default.get("type") - if type_ is None: - actual_cls = cls - else: - if type_ not in cls._registry: - raise ValueError(f"Unknown config type {type_}.") - actual_cls = cls._registry[type_] - if not issubclass(actual_cls, cls): - raise ValueError( - f"Config class {actual_cls.__name__} (from type {type_}) is not a subclass of {cls.__name__}" - ) - return actual_cls._from_dict(default, strict=strict) + return cls._from_dict(default, strict=strict) @classmethod def from_flat_dict( @@ -758,16 +755,24 @@ def _from_dict( flat: bool = False, ) -> typing.Self: # TODO v0.3: Remove flat format - out_arg_dict = {} + out_arg_dict = {"_from_dict_check": True} # TODO v0.3: Remove backward compatibility fix if "__class__" in default: del default["__class__"] + try: + actual_cls = cls.get_subclass(default.get("type")) + if actual_cls is not None and actual_cls is not cls: + return actual_cls._from_dict(default, strict=strict, flat=flat) + except KeyError: + # Postpone error to validation. + pass + # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): @@ -889,18 +894,10 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ log_fn=log_fn, ) - @classmethod - def _check_abstract(cls) -> None: - if cls._abstract: - raise ValidationError(f"{cls.__name__} is abstract") - if not cls.__class_validated__: - raise ValidationError( - f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator." - ) - @classmethod def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) + assert not cls_._abstract if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: @@ -910,8 +907,10 @@ def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: cls._registry[name] = cls_ @classmethod - def get_subclass(cls, name): + def get_subclass(cls, name: str | None): # TODO: Make it case-insensitive? 
+ if name is None: + return None cls_ = None for base_class in cls.__mro__: if issubclass(base_class, Config) and name in base_class._registry: @@ -922,7 +921,7 @@ def get_subclass(cls, name): elif base_class._registry[name] is not cls_: # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization. # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit? - raise RuntimeError( + raise KeyError( f"Ambiguous type `{name}` for base class {cls.__name__}." f" ({cls_.__name__} vs {base_class._registry[name]})" ) @@ -936,10 +935,11 @@ def __init_subclass__(cls): """ Assert.eq(cls.__name__, cls.__qualname__) cls._registry = Registry[str, type[cls]](cls.__name__, {}) - Config.register_subclass(cls.__name__, cls) - short_name = cls.__name__.strip("Config") - if short_name != cls.__name__: - Config.register_subclass(short_name, cls) + if not cls._abstract: + Config.register_subclass(cls.__name__, cls) + short_name = cls.__name__.strip("Config") + if short_name != cls.__name__: + Config.register_subclass(short_name, cls) for base_class in cls.__mro__: if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( @@ -986,10 +986,10 @@ def __init_subclass__(cls): cls.__annotations__[name] = base_class_field.type # Type for the field. At the end of class definition to avoid shadowing builtin. - type: str | None = Field( + type: str = Field( default=None, desc="The config class name.", - hint=FieldHint.core, + hint=FieldHint.feature, ) diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 25850ac3..41dbb5d9 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -9,6 +9,4 @@ class DataConfig(Config): _abstract = True _sampling_config_class: typing.ClassVar[type[SamplingData]] - sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Default configuration for dataset sampling." 
- ) + sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.") diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 6c598c0c..85bcc656 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): _abstract = False tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) + sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7901d6e7..1bb4b6be 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): _abstract = True sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) dataset: SampledDatasetConfig = Field( - default_factory=SampledDatasetConfig, desc="The dataset to sample from.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 87feb976..df38474e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -111,7 +111,6 @@ def build(self) -> "GPTIndexedDataset": @config_class() class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "random" name: str = Field( default="dummy", desc="The name of the dataset.", @@ -127,7 +126,6 @@ def build(self) -> "GPTRandomDataset": @config_class() class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "memmap" path: pathlib.Path = Field( default=None, desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", @@ -153,7 +151,6 @@ def build(self) -> "GPTMemmapDataset": @config_class() class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated" datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() def build(self) -> "GPTConcatenatedDataset": @@ -165,7 +162,6 @@ def build(self) -> "GPTConcatenatedDataset": @config_class() class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "slice" dataset: GPTIndexedDatasetConfig = FieldUpdate() def build(self) -> "GPTDatasetSlice": @@ -177,22 +173,19 @@ def build(self) -> "GPTDatasetSlice": @config_class() class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False - type_: typing.ClassVar[str | None] = "sampled" - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) - dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) + sampling: GPTSamplingConfig = FieldUpdate() + dataset: GPTSampledDatasetConfig = FieldUpdate() @config_class() class GPTBlendedDatasetConfig(BlendedDatasetConfig, 
GPTSampledDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "blended" datasets: list[GPTSampledDatasetConfig] = FieldUpdate() @config_class() class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "file" path: pathlib.Path = Field( default=None, desc="The path to a dataset config file.", @@ -232,7 +225,6 @@ def _convert_paths(self, config): class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): # TODO v0.3: Remove. _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated_memmap" path: pathlib.Path = Field( default=None, desc="The path to a dataset directory.", @@ -342,7 +334,6 @@ class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): """ _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "fim" dataset: GPTSampledDatasetConfig = Field( default=None, @@ -398,7 +389,6 @@ class GPTLegacyConfig(Config): valid=_validate_path, ) fim: FimConfig = Field( - default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) @@ -407,7 +397,6 @@ class GPTLegacyConfig(Config): @config_class() class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "legacy" def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: @@ -494,7 +483,6 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): # TODO: This belongs to a testing plugin. _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "test_slow" sleep: float = Field( default=1, desc="Sleep time during build, in seconds.", @@ -510,6 +498,7 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset: # Add user-friendly names for the configs. 
GPTSampledDatasetConfig.register_subclass("dummy", GPTRandomDatasetConfig) +GPTSampledDatasetConfig.register_subclass("random", GPTRandomDatasetConfig) GPTSampledDatasetConfig.register_subclass("memmap", GPTMemmapDatasetConfig) GPTSampledDatasetConfig.register_subclass("concatenated", GPTConcatenatedDatasetConfig) GPTSampledDatasetConfig.register_subclass("slice", GPTDatasetSliceConfig) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 7e0604d4..7091f3c8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -120,7 +120,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=DatasetPreparatorDistributedConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) @@ -149,12 +148,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): valid=check_field(Assert.geq, 1), ) dataset: GPTHuggingfaceDatasetConfig = Field( - default_factory=GPTHuggingfaceDatasetConfig, desc="Configuration for the dataset.", hint=FieldHint.feature, ) tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 2be1e487..df603a91 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -90,7 +90,7 @@ def __init__( config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space = TensorSpace(distributed_config) + self._tensor_space: TensorSpace = TensorSpace(distributed_config) config.setup_tensor_space(self._tensor_space) super().__init__(config) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index d6377409..126e0ae8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,9 +20,7 @@ @config_class() class RunConfig(Config): - tensor_logs: TensorLogsConfig = Field( - default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging - ) + tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) # TODO v0.3: Adjust (now only affects logging to file). 
structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging @@ -70,9 +68,7 @@ def _validate(self): @config_class() class ExperimentConfig(RunnableConfig): - run: RunConfig = Field( - default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core - ) + run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core) def _show( self, diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d37bd890..9434fba6 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -211,17 +211,12 @@ class FastLLMModelConfig(Config): FastLLMCheckpointFormat, ) model_name: typing.ClassVar[str] - base_model: BaseModelConfig = Field( - default_factory=BaseModelConfig, desc="Configuration for the base model.", hint=FieldHint.core - ) + base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( - default_factory=MultiStageConfig, desc="Configuration for the stage breakdown of the model.", hint=FieldHint.core, ) - distributed: DistributedConfig = Field( - default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core - ) + distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) @classmethod def __fast_llm_serialize__(cls) -> str: @@ -291,11 +286,8 @@ class PretrainedFastLLMModelConfig(Config): # TODO: Generalize data, schedule, logging, etc. _abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. - model: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core - ) + model: FastLLMModelConfig = Field(desc="Configuration for the Fast-LLM model.", hint=FieldHint.core) pretrained: CheckpointLoadConfig = Field( - default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) @@ -336,7 +328,6 @@ class CheckpointMetadata(Config): hint=FieldHint.core, ) config: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="The Fast-LLM model configuration for the saved model.", hint=FieldHint.core, ) diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index 3a154c9e..f4303a5d 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -74,12 +74,10 @@ class GradientScalerConfig(Config): class OptimizerConfig(Config): learning_rate: LearningRateScheduleConfig = Field( - default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) gradient_scaler: GradientScalerConfig = Field( - default_factory=GradientScalerConfig, desc="Configuration for the fixed or dynamic gradient scaling.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 1e990e9c..0b572779 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -141,7 +141,6 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( - default_factory=WandbAlertConfig, desc="Configuration for Wandb alerts." 
" The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, @@ -175,7 +174,6 @@ class TrainingCheckpointBaseConfig(IntervalConfig): _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( - default_factory=CallbackConfig, desc="Callback (shell script).", hint=FieldHint.core, ) @@ -284,19 +282,11 @@ class TrainingConfig(Config): desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) - logs: MetricsLogsConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core - ) - checkpoint: TrainingCheckpointConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core - ) - export: TrainingExportConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core - ) - shutdown: ShutdownConfig = Field( - default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core - ) - wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) + logs: MetricsLogsConfig = Field(desc="Configuration for metric logging.", hint=FieldHint.core) + checkpoint: TrainingCheckpointConfig = Field(desc="Configuration for checkpoints.", hint=FieldHint.core) + export: TrainingExportConfig = Field(desc="Configuration for exports.", hint=FieldHint.core) + shutdown: ShutdownConfig = Field(desc="Configuration for automated shutdown.", hint=FieldHint.core) + wandb: WandbConfig = Field(desc="Configuration for Wandb.", hint=FieldHint.core) train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -349,30 +339,23 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True # TODO: Generalize data, schedule, logging, etc. 
training: TrainingConfig = Field( - default_factory=TrainingConfig, desc="Configuration for the training phases and global properties.", hint=FieldHint.core, ) batch: BatchConfig = Field( - default_factory=BatchConfig, desc="Configuration for the training, validation and test batches.", hint=FieldHint.core, ) - schedule: ScheduleConfig = Field( - default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core - ) + schedule: ScheduleConfig = Field(desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core) data: DataConfig = Field( - default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) profiling: ProfilingConfig = Field( - default_factory=ProfilingConfig, desc="Configuration for the optional profiling of GPU and CPU CUDA operations.", hint=FieldHint.logging, ) optimizer: OptimizerConfig = Field( - default_factory=OptimizerConfig, desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 3678acb5..c03e9957 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -34,6 +34,18 @@ class NormalizationConfig(BaseModelConfig): def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": pass + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return LayerNormalizationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @config_class() class NoNormalizationConfig(NormalizationConfig): @@ -45,7 +57,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": @config_class() -class LayerNormBaseConfig(NormalizationConfig): +class LayerNormalizationBaseConfig(NormalizationConfig): """ Common configuration for layer norm and rms norm """ @@ -112,7 +124,7 @@ def _from_dict( @config_class() -class LayerNormalizationConfig(LayerNormBaseConfig): +class LayerNormalizationConfig(LayerNormalizationBaseConfig): _abstract = False @property @@ -123,7 +135,7 @@ def module_class(self): @config_class() -class RMSNormalizationConfig(LayerNormBaseConfig): +class RMSNormalizationConfig(LayerNormalizationBaseConfig): _abstract = False @property diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf..0db76ad1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -40,7 +40,6 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c6fe622e..25ad3d22 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -26,7 +26,6 @@ class SSMConfig(BaseModelConfig): # Normalization normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f551ef13..a52933a9 
100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -12,13 +12,7 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import ( - LayerNormalizationConfig, - LoRAConfig, - NoPeftConfig, - NormalizationConfig, - PeftConfig, -) +from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, NormalizationConfig, PeftConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -99,6 +93,18 @@ class TransformerLossNames: class RotaryConfig(BaseModelConfig): # TODO: Move rotary to its own submodule. + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is RotaryConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return DefaultRotaryConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @property def enabled(self) -> bool: return False @@ -114,6 +120,7 @@ class NoRotaryConfig(RotaryConfig): @config_class() class DefaultRotaryConfig(RotaryConfig): + _abstract = False theta: float = Field( default=10000, desc="Scale for the rotary positional embeddings", @@ -160,6 +167,8 @@ def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") return frequencies def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + return self.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) @@ -266,7 +275,7 @@ def _get_correction(self, beta: float, dim: int) -> float: ) -RotaryConfig.register_subclass("none", RotaryConfig) +RotaryConfig.register_subclass("none", NoRotaryConfig) RotaryConfig.register_subclass("default", DefaultRotaryConfig) RotaryConfig.register_subclass("llama3", Llama3RotaryConfig) RotaryConfig.register_subclass("yarn", YarnRotaryConfig) @@ -295,11 +304,29 @@ class TransformerPeftConfig(PeftConfig): def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": pass + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return TransformerNoPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @config_class() class TransformerNoPeftConfig(TransformerPeftConfig, NoPeftConfig): _abstract = False + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + @config_class() class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): @@ -367,17 +394,14 @@ def _validate(self) -> None: class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( - default_factory=LayerNormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( - default_factory=NoRotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( - default_factory=TransformerNoPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dee3bb65..bcf41365 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -1,136 +1,16 @@ import logging -import math import typing import torch from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.functional.rotary import convert_rotary_complex_to_real -from fast_llm.layers.transformer.config import ( - Llama3RotaryConfig, - RotaryConfig, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - YarnRotaryConfig, -) +from fast_llm.layers.transformer.config import RotaryConfig, TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) -def apply_llama3_scaling(config: Llama3RotaryConfig, frequencies: torch.Tensor) -> tuple[torch.Tensor, float]: - """ - Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 - """ - low_frequency_wavelength = config.original_context_length / config.low_frequency_factor - high_frequency_wavelength = config.original_context_length / config.high_frequency_factor - new_frequencies = [] - for frequency in frequencies: - wavelength = 2 * math.pi / frequency - if wavelength < high_frequency_wavelength: - new_frequencies.append(frequency) - elif wavelength > low_frequency_wavelength: - new_frequencies.append(frequency / config.scale_factor) - else: - assert low_frequency_wavelength != high_frequency_wavelength - smooth = (config.original_context_length / wavelength - config.low_frequency_factor) / ( - config.high_frequency_factor - config.low_frequency_factor - ) - new_frequencies.append((1 - smooth) * frequency / config.scale_factor + smooth * frequency) - return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0 - - -def apply_yarn_scaling(config: YarnRotaryConfig, frequencies: torch.Tensor, kv_channels) -> tuple[torch.Tensor, float]: - """ - Yarn scaling: - https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 - [original paper](https://arxiv.org/abs/2309.00071) - """ - base = config.theta - 
partial_rotary_factor = 1.0 - dim = int(kv_channels * partial_rotary_factor) - factor = config.scale_factor - - attention_factor = config.attention_factor - if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): - """Find dimension range bounds based on rotations""" - low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # to expand the possible context length. In other words, interpolation = apply scaling factor. - # pos_freqs = base ** (torch.arange(0, dim, 2).float().to(frequencies.device) / dim) - # inv_freq_extrapolation = 1.0 / pos_freqs - # inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - inv_freq_extrapolation = frequencies - inv_freq_interpolation = frequencies / factor - - # TODO: max_position_embeddings or original_context_length? - # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 - low, high = find_correction_range(config.beta_fast, config.beta_slow, dim, base, config.original_context_length) - - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(frequencies.device) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor - ) - - return inv_freq, attention_factor - - -def get_rotary_frequencies( - config: RotaryConfig, - sequence_length, - kv_channels, - *, - device="cuda", -) -> torch.Tensor: - # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) - # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, - # `a = theta ** - (2 * (channel // 2) / kv_channels)`, - # where n is the position in the sequence. - # We preform the calculation in high precision because it matters for rotary embeddings. 
- positions = torch.arange(sequence_length, device=device, dtype=torch.float64) - frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) - # Apply scaling - if config.type == RotaryEmbeddingType.llama3: - frequencies, attention_scaling = apply_llama3_scaling(config, frequencies) - elif config.type == RotaryEmbeddingType.yarn: - frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels) - else: - attention_scaling = 1.0 - angles = torch.outer(positions, frequencies) - frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not config.complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 - ).contiguous() - # Advanced Rope types like yarn apply a post-processing scaling factor, equivalent to scaling attention. - frequencies = frequencies * attention_scaling - return frequencies - - class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim @@ -156,8 +36,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - self._rotary_embedding_frequencies = get_rotary_frequencies( - self._config, + self._rotary_embedding_frequencies = self._config.get_frequencies( sequence_length, self._kv_channels_dim.global_size, device=self._tensor_space.distributed.device, diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 8be45e1c..08902e2c 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -26,7 +26,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate() @classmethod def get_model_class(cls) -> type["CustomModel"]: @@ -43,14 +43,14 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM" @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate() @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). 
- data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) + data: CustomDataConfig = FieldUpdate() + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e..0ec3fb51 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -129,7 +129,7 @@ def _from_dict( class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -156,13 +156,13 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate() @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() # TODO: Use dynamic model type? reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 0311cc69..771a4fca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -26,7 +26,6 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): _abstract = False ssm: SSMConfig = Field( - default_factory=SSMConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -129,7 +128,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) + base_model: HybridSSMBaseModelConfig = FieldUpdate() checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -154,13 +153,13 @@ def _validate(self): @config_class() class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) + model: HybridSSMModelConfig = FieldUpdate() @config_class() class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index d3db3745..7f327618 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -20,8 +20,8 @@ @config_class() class ConversionConfig(RunnableConfig): - input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) - output: 
CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) + input: CheckpointLoadConfig = Field() + output: CheckpointSaveConfig = Field() use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) diff --git a/tests/config/common.py b/tests/config/common.py index b671c4af..9ccfb597 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -58,7 +58,7 @@ def _validate(self) -> None: @config_class() class ExampleNestedConfig(ExampleConfig): - nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + nested_field: ExampleConfig = Field(hint=FieldHint.core) def check_config( diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 00000000..4c473fa6 --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,30 @@ +import pytest + +from fast_llm.config import NoAutoValidate +from tests.config.common import ExampleConfig + + +def test_auto_validate(): + assert (config := ExampleConfig())._validated + with pytest.raises(RuntimeError): + config.bool_field = True + config.bool_field = False + + assert ExampleConfig.from_dict({})._validated + + with NoAutoValidate(): + assert not (config := ExampleConfig())._validated + + config.bool_field = True + + config.validate() + + assert config._validated + with pytest.raises(RuntimeError): + config.bool_field = False + config.bool_field = True + + with NoAutoValidate(): + assert not (config := ExampleConfig.from_dict({}))._validated + config.validate() + assert config._validated diff --git a/tests/data/common.py b/tests/data/common.py index 47b53195..00c3ff20 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -31,7 +31,6 @@ def get_sampling_data( *, seed: int = 54983, cache_directory: pathlib.Path | None = None, - distributed: Distributed = Distributed(DistributedConfig(), use_cpu=True), phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, @@ -41,6 +40,7 @@ def get_sampling_data( truncate_documents=True, ) -> GPTSamplingData: # Config with convenient defaults. 
+ distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( config=GPTSamplingConfig( seed=seed, diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9dd7975c..ccb94c23 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig, GPTDatasetSliceConfig, GPTIndexedDatasetConfig from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig @@ -77,12 +77,12 @@ def test_absent_metadata_local(): DATASET_DICT_0 = { - "type": "mock_memmap", + "type": MockGPTMemmapDatasetConfig.__name__, "num_documents": 500, "num_tokens_per_document": 300, } DATASET_DICT_1 = { - "type": "mock_memmap", + "type": MockGPTMemmapDatasetConfig.__name__, "num_documents": 1500, "num_tokens_per_document": 100, } @@ -101,13 +101,13 @@ def test_split_dataset(): config, { "training": { - "type": "slice", + "type": GPTDatasetSliceConfig.__name__, "dataset": dataset_config_0.to_dict(), "begin": 0, "end": 0.75, }, "validation": { - "type": "slice", + "type": GPTDatasetSliceConfig.__name__, "dataset": dataset_config_0.to_dict(), "begin": 0.75, "end": 1, @@ -147,11 +147,11 @@ def test_split_datasets_1(): config, { "training": { - "type": "blended", + "type": GPTBlendedDatasetConfig.__name__, "datasets": [ dataset_config_0.to_dict(), { - "type": "slice", + "type": GPTDatasetSliceConfig.__name__, "dataset": dataset_config_1.to_dict(), "begin": 0, "end": 0.5, @@ -160,7 +160,7 @@ def test_split_datasets_1(): "weights": [2 / 3, 1 / 3], }, "validation": { - "type": "slice", + "type": GPTDatasetSliceConfig.__name__, "dataset": dataset_config_1.to_dict(), "begin": 0.5, "end": 1, From 5735d218e7d49116606b2c035b047520aeb1cb3e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 May 2025 13:00:54 -0400 Subject: [PATCH 08/26] fix --- fast_llm/layers/transformer/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a52933a9..8e71ab03 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -304,6 +304,14 @@ class TransformerPeftConfig(PeftConfig): def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": pass + @abc.abstractmethod + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + pass + + @abc.abstractmethod + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + pass + @classmethod def _from_dict( cls, From 6357365aa95a52559435357969a76967f3eeffb4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 May 2025 17:18:05 -0400 Subject: [PATCH 09/26] stuff --- fast_llm/cli.py | 48 +++++++++++++++++++ fast_llm/config.py | 10 ++-- fast_llm/data/auto.py | 16 ++----- fast_llm/data/preparator/gpt_memmap/config.py | 4 ++ .../{tools => engine/checkpoint}/convert.py | 25 ++++------ fast_llm/engine/config_utils/runnable.py | 3 ++ fast_llm/engine/training/config.py | 1 - fast_llm/layers/transformer/config.py | 5 +- fast_llm/models/custom/config.py | 6 +++ fast_llm/models/gpt/config.py | 5 ++ fast_llm/models/ssm/config.py | 5 ++ fast_llm/tools/__init__.py | 0 
fast_llm/tools/cli.py | 35 -------------- fast_llm/tools/prepare_dataset.py | 24 ---------- fast_llm/tools/train.py | 24 ---------- setup.cfg | 2 +- tests/test_checkpoint.py | 16 +++---- tests/test_config.py | 6 +-- tools/push_model.py | 4 +- 19 files changed, 106 insertions(+), 133 deletions(-) create mode 100644 fast_llm/cli.py rename fast_llm/{tools => engine/checkpoint}/convert.py (90%) delete mode 100644 fast_llm/tools/__init__.py delete mode 100644 fast_llm/tools/cli.py delete mode 100644 fast_llm/tools/prepare_dataset.py delete mode 100644 fast_llm/tools/train.py diff --git a/fast_llm/cli.py b/fast_llm/cli.py new file mode 100644 index 00000000..bce0e097 --- /dev/null +++ b/fast_llm/cli.py @@ -0,0 +1,48 @@ +import logging +import sys +import traceback + +from fast_llm.config import ValidationError +from fast_llm.engine.config_utils.logging import configure_logging +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.engine.config_utils.runnable import RunnableConfig + +# Import these submodules to ensure classes are added to the dynamic class registry. +import fast_llm.data.auto # isort: skip +import fast_llm.engine.checkpoint.convert # isort: skip +import fast_llm.models.auto # isort: skip + +logger = logging.getLogger(__name__) + + +def fast_llm_main(args=None): + # TODO: Add hook to register model classes? (environment variable?) + # (Pre-)configure logging + configure_logging() + try: + if args is None: + args = sys.argv[1:] + # TODO: Remove backward compatibility. + if len(args) >= 2 and args[0] == "train": + if args[1] == "gpt": + args = ["type=train_gpt"] + args[2:] + elif args[1] == "hybrid_ssm": + args = ["type=train_hybrid_ssm"] + args[2:] + elif len(args) >= 2 and args[0] == "convert": + if "=" not in args[1]: + args = ["type=convert", f"model={args[1]}"] + args[2:] + elif len(args) >= 2 and args[0] == "prepare" and args[1] == "gpt_memmap": + args = ["type=prepare_gpt_memmap"] + args[2:] + RunnableConfig.parse_and_run(args) + except Exception as e: + if sys.gettrace(): + raise + if isinstance(e, ValidationError): + log_main_rank(traceback.format_exc(), log_fn=logger.error) + else: + logger.critical(traceback.format_exc()) + sys.exit(1) + + +if __name__ == "__main__": + fast_llm_main() diff --git a/fast_llm/config.py b/fast_llm/config.py index 732e7611..2c03fda2 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -147,6 +147,9 @@ def __init__( raise ValueError("cannot specify both default and default_factory") if isinstance(default_factory, type) and issubclass(default_factory, Config): raise ValueError("Config classes should not be used as `default_factory`") + if not init: + # Non-init fields cause errors when printed before validation. + repr = False super().__init__( default=default, default_factory=default_factory, @@ -275,12 +278,9 @@ def __init__(self, **kwargs): class ConfigMeta(abc.ABCMeta): def __call__(cls: "type[Config]", **kwargs): # Always go through `_from_dict` for correct dynamic class selection and nested config instantiation. 
- print("AIKDNJOINF", cls) if not kwargs.pop("_from_dict_check", False): - print("AAA") - with NoAutoValidate(): - return cls._from_dict(kwargs) - print("BBB", kwargs) + # with NoAutoValidate(): + return cls._from_dict(kwargs) return super().__call__(**kwargs) diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index 902faf1c..c44e538f 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -1,13 +1,5 @@ -from fast_llm.data.preparator.config import DatasetPreparatorConfig -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.utils import Registry +""" +Import these submodules to ensure classes are added to the dynamic class registry. +""" -dataset_preparator_registry = Registry[str, DatasetPreparatorConfig]( - "DatasetPreparator", - { - dataset_preparator.preparator_name: dataset_preparator - for dataset_preparator in [ - GPTMemmapDatasetPreparatorConfig, - ] - }, -) +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 7091f3c8..775c367c 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -6,6 +6,7 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -173,3 +174,6 @@ def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator return GPTMemmapDatasetPreparator + + +RunnableConfig.register_subclass("prepare_gpt_memmap", GPTMemmapDatasetPreparatorConfig) diff --git a/fast_llm/tools/convert.py b/fast_llm/engine/checkpoint/convert.py similarity index 90% rename from fast_llm/tools/convert.py rename to fast_llm/engine/checkpoint/convert.py index 7f327618..998e8b8b 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -19,23 +19,13 @@ @config_class() -class ConversionConfig(RunnableConfig): +class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) - model_config_class: type[FastLLMModelConfig] = Field(default=None) - - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=model_registry.keys(), - help="The Fast-LLM model type to use. 
Must be defined in the model registry in `fast_llm.models.auto`.", - ) - return parser + model: type[FastLLMModelConfig] = Field(default=None) @classmethod def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): @@ -44,9 +34,11 @@ def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): return config def _validate(self): - assert self.model_config_class is not None - self.input.setup(self.model_config_class) - self.output.setup(self.model_config_class) + assert self.model is not None + if isinstance(self.model, str): + self.model = FastLLMModelConfig.get_subclass(self.model) + self.input.setup(self.model) + self.output.setup(self.model) super()._validate() def _convert_model_partial( @@ -160,5 +152,4 @@ def run(self): logger.info(f">>> All done!") -if __name__ == "__main__": - ConversionConfig.parse_and_run() +RunnableConfig.register_subclass("convert", ConvertConfig) diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 6142de47..6c105491 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -20,6 +20,9 @@ class RunnableConfig(Config): @classmethod def parse_and_run(cls, args=None) -> None: + if len(args) >= 1 and "=" not in args[0]: + # Make the `type=` part optional. + args = [f"type={args[0]}"] + args[1:] parsed, unparsed = cls._get_parser().parse_known_args(args) with NoAutoValidate(): config: "RunnableConfig" = cls._from_parsed_args(parsed, unparsed) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 0b572779..768d42f5 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -255,7 +255,6 @@ class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConf offset = FieldUpdate(desc="Offset for the first export.") callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") - @abc.abstractmethod def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / self.format.name diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 8e71ab03..7ab0a299 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -326,9 +326,12 @@ def _from_dict( @config_class() -class TransformerNoPeftConfig(TransformerPeftConfig, NoPeftConfig): +class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): _abstract = False + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + return super().apply_linear(linear) + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": return module diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 08902e2c..5bc4bb13 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,6 +2,8 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig if typing.TYPE_CHECKING: @@ -57,3 +59,7 @@ def get_trainer_class(cls) -> type["CustomTrainer"]: from fast_llm.models.custom.trainer import CustomTrainer return CustomTrainer + + 
+FastLLMModelConfig.register_subclass("gpt_custom", GPTModelConfig) +RunnableConfig.register_subclass("train_gpt_custom", GPTTrainerConfig) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0ec3fb51..e6798339 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,6 +4,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -210,3 +211,7 @@ def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: from fast_llm.models.gpt.model import GPTInferenceRunner return GPTInferenceRunner + + +FastLLMModelConfig.register_subclass("gpt", GPTModelConfig) +RunnableConfig.register_subclass("train_gpt", GPTTrainerConfig) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 771a4fca..cc41d89b 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -166,3 +167,7 @@ def get_trainer_class(cls) -> type["SSMTrainer"]: from fast_llm.models.ssm.trainer import SSMTrainer return SSMTrainer + + +FastLLMModelConfig.register_subclass("hybrid_ssm", HybridSSMModelConfig) +RunnableConfig.register_subclass("train_hybrid_ssm", HybridTrainerConfig) diff --git a/fast_llm/tools/__init__.py b/fast_llm/tools/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py deleted file mode 100644 index f7148322..00000000 --- a/fast_llm/tools/cli.py +++ /dev/null @@ -1,35 +0,0 @@ -import logging -import sys -import traceback - -from fast_llm.config import ValidationError -from fast_llm.engine.config_utils.logging import configure_logging -from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.config_utils.runnable import RunnableConfig - -# TODO: These imports indirectly adds all known runnables to the config registry, need a better way? -from fast_llm.data.preparator.config import DatasetPreparatorConfig # isort: skip -from fast_llm.models.auto import trainer_registry # isort: skip -from fast_llm.tools.convert import ConversionConfig # isort: skip - -logger = logging.getLogger(__name__) - - -def fast_llm(args=None): - # TODO: Add hook to register model classes? (environment variable?) 
- # (Pre-)configure logging - configure_logging() - try: - RunnableConfig.parse_and_run(args) - except Exception as e: - if sys.gettrace(): - raise - if isinstance(e, ValidationError): - log_main_rank(traceback.format_exc(), log_fn=logger.error) - else: - logger.critical(traceback.format_exc()) - sys.exit(1) - - -if __name__ == "__main__": - fast_llm() diff --git a/fast_llm/tools/prepare_dataset.py b/fast_llm/tools/prepare_dataset.py deleted file mode 100644 index aafe2690..00000000 --- a/fast_llm/tools/prepare_dataset.py +++ /dev/null @@ -1,24 +0,0 @@ -import argparse - -from fast_llm.data.auto import dataset_preparator_registry -from fast_llm.engine.config_utils.runnable import RunnableConfig - - -class PrepareDatasetConfig(RunnableConfig): - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=dataset_preparator_registry.keys(), - help="The Fast-LLM model type to use. Must be defined in the model registry in `fast_llm.models.auto`.", - ) - return parser - - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - return dataset_preparator_registry[parsed.model_type]._from_parsed_args(parsed, unparsed) - - -if __name__ == "__main__": - PrepareDatasetConfig.parse_and_run() diff --git a/fast_llm/tools/train.py b/fast_llm/tools/train.py deleted file mode 100644 index ae902279..00000000 --- a/fast_llm/tools/train.py +++ /dev/null @@ -1,24 +0,0 @@ -import argparse - -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.models.auto import trainer_registry - - -class CliTrainingConfig(RunnableConfig): - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=trainer_registry.keys(), - help="The Fast-LLM model type to use. 
Must be defined in the trainer registry in `fast_llm.models.auto`.", - ) - return parser - - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - return trainer_registry[parsed.model_type]._from_parsed_args(parsed, unparsed) - - -if __name__ == "__main__": - CliTrainingConfig.parse_and_run() diff --git a/setup.cfg b/setup.cfg index 9b944b27..a48759c1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,4 +68,4 @@ DOCS = [options.entry_points] console_scripts = - fast-llm = fast_llm.tools.cli:fast_llm + fast-llm = fast_llm.cli:fast_llm_main diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e9..9c3d4611 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,10 +14,10 @@ FastLLMCheckpointFormat, ModelConfigType, ) +from fast_llm.engine.checkpoint.convert import ConvertConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry -from fast_llm.tools.convert import ConversionConfig from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -90,7 +90,7 @@ def test_resume(): ) -def _run_conversion(config: ConversionConfig): +def _run_conversion(config: ConvertConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) if not config.output.path.is_dir(): @@ -106,7 +106,7 @@ def _run_conversion(config: ConversionConfig): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_convert_distributed_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -125,7 +125,7 @@ def test_convert_fast_llm_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, @@ -142,7 +142,7 @@ def test_convert_fast_llm_to_huggingface(): @pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) def test_convert_huggingface_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -161,7 +161,7 @@ def test_convert_distributed_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -178,7 +178,7 @@ def test_convert_distributed_to_huggingface(): @pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) def test_convert_huggingface_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -195,7 +195,7 @@ def test_convert_huggingface_to_fast_llm(): @pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) def test_convert_fast_llm_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, diff --git a/tests/test_config.py b/tests/test_config.py index 80bed418..cc239043 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -44,17 +44,17 @@ def run_without_import(cmd: str): def 
test_validate_train_gpt_without_import(): - run_without_import("main(['train', 'gpt', '-v'])") + run_without_import("main(['train_gpt', '-v'])") def test_validate_prepare_gpt_memmap_without_import(): run_without_import( - "main(['prepare', 'gpt_memmap', '-v', 'dataset.path=test', 'output_path=test', 'tokenizer.path=test'])" + "main(['prepare_gpt_memmap', '-v', 'dataset.path=test', 'output_path=test', 'tokenizer.path=test'])" ) def test_validate_convert_gpt_without_import(): - run_without_import("main(['convert', 'gpt', '-v'])") + run_without_import("main(['convert', 'model=gpt', '-v'])") def test_validate_example_config(): diff --git a/tools/push_model.py b/tools/push_model.py index cd98b93c..edab3312 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -27,7 +27,7 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.tools.convert import ConversionConfig # isort:skip +from fast_llm.tools.convert import ConvertConfig # isort:skip logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ def run(self) -> None: for _, checkpoint_path in new_checkpoint_paths: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=checkpoint_path, format=DistributedCheckpointFormat, From 207aef089a2ab8eb582c3c2fb13c3cb6efe457e5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 12 May 2025 13:09:45 -0400 Subject: [PATCH 10/26] stuff --- fast_llm/cli.py | 15 +--------- fast_llm/engine/checkpoint/convert.py | 8 ------ fast_llm/engine/config_utils/runnable.py | 16 +++++++---- fast_llm/models/auto.py | 35 ++++-------------------- fast_llm/models/custom/config.py | 4 +-- fast_llm/models/gpt/config.py | 3 +- fast_llm/models/ssm/config.py | 13 ++++----- fast_llm/models/ssm/trainer.py | 6 ++-- tests/common.py | 6 ++-- tests/test_checkpoint.py | 3 +- tests/test_config.py | 5 ++-- tests/test_multi_stage.py | 6 ++-- 12 files changed, 38 insertions(+), 82 deletions(-) diff --git a/fast_llm/cli.py b/fast_llm/cli.py index bce0e097..34546120 100644 --- a/fast_llm/cli.py +++ b/fast_llm/cli.py @@ -15,24 +15,11 @@ logger = logging.getLogger(__name__) -def fast_llm_main(args=None): +def fast_llm_main(args: list[str] | None = None): # TODO: Add hook to register model classes? (environment variable?) # (Pre-)configure logging configure_logging() try: - if args is None: - args = sys.argv[1:] - # TODO: Remove backward compatibility. 
- if len(args) >= 2 and args[0] == "train": - if args[1] == "gpt": - args = ["type=train_gpt"] + args[2:] - elif args[1] == "hybrid_ssm": - args = ["type=train_hybrid_ssm"] + args[2:] - elif len(args) >= 2 and args[0] == "convert": - if "=" not in args[1]: - args = ["type=convert", f"model={args[1]}"] + args[2:] - elif len(args) >= 2 and args[0] == "prepare" and args[1] == "gpt_memmap": - args = ["type=prepare_gpt_memmap"] + args[2:] RunnableConfig.parse_and_run(args) except Exception as e: if sys.gettrace(): diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 998e8b8b..97d9643d 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -1,4 +1,3 @@ -import argparse import json import logging import math @@ -9,7 +8,6 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.functional.config import TritonConfig -from fast_llm.models.auto import model_registry from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -27,12 +25,6 @@ class ConvertConfig(RunnableConfig): layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - config = super()._from_parsed_args(parsed, unparsed) - config.model_config_class = model_registry[parsed.model_type] - return config - def _validate(self): assert self.model is not None if isinstance(self.model, str): diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 6c105491..01d24eaa 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -19,13 +19,17 @@ @config_class() class RunnableConfig(Config): @classmethod - def parse_and_run(cls, args=None) -> None: - if len(args) >= 1 and "=" not in args[0]: - # Make the `type=` part optional. - args = [f"type={args[0]}"] + args[1:] - parsed, unparsed = cls._get_parser().parse_known_args(args) + def parse_and_run(cls, args: list[str] | None = None) -> None: + if args is None: + args = sys.argv[1:] + cls_ = cls + while len(args) >= 1 and "=" not in args[0]: + # Allow chained dynamic type selection without the `type=`, ex. `train gpt`. + cls_ = cls_.get_subclass(args[0]) + args = args[1:] + parsed, unparsed = cls_._get_parser().parse_known_args([f"type={cls_.__name__}"] + args) with NoAutoValidate(): - config: "RunnableConfig" = cls._from_parsed_args(parsed, unparsed) + config: "RunnableConfig" = cls_._from_parsed_args(parsed, unparsed) try: config.configure_logging() config.validate() diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 8f16aaea..3be74856 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -1,30 +1,7 @@ -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig -from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig -from fast_llm.utils import Registry +""" +Import these submodules to ensure classes are added to the dynamic class registry. 
+""" -model_registry = Registry[str, FastLLMModelConfig]( - "Model", - { - model.model_name: model - for model in [ - GPTModelConfig, - CustomModelConfig, - HybridSSMModelConfig, - ] - }, -) - -trainer_registry = Registry[str, TrainerConfig]( - "Model", - { - trainer.get_field("model").type.model_name: trainer - for trainer in [ - GPTTrainerConfig, - CustomTrainerConfig, - HybridTrainerConfig, - ] - }, -) +from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig # isort: skip +from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip +from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig # isort: skip diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 5bc4bb13..f09657e5 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,8 +2,8 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig if typing.TYPE_CHECKING: @@ -62,4 +62,4 @@ def get_trainer_class(cls) -> type["CustomTrainer"]: FastLLMModelConfig.register_subclass("gpt_custom", GPTModelConfig) -RunnableConfig.register_subclass("train_gpt_custom", GPTTrainerConfig) +TrainerConfig.register_subclass("gpt_custom", CustomTrainerConfig) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index e6798339..3c889e4e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -214,4 +213,4 @@ def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: FastLLMModelConfig.register_subclass("gpt", GPTModelConfig) -RunnableConfig.register_subclass("train_gpt", GPTTrainerConfig) +TrainerConfig.register_subclass("gpt", GPTTrainerConfig) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc41d89b..7db2a2b3 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,7 +5,6 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -17,7 +16,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM from fast_llm.models.ssm.model import HybridSSMModel - from fast_llm.models.ssm.trainer import SSMTrainer + from fast_llm.models.ssm.trainer 
import HybridSSMTrainer logger = logging.getLogger(__name__) @@ -158,16 +157,16 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): @config_class() -class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): +class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() @classmethod - def get_trainer_class(cls) -> type["SSMTrainer"]: - from fast_llm.models.ssm.trainer import SSMTrainer + def get_trainer_class(cls) -> type["HybridSSMTrainer"]: + from fast_llm.models.ssm.trainer import HybridSSMTrainer - return SSMTrainer + return HybridSSMTrainer FastLLMModelConfig.register_subclass("hybrid_ssm", HybridSSMModelConfig) -RunnableConfig.register_subclass("train_hybrid_ssm", HybridTrainerConfig) +TrainerConfig.register_subclass("hybrid_ssm", HybridSSMTrainerConfig) diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py index c0e5be26..efa7b704 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/ssm/trainer.py @@ -1,10 +1,10 @@ import typing from fast_llm.models.gpt.trainer import GPTTrainer -from fast_llm.models.ssm.config import HybridTrainerConfig +from fast_llm.models.ssm.config import HybridSSMTrainerConfig from fast_llm.models.ssm.model import HybridSSMModel -class SSMTrainer[ConfigType: HybridTrainerConfig](GPTTrainer[ConfigType]): - config_class: typing.ClassVar[type[HybridTrainerConfig]] = HybridTrainerConfig +class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): + config_class: typing.ClassVar[type[HybridSSMTrainerConfig]] = HybridSSMTrainerConfig model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/tests/common.py b/tests/common.py index 569d690c..bcc563d7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -13,6 +13,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.config import ( @@ -24,7 +25,6 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs # FIXME: figure out correct import of megatron modules without this hack @@ -392,7 +392,7 @@ def run_test_script( if is_megatron: script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] else: - script = [model_type, *script, f"run.experiment_dir={path}"] + script = ["train", model_type, *script, f"run.experiment_dir={path}"] header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] command = [ "python", @@ -408,7 +408,7 @@ def run_test_script( else: get_test_dataset() if num_gpus == 1 and not is_megatron: - CliTrainingConfig.parse_and_run(script) + RunnableConfig.parse_and_run(script) else: completed_proc = subprocess.run(command, env=env, timeout=60) if completed_proc.returncode: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 9c3d4611..eb21f3b3 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,7 +17,6 @@ from fast_llm.engine.checkpoint.convert import ConvertConfig from fast_llm.engine.multi_stage.config import 
FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName -from fast_llm.models.auto import model_registry from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -31,7 +30,7 @@ ) from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor -TEST_MODEL_CONFIG_CLS = model_registry[TEST_MODEL_TYPE] +TEST_MODEL_CONFIG_CLS = FastLLMModelConfig.get_subclass(TEST_MODEL_TYPE) TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_class() TEST_MODEL_CLS = TEST_MODEL_CONFIG_CLS.get_model_class() TEST_BASE_MODEL_CONFIG_CLS = TEST_MODEL_CONFIG_CLS.get_base_model_config_class() diff --git a/tests/test_config.py b/tests/test_config.py index cc239043..1ea225b7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,8 +11,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.models.auto import trainer_registry -from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert, check_equal_nested from tests.common import TEST_RESULTS_PATH @@ -61,7 +60,7 @@ def test_validate_example_config(): fast_llm_config_dict = yaml.safe_load( (pathlib.Path(__file__).parents[1] / "examples" / "mistral.yaml").read_text() ) - trainer_registry["gpt"].from_dict(fast_llm_config_dict) + GPTTrainerConfig.from_dict(fast_llm_config_dict) def test_do_use_flash_attention(): diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index bb468ceb..aeea5b8c 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -2,14 +2,14 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert from tests.common import CONFIG_COMMON, requires_cuda def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: - parsed, unparsed = CliTrainingConfig._get_parser().parse_known_args([model_type] + args) - config: TrainerConfig = CliTrainingConfig._from_parsed_args(parsed, unparsed) + cls = TrainerConfig.get_subclass(model_type) + parsed, unparsed = cls._get_parser().parse_known_args(args) + config: TrainerConfig = cls._from_parsed_args(parsed, unparsed) distributed = Distributed(config.model.distributed) trainer = config.get_trainer_class()(config=config) trainer.setup(distributed, config.get_run(distributed)) From 31579bd01afa2db6b533e48e990117b8a26c2ca9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 May 2025 11:16:30 -0400 Subject: [PATCH 11/26] Bring back default_factory --- fast_llm/config.py | 21 +++++------- fast_llm/data/data/config.py | 4 ++- fast_llm/data/data/gpt/config.py | 3 +- fast_llm/data/dataset/config.py | 2 ++ fast_llm/data/dataset/gpt/config.py | 5 +-- fast_llm/data/preparator/config.py | 3 ++ fast_llm/data/preparator/gpt_memmap/config.py | 4 +++ fast_llm/engine/base_model/config.py | 1 + fast_llm/engine/checkpoint/convert.py | 4 +-- fast_llm/engine/checkpoint/state_dict.py | 2 -- fast_llm/engine/config_utils/run.py | 8 +++-- fast_llm/engine/config_utils/runnable.py | 2 +- fast_llm/engine/multi_stage/config.py | 21 +++++++----- fast_llm/engine/optimizer/config.py | 2 ++ 
fast_llm/engine/training/config.py | 33 +++++++++++++++---- fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/ssm/config.py | 1 + fast_llm/models/custom/config.py | 10 +++--- fast_llm/models/gpt/config.py | 10 +++--- fast_llm/models/ssm/config.py | 11 ++++--- tests/config/common.py | 2 +- tests/test_config.py | 14 ++++---- tests/test_ssms.py | 2 ++ 23 files changed, 109 insertions(+), 57 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 2c03fda2..f7931697 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -143,10 +143,8 @@ def __init__( metadata=None, kw_only=dataclasses.MISSING, ): - if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: - raise ValueError("cannot specify both default and default_factory") - if isinstance(default_factory, type) and issubclass(default_factory, Config): - raise ValueError("Config classes should not be used as `default_factory`") + if (default is dataclasses.MISSING) == (default_factory is dataclasses.MISSING): + raise ValueError("Fields should define exactly one of `default` or `default_factory`.") if not init: # Non-init fields cause errors when printed before validation. repr = False @@ -782,9 +780,9 @@ def _from_dict( else: # Check for nested configs to instantiate. try: - value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict) - if value is not MISSING: - out_arg_dict[name] = value + if name in default: + out_arg_dict[name] = cls._from_dict_nested(default[name], field.type, strict) + except FieldTypeError as e: raise FieldTypeError( f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " @@ -817,11 +815,8 @@ def _from_dict_nested(cls, value, type_, strict: bool): raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): raise FieldTypeError(f"Not a type: {type_}.") - elif issubclass(type_, Config): - if value is MISSING: - value = {} - if isinstance(value, dict): - value = type_._from_dict(value, strict) + elif issubclass(type_, Config) and isinstance(value, dict): + value = type_._from_dict(value, strict) return value @classmethod @@ -893,11 +888,11 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) + return None @classmethod def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) - assert not cls_._abstract if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 41dbb5d9..25850ac3 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -9,4 +9,6 @@ class DataConfig(Config): _abstract = True _sampling_config_class: typing.ClassVar[type[SamplingData]] - sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.") + sampling: SamplingConfig = Field( + default_factory=SamplingConfig, desc="Default configuration for dataset sampling." 
+ ) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 85bcc656..6c598c0c 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -27,6 +27,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): _abstract = False tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) @@ -36,7 +37,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate() + sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 1bb4b6be..7901d6e7 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -174,10 +174,12 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): _abstract = True sampling: SamplingConfig = Field( + default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) dataset: SampledDatasetConfig = Field( + default_factory=SampledDatasetConfig, desc="The dataset to sample from.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index df38474e..6c48e170 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -173,8 +173,8 @@ def build(self) -> "GPTDatasetSlice": @config_class() class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False - sampling: GPTSamplingConfig = FieldUpdate() - dataset: GPTSampledDatasetConfig = FieldUpdate() + sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) + dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) @config_class() @@ -389,6 +389,7 @@ class GPTLegacyConfig(Config): valid=_validate_path, ) fim: FimConfig = Field( + default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index edf088c0..b2068ddc 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -24,3 +24,6 @@ class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[Config @abc.abstractmethod def run(self) -> None: raise NotImplementedError + + +RunnableConfig.register_subclass("prepare", DatasetPreparatorConfig) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 775c367c..bb16a82d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -121,6 +121,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( + default_factory=FimConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) @@ -149,10 +150,12 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): valid=check_field(Assert.geq, 1), ) dataset: GPTHuggingfaceDatasetConfig = Field( + default_factory=GPTHuggingfaceDatasetConfig, desc="Configuration for the dataset.", hint=FieldHint.feature, ) tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, 
desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) @@ -177,3 +180,4 @@ def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: RunnableConfig.register_subclass("prepare_gpt_memmap", GPTMemmapDatasetPreparatorConfig) +DatasetPreparatorConfig.register_subclass("gpt_memmap", GPTMemmapDatasetPreparatorConfig) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 25f53e4a..4be42e06 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: if isinstance(value, BaseModelConfig): diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 97d9643d..51bd4a5f 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -18,8 +18,8 @@ @config_class() class ConvertConfig(RunnableConfig): - input: CheckpointLoadConfig = Field() - output: CheckpointSaveConfig = Field() + input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) + output: CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index d6807138..556e97be 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -26,8 +26,6 @@ logger = logging.getLogger(__name__) -torch.distributed.gather - class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 126e0ae8..d6377409 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,7 +20,9 @@ @config_class() class RunConfig(Config): - tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) + tensor_logs: TensorLogsConfig = Field( + default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging + ) # TODO v0.3: Adjust (now only affects logging to file). structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging @@ -68,7 +70,9 @@ def _validate(self): @config_class() class ExperimentConfig(RunnableConfig): - run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core) + run: RunConfig = Field( + default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core + ) def _show( self, diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 01d24eaa..ac10225e 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -23,7 +23,7 @@ def parse_and_run(cls, args: list[str] | None = None) -> None: if args is None: args = sys.argv[1:] cls_ = cls - while len(args) >= 1 and "=" not in args[0]: + while len(args) >= 1 and "=" not in args[0] and not args[0].startswith("-"): # Allow chained dynamic type selection without the `type=`, ex. 
`train gpt`. cls_ = cls_.get_subclass(args[0]) args = args[1:] diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 9434fba6..97534319 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -211,12 +211,17 @@ class FastLLMModelConfig(Config): FastLLMCheckpointFormat, ) model_name: typing.ClassVar[str] - base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) + base_model: BaseModelConfig = Field( + default_factory=BaseModelConfig, desc="Configuration for the base model.", hint=FieldHint.core + ) multi_stage: MultiStageConfig = Field( + default_factory=MultiStageConfig, desc="Configuration for the stage breakdown of the model.", hint=FieldHint.core, ) - distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) + distributed: DistributedConfig = Field( + default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core + ) @classmethod def __fast_llm_serialize__(cls) -> str: @@ -286,8 +291,11 @@ class PretrainedFastLLMModelConfig(Config): # TODO: Generalize data, schedule, logging, etc. _abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. - model: FastLLMModelConfig = Field(desc="Configuration for the Fast-LLM model.", hint=FieldHint.core) + model: FastLLMModelConfig = Field( + default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core + ) pretrained: CheckpointLoadConfig = Field( + default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) @@ -328,6 +336,7 @@ class CheckpointMetadata(Config): hint=FieldHint.core, ) config: FastLLMModelConfig = Field( + default_factory=FastLLMModelConfig, desc="The Fast-LLM model configuration for the saved model.", hint=FieldHint.core, ) @@ -372,13 +381,9 @@ def _from_dict( if "fast_llm_version" not in default: default["fast_llm_version"] = "0" - # Determine the model config class. - from fast_llm.models.auto import model_registry - model_config_class = default["model"] if isinstance(model_config_class, str): - Assert.incl(model_config_class, model_registry) - model_config_class = model_registry[model_config_class] + model_config_class = FastLLMModelConfig.get_subclass(default["model"]) default["model"] = model_config_class # TODO v0.3: Remove backward compatibility. 
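For readers following the registry changes above, here is a minimal, self-contained sketch of the pattern the patches converge on: concrete config classes register under a short name, and a factory reads an optional `type` key to pick the class to instantiate. Every name in the sketch (ConfigBase, GptLikeConfig, and their fields) is a hypothetical stand-in, not the Fast-LLM API; the real `register_subclass`/`get_subclass` methods shown in the diff may differ in signature and behavior.

    import dataclasses
    import typing


    @dataclasses.dataclass
    class ConfigBase:
        # Class-level registry mapping short names to concrete config classes.
        _registry: typing.ClassVar[dict[str, type["ConfigBase"]]] = {}
        type: str | None = None

        @classmethod
        def register_subclass(cls, name: str, cls_: type["ConfigBase"]) -> None:
            # Register a concrete class under a short, user-facing name.
            cls._registry[name] = cls_

        @classmethod
        def get_subclass(cls, name: str) -> type["ConfigBase"]:
            # Look up a registered class by name, mirroring the lookups in the diff.
            return cls._registry[name]

        @classmethod
        def from_dict(cls, data: dict) -> "ConfigBase":
            # Dispatch on the optional `type` key, falling back to the called class.
            actual_cls = cls.get_subclass(data["type"]) if "type" in data else cls
            return actual_cls(**data)


    @dataclasses.dataclass
    class GptLikeConfig(ConfigBase):
        hidden_size: int = 1024


    ConfigBase.register_subclass("gpt", GptLikeConfig)

    # `type=gpt` selects the registered subclass even when calling the base class.
    config = ConfigBase.from_dict({"type": "gpt", "hidden_size": 512})
    assert isinstance(config, GptLikeConfig)

The `parse_and_run` change earlier in this patch applies the same idea to the CLI: each leading argument without an `=` is resolved through `get_subclass`, so an invocation like `fast-llm train gpt` stands in for an explicit `type=train_gpt`.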
diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index f4303a5d..3a154c9e 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -74,10 +74,12 @@ class GradientScalerConfig(Config): class OptimizerConfig(Config): learning_rate: LearningRateScheduleConfig = Field( + default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) gradient_scaler: GradientScalerConfig = Field( + default_factory=GradientScalerConfig, desc="Configuration for the fixed or dynamic gradient scaling.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 768d42f5..210f302a 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,6 +23,7 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -141,6 +142,7 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( + default_factory=WandbAlertConfig, desc="Configuration for Wandb alerts." " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, @@ -174,6 +176,7 @@ class TrainingCheckpointBaseConfig(IntervalConfig): _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( + default_factory=CallbackConfig, desc="Callback (shell script).", hint=FieldHint.core, ) @@ -281,11 +284,19 @@ class TrainingConfig(Config): desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) - logs: MetricsLogsConfig = Field(desc="Configuration for metric logging.", hint=FieldHint.core) - checkpoint: TrainingCheckpointConfig = Field(desc="Configuration for checkpoints.", hint=FieldHint.core) - export: TrainingExportConfig = Field(desc="Configuration for exports.", hint=FieldHint.core) - shutdown: ShutdownConfig = Field(desc="Configuration for automated shutdown.", hint=FieldHint.core) - wandb: WandbConfig = Field(desc="Configuration for Wandb.", hint=FieldHint.core) + logs: MetricsLogsConfig = Field( + default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core + ) + checkpoint: TrainingCheckpointConfig = Field( + default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core + ) + export: TrainingExportConfig = Field( + default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core + ) + shutdown: ShutdownConfig = Field( + default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core + ) + wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -338,23 +349,30 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True # TODO: Generalize data, schedule, logging, etc. 
training: TrainingConfig = Field( + default_factory=TrainingConfig, desc="Configuration for the training phases and global properties.", hint=FieldHint.core, ) batch: BatchConfig = Field( + default_factory=BatchConfig, desc="Configuration for the training, validation and test batches.", hint=FieldHint.core, ) - schedule: ScheduleConfig = Field(desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core) + schedule: ScheduleConfig = Field( + default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core + ) data: DataConfig = Field( + default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) profiling: ProfilingConfig = Field( + default_factory=ProfilingConfig, desc="Configuration for the optional profiling of GPU and CPU CUDA operations.", hint=FieldHint.logging, ) optimizer: OptimizerConfig = Field( + default_factory=OptimizerConfig, desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) @@ -420,3 +438,6 @@ def new_setup(): old_setup() object.__setattr__(pretrained, "_setup", new_setup) + + +RunnableConfig.register_subclass("train", TrainerConfig) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0db76ad1..d0f03ccf 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -40,6 +40,7 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): transformer: TransformerConfig = Field( + default_factory=TransformerConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 25ad3d22..c6fe622e 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -26,6 +26,7 @@ class SSMConfig(BaseModelConfig): # Normalization normalization: NormalizationConfig = Field( + default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index f09657e5..414263c0 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,6 +2,7 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -28,7 +29,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). 
model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate() + base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) @classmethod def get_model_class(cls) -> type["CustomModel"]: @@ -45,14 +46,14 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM" @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate() + model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate() - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() + data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: @@ -62,4 +63,5 @@ def get_trainer_class(cls) -> type["CustomTrainer"]: FastLLMModelConfig.register_subclass("gpt_custom", GPTModelConfig) +RunnableConfig.register_subclass("train_gpt_custom", CustomTrainerConfig) TrainerConfig.register_subclass("gpt_custom", CustomTrainerConfig) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 3c889e4e..f2894f8c 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,6 +4,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -129,7 +130,7 @@ def _from_dict( class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate() + base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -156,13 +157,13 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate() + model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() - batch: GPTBatchConfig = FieldUpdate() + data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) # TODO: Use dynamic model type? 
reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() @@ -213,4 +214,5 @@ def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: FastLLMModelConfig.register_subclass("gpt", GPTModelConfig) +RunnableConfig.register_subclass("train_gpt", GPTTrainerConfig) TrainerConfig.register_subclass("gpt", GPTTrainerConfig) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7db2a2b3..59f640b9 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -26,6 +27,7 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): _abstract = False ssm: SSMConfig = Field( + default_factory=SSMConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -128,7 +130,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate() + base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -153,13 +155,13 @@ def _validate(self): @config_class() class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate() + model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) @config_class() class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() - batch: GPTBatchConfig = FieldUpdate() + data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) @classmethod def get_trainer_class(cls) -> type["HybridSSMTrainer"]: @@ -169,4 +171,5 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]: FastLLMModelConfig.register_subclass("hybrid_ssm", HybridSSMModelConfig) +RunnableConfig.register_subclass("train_hybrid_ssm", HybridSSMTrainerConfig) TrainerConfig.register_subclass("hybrid_ssm", HybridSSMTrainerConfig) diff --git a/tests/config/common.py b/tests/config/common.py index 9ccfb597..b671c4af 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -58,7 +58,7 @@ def _validate(self) -> None: @config_class() class ExampleNestedConfig(ExampleConfig): - nested_field: ExampleConfig = Field(hint=FieldHint.core) + nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) def check_config( diff --git a/tests/test_config.py b/tests/test_config.py index 1ea225b7..ec91f21c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,7 +31,7 @@ def run_without_import(cmd: str): "sys.path=[p for p in sys.path if not any(x in p for x in ('site-packages', 'dist-packages', '.egg'))]", # We still want to enable imports from within Fast-llm f"sys.path.append('{repo_path}')", - "from fast_llm.tools.cli import fast_llm as main", + "from 
fast_llm.cli import fast_llm_main as main", cmd, ] ), @@ -110,7 +110,7 @@ def test_pretrained_config(load_config: ModelConfigType): "rotary": {"type": "default"}, "num_layers": 12, # Default "hidden_size": 1024, # Default - "window_size": 32, # Non-architecture + "window_size": 32, "ffn_hidden_size": 4096, # Implicit default, default value "activation_type": "silu", # Implicit default, non-default value "head_groups": 4, @@ -131,7 +131,7 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { # rotary: Don't override nested. "normalization": {"implementation": "triton"}, # Update non-default nested - "peft": {"freeze_others": False}, # Update default nested, non-architecture + "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type "hidden_size": 512, # Override, affects derived value (kv channels) "head_groups": 1, # Override to default }, @@ -156,9 +156,9 @@ def test_pretrained_config(load_config: ModelConfigType): if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { "transformer": { - "normalization": {"type": "rms_norm", "implementation": "triton"}, - "rotary": {"type": "default"}, - "peft": {"freeze_others": False}, + "normalization": {"type": "RMSNormalizationConfig", "implementation": "triton"}, + "rotary": {"type": "DefaultRotaryConfig"}, + "peft": {"type": "TransformerLoRAConfig", "freeze_others": False}, "num_layers": 12, "hidden_size": 512, "ffn_hidden_size": 4096, @@ -170,6 +170,8 @@ def test_pretrained_config(load_config: ModelConfigType): "vocab_size": 1000, } else: + base_model_update["transformer"]["normalization"]["type"] = "RMSNormalizationConfig" + base_model_update["transformer"]["rotary"]["type"] = "DefaultRotaryConfig" expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index e6c9aafd..0fec3741 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +@pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", @@ -207,6 +208,7 @@ def test_mamba_block(distributed_config, distributed): assert not torch.isinf(hidden_states).any() +@pytest.mark.slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( ("hybrid_block_layout"), From f79ed279b15ad64ac898bbc4184b8ea1ddceeddc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 May 2025 11:55:12 -0400 Subject: [PATCH 12/26] fix --- fast_llm/config.py | 9 ++++--- fast_llm/data/preparator/gpt_memmap/config.py | 2 +- fast_llm/engine/training/config.py | 4 +-- fast_llm/layers/common/config.py | 12 --------- fast_llm/layers/transformer/config.py | 27 +++---------------- 5 files changed, 11 insertions(+), 43 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f7931697..9fc07e7a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -143,9 +143,10 @@ def __init__( metadata=None, kw_only=dataclasses.MISSING, ): - if (default is dataclasses.MISSING) == (default_factory is dataclasses.MISSING): - raise ValueError("Fields should define exactly one of `default` or `default_factory`.") - if not init: + if init: + if (default is dataclasses.MISSING) == (default_factory is dataclasses.MISSING): + 
raise ValueError("Fields should define exactly one of `default` or `default_factory`.") + else: # Non-init fields cause errors when printed before validation. repr = False super().__init__( @@ -781,7 +782,7 @@ def _from_dict( # Check for nested configs to instantiate. try: if name in default: - out_arg_dict[name] = cls._from_dict_nested(default[name], field.type, strict) + out_arg_dict[name] = cls._from_dict_nested(default.pop(name), field.type, strict) except FieldTypeError as e: raise FieldTypeError( diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index bb16a82d..013229be 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -121,7 +121,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=FimConfig, + default_factory=DatasetPreparatorDistributedConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 210f302a..65a7411a 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -288,10 +288,10 @@ class TrainingConfig(Config): default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core ) checkpoint: TrainingCheckpointConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core + default_factory=TrainingCheckpointConfig, desc="Configuration for checkpoints.", hint=FieldHint.core ) export: TrainingExportConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core + default_factory=TrainingExportConfig, desc="Configuration for exports.", hint=FieldHint.core ) shutdown: ShutdownConfig = Field( default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index c03e9957..7914ee33 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -34,18 +34,6 @@ class NormalizationConfig(BaseModelConfig): def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": pass - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return LayerNormalizationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - @config_class() class NoNormalizationConfig(NormalizationConfig): diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 7ab0a299..d7297e28 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -93,18 +93,6 @@ class TransformerLossNames: class RotaryConfig(BaseModelConfig): # TODO: Move rotary to its own submodule. - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is RotaryConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return DefaultRotaryConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - @property def enabled(self) -> bool: return False @@ -312,18 +300,6 @@ def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": pass - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. - return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - @config_class() class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): @@ -405,14 +381,17 @@ def _validate(self) -> None: class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( + default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( + default_factory=RotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( + default_factory=TransformerPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) From 0a3720956f781585cfaa92c908a42b270445646e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 May 2025 14:13:39 -0400 Subject: [PATCH 13/26] Revert "fix" This reverts commit f79ed279b15ad64ac898bbc4184b8ea1ddceeddc. --- fast_llm/config.py | 9 +++---- fast_llm/data/preparator/gpt_memmap/config.py | 2 +- fast_llm/engine/training/config.py | 4 +-- fast_llm/layers/common/config.py | 12 +++++++++ fast_llm/layers/transformer/config.py | 27 ++++++++++++++++--- 5 files changed, 43 insertions(+), 11 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 9fc07e7a..f7931697 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -143,10 +143,9 @@ def __init__( metadata=None, kw_only=dataclasses.MISSING, ): - if init: - if (default is dataclasses.MISSING) == (default_factory is dataclasses.MISSING): - raise ValueError("Fields should define exactly one of `default` or `default_factory`.") - else: + if (default is dataclasses.MISSING) == (default_factory is dataclasses.MISSING): + raise ValueError("Fields should define exactly one of `default` or `default_factory`.") + if not init: # Non-init fields cause errors when printed before validation. repr = False super().__init__( @@ -782,7 +781,7 @@ def _from_dict( # Check for nested configs to instantiate. 
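The invariant toggled in the `Field.__init__` hunks above can be illustrated with a small standalone helper (plain dataclasses, not the Fast-LLM `Field` class): a field accepts exactly one of `default` or `default_factory`, so mutable defaults are never shared between instances.

import dataclasses

MISSING = dataclasses.MISSING


def config_field(default=MISSING, default_factory=MISSING, **kwargs):
    # Exactly one of `default` / `default_factory` must be set, so mutable
    # defaults are always built per instance instead of being shared.
    if (default is MISSING) == (default_factory is MISSING):
        raise ValueError("Fields should define exactly one of `default` or `default_factory`.")
    if default is MISSING:
        return dataclasses.field(default_factory=default_factory, **kwargs)
    return dataclasses.field(default=default, **kwargs)


@dataclasses.dataclass
class ExampleConfig:
    name: str = config_field(default="demo")
    tags: list[str] = config_field(default_factory=list)


a, b = ExampleConfig(), ExampleConfig()
a.tags.append("x")
assert b.tags == []  # each instance gets its own list
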
try: if name in default: - out_arg_dict[name] = cls._from_dict_nested(default.pop(name), field.type, strict) + out_arg_dict[name] = cls._from_dict_nested(default[name], field.type, strict) except FieldTypeError as e: raise FieldTypeError( diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 013229be..bb16a82d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -121,7 +121,7 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=DatasetPreparatorDistributedConfig, + default_factory=FimConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 65a7411a..210f302a 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -288,10 +288,10 @@ class TrainingConfig(Config): default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core ) checkpoint: TrainingCheckpointConfig = Field( - default_factory=TrainingCheckpointConfig, desc="Configuration for checkpoints.", hint=FieldHint.core + default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core ) export: TrainingExportConfig = Field( - default_factory=TrainingExportConfig, desc="Configuration for exports.", hint=FieldHint.core + default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core ) shutdown: ShutdownConfig = Field( default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 7914ee33..c03e9957 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -34,6 +34,18 @@ class NormalizationConfig(BaseModelConfig): def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": pass + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return LayerNormalizationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @config_class() class NoNormalizationConfig(NormalizationConfig): diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index d7297e28..7ab0a299 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -93,6 +93,18 @@ class TransformerLossNames: class RotaryConfig(BaseModelConfig): # TODO: Move rotary to its own submodule. + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is RotaryConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return DefaultRotaryConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @property def enabled(self) -> bool: return False @@ -300,6 +312,18 @@ def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": pass + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return TransformerNoPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @config_class() class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): @@ -381,17 +405,14 @@ def _validate(self) -> None: class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( - default_factory=RotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( - default_factory=TransformerPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) From 897cc0f64a24cf5091369a7aca028868f0140233 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 May 2025 14:14:11 -0400 Subject: [PATCH 14/26] Revert "Bring back default_factory" This reverts commit 31579bd01afa2db6b533e48e990117b8a26c2ca9. --- fast_llm/config.py | 21 +++++++----- fast_llm/data/data/config.py | 4 +-- fast_llm/data/data/gpt/config.py | 3 +- fast_llm/data/dataset/config.py | 2 -- fast_llm/data/dataset/gpt/config.py | 5 ++- fast_llm/data/preparator/config.py | 3 -- fast_llm/data/preparator/gpt_memmap/config.py | 4 --- fast_llm/engine/base_model/config.py | 1 - fast_llm/engine/checkpoint/convert.py | 4 +-- fast_llm/engine/checkpoint/state_dict.py | 2 ++ fast_llm/engine/config_utils/run.py | 8 ++--- fast_llm/engine/config_utils/runnable.py | 2 +- fast_llm/engine/multi_stage/config.py | 21 +++++------- fast_llm/engine/optimizer/config.py | 2 -- fast_llm/engine/training/config.py | 33 ++++--------------- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/ssm/config.py | 1 - fast_llm/models/custom/config.py | 10 +++--- fast_llm/models/gpt/config.py | 10 +++--- fast_llm/models/ssm/config.py | 11 +++---- tests/config/common.py | 2 +- tests/test_config.py | 14 ++++---- tests/test_ssms.py | 2 -- 23 files changed, 57 insertions(+), 109 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f7931697..2c03fda2 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -143,8 +143,10 @@ def __init__( metadata=None, kw_only=dataclasses.MISSING, ): - if (default is dataclasses.MISSING) == (default_factory is dataclasses.MISSING): - raise ValueError("Fields should define exactly one of `default` or `default_factory`.") + if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: + raise ValueError("cannot specify both default and default_factory") + if isinstance(default_factory, type) and issubclass(default_factory, Config): + raise ValueError("Config classes should not be used as `default_factory`") if not init: # Non-init fields cause errors when printed before validation. 
repr = False @@ -780,9 +782,9 @@ def _from_dict( else: # Check for nested configs to instantiate. try: - if name in default: - out_arg_dict[name] = cls._from_dict_nested(default[name], field.type, strict) - + value = cls._from_dict_nested(default.pop(name, MISSING), field.type, strict) + if value is not MISSING: + out_arg_dict[name] = value except FieldTypeError as e: raise FieldTypeError( f"Invalid field type `{get_type_name(field.type)}` in class {cls._get_class_name()}: " @@ -815,8 +817,11 @@ def _from_dict_nested(cls, value, type_, strict: bool): raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): raise FieldTypeError(f"Not a type: {type_}.") - elif issubclass(type_, Config) and isinstance(value, dict): - value = type_._from_dict(value, strict) + elif issubclass(type_, Config): + if value is MISSING: + value = {} + if isinstance(value, dict): + value = type_._from_dict(value, strict) return value @classmethod @@ -888,11 +893,11 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) - return None @classmethod def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) + assert not cls_._abstract if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 25850ac3..41dbb5d9 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -9,6 +9,4 @@ class DataConfig(Config): _abstract = True _sampling_config_class: typing.ClassVar[type[SamplingData]] - sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Default configuration for dataset sampling." 
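The `_from_dict` change above is what makes `default_factory` unnecessary on nested config fields: when the nested key is absent, the loader builds the nested config from an empty dict. A rough standalone sketch of that behavior, with illustrative names (not the actual Fast-LLM API):

import dataclasses
import typing


@dataclasses.dataclass
class ConfigSketch:
    @classmethod
    def from_dict(cls, data: dict) -> "ConfigSketch":
        data = dict(data)
        kwargs = {}
        hints = typing.get_type_hints(cls)
        for field in dataclasses.fields(cls):
            value = data.pop(field.name, None)
            field_type = hints[field.name]
            if isinstance(field_type, type) and issubclass(field_type, ConfigSketch):
                # Nested configs are always instantiated, even when the key is
                # missing, so the field itself needs no `default_factory`.
                kwargs[field.name] = field_type.from_dict(value or {})
            elif value is not None:
                kwargs[field.name] = value
        if data:
            raise ValueError(f"Unknown fields: {sorted(data)}")
        return cls(**kwargs)


@dataclasses.dataclass
class SamplingConfigSketch(ConfigSketch):
    seed: int = 0


@dataclasses.dataclass
class DataConfigSketch(ConfigSketch):
    path: str = "."
    sampling: SamplingConfigSketch = None  # built by `from_dict`


config = DataConfigSketch.from_dict({"path": "/data"})
assert config.sampling.seed == 0 and config.path == "/data"
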
- ) + sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.") diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 6c598c0c..85bcc656 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): _abstract = False tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) + sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7901d6e7..1bb4b6be 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): _abstract = True sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) dataset: SampledDatasetConfig = Field( - default_factory=SampledDatasetConfig, desc="The dataset to sample from.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 6c48e170..df38474e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -173,8 +173,8 @@ def build(self) -> "GPTDatasetSlice": @config_class() class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) - dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) + sampling: GPTSamplingConfig = FieldUpdate() + dataset: GPTSampledDatasetConfig = FieldUpdate() @config_class() @@ -389,7 +389,6 @@ class GPTLegacyConfig(Config): valid=_validate_path, ) fim: FimConfig = Field( - default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index b2068ddc..edf088c0 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -24,6 +24,3 @@ class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[Config @abc.abstractmethod def run(self) -> None: raise NotImplementedError - - -RunnableConfig.register_subclass("prepare", DatasetPreparatorConfig) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index bb16a82d..775c367c 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -121,7 +121,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=FimConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) @@ -150,12 +149,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): valid=check_field(Assert.geq, 1), ) dataset: GPTHuggingfaceDatasetConfig = Field( - default_factory=GPTHuggingfaceDatasetConfig, desc="Configuration for the dataset.", 
hint=FieldHint.feature, ) tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) @@ -180,4 +177,3 @@ def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: RunnableConfig.register_subclass("prepare_gpt_memmap", GPTMemmapDatasetPreparatorConfig) -DatasetPreparatorConfig.register_subclass("gpt_memmap", GPTMemmapDatasetPreparatorConfig) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 4be42e06..25f53e4a 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -42,7 +42,6 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) - return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: if isinstance(value, BaseModelConfig): diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 51bd4a5f..97d9643d 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -18,8 +18,8 @@ @config_class() class ConvertConfig(RunnableConfig): - input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) - output: CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) + input: CheckpointLoadConfig = Field() + output: CheckpointSaveConfig = Field() use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 556e97be..d6807138 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -26,6 +26,8 @@ logger = logging.getLogger(__name__) +torch.distributed.gather + class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index d6377409..126e0ae8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,9 +20,7 @@ @config_class() class RunConfig(Config): - tensor_logs: TensorLogsConfig = Field( - default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging - ) + tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) # TODO v0.3: Adjust (now only affects logging to file). 
structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging @@ -70,9 +68,7 @@ def _validate(self): @config_class() class ExperimentConfig(RunnableConfig): - run: RunConfig = Field( - default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core - ) + run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core) def _show( self, diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index ac10225e..01d24eaa 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -23,7 +23,7 @@ def parse_and_run(cls, args: list[str] | None = None) -> None: if args is None: args = sys.argv[1:] cls_ = cls - while len(args) >= 1 and "=" not in args[0] and not args[0].startswith("-"): + while len(args) >= 1 and "=" not in args[0]: # Allow chained dynamic type selection without the `type=`, ex. `train gpt`. cls_ = cls_.get_subclass(args[0]) args = args[1:] diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 97534319..9434fba6 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -211,17 +211,12 @@ class FastLLMModelConfig(Config): FastLLMCheckpointFormat, ) model_name: typing.ClassVar[str] - base_model: BaseModelConfig = Field( - default_factory=BaseModelConfig, desc="Configuration for the base model.", hint=FieldHint.core - ) + base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( - default_factory=MultiStageConfig, desc="Configuration for the stage breakdown of the model.", hint=FieldHint.core, ) - distributed: DistributedConfig = Field( - default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core - ) + distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) @classmethod def __fast_llm_serialize__(cls) -> str: @@ -291,11 +286,8 @@ class PretrainedFastLLMModelConfig(Config): # TODO: Generalize data, schedule, logging, etc. _abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. - model: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core - ) + model: FastLLMModelConfig = Field(desc="Configuration for the Fast-LLM model.", hint=FieldHint.core) pretrained: CheckpointLoadConfig = Field( - default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) @@ -336,7 +328,6 @@ class CheckpointMetadata(Config): hint=FieldHint.core, ) config: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="The Fast-LLM model configuration for the saved model.", hint=FieldHint.core, ) @@ -381,9 +372,13 @@ def _from_dict( if "fast_llm_version" not in default: default["fast_llm_version"] = "0" + # Determine the model config class. + from fast_llm.models.auto import model_registry + model_config_class = default["model"] if isinstance(model_config_class, str): - model_config_class = FastLLMModelConfig.get_subclass(default["model"]) + Assert.incl(model_config_class, model_registry) + model_config_class = model_registry[model_config_class] default["model"] = model_config_class # TODO v0.3: Remove backward compatibility. 
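Both variants of the `CheckpointMetadata._from_dict` hunk above resolve a model name string back to its config class; a stripped-down sketch of that lookup, with made-up class and registry names:

class FastLLMModelConfigSketch:
    pass


class GPTModelConfigSketch(FastLLMModelConfigSketch):
    pass


# Hypothetical registry mapping model names to config classes.
model_registry = {"gpt": GPTModelConfigSketch}


def resolve_model_class(default: dict) -> type[FastLLMModelConfigSketch]:
    model_config_class = default["model"]
    if isinstance(model_config_class, str):
        # Checkpoint metadata stores the model as a name; map it back to a class.
        if model_config_class not in model_registry:
            raise ValueError(f"Unknown model {model_config_class}.")
        model_config_class = model_registry[model_config_class]
    return model_config_class


assert resolve_model_class({"model": "gpt"}) is GPTModelConfigSketch
assert resolve_model_class({"model": GPTModelConfigSketch}) is GPTModelConfigSketch
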
diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index 3a154c9e..f4303a5d 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -74,12 +74,10 @@ class GradientScalerConfig(Config): class OptimizerConfig(Config): learning_rate: LearningRateScheduleConfig = Field( - default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) gradient_scaler: GradientScalerConfig = Field( - default_factory=GradientScalerConfig, desc="Configuration for the fixed or dynamic gradient scaling.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 210f302a..768d42f5 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,7 +23,6 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -142,7 +141,6 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( - default_factory=WandbAlertConfig, desc="Configuration for Wandb alerts." " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, @@ -176,7 +174,6 @@ class TrainingCheckpointBaseConfig(IntervalConfig): _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( - default_factory=CallbackConfig, desc="Callback (shell script).", hint=FieldHint.core, ) @@ -284,19 +281,11 @@ class TrainingConfig(Config): desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) - logs: MetricsLogsConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core - ) - checkpoint: TrainingCheckpointConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core - ) - export: TrainingExportConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core - ) - shutdown: ShutdownConfig = Field( - default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core - ) - wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) + logs: MetricsLogsConfig = Field(desc="Configuration for metric logging.", hint=FieldHint.core) + checkpoint: TrainingCheckpointConfig = Field(desc="Configuration for checkpoints.", hint=FieldHint.core) + export: TrainingExportConfig = Field(desc="Configuration for exports.", hint=FieldHint.core) + shutdown: ShutdownConfig = Field(desc="Configuration for automated shutdown.", hint=FieldHint.core) + wandb: WandbConfig = Field(desc="Configuration for Wandb.", hint=FieldHint.core) train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -349,30 +338,23 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True # TODO: Generalize data, schedule, logging, etc. 
training: TrainingConfig = Field( - default_factory=TrainingConfig, desc="Configuration for the training phases and global properties.", hint=FieldHint.core, ) batch: BatchConfig = Field( - default_factory=BatchConfig, desc="Configuration for the training, validation and test batches.", hint=FieldHint.core, ) - schedule: ScheduleConfig = Field( - default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core - ) + schedule: ScheduleConfig = Field(desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core) data: DataConfig = Field( - default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) profiling: ProfilingConfig = Field( - default_factory=ProfilingConfig, desc="Configuration for the optional profiling of GPU and CPU CUDA operations.", hint=FieldHint.logging, ) optimizer: OptimizerConfig = Field( - default_factory=OptimizerConfig, desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) @@ -438,6 +420,3 @@ def new_setup(): old_setup() object.__setattr__(pretrained, "_setup", new_setup) - - -RunnableConfig.register_subclass("train", TrainerConfig) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf..0db76ad1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -40,7 +40,6 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c6fe622e..25ad3d22 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -26,7 +26,6 @@ class SSMConfig(BaseModelConfig): # Normalization normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 414263c0..f09657e5 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,7 +2,6 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -29,7 +28,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). 
model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate() @classmethod def get_model_class(cls) -> type["CustomModel"]: @@ -46,14 +45,14 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM" @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate() @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) + data: CustomDataConfig = FieldUpdate() + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: @@ -63,5 +62,4 @@ def get_trainer_class(cls) -> type["CustomTrainer"]: FastLLMModelConfig.register_subclass("gpt_custom", GPTModelConfig) -RunnableConfig.register_subclass("train_gpt_custom", CustomTrainerConfig) TrainerConfig.register_subclass("gpt_custom", CustomTrainerConfig) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index f2894f8c..3c889e4e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -130,7 +129,7 @@ def _from_dict( class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -157,13 +156,13 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate() @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() # TODO: Use dynamic model type? 
reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() @@ -214,5 +213,4 @@ def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: FastLLMModelConfig.register_subclass("gpt", GPTModelConfig) -RunnableConfig.register_subclass("train_gpt", GPTTrainerConfig) TrainerConfig.register_subclass("gpt", GPTTrainerConfig) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 59f640b9..7db2a2b3 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,7 +5,6 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -27,7 +26,6 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): _abstract = False ssm: SSMConfig = Field( - default_factory=SSMConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -130,7 +128,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) + base_model: HybridSSMBaseModelConfig = FieldUpdate() checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -155,13 +153,13 @@ def _validate(self): @config_class() class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) + model: HybridSSMModelConfig = FieldUpdate() @config_class() class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["HybridSSMTrainer"]: @@ -171,5 +169,4 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]: FastLLMModelConfig.register_subclass("hybrid_ssm", HybridSSMModelConfig) -RunnableConfig.register_subclass("train_hybrid_ssm", HybridSSMTrainerConfig) TrainerConfig.register_subclass("hybrid_ssm", HybridSSMTrainerConfig) diff --git a/tests/config/common.py b/tests/config/common.py index b671c4af..9ccfb597 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -58,7 +58,7 @@ def _validate(self) -> None: @config_class() class ExampleNestedConfig(ExampleConfig): - nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + nested_field: ExampleConfig = Field(hint=FieldHint.core) def check_config( diff --git a/tests/test_config.py b/tests/test_config.py index ec91f21c..1ea225b7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,7 +31,7 @@ def run_without_import(cmd: str): "sys.path=[p for p in sys.path if not any(x in p for x in ('site-packages', 'dist-packages', '.egg'))]", # We still want to enable imports from within Fast-llm f"sys.path.append('{repo_path}')", - "from fast_llm.cli import fast_llm_main as main", + "from 
fast_llm.tools.cli import fast_llm as main", cmd, ] ), @@ -110,7 +110,7 @@ def test_pretrained_config(load_config: ModelConfigType): "rotary": {"type": "default"}, "num_layers": 12, # Default "hidden_size": 1024, # Default - "window_size": 32, + "window_size": 32, # Non-architecture "ffn_hidden_size": 4096, # Implicit default, default value "activation_type": "silu", # Implicit default, non-default value "head_groups": 4, @@ -131,7 +131,7 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { # rotary: Don't override nested. "normalization": {"implementation": "triton"}, # Update non-default nested - "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type + "peft": {"freeze_others": False}, # Update default nested, non-architecture "hidden_size": 512, # Override, affects derived value (kv channels) "head_groups": 1, # Override to default }, @@ -156,9 +156,9 @@ def test_pretrained_config(load_config: ModelConfigType): if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { "transformer": { - "normalization": {"type": "RMSNormalizationConfig", "implementation": "triton"}, - "rotary": {"type": "DefaultRotaryConfig"}, - "peft": {"type": "TransformerLoRAConfig", "freeze_others": False}, + "normalization": {"type": "rms_norm", "implementation": "triton"}, + "rotary": {"type": "default"}, + "peft": {"freeze_others": False}, "num_layers": 12, "hidden_size": 512, "ffn_hidden_size": 4096, @@ -170,8 +170,6 @@ def test_pretrained_config(load_config: ModelConfigType): "vocab_size": 1000, } else: - base_model_update["transformer"]["normalization"]["type"] = "RMSNormalizationConfig" - base_model_update["transformer"]["rotary"]["type"] = "DefaultRotaryConfig" expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 0fec3741..e6c9aafd 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -139,7 +139,6 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) -@pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", @@ -208,7 +207,6 @@ def test_mamba_block(distributed_config, distributed): assert not torch.isinf(hidden_states).any() -@pytest.mark.slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( ("hybrid_block_layout"), From 8a49e0f05fa097e5122c7961912ae7f9c3a999dc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 15:49:54 -0400 Subject: [PATCH 15/26] stuff --- fast_llm/config.py | 43 +++++++++---------- fast_llm/data/preparator/config.py | 3 ++ fast_llm/data/preparator/gpt_memmap/config.py | 1 + fast_llm/engine/base_model/config.py | 1 + fast_llm/engine/checkpoint/convert.py | 2 +- fast_llm/engine/config_utils/runnable.py | 2 +- fast_llm/engine/multi_stage/config.py | 6 +-- fast_llm/engine/training/config.py | 6 ++- fast_llm/layers/transformer/config.py | 1 + fast_llm/models/custom/config.py | 9 ++-- fast_llm/models/gpt/config.py | 9 ++-- fast_llm/models/gpt/conversion.py | 19 ++++---- fast_llm/models/ssm/config.py | 11 ++--- tests/test_checkpoint.py | 12 +++--- tests/test_config.py | 17 +++++--- tests/test_ssms.py | 2 + 16 files changed, 73 insertions(+), 71 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 
2c03fda2..a99553aa 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -137,7 +137,6 @@ def __init__( default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init: bool = True, - repr: bool = True, hash=None, compare: bool = True, metadata=None, @@ -147,14 +146,11 @@ def __init__( raise ValueError("cannot specify both default and default_factory") if isinstance(default_factory, type) and issubclass(default_factory, Config): raise ValueError("Config classes should not be used as `default_factory`") - if not init: - # Non-init fields cause errors when printed before validation. - repr = False super().__init__( default=default, default_factory=default_factory, init=init, - repr=repr, + repr=False, hash=hash, compare=compare, metadata=metadata, @@ -246,7 +242,9 @@ def _process_config_class(cls: type["Config"]): return cls -def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]: +def config_class[ + T: Config +](dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -256,7 +254,7 @@ def wrap(cls): if hasattr(cls, "__post_init__"): raise TypeError(f"`__post_init__` should not be implemented for `Config` classes") - wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True)) + wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False)) wrapped_init = cls.__init__ @@ -269,7 +267,12 @@ def __init__(self, **kwargs): if _AUTO_VALIDATE: self.validate() - cls.__init__ = __init__ + wrapped.__init__ = __init__ + + if dynamic_type is not None: + for cls_, name in dynamic_type.items(): + cls_.register_subclass(name, wrapped) + return wrapped return wrap @@ -284,7 +287,7 @@ def __call__(cls: "type[Config]", **kwargs): return super().__call__(**kwargs) -@dataclasses.dataclass() +@dataclasses.dataclass(kw_only=True, repr=False) class Config(metaclass=ConfigMeta): """ An advanced `dataclass` with basic type checking, validation and argparse support. @@ -299,14 +302,14 @@ class Config(metaclass=ConfigMeta): # Set to true to prevent instantiation. _abstract: typing.ClassVar[bool] = False # Keep track of whether an instance has been validated - _validated: bool = Field(init=False, repr=False) + _validated: bool = Field(init=False) # Keep track of unknown fields so they can be reported during validation. - _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + _unknown_fields: dict[str, typing.Any] = Field(init=False) # Keep track of explicitly set fields to ensure they get serialized and used as config updates. - _explicit_fields: set[str] = Field(init=False, repr=False) + _explicit_fields: set[str] = Field(init=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool | None = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False) # A registry for all the config classes. _registry: typing.ClassVar[Registry[str, type[typing.Self]]] = Registry[str, "type[Config]"]("Config", {}) @@ -359,7 +362,7 @@ def _set_implicit_default(self, _value: bool | None = True): yield self._setting_implicit_default = False - def validate[T](self: T, *, _is_validating: bool = False) -> T: + def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only This should not be overridden in derived classes. 
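The new `dynamic_type` argument lets `@config_class(...)` perform the registration that previously required a separate `register_subclass(...)` call at module level (see the model config hunks later in this patch). A toy standalone version of the idea, not the real decorator:

import dataclasses
import typing


@dataclasses.dataclass
class Base:
    # Each base keeps a name -> subclass registry.
    _registry: typing.ClassVar[dict[str, type]] = {}

    @classmethod
    def register_subclass(cls, name: str, cls_: type) -> None:
        assert issubclass(cls_, cls)
        cls._registry[name] = cls_

    @classmethod
    def get_subclass(cls, name: str) -> type:
        return cls._registry[name]


def config_class(dynamic_type: dict[type[Base], str] | None = None):
    def wrap(cls):
        wrapped = dataclasses.dataclass(cls)
        if dynamic_type is not None:
            # Register the wrapped class under the given name for each base.
            for base, name in dynamic_type.items():
                base.register_subclass(name, wrapped)
        return wrapped

    return wrap


@config_class(dynamic_type={Base: "gpt"})
class GPTConfig(Base):
    vocab_size: int = 32000


assert Base.get_subclass("gpt") is GPTConfig
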
@@ -481,12 +484,6 @@ def _validate_element(cls, value, type_, name: str): raise FieldTypeError(f"Not a type.") elif issubclass(type_, Config): cls._validate_element_type(value, type_, strict=False) - # If the value belongs to a proper subclass of `type_`, - # we need an explicitly set `type` field for serialization to remember the actual config class. - if type(value) != type_: - if value.type is None: - value.type = value.__class__.__name__ - value._explicit_fields.add("type") value.validate(_is_validating=True) else: @@ -693,6 +690,9 @@ def to_copy[ ) -> T: return self.from_dict(self, *updates, strict=strict, update_type=update_type) + def __repr__(self): + return self.to_logs(log_fn=str) + def to_logs[ T ]( @@ -893,11 +893,11 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) + return None @classmethod def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) - assert not cls_._abstract if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: @@ -970,7 +970,6 @@ def __init_subclass__(cls): valid=value.pop("valid", base_class_field.valid), default=value.pop("default", base_class_field.default), default_factory=value.pop("default_factory", base_class_field.default_factory), - repr=value.pop("repr", base_class_field.repr), hash=value.pop("hash", base_class_field.hash), compare=value.pop("compare", base_class_field.compare), metadata=value.pop("metadata", base_class_field.metadata), diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index edf088c0..b2068ddc 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -24,3 +24,6 @@ class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[Config @abc.abstractmethod def run(self) -> None: raise NotImplementedError + + +RunnableConfig.register_subclass("prepare", DatasetPreparatorConfig) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 775c367c..253f474e 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -177,3 +177,4 @@ def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: RunnableConfig.register_subclass("prepare_gpt_memmap", GPTMemmapDatasetPreparatorConfig) +DatasetPreparatorConfig.register_subclass("gpt_memmap", GPTMemmapDatasetPreparatorConfig) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 25f53e4a..4be42e06 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: if isinstance(value, BaseModelConfig): diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 97d9643d..7b097072 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -65,7 +65,7 @@ def run(self): f"Output path {self.output.path} already exists and has been processed. Skipping model conversion..." 
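The `__repr__` added above routes representation through the same serialization helper that logging uses instead of the dataclass-generated `repr`. A compact sketch of that delegation, using plain dataclasses and JSON as a stand-in for the real serializer:

import dataclasses
import json


@dataclasses.dataclass(repr=False)
class ConfigSketch:
    hidden_size: int = 1024
    num_layers: int = 12

    def to_logs(self, log_fn=print):
        # Serialize the config and hand it to the chosen sink.
        return log_fn(json.dumps(dataclasses.asdict(self), indent=2))

    def __repr__(self) -> str:
        # `repr()` reuses the same path, with `str` as the sink.
        return self.to_logs(log_fn=str)


print(ConfigSketch())  # prints the JSON form of the config
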
) return - model_class = self.model_config_class.get_model_class() + model_class = self.model.get_model_class() if self.layers_per_step is None: self._convert_model_partial(model_class, self.output) else: diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 01d24eaa..ac10225e 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -23,7 +23,7 @@ def parse_and_run(cls, args: list[str] | None = None) -> None: if args is None: args = sys.argv[1:] cls_ = cls - while len(args) >= 1 and "=" not in args[0]: + while len(args) >= 1 and "=" not in args[0] and not args[0].startswith("-"): # Allow chained dynamic type selection without the `type=`, ex. `train gpt`. cls_ = cls_.get_subclass(args[0]) args = args[1:] diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 9434fba6..ad61c70f 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -372,13 +372,9 @@ def _from_dict( if "fast_llm_version" not in default: default["fast_llm_version"] = "0" - # Determine the model config class. - from fast_llm.models.auto import model_registry - model_config_class = default["model"] if isinstance(model_config_class, str): - Assert.incl(model_config_class, model_registry) - model_config_class = model_registry[model_config_class] + model_config_class = FastLLMModelConfig.get_subclass(default["model"]) default["model"] = model_config_class # TODO v0.3: Remove backward compatibility. diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 768d42f5..0dba6e70 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,6 +23,7 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -120,7 +121,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) - post_alerts: bool = Field(init=False, repr=False) + post_alerts: bool = Field(init=False) def _validate(self) -> None: if self.status_updates is None: @@ -420,3 +421,6 @@ def new_setup(): old_setup() object.__setattr__(pretrained, "_setup", new_setup) + + +RunnableConfig.register_subclass("train", TrainerConfig) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 7ab0a299..e06977b8 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -375,6 +375,7 @@ def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": return parameter def _validate(self) -> None: + super()._validate() if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: # TODO: Add MLP support. 
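The `parse_and_run` tweak above allows chained selectors such as `fast-llm train gpt ...` while leaving `key=value` overrides and `-`/`--` flags untouched. A minimal sketch of that argument-walking logic:

def split_selectors(args: list[str]) -> tuple[list[str], list[str]]:
    """Consume leading positional selectors, stop at overrides or flags."""
    selectors = []
    while args and "=" not in args[0] and not args[0].startswith("-"):
        selectors.append(args[0])
        args = args[1:]
    return selectors, args


# Example: `train gpt training.train_iters=100 --verbose`
selectors, rest = split_selectors(["train", "gpt", "training.train_iters=100", "--verbose"])
assert selectors == ["train", "gpt"]
assert rest == ["training.train_iters=100", "--verbose"]
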
raise NotImplementedError("LoRA not supported for MLP.") diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index f09657e5..963ffd35 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,6 +2,7 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -24,7 +25,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): pass -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "gpt_custom"}) class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). model_name: typing.ClassVar[str] = "gpt_custom" @@ -48,7 +49,7 @@ class PretrainedCustomModelConfig(PretrainedGPTModelConfig): model: CustomModelConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={RunnableConfig: "train_gpt_custom", TrainerConfig: "gpt_custom"}) class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). data: CustomDataConfig = FieldUpdate() @@ -59,7 +60,3 @@ def get_trainer_class(cls) -> type["CustomTrainer"]: from fast_llm.models.custom.trainer import CustomTrainer return CustomTrainer - - -FastLLMModelConfig.register_subclass("gpt_custom", GPTModelConfig) -TrainerConfig.register_subclass("gpt_custom", CustomTrainerConfig) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 3c889e4e..64f6f1de 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,6 +4,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -125,7 +126,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" @@ -159,7 +160,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): model: GPTModelConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() @@ -210,7 +211,3 @@ def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: from fast_llm.models.gpt.model import GPTInferenceRunner return GPTInferenceRunner - - -FastLLMModelConfig.register_subclass("gpt", GPTModelConfig) -TrainerConfig.register_subclass("gpt", GPTTrainerConfig) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 9a27a079..46264d29 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -387,7 +387,6 @@ def __post_init__(self): def 
export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (rotary_config,) = fast_llm_values - serialized_config = rotary_config.to_dict() if type(rotary_config) is DefaultRotaryConfig: rotary_scaling = { "rope_type": "default", @@ -395,23 +394,23 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing elif type(rotary_config) is Llama3RotaryConfig: rotary_scaling = { "rope_type": "llama3", - "factor": serialized_config["scale_factor"], - "low_freq_factor": serialized_config["low_frequency_factor"], - "high_freq_factor": serialized_config["high_frequency_factor"], - "original_max_position_embeddings": serialized_config["original_context_length"], + "factor": rotary_config.scale_factor, + "low_freq_factor": rotary_config.low_frequency_factor, + "high_freq_factor": rotary_config.high_frequency_factor, + "original_max_position_embeddings": rotary_config.original_context_length, } elif type(rotary_config) is YarnRotaryConfig: rotary_scaling = { "rope_type": "yarn", - "attention_factor": serialized_config["attention_factor"], - "beta_fast": serialized_config["beta_fast"], - "beta_slow": serialized_config["beta_slow"], - "original_max_position_embeddings": serialized_config["original_context_length"], + "attention_factor": rotary_config.attention_factor, + "beta_fast": rotary_config.beta_fast, + "beta_slow": rotary_config.beta_slow, + "original_max_position_embeddings": rotary_config.original_context_length, } else: raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") - return serialized_config["theta"], rotary_scaling + return rotary_config.theta, rotary_scaling def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: rotary_theta, rope_scaling = export_values diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7db2a2b3..0c2b3c48 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -30,7 +31,7 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): hint=FieldHint.architecture, ) hybrid_block_layout: list[str] = Field( - default_factory=lambda: ["m2"], + default=("m2",), desc="Pattern of blocks to use in the model. 
't' for Transformer, 'm' for Mamba1, 'm2' for Discrete Mamba2", hint=FieldHint.architecture, ) @@ -124,7 +125,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" @@ -156,7 +157,7 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): model: HybridSSMModelConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={RunnableConfig: "train_hybrid_ssm", TrainerConfig: "hybrid_ssm"}) class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() @@ -166,7 +167,3 @@ def get_trainer_class(cls) -> type["HybridSSMTrainer"]: from fast_llm.models.ssm.trainer import HybridSSMTrainer return HybridSSMTrainer - - -FastLLMModelConfig.register_subclass("hybrid_ssm", HybridSSMModelConfig) -TrainerConfig.register_subclass("hybrid_ssm", HybridSSMTrainerConfig) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index eb21f3b3..5c5f5b90 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -114,7 +114,7 @@ def test_convert_distributed_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -133,7 +133,7 @@ def test_convert_fast_llm_to_huggingface(): path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -150,7 +150,7 @@ def test_convert_huggingface_to_distributed(): path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -169,7 +169,7 @@ def test_convert_distributed_to_huggingface(): path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -186,7 +186,7 @@ def test_convert_huggingface_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -203,7 +203,7 @@ def test_convert_fast_llm_to_distributed(): path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) diff --git a/tests/test_config.py b/tests/test_config.py index 1ea225b7..07617f35 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,7 +31,7 @@ def run_without_import(cmd: str): "sys.path=[p for p in sys.path if not any(x in p for x in ('site-packages', 'dist-packages', '.egg'))]", # We still want to enable imports from within Fast-llm f"sys.path.append('{repo_path}')", - "from fast_llm.tools.cli import fast_llm as main", + "from fast_llm.cli import fast_llm_main as main", cmd, ] ), @@ -110,7 +110,7 @@ def test_pretrained_config(load_config: ModelConfigType): "rotary": {"type": "default"}, "num_layers": 12, # Default "hidden_size": 1024, # Default - "window_size": 32, # Non-architecture + "window_size": 32, "ffn_hidden_size": 4096, # Implicit default, default value "activation_type": "silu", # Implicit default, non-default value "head_groups": 4, @@ -131,7 +131,7 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { 
# rotary: Don't override nested. "normalization": {"implementation": "triton"}, # Update non-default nested - "peft": {"freeze_others": False}, # Update default nested, non-architecture + "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type "hidden_size": 512, # Override, affects derived value (kv channels) "head_groups": 1, # Override to default }, @@ -156,9 +156,9 @@ def test_pretrained_config(load_config: ModelConfigType): if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { "transformer": { - "normalization": {"type": "rms_norm", "implementation": "triton"}, - "rotary": {"type": "default"}, - "peft": {"freeze_others": False}, + "normalization": {"type": "RMSNormalizationConfig", "implementation": "triton"}, + "rotary": {"type": "DefaultRotaryConfig"}, + "peft": {"type": "TransformerLoRAConfig", "freeze_others": False, "layers": ["query", "value"]}, "num_layers": 12, "hidden_size": 512, "ffn_hidden_size": 4096, @@ -170,6 +170,11 @@ def test_pretrained_config(load_config: ModelConfigType): "vocab_size": 1000, } else: + base_model_update["transformer"]["peft"] = { + "type": "TransformerLoRAConfig", + "freeze_others": False, + "layers": ["query", "value"], + } expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index e6c9aafd..0fec3741 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +@pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", @@ -207,6 +208,7 @@ def test_mamba_block(distributed_config, distributed): assert not torch.isinf(hidden_states).any() +@pytest.mark.slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( ("hybrid_block_layout"), From 843a62108f464403c70e487000a70006e93ca118 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 16:00:02 -0400 Subject: [PATCH 16/26] fix --- tools/push_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/push_model.py b/tools/push_model.py index edab3312..39a3b914 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -27,7 +27,7 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.tools.convert import ConvertConfig # isort:skip +from fast_llm.engine.checkpoint.convert import ConvertConfig # isort:skip logger = logging.getLogger(__name__) From aa3bc0be1368c22a9eceb5d00a1f69db779858b9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 16:43:45 -0400 Subject: [PATCH 17/26] stuff --- fast_llm/config.py | 92 ++++++++----------- fast_llm/data/data/config.py | 4 +- fast_llm/data/data/gpt/config.py | 3 +- fast_llm/data/dataset/config.py | 2 - fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/preparator/gpt_memmap/config.py | 7 +- fast_llm/engine/base_model/base_model.py | 2 +- fast_llm/engine/base_model/config.py | 1 + fast_llm/engine/config_utils/run.py | 8 +- fast_llm/engine/multi_stage/config.py | 17 +--- fast_llm/engine/training/config.py | 32 ++----- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/ssm/config.py | 1 - fast_llm/layers/transformer/config.py | 3 - fast_llm/models/custom/config.py | 8 +- 
 fast_llm/models/gpt/config.py | 8 +-
 fast_llm/models/ssm/config.py | 9 +-
 fast_llm/tools/cli.py | 2 +-
 fast_llm/tools/convert.py | 18 ++--
 tests/config/common.py | 8 +-
 tests/config/test_config.py | 30 ++++++
 tests/data/common.py | 2 +-
 tests/test_checkpoint.py | 16 ++--
 tests/test_ssms.py | 2 +
 tools/push_model.py | 4 +-
 25 files changed, 128 insertions(+), 157 deletions(-)
 create mode 100644 tests/config/test_config.py

diff --git a/fast_llm/config.py b/fast_llm/config.py
index 46c903f1..3b277202 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -1,3 +1,4 @@
+import abc
 import contextlib
 import copy
 import dataclasses
@@ -137,7 +138,6 @@ def __init__(
         default=dataclasses.MISSING,
         default_factory=dataclasses.MISSING,
         init: bool = True,
-        repr: bool = True,
         hash=None,
         compare: bool = True,
         metadata=None,
@@ -146,12 +146,11 @@ def __init__(
         if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING:
             raise ValueError("cannot specify both default and default_factory")
         if isinstance(default_factory, type) and issubclass(default_factory, Config):
-            default_factory = _ConfigFactory(default_factory)
+            raise ValueError("Config classes should not be used as `default_factory`")
         super().__init__(
             default=default,
             default_factory=default_factory,
             init=init,
-            repr=repr,
             hash=hash,
             compare=compare,
             metadata=metadata,
@@ -223,20 +222,6 @@ def valid(x):
     return valid
 
 
-class _ConfigFactory:
-    """
-    A dataclass default factory that prevents early validation.
-    Validation is still done through the parent config if needed.
-    """
-
-    def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]):
-        self._factory = factory
-
-    def __call__(self):
-        with NoAutoValidate():
-            return self._factory()
-
-
 class ValidationError(ValueError):
     pass
 
@@ -257,7 +242,7 @@ def _process_config_class(cls: type["Config"]):
     return cls
 
 
-def config_class(cls=None):
+def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
     """
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
@@ -280,20 +265,23 @@ def __init__(self, **kwargs):
             if _AUTO_VALIDATE:
                 self.validate()
 
-        cls.__init__ = __init__
+        wrapped.__init__ = __init__
         return wrapped
 
-    # See if we're being called as @config_class or @config_class().
-    if cls is None:
-        # We're called with parens.
-        return wrap
+    return wrap
+
 
-    # We're called as @config_class without parens.
-    return wrap(cls)
+class ConfigMeta(abc.ABCMeta):
+    def __call__(cls: "type[Config]", **kwargs):
+        # Always go through `_from_dict` for correct dynamic class selection and nested config instantiation.
+        if not kwargs.pop("_from_dict_check", False):
+            # with NoAutoValidate():
+            return cls._from_dict(kwargs)
+        return super().__call__(**kwargs)
 
 
-@dataclasses.dataclass()
-class Config:
+@dataclasses.dataclass(kw_only=True, repr=False)
+class Config(metaclass=ConfigMeta):
     """
     An advanced `dataclass` with basic type checking, validation and argparse support.
     Typically, a subclass will:
@@ -307,14 +295,14 @@ class Config:
     # Set to true to prevent instantiation.
     _abstract: typing.ClassVar[bool] = False
     # Keep track of whether an instance has been validated
-    _validated: bool = Field(init=False, repr=False)
+    _validated: bool = Field(init=False)
     # Keep track of unknown fields so they can be reported during validation.
- _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + _unknown_fields: dict[str, typing.Any] = Field(init=False) # Keep track of explicitly set fields to ensure they get serialized and used as config updates. - _explicit_fields: set[str] = Field(init=False, repr=False) + _explicit_fields: set[str] = Field(init=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool | None = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -339,7 +327,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None: ) else: field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + if field.init and field._field_type == dataclasses._FIELD: # Adding to explicit field list except within `_set_implicit_default` context, # during dataclass initialization (`_setting_implicit_default` not yet set) # and during automated config validation (`_setting_implicit_default=None`) @@ -358,13 +346,13 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self, _value: bool | int = True): + def _set_implicit_default(self, _value: bool | None = True): assert self._setting_implicit_default is False self._setting_implicit_default = _value yield self._setting_implicit_default = False - def validate[T](self: T, *, _is_validating: bool = False) -> T: + def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only This should not be overridden in derived classes. @@ -388,11 +376,16 @@ def _validate(self) -> None: Can be extended to add custom post-processing (typically before the super() call) and validation (typically after) """ - self._check_abstract() + if self._abstract: + raise ValidationError(f"{type(self).__name__} is abstract") + if not self.__class_validated__: + raise ValidationError( + f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator." + ) errors = [] with self._set_implicit_default(None): for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue value = getattr(self, name) if isinstance(value, Tag): @@ -610,11 +603,7 @@ def _add_field_to_args( all_fields: bool = False, serializable: bool = True, ) -> None: - if ( - field is not None - and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not all_fields - ): + if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields: # Exclude class variables and derived fields unless requested explicitly. return explicit_field = ( @@ -677,6 +666,9 @@ def to_copy[ ) -> T: return self.from_dict(self, *updates, strict=strict, update_type=update_type) + def __repr__(self): + return self.to_logs(log_fn=str) + def to_logs[ T ]( @@ -739,7 +731,7 @@ def _from_dict( flat: bool = False, ) -> typing.Self: # TODO v0.3: Remove flat format - out_arg_dict = {} + out_arg_dict = {"_from_dict_check": True} # TODO v0.3: Remove backward compatibility fix if "__class__" in default: @@ -748,7 +740,7 @@ def _from_dict( # Do not validate yet in case the root class sets cross-dependencies in validation. 
with NoAutoValidate(): for name, field in cls.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): @@ -869,22 +861,15 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) - - @classmethod - def _check_abstract(cls) -> None: - if cls._abstract: - raise ValidationError(f"{cls.__name__} is abstract") - if not cls.__class_validated__: - raise ValidationError( - f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator." - ) + return None def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ + Assert.eq(cls.__name__, cls.__qualname__) for base_class in cls.__mro__: - if issubclass(base_class, Config): + if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated." f" Make sure to use the @config_class decorator." @@ -913,7 +898,6 @@ def __init_subclass__(cls): valid=value.pop("valid", base_class_field.valid), default=value.pop("default", base_class_field.default), default_factory=value.pop("default_factory", base_class_field.default_factory), - repr=value.pop("repr", base_class_field.repr), hash=value.pop("hash", base_class_field.hash), compare=value.pop("compare", base_class_field.compare), metadata=value.pop("metadata", base_class_field.metadata), diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 25850ac3..41dbb5d9 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -9,6 +9,4 @@ class DataConfig(Config): _abstract = True _sampling_config_class: typing.ClassVar[type[SamplingData]] - sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Default configuration for dataset sampling." 
- ) + sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.") diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 6c598c0c..85bcc656 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): _abstract = False tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) + sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7901d6e7..1bb4b6be 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): _abstract = True sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) dataset: SampledDatasetConfig = Field( - default_factory=SampledDatasetConfig, desc="The dataset to sample from.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ed9128c6..f4f6e282 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -230,8 +230,8 @@ def build(self) -> "GPTDatasetSlice": class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False type_: typing.ClassVar[str | None] = "sampled" - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) - dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) + sampling: GPTSamplingConfig = FieldUpdate() + dataset: GPTSampledDatasetConfig = FieldUpdate() @config_class() @@ -450,7 +450,6 @@ class GPTLegacyConfig(Config): valid=_validate_path, ) fim: FimConfig = Field( - default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c3..7091f3c8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -24,7 +24,7 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -@config_class +@config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( default=None, @@ -77,7 +77,7 @@ class GPTHuggingfaceDatasetConfig(Config): ) -@config_class +@config_class() class DatasetPreparatorDistributedConfig(Config): # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig @@ -120,7 +120,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=DatasetPreparatorDistributedConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) @@ -149,12 +148,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): valid=check_field(Assert.geq, 1), ) dataset: GPTHuggingfaceDatasetConfig = Field( - default_factory=GPTHuggingfaceDatasetConfig, desc="Configuration for the dataset.", 
hint=FieldHint.feature, ) tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 2be1e487..df603a91 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -90,7 +90,7 @@ def __init__( config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space = TensorSpace(distributed_config) + self._tensor_space: TensorSpace = TensorSpace(distributed_config) config.setup_tensor_space(self._tensor_space) super().__init__(config) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 25f53e4a..4be42e06 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: if isinstance(value, BaseModelConfig): diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index d6377409..126e0ae8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,9 +20,7 @@ @config_class() class RunConfig(Config): - tensor_logs: TensorLogsConfig = Field( - default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging - ) + tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) # TODO v0.3: Adjust (now only affects logging to file). structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging @@ -70,9 +68,7 @@ def _validate(self): @config_class() class ExperimentConfig(RunnableConfig): - run: RunConfig = Field( - default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core - ) + run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core) def _show( self, diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80..9434fba6 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -211,17 +211,12 @@ class FastLLMModelConfig(Config): FastLLMCheckpointFormat, ) model_name: typing.ClassVar[str] - base_model: BaseModelConfig = Field( - default_factory=BaseModelConfig, desc="Configuration for the base model.", hint=FieldHint.core - ) + base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( - default_factory=MultiStageConfig, desc="Configuration for the stage breakdown of the model.", hint=FieldHint.core, ) - distributed: DistributedConfig = Field( - default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core - ) + distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) @classmethod def __fast_llm_serialize__(cls) -> str: @@ -291,11 +286,8 @@ class PretrainedFastLLMModelConfig(Config): # TODO: Generalize data, schedule, logging, etc. 
_abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. - model: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core - ) + model: FastLLMModelConfig = Field(desc="Configuration for the Fast-LLM model.", hint=FieldHint.core) pretrained: CheckpointLoadConfig = Field( - default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) @@ -315,7 +307,7 @@ def _setup(self) -> None: pass -@config_class +@config_class() class CheckpointMetadata(Config): # TODO: Make entries more flexible? # I.e.. model / format / usage (ex. training) - specific entries instead of a generic metadata? @@ -336,7 +328,6 @@ class CheckpointMetadata(Config): hint=FieldHint.core, ) config: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="The Fast-LLM model configuration for the saved model.", hint=FieldHint.core, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 1e990e9c..a5be2e7e 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -120,7 +120,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) - post_alerts: bool = Field(init=False, repr=False) + post_alerts: bool = Field(init=False) def _validate(self) -> None: if self.status_updates is None: @@ -141,7 +141,6 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( - default_factory=WandbAlertConfig, desc="Configuration for Wandb alerts." 
" The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, @@ -175,7 +174,6 @@ class TrainingCheckpointBaseConfig(IntervalConfig): _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( - default_factory=CallbackConfig, desc="Callback (shell script).", hint=FieldHint.core, ) @@ -257,7 +255,6 @@ class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConf offset = FieldUpdate(desc="Offset for the first export.") callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") - @abc.abstractmethod def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / self.format.name @@ -284,19 +281,11 @@ class TrainingConfig(Config): desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) - logs: MetricsLogsConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core - ) - checkpoint: TrainingCheckpointConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core - ) - export: TrainingExportConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core - ) - shutdown: ShutdownConfig = Field( - default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core - ) - wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) + logs: MetricsLogsConfig = Field(desc="Configuration for metric logging.", hint=FieldHint.core) + checkpoint: TrainingCheckpointConfig = Field(desc="Configuration for checkpoints.", hint=FieldHint.core) + export: TrainingExportConfig = Field(desc="Configuration for exports.", hint=FieldHint.core) + shutdown: ShutdownConfig = Field(desc="Configuration for automated shutdown.", hint=FieldHint.core) + wandb: WandbConfig = Field(desc="Configuration for Wandb.", hint=FieldHint.core) train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -349,30 +338,23 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True # TODO: Generalize data, schedule, logging, etc. 
training: TrainingConfig = Field( - default_factory=TrainingConfig, desc="Configuration for the training phases and global properties.", hint=FieldHint.core, ) batch: BatchConfig = Field( - default_factory=BatchConfig, desc="Configuration for the training, validation and test batches.", hint=FieldHint.core, ) - schedule: ScheduleConfig = Field( - default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core - ) + schedule: ScheduleConfig = Field(desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core) data: DataConfig = Field( - default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) profiling: ProfilingConfig = Field( - default_factory=ProfilingConfig, desc="Configuration for the optional profiling of GPU and CPU CUDA operations.", hint=FieldHint.logging, ) optimizer: OptimizerConfig = Field( - default_factory=OptimizerConfig, desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf..0db76ad1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -40,7 +40,6 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c6fe622e..25ad3d22 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -26,7 +26,6 @@ class SSMConfig(BaseModelConfig): # Normalization normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e69b1841..c621139c 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -248,17 +248,14 @@ def _validate(self) -> None: class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( - default_factory=RotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( - default_factory=TransformerPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 8be45e1c..08902e2c 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -26,7 +26,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). 
model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate() @classmethod def get_model_class(cls) -> type["CustomModel"]: @@ -43,14 +43,14 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM" @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate() @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) + data: CustomDataConfig = FieldUpdate() + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e..0ec3fb51 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -129,7 +129,7 @@ def _from_dict( class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -156,13 +156,13 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate() @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() # TODO: Use dynamic model type? 
reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 0311cc69..771a4fca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -26,7 +26,6 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): _abstract = False ssm: SSMConfig = Field( - default_factory=SSMConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -129,7 +128,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) + base_model: HybridSSMBaseModelConfig = FieldUpdate() checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -154,13 +153,13 @@ def _validate(self): @config_class() class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) + model: HybridSSMModelConfig = FieldUpdate() @config_class() class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 0cc02f42..8df884fe 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -21,7 +21,7 @@ def fast_llm(args=None): if parsed.subcommand == "train": from fast_llm.tools.train import CliTrainingConfig as Runnable elif parsed.subcommand == "convert": - from fast_llm.tools.convert import ConversionConfig as Runnable + from fast_llm.tools.convert import ConvertConfig as Runnable elif parsed.subcommand == "prepare": from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable else: diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index d3db3745..3ee580aa 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -19,13 +19,13 @@ @config_class() -class ConversionConfig(RunnableConfig): - input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) - output: CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) +class ConvertConfig(RunnableConfig): + input: CheckpointLoadConfig = Field() + output: CheckpointSaveConfig = Field() use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) - model_config_class: type[FastLLMModelConfig] = Field(default=None) + model: type[FastLLMModelConfig] = Field(default=None) @classmethod def _get_parser(cls): @@ -44,9 +44,9 @@ def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): return config def _validate(self): - assert self.model_config_class is not None - self.input.setup(self.model_config_class) - self.output.setup(self.model_config_class) + assert self.model is not None + self.input.setup(self.model) + self.output.setup(self.model) super()._validate() def _convert_model_partial( @@ -81,7 +81,7 @@ def run(self): f"Output path {self.output.path} already exists and has been processed. Skipping model conversion..." 
) return - model_class = self.model_config_class.get_model_class() + model_class = self.model.get_model_class() if self.layers_per_step is None: self._convert_model_partial(model_class, self.output) else: @@ -161,4 +161,4 @@ def run(self): if __name__ == "__main__": - ConversionConfig.parse_and_run() + ConvertConfig.parse_and_run() diff --git a/tests/config/common.py b/tests/config/common.py index a2657926..9ccfb597 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -13,7 +13,7 @@ class ExampleEnum(enum.StrEnum): c = "c" -@config_class +@config_class() class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) @@ -40,7 +40,7 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleVerboseConfig(Config): # These fields will have non-empty default serialized values. list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) @@ -56,9 +56,9 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleNestedConfig(ExampleConfig): - nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + nested_field: ExampleConfig = Field(hint=FieldHint.core) def check_config( diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 00000000..4c473fa6 --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,30 @@ +import pytest + +from fast_llm.config import NoAutoValidate +from tests.config.common import ExampleConfig + + +def test_auto_validate(): + assert (config := ExampleConfig())._validated + with pytest.raises(RuntimeError): + config.bool_field = True + config.bool_field = False + + assert ExampleConfig.from_dict({})._validated + + with NoAutoValidate(): + assert not (config := ExampleConfig())._validated + + config.bool_field = True + + config.validate() + + assert config._validated + with pytest.raises(RuntimeError): + config.bool_field = False + config.bool_field = True + + with NoAutoValidate(): + assert not (config := ExampleConfig.from_dict({}))._validated + config.validate() + assert config._validated diff --git a/tests/data/common.py b/tests/data/common.py index 47b53195..00c3ff20 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -31,7 +31,6 @@ def get_sampling_data( *, seed: int = 54983, cache_directory: pathlib.Path | None = None, - distributed: Distributed = Distributed(DistributedConfig(), use_cpu=True), phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, @@ -41,6 +40,7 @@ def get_sampling_data( truncate_documents=True, ) -> GPTSamplingData: # Config with convenient defaults. 
+ distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( config=GPTSamplingConfig( seed=seed, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e9..77a4b482 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,7 +17,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry -from fast_llm.tools.convert import ConversionConfig +from fast_llm.tools.convert import ConvertConfig from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -90,7 +90,7 @@ def test_resume(): ) -def _run_conversion(config: ConversionConfig): +def _run_conversion(config: ConvertConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) if not config.output.path.is_dir(): @@ -106,7 +106,7 @@ def _run_conversion(config: ConversionConfig): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_convert_distributed_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -125,7 +125,7 @@ def test_convert_fast_llm_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, @@ -142,7 +142,7 @@ def test_convert_fast_llm_to_huggingface(): @pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) def test_convert_huggingface_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -161,7 +161,7 @@ def test_convert_distributed_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -178,7 +178,7 @@ def test_convert_distributed_to_huggingface(): @pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) def test_convert_huggingface_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -195,7 +195,7 @@ def test_convert_huggingface_to_fast_llm(): @pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) def test_convert_fast_llm_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, diff --git a/tests/test_ssms.py b/tests/test_ssms.py index e6c9aafd..0fec3741 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +@pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", @@ -207,6 +208,7 @@ def test_mamba_block(distributed_config, distributed): assert not torch.isinf(hidden_states).any() +@pytest.mark.slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( ("hybrid_block_layout"), diff --git 
a/tools/push_model.py b/tools/push_model.py index cd98b93c..edab3312 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -27,7 +27,7 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.tools.convert import ConversionConfig # isort:skip +from fast_llm.tools.convert import ConvertConfig # isort:skip logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ def run(self) -> None: for _, checkpoint_path in new_checkpoint_paths: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=checkpoint_path, format=DistributedCheckpointFormat, From 28d321e320354539f2939fe3f94095a96fc43dcc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 17:02:59 -0400 Subject: [PATCH 18/26] stuff --- fast_llm/config.py | 1 + fast_llm/engine/optimizer/config.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 3b277202..4928cdbd 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -151,6 +151,7 @@ def __init__( default=default, default_factory=default_factory, init=init, + repr=False, hash=hash, compare=compare, metadata=metadata, diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index 3a154c9e..f4303a5d 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -74,12 +74,10 @@ class GradientScalerConfig(Config): class OptimizerConfig(Config): learning_rate: LearningRateScheduleConfig = Field( - default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) gradient_scaler: GradientScalerConfig = Field( - default_factory=GradientScalerConfig, desc="Configuration for the fixed or dynamic gradient scaling.", hint=FieldHint.feature, ) From 1bbd7fb1bf55258a4f7589d6eb0b63d71d2f0fa5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 17:46:02 -0400 Subject: [PATCH 19/26] stuff --- fast_llm/tools/convert.py | 2 +- tests/test_checkpoint.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index 3ee580aa..648138ec 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -40,7 +40,7 @@ def _get_parser(cls): @classmethod def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): config = super()._from_parsed_args(parsed, unparsed) - config.model_config_class = model_registry[parsed.model_type] + config.model = model_registry[parsed.model_type] return config def _validate(self): diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 77a4b482..e0845a4c 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -115,7 +115,7 @@ def test_convert_distributed_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -134,7 +134,7 @@ def test_convert_fast_llm_to_huggingface(): path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -151,7 +151,7 @@ def test_convert_huggingface_to_distributed(): path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -170,7 +170,7 @@ def 
test_convert_distributed_to_huggingface(): path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -187,7 +187,7 @@ def test_convert_huggingface_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -204,7 +204,7 @@ def test_convert_fast_llm_to_distributed(): path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) From 87e11f0c2e73fe40088a8a2c33c4ec84b9d6873e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 18:08:40 -0400 Subject: [PATCH 20/26] stuff --- fast_llm/data/dataset/gpt/config.py | 38 ++++++------------- fast_llm/data/preparator/config.py | 5 +-- fast_llm/data/preparator/gpt_memmap/config.py | 6 +-- fast_llm/engine/checkpoint/convert.py | 5 +-- fast_llm/engine/checkpoint/state_dict.py | 2 - fast_llm/engine/training/config.py | 5 +-- fast_llm/layers/common/config.py | 11 ++---- fast_llm/layers/transformer/config.py | 22 +++-------- 8 files changed, 25 insertions(+), 69 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index df38474e..ed9f57fc 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -108,7 +108,7 @@ def build(self) -> "GPTIndexedDataset": raise NotImplementedError() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False name: str = Field( @@ -123,7 +123,7 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( @@ -148,7 +148,7 @@ def build(self) -> "GPTMemmapDataset": return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() @@ -159,7 +159,7 @@ def build(self) -> "GPTConcatenatedDataset": return self._build(GPTConcatenatedDataset) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False dataset: GPTIndexedDatasetConfig = FieldUpdate() @@ -170,20 +170,20 @@ def build(self) -> "GPTDatasetSlice": return self._build(GPTDatasetSlice) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False sampling: GPTSamplingConfig = FieldUpdate() dataset: GPTSampledDatasetConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): _abstract: typing.ClassVar[bool] = False datasets: list[GPTSampledDatasetConfig] = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) class 
GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( @@ -221,7 +221,8 @@ def _convert_paths(self, config): return config -@config_class() +# Add user-friendly names for the configs. +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"}) class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): # TODO v0.3: Remove. _abstract: typing.ClassVar[bool] = False @@ -327,7 +328,7 @@ class FimConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): """ Configuration for FIM. @@ -394,7 +395,7 @@ class GPTLegacyConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"}) class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): _abstract: typing.ClassVar[bool] = False @@ -475,7 +476,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. @@ -494,18 +495,3 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset: if sampling.distributed.config.rank == 0: time.sleep(self.sleep) return GPTRandomDatasetConfig().build_and_sample(sampling) - - -# Add user-friendly names for the configs. -GPTSampledDatasetConfig.register_subclass("dummy", GPTRandomDatasetConfig) -GPTSampledDatasetConfig.register_subclass("random", GPTRandomDatasetConfig) -GPTSampledDatasetConfig.register_subclass("memmap", GPTMemmapDatasetConfig) -GPTSampledDatasetConfig.register_subclass("concatenated", GPTConcatenatedDatasetConfig) -GPTSampledDatasetConfig.register_subclass("slice", GPTDatasetSliceConfig) -GPTSampledDatasetConfig.register_subclass("sampled", GPTSampledDatasetUpdateConfig) -GPTSampledDatasetConfig.register_subclass("blended", GPTBlendedDatasetConfig) -GPTSampledDatasetConfig.register_subclass("file", GPTDatasetFromFileConfig) -GPTSampledDatasetConfig.register_subclass("concatenated_memmap", GPTConcatenatedMemmapConfig) -GPTSampledDatasetConfig.register_subclass("fim", GPTFimSampledDatasetConfig) -GPTSampledDatasetConfig.register_subclass("legacy", GPTLegacyDatasetConfig) -GPTSampledDatasetConfig.register_subclass("test_slow", GPTTestSlowDatasetConfig) diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index b2068ddc..f5bdeed2 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig -@config_class() +@config_class(dynamic_type={RunnableConfig: "prepare"}) class DatasetPreparatorConfig(RunnableConfig): preparator_name: typing.ClassVar[str] @@ -24,6 +24,3 @@ class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[Config @abc.abstractmethod def run(self) -> None: raise NotImplementedError - - -RunnableConfig.register_subclass("prepare", DatasetPreparatorConfig) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 253f474e..51831bde 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -111,7 +111,7 @@ def _validate(self) -> None: 
Assert.in_range(self.rank, 0, self.world_size) -@config_class() +@config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): preparator_name: typing.ClassVar[str] = "gpt_memmap" @@ -174,7 +174,3 @@ def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator return GPTMemmapDatasetPreparator - - -RunnableConfig.register_subclass("prepare_gpt_memmap", GPTMemmapDatasetPreparatorConfig) -DatasetPreparatorConfig.register_subclass("gpt_memmap", GPTMemmapDatasetPreparatorConfig) diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 7b097072..0f670854 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -@config_class() +@config_class(dynamic_type={RunnableConfig: "convert"}) class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() @@ -142,6 +142,3 @@ def run(self): # All good! (self.output.path / "ok").open("w") logger.info(f">>> All done!") - - -RunnableConfig.register_subclass("convert", ConvertConfig) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index d6807138..556e97be 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -26,8 +26,6 @@ logger = logging.getLogger(__name__) -torch.distributed.gather - class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 0dba6e70..b0a7a26b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -275,7 +275,7 @@ class ShutdownConfig(IntervalConfig): ) -@config_class() +@config_class(dynamic_type={RunnableConfig: "train"}) class TrainingConfig(Config): evaluations: dict[str, EvaluationConfig] = Field( default_factory=dict, @@ -421,6 +421,3 @@ def new_setup(): old_setup() object.__setattr__(pretrained, "_setup", new_setup) - - -RunnableConfig.register_subclass("train", TrainerConfig) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index c03e9957..f6fbd4f5 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -47,7 +47,7 @@ def _from_dict( return super()._from_dict(default, strict=strict, flat=flat) -@config_class() +@config_class(dynamic_type={NormalizationConfig: "none"}) class NoNormalizationConfig(NormalizationConfig): _abstract = False @@ -123,7 +123,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) -@config_class() +@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): _abstract = False @@ -134,7 +134,7 @@ def module_class(self): return LayerNorm -@config_class() +@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) class RMSNormalizationConfig(LayerNormalizationBaseConfig): _abstract = False @@ -145,11 +145,6 @@ def module_class(self): return RMSNorm -NormalizationConfig.register_subclass("none", NoNormalizationConfig) -NormalizationConfig.register_subclass("layer_norm", LayerNormalizationConfig) -NormalizationConfig.register_subclass("rms_norm", RMSNormalizationConfig) - - @config_class() class 
PeftConfig(BaseModelConfig): @abc.abstractmethod diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e06977b8..c83b0b43 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -113,12 +113,12 @@ def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") raise NotImplementedError() -@config_class() +@config_class(dynamic_type={RotaryConfig: "none"}) class NoRotaryConfig(RotaryConfig): _abstract = False -@config_class() +@config_class(dynamic_type={RotaryConfig: "default"}) class DefaultRotaryConfig(RotaryConfig): _abstract = False theta: float = Field( @@ -172,7 +172,7 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": return self.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) -@config_class() +@config_class(dynamic_type={RotaryConfig: "llama3"}) class Llama3RotaryConfig(DefaultRotaryConfig): """ Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 @@ -209,7 +209,7 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": return torch.stack(new_scales) -@config_class() +@config_class(dynamic_type={RotaryConfig: "yarn"}) class YarnRotaryConfig(DefaultRotaryConfig): """ Yarn scaling: @@ -275,12 +275,6 @@ def _get_correction(self, beta: float, dim: int) -> float: ) -RotaryConfig.register_subclass("none", NoRotaryConfig) -RotaryConfig.register_subclass("default", DefaultRotaryConfig) -RotaryConfig.register_subclass("llama3", Llama3RotaryConfig) -RotaryConfig.register_subclass("yarn", YarnRotaryConfig) - - class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -325,7 +319,7 @@ def _from_dict( return super()._from_dict(default, strict=strict, flat=flat) -@config_class() +@config_class(dynamic_type={TransformerPeftConfig: "none"}) class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): _abstract = False @@ -339,7 +333,7 @@ def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": return parameter -@config_class() +@config_class(dynamic_type={TransformerPeftConfig: "lora"}) class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): layers: list[TransformerSubLayerName] = Field( default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), @@ -398,10 +392,6 @@ def _validate(self) -> None: ) -TransformerPeftConfig.register_subclass("none", TransformerNoPeftConfig) -TransformerPeftConfig.register_subclass("lora", TransformerLoRAConfig) - - @config_class() class TransformerConfig(BaseModelConfig): _abstract = False From 60a656ec7d2e9c1f794e8ba5bb3e2a7c964b3eb2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 19:04:44 -0400 Subject: [PATCH 21/26] stuff --- tests/data/test_prepare_gpt_memmap.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index ccb94c23..9dd7975c 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig, GPTDatasetSliceConfig, GPTIndexedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import 
GPTSample
 from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig
@@ -77,12 +77,12 @@ def test_absent_metadata_local():
 
 
 DATASET_DICT_0 = {
-    "type": MockGPTMemmapDatasetConfig.__name__,
+    "type": "mock_memmap",
     "num_documents": 500,
     "num_tokens_per_document": 300,
 }
 DATASET_DICT_1 = {
-    "type": MockGPTMemmapDatasetConfig.__name__,
+    "type": "mock_memmap",
     "num_documents": 1500,
     "num_tokens_per_document": 100,
 }
@@ -101,13 +101,13 @@ def test_split_dataset():
         config,
         {
             "training": {
-                "type": GPTDatasetSliceConfig.__name__,
+                "type": "slice",
                 "dataset": dataset_config_0.to_dict(),
                 "begin": 0,
                 "end": 0.75,
             },
             "validation": {
-                "type": GPTDatasetSliceConfig.__name__,
+                "type": "slice",
                 "dataset": dataset_config_0.to_dict(),
                 "begin": 0.75,
                 "end": 1,
@@ -147,11 +147,11 @@ def test_split_datasets_1():
         config,
         {
             "training": {
-                "type": GPTBlendedDatasetConfig.__name__,
+                "type": "blended",
                 "datasets": [
                     dataset_config_0.to_dict(),
                     {
-                        "type": GPTDatasetSliceConfig.__name__,
+                        "type": "slice",
                         "dataset": dataset_config_1.to_dict(),
                         "begin": 0,
                         "end": 0.5,
                     },
                 ],
                 "weights": [2 / 3, 1 / 3],
             },
             "validation": {
-                "type": GPTDatasetSliceConfig.__name__,
+                "type": "slice",
                 "dataset": dataset_config_1.to_dict(),
                 "begin": 0.5,
                 "end": 1,

From 35959491886cd87d6a592b61745c42214a08c4a5 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Wed, 14 May 2025 19:29:34 -0400
Subject: [PATCH 22/26] Minimalistic dynamic configs

---
 fast_llm/config.py | 73 ++++++++++++++++++++++-
 fast_llm/data/dataset/gpt/config.py | 92 +++++------------------------
 2 files changed, 85 insertions(+), 80 deletions(-)

diff --git a/fast_llm/config.py b/fast_llm/config.py
index 4928cdbd..6e3e92dc 100644
--- a/fast_llm/config.py
+++ b/fast_llm/config.py
@@ -12,7 +12,7 @@
 
 import yaml
 
-from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log
+from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log
 
 logger = logging.getLogger(__name__)
 
@@ -243,7 +243,9 @@ def _process_config_class(cls: type["Config"]):
     return cls
 
 
-def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]:
+def config_class[
+    T: Config
+](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
     """
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
@@ -253,7 +255,7 @@ def wrap(cls):
         if hasattr(cls, "__post_init__"):
             raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")
 
-        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True))
+        wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False))
 
         wrapped_init = cls.__init__
 
@@ -267,6 +269,13 @@ def __init__(self, **kwargs):
                 self.validate()
 
         wrapped.__init__ = __init__
+
+        wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None
+
+        if dynamic_type is not None:
+            for cls_, name in dynamic_type.items():
+                cls_.register_subclass(name, wrapped)
+
         return wrapped
 
     return wrap
@@ -305,6 +314,9 @@ class Config(metaclass=ConfigMeta):
     # without them being automatically added to `_explicit_fields`.
     _setting_implicit_default: bool | None = Field(init=False)
 
+    # A registry for all the config classes.
+ _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None + def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -358,6 +370,17 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: Validate a class and mark it as read-only This should not be overridden in derived classes. """ + # Should be handled in `from_dict`, but can fail if instantiating directly. + try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValidationError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.is_(self.__class__, expected_class) + if not self._validated: try: self._validate() @@ -738,6 +761,14 @@ def _from_dict( if "__class__" in default: del default["__class__"] + try: + actual_cls = cls.get_subclass(default.get("type")) + if actual_cls is not None and actual_cls is not cls: + return actual_cls._from_dict(default, strict=strict, flat=flat) + except KeyError: + # Postpone error to validation. + pass + # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): @@ -864,6 +895,42 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ ) return None + @classmethod + def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: + Assert.custom(issubclass, cls_, cls) + if cls._registry is None: + raise NotImplementedError(f"Subclass `{name}` doesn't have a registry..") + if name in cls._registry: + old_cls = cls._registry[name] + if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: + del cls._registry[name] + else: + raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {cls.__name__}.") + cls._registry[name] = cls_ + + @classmethod + def get_subclass(cls, name: str | None): + # TODO: Make it case-insensitive? + if name is None: + return None + cls_ = None + for base_class in cls.__mro__: + if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry: + if cls_ is None: + cls_ = base_class._registry[name] + if not issubclass(cls_, cls): + raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})") + elif base_class._registry[name] is not cls_: + # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization. + # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit? + raise KeyError( + f"Ambiguous type `{name}` for base class {cls.__name__}." + f" ({cls_.__name__} vs {base_class._registry[name]})" + ) + if cls_ is None: + raise KeyError(f"Unknown type {name} for base class {cls.__name__}") + return cls_ + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. 
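
Note (not part of the patch): the `register_subclass` / `get_subclass` / `_from_dict` machinery above is what lets a serialized `type` key pick the concrete config class. A minimal sketch of the intended usage, assuming the decorator semantics shown in this hunk; the `ExampleBaseConfig` / `ExampleYarnConfig` names and the `scale` field are made up for illustration:

```python
from fast_llm.config import Config, Field, FieldHint, config_class


@config_class(registry=True)
class ExampleBaseConfig(Config):
    # `registry=True` gives this base class its own subclass registry.
    pass


@config_class(dynamic_type={ExampleBaseConfig: "yarn"})
class ExampleYarnConfig(ExampleBaseConfig):
    # `dynamic_type` registers this class under the name "yarn" in the base registry.
    _abstract = False
    scale: float = Field(default=8.0, desc="Illustrative field.", hint=FieldHint.feature)


# `_from_dict` resolves the `type` key through `get_subclass`, so deserialization
# returns the registered subclass rather than the base class it was called on.
config = ExampleBaseConfig.from_dict({"type": "yarn", "scale": 4.0})
assert isinstance(config, ExampleYarnConfig)
```
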
diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index f4f6e282..4ab0b7df 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -23,7 +23,7 @@ SamplingParameters, ) from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset @@ -92,61 +92,9 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True -@config_class() +@config_class(registry=True) class GPTSampledDatasetConfig(SampledDatasetConfig): - - # TODO: Generalize dynamic types? - _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[ - str, type["GPTDatasetConfig"] - ]("gpt_dataset_class", {}) - type_: typing.ClassVar[str | None] = None - type: str | None = Field( - default=None, - desc="The type of dataset.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - if self.type is None: - self.type = self.type_ - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.eq(self.type, self.__class__.type_) - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - type_ = default.get("type") - if type_ is None: - actual_cls = cls - else: - if type_ not in cls._registry: - raise ValueError( - f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}" - ) - actual_cls = cls._registry[type_] - Assert.custom(issubclass, actual_cls, cls) - if actual_cls == cls: - return super()._from_dict(default, strict=strict, flat=flat) - else: - return actual_cls._from_dict(default, strict=strict, flat=flat) - - def __init_subclass__(cls) -> None: - if cls._abstract and cls.type_ is not None: - # Abstract classes should not have a `type_` - raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") - if cls.type_ is not None: - if cls.type_ in cls._registry: - raise ValueError( - f"Registry {cls._registry.name} already contains type {cls.type_}." - f" Make sure all classes either have a unique or `None` type." 
- ) - GPTSampledDatasetConfig._registry[cls.type_] = cls - super().__init_subclass__() + pass @config_class() @@ -160,10 +108,9 @@ def build(self) -> "GPTIndexedDataset": raise NotImplementedError() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "random" name: str = Field( default="dummy", desc="The name of the dataset.", @@ -176,10 +123,9 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "memmap" path: pathlib.Path = Field( default=None, desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", @@ -202,10 +148,9 @@ def build(self) -> "GPTMemmapDataset": return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated" datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() def build(self) -> "GPTConcatenatedDataset": @@ -214,10 +159,9 @@ def build(self) -> "GPTConcatenatedDataset": return self._build(GPTConcatenatedDataset) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "slice" dataset: GPTIndexedDatasetConfig = FieldUpdate() def build(self) -> "GPTDatasetSlice": @@ -226,25 +170,22 @@ def build(self) -> "GPTDatasetSlice": return self._build(GPTDatasetSlice) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False - type_: typing.ClassVar[str | None] = "sampled" sampling: GPTSamplingConfig = FieldUpdate() dataset: GPTSampledDatasetConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "blended" datasets: list[GPTSampledDatasetConfig] = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "file" path: pathlib.Path = Field( default=None, desc="The path to a dataset config file.", @@ -280,11 +221,11 @@ def _convert_paths(self, config): return config -@config_class() +# Add user-friendly names for the configs. +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"}) class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): # TODO v0.3: Remove. 
_abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated_memmap" path: pathlib.Path = Field( default=None, desc="The path to a dataset directory.", @@ -387,14 +328,13 @@ class FimConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "fim" dataset: GPTSampledDatasetConfig = Field( default=None, @@ -455,10 +395,9 @@ class GPTLegacyConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"}) class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "legacy" def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: @@ -537,7 +476,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. @@ -545,7 +484,6 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): # TODO: This belongs to a testing plugin. _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "test_slow" sleep: float = Field( default=1, desc="Sleep time during build, in seconds.", From 39b1a04fd140718afda39e96a5882754819b49d8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 20:42:56 -0400 Subject: [PATCH 23/26] stuff --- fast_llm/config.py | 10 +++++++++- fast_llm/layers/common/config.py | 8 +++++++- fast_llm/layers/transformer/config.py | 16 ++++++++++++++-- tests/data/common.py | 3 +-- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 6e3e92dc..380100e3 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -274,6 +274,7 @@ def __init__(self, **kwargs): if dynamic_type is not None: for cls_, name in dynamic_type.items(): + print(cls_, name, wrapped) cls_.register_subclass(name, wrapped) return wrapped @@ -899,7 +900,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) if cls._registry is None: - raise NotImplementedError(f"Subclass `{name}` doesn't have a registry..") + raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..") if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: @@ -980,6 +981,13 @@ def __init_subclass__(cls): # dataclasses expects an annotation, so we use the one from the base class. cls.__annotations__[name] = base_class_field.type + # Type for the field. At the end of class definition to avoid shadowing builtin. 
+ type: str | None = Field( + default=None, + desc="The config class name.", + hint=FieldHint.feature, + ) + class Configurable[ConfigType: Config]: config_class: typing.ClassVar[type[Config]] = Config diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 269989ce..054c26c3 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -33,7 +33,7 @@ class NormalizationType(str, enum.Enum): rms_norm = "rms_norm" -@config_class() +@config_class(registry=True) class NormalizationConfig(BaseModelConfig): _abstract = False @@ -107,6 +107,12 @@ def _from_dict( return super()._from_dict(default, strict, flat) +for name in NormalizationType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + NormalizationConfig.register_subclass(name.value, NormalizationConfig) + + class PeftType(str, enum.Enum): # TODO : Use a dynamic config type instead. none = "none" diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c621139c..e7ef0b15 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -95,7 +95,7 @@ class RotaryEmbeddingType(str, enum.Enum): yarn = "yarn" -@config_class() +@config_class(registry=True) class RotaryConfig(BaseModelConfig): _abstract = False type: RotaryEmbeddingType = Field( @@ -158,6 +158,12 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") +for name in RotaryEmbeddingType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + RotaryConfig.register_subclass(name.value, RotaryConfig) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -175,7 +181,7 @@ class TransformerSubLayerName(str, enum.Enum): mlp_2 = "mlp_2" -@config_class() +@config_class(registry=True) class TransformerPeftConfig(PeftConfig): layers: list[TransformerSubLayerName] = Field( default=None, @@ -244,6 +250,12 @@ def _validate(self) -> None: ) +for name in PeftType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. 
+ TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) + + @config_class() class TransformerConfig(BaseModelConfig): _abstract = False diff --git a/tests/data/common.py b/tests/data/common.py index 00c3ff20..cacb28e6 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -189,10 +189,9 @@ def validate_indexed_dataset_sampling( return token_ids -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"}) class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "mock_memmap" num_documents: int | None = Field( default=None, desc="Expected number of documents in the dataset.", From 743edaae5b6952c39fd710fc4be90e9b2ec9ae51 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 15 May 2025 11:16:41 -0400 Subject: [PATCH 24/26] Simplify cli --- fast_llm/cli.py | 35 +++++++++++++++ fast_llm/data/preparator/config.py | 2 +- fast_llm/data/preparator/gpt_memmap/config.py | 3 +- .../{tools => engine/checkpoint}/convert.py | 26 ++--------- fast_llm/engine/config_utils/runnable.py | 15 +++++-- fast_llm/engine/multi_stage/config.py | 8 +--- fast_llm/engine/training/config.py | 3 +- fast_llm/models/auto.py | 35 +++------------ fast_llm/models/custom/config.py | 7 ++- fast_llm/models/gpt/config.py | 5 ++- fast_llm/models/ssm/config.py | 15 ++++--- fast_llm/tools/__init__.py | 0 fast_llm/tools/cli.py | 43 ------------------- fast_llm/tools/prepare_dataset.py | 24 ----------- fast_llm/tools/train.py | 24 ----------- fast_llm/utils.py | 5 +++ setup.cfg | 2 +- tests/common.py | 6 +-- tests/test_checkpoint.py | 5 +-- tests/test_multi_stage.py | 6 +-- tools/push_model.py | 2 +- 21 files changed, 93 insertions(+), 178 deletions(-) create mode 100644 fast_llm/cli.py rename fast_llm/{tools => engine/checkpoint}/convert.py (89%) delete mode 100644 fast_llm/tools/__init__.py delete mode 100644 fast_llm/tools/cli.py delete mode 100644 fast_llm/tools/prepare_dataset.py delete mode 100644 fast_llm/tools/train.py diff --git a/fast_llm/cli.py b/fast_llm/cli.py new file mode 100644 index 00000000..34546120 --- /dev/null +++ b/fast_llm/cli.py @@ -0,0 +1,35 @@ +import logging +import sys +import traceback + +from fast_llm.config import ValidationError +from fast_llm.engine.config_utils.logging import configure_logging +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.engine.config_utils.runnable import RunnableConfig + +# Import these submodules to ensure classes are added to the dynamic class registry. +import fast_llm.data.auto # isort: skip +import fast_llm.engine.checkpoint.convert # isort: skip +import fast_llm.models.auto # isort: skip + +logger = logging.getLogger(__name__) + + +def fast_llm_main(args: list[str] | None = None): + # TODO: Add hook to register model classes? (environment variable?) 
+ # (Pre-)configure logging + configure_logging() + try: + RunnableConfig.parse_and_run(args) + except Exception as e: + if sys.gettrace(): + raise + if isinstance(e, ValidationError): + log_main_rank(traceback.format_exc(), log_fn=logger.error) + else: + logger.critical(traceback.format_exc()) + sys.exit(1) + + +if __name__ == "__main__": + fast_llm_main() diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index edf088c0..7f6376c7 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig -@config_class() +@config_class(registry=True, dynamic_type={RunnableConfig: "prepare"}) class DatasetPreparatorConfig(RunnableConfig): preparator_name: typing.ClassVar[str] diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 7091f3c8..4b5aa302 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -6,6 +6,7 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.data.preparator.config import DatasetPreparatorConfig from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,7 +25,7 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -@config_class() +@config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTHuggingfaceDatasetConfig(Config): path: str = Field( default=None, diff --git a/fast_llm/tools/convert.py b/fast_llm/engine/checkpoint/convert.py similarity index 89% rename from fast_llm/tools/convert.py rename to fast_llm/engine/checkpoint/convert.py index 648138ec..0f670854 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -1,4 +1,3 @@ -import argparse import json import logging import math @@ -9,7 +8,6 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.functional.config import TritonConfig -from fast_llm.models.auto import model_registry from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -18,7 +16,7 @@ logger = logging.getLogger(__name__) -@config_class() +@config_class(dynamic_type={RunnableConfig: "convert"}) class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() @@ -27,24 +25,10 @@ class ConvertConfig(RunnableConfig): layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=model_registry.keys(), - help="The Fast-LLM model type to use. Must be defined in the model registry in `fast_llm.models.auto`.", - ) - return parser - - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - config = super()._from_parsed_args(parsed, unparsed) - config.model = model_registry[parsed.model_type] - return config - def _validate(self): assert self.model is not None + if isinstance(self.model, str): + self.model = FastLLMModelConfig.get_subclass(self.model) self.input.setup(self.model) self.output.setup(self.model) super()._validate() @@ -158,7 +142,3 @@ def run(self): # All good! 
(self.output.path / "ok").open("w") logger.info(f">>> All done!") - - -if __name__ == "__main__": - ConvertConfig.parse_and_run() diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 6142de47..a89bfa3c 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -16,13 +16,20 @@ logger = logging.getLogger(__name__) -@config_class() +@config_class(registry=True) class RunnableConfig(Config): @classmethod - def parse_and_run(cls, args=None) -> None: - parsed, unparsed = cls._get_parser().parse_known_args(args) + def parse_and_run(cls, args: list[str] | None = None) -> None: + if args is None: + args = sys.argv[1:] + cls_ = cls + while len(args) >= 1 and "=" not in args[0] and not args[0].startswith("-"): + # Allow chained dynamic type selection without the `type=`, ex. `train gpt`. + cls_ = cls_.get_subclass(args[0]) + args = args[1:] + parsed, unparsed = cls_._get_parser().parse_known_args([f"type={cls_.__name__}"] + args) with NoAutoValidate(): - config: "RunnableConfig" = cls._from_parsed_args(parsed, unparsed) + config: "RunnableConfig" = cls_._from_parsed_args(parsed, unparsed) try: config.configure_logging() config.validate() diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 9434fba6..a040a4fb 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -203,7 +203,7 @@ def _validate(self) -> None: SHARD_PAD_TO_MULTIPLE = 32 -@config_class() +@config_class(registry=True) class FastLLMModelConfig(Config): _abstract = True checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = ( @@ -372,13 +372,9 @@ def _from_dict( if "fast_llm_version" not in default: default["fast_llm_version"] = "0" - # Determine the model config class. - from fast_llm.models.auto import model_registry - model_config_class = default["model"] if isinstance(model_config_class, str): - Assert.incl(model_config_class, model_registry) - model_config_class = model_registry[model_config_class] + model_config_class = FastLLMModelConfig.get_subclass(default["model"]) default["model"] = model_config_class # TODO v0.3: Remove backward compatibility. 
diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index a5be2e7e..b0a7a26b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,6 +23,7 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -274,7 +275,7 @@ class ShutdownConfig(IntervalConfig): ) -@config_class() +@config_class(dynamic_type={RunnableConfig: "train"}) class TrainingConfig(Config): evaluations: dict[str, EvaluationConfig] = Field( default_factory=dict, diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 8f16aaea..3be74856 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -1,30 +1,7 @@ -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig -from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig -from fast_llm.utils import Registry +""" +Import these submodules to ensure classes are added to the dynamic class registry. +""" -model_registry = Registry[str, FastLLMModelConfig]( - "Model", - { - model.model_name: model - for model in [ - GPTModelConfig, - CustomModelConfig, - HybridSSMModelConfig, - ] - }, -) - -trainer_registry = Registry[str, TrainerConfig]( - "Model", - { - trainer.get_field("model").type.model_name: trainer - for trainer in [ - GPTTrainerConfig, - CustomTrainerConfig, - HybridTrainerConfig, - ] - }, -) +from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig # isort: skip +from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip +from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig # isort: skip diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 08902e2c..963ffd35 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,6 +2,9 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig if typing.TYPE_CHECKING: @@ -22,7 +25,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): pass -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "gpt_custom"}) class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). model_name: typing.ClassVar[str] = "gpt_custom" @@ -46,7 +49,7 @@ class PretrainedCustomModelConfig(PretrainedGPTModelConfig): model: CustomModelConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={RunnableConfig: "train_gpt_custom", TrainerConfig: "gpt_custom"}) class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). 
data: CustomDataConfig = FieldUpdate() diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 0ec3fb51..64f6f1de 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,6 +4,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -125,7 +126,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "gpt"}) class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" @@ -159,7 +160,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): model: GPTModelConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 771a4fca..a69e7b08 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -16,7 +17,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM from fast_llm.models.ssm.model import HybridSSMModel - from fast_llm.models.ssm.trainer import SSMTrainer + from fast_llm.models.ssm.trainer import HybridSSMTrainer logger = logging.getLogger(__name__) @@ -124,7 +125,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -@config_class() +@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" @@ -156,13 +157,13 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): model: HybridSSMModelConfig = FieldUpdate() -@config_class() -class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): +@config_class(dynamic_type={RunnableConfig: "train_hybrid_ssm", TrainerConfig: "hybrid_ssm"}) +class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() @classmethod - def get_trainer_class(cls) -> type["SSMTrainer"]: - from fast_llm.models.ssm.trainer import SSMTrainer + def get_trainer_class(cls) -> type["HybridSSMTrainer"]: + from fast_llm.models.ssm.trainer import HybridSSMTrainer - return SSMTrainer + return HybridSSMTrainer diff --git a/fast_llm/tools/__init__.py b/fast_llm/tools/__init__.py deleted file mode 100644 index 
e69de29b..00000000 diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py deleted file mode 100644 index 8df884fe..00000000 --- a/fast_llm/tools/cli.py +++ /dev/null @@ -1,43 +0,0 @@ -import argparse -import logging -import sys -import traceback - -from fast_llm.config import ValidationError -from fast_llm.engine.config_utils.logging import configure_logging -from fast_llm.engine.config_utils.run import log_main_rank - -logger = logging.getLogger(__name__) - - -def fast_llm(args=None): - # TODO: Add hook to register model classes? (environment variable?) - # (Pre-)configure logging - configure_logging() - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument("subcommand", choices=["train", "convert", "prepare"]) - parsed, unparsed = parser.parse_known_args(args) - try: - if parsed.subcommand == "train": - from fast_llm.tools.train import CliTrainingConfig as Runnable - elif parsed.subcommand == "convert": - from fast_llm.tools.convert import ConvertConfig as Runnable - elif parsed.subcommand == "prepare": - from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable - else: - raise RuntimeError("Unknown subcommand") - Runnable.parse_and_run(unparsed) - except ValidationError: - if sys.gettrace(): - raise - log_main_rank(traceback.format_exc(), log_fn=logger.error) - sys.exit(1) - except Exception: # noqa - if sys.gettrace(): - raise - logger.critical(traceback.format_exc()) - sys.exit(1) - - -if __name__ == "__main__": - fast_llm() diff --git a/fast_llm/tools/prepare_dataset.py b/fast_llm/tools/prepare_dataset.py deleted file mode 100644 index aafe2690..00000000 --- a/fast_llm/tools/prepare_dataset.py +++ /dev/null @@ -1,24 +0,0 @@ -import argparse - -from fast_llm.data.auto import dataset_preparator_registry -from fast_llm.engine.config_utils.runnable import RunnableConfig - - -class PrepareDatasetConfig(RunnableConfig): - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=dataset_preparator_registry.keys(), - help="The Fast-LLM model type to use. Must be defined in the model registry in `fast_llm.models.auto`.", - ) - return parser - - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - return dataset_preparator_registry[parsed.model_type]._from_parsed_args(parsed, unparsed) - - -if __name__ == "__main__": - PrepareDatasetConfig.parse_and_run() diff --git a/fast_llm/tools/train.py b/fast_llm/tools/train.py deleted file mode 100644 index ae902279..00000000 --- a/fast_llm/tools/train.py +++ /dev/null @@ -1,24 +0,0 @@ -import argparse - -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.models.auto import trainer_registry - - -class CliTrainingConfig(RunnableConfig): - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=trainer_registry.keys(), - help="The Fast-LLM model type to use. 
Must be defined in the trainer registry in `fast_llm.models.auto`.", - ) - return parser - - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - return trainer_registry[parsed.model_type]._from_parsed_args(parsed, unparsed) - - -if __name__ == "__main__": - CliTrainingConfig.parse_and_run() diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d89c9d76..460745e3 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -218,6 +218,11 @@ def __setitem__(self, key: KeyType, value: ValueType): raise KeyError(f"Entry {key} already in {self._name} registry") self._data[key] = value + def __delitem__(self, key: KeyType): + if key not in self: + raise KeyError(f"Entry {key} not found in {self._name} registry") + del self._data[key] + def keys(self) -> list[KeyType]: return list(self._data) diff --git a/setup.cfg b/setup.cfg index 9b944b27..a48759c1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,4 +68,4 @@ DOCS = [options.entry_points] console_scripts = - fast-llm = fast_llm.tools.cli:fast_llm + fast-llm = fast_llm.cli:fast_llm_main diff --git a/tests/common.py b/tests/common.py index 569d690c..bcc563d7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -13,6 +13,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.config import ( @@ -24,7 +25,6 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs # FIXME: figure out correct import of megatron modules without this hack @@ -392,7 +392,7 @@ def run_test_script( if is_megatron: script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] else: - script = [model_type, *script, f"run.experiment_dir={path}"] + script = ["train", model_type, *script, f"run.experiment_dir={path}"] header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] command = [ "python", @@ -408,7 +408,7 @@ def run_test_script( else: get_test_dataset() if num_gpus == 1 and not is_megatron: - CliTrainingConfig.parse_and_run(script) + RunnableConfig.parse_and_run(script) else: completed_proc = subprocess.run(command, env=env, timeout=60) if completed_proc.returncode: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index e0845a4c..5c5f5b90 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,10 +14,9 @@ FastLLMCheckpointFormat, ModelConfigType, ) +from fast_llm.engine.checkpoint.convert import ConvertConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName -from fast_llm.models.auto import model_registry -from fast_llm.tools.convert import ConvertConfig from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -31,7 +30,7 @@ ) from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor -TEST_MODEL_CONFIG_CLS = model_registry[TEST_MODEL_TYPE] +TEST_MODEL_CONFIG_CLS = FastLLMModelConfig.get_subclass(TEST_MODEL_TYPE) TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_class() TEST_MODEL_CLS = TEST_MODEL_CONFIG_CLS.get_model_class() 
TEST_BASE_MODEL_CONFIG_CLS = TEST_MODEL_CONFIG_CLS.get_base_model_config_class() diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index bb468ceb..aeea5b8c 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -2,14 +2,14 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.tools.train import CliTrainingConfig from fast_llm.utils import Assert from tests.common import CONFIG_COMMON, requires_cuda def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: - parsed, unparsed = CliTrainingConfig._get_parser().parse_known_args([model_type] + args) - config: TrainerConfig = CliTrainingConfig._from_parsed_args(parsed, unparsed) + cls = TrainerConfig.get_subclass(model_type) + parsed, unparsed = cls._get_parser().parse_known_args(args) + config: TrainerConfig = cls._from_parsed_args(parsed, unparsed) distributed = Distributed(config.model.distributed) trainer = config.get_trainer_class()(config=config) trainer.setup(distributed, config.get_run(distributed)) diff --git a/tools/push_model.py b/tools/push_model.py index edab3312..39a3b914 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -27,7 +27,7 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.tools.convert import ConvertConfig # isort:skip +from fast_llm.engine.checkpoint.convert import ConvertConfig # isort:skip logger = logging.getLogger(__name__) From 038106fe2810fc2e95f4105234799b7d5c58d216 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 15 May 2025 11:19:08 -0400 Subject: [PATCH 25/26] Simplify cli --- fast_llm/data/auto.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index 902faf1c..c44e538f 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -1,13 +1,5 @@ -from fast_llm.data.preparator.config import DatasetPreparatorConfig -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.utils import Registry +""" +Import these submodules to ensure classes are added to the dynamic class registry. 
+""" -dataset_preparator_registry = Registry[str, DatasetPreparatorConfig]( - "DatasetPreparator", - { - dataset_preparator.preparator_name: dataset_preparator - for dataset_preparator in [ - GPTMemmapDatasetPreparatorConfig, - ] - }, -) +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip From e199d0aa1d97d8f3185d7a440b887656baa1f722 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 27 May 2025 11:05:29 -0400 Subject: [PATCH 26/26] fix --- .github/workflows/docs.yaml | 8 +- Megatron-LM | 2 +- docs/recipes/generate.md | 77 ++++ fast_llm/data/dataset/gpt/fim.py | 3 + fast_llm/data/dataset/gpt/sampled.py | 78 +++- .../data/preparator/gpt_memmap/prepare.py | 90 +++++ fast_llm/data/tokenizer.py | 1 + fast_llm/engine/checkpoint/safe_load.py | 22 +- fast_llm/engine/distributed/config.py | 5 +- fast_llm/engine/inference/config.py | 17 + fast_llm/engine/inference/huggingface.py | 74 +++- fast_llm/engine/inference/runner.py | 39 +- fast_llm/engine/multi_stage/config.py | 15 +- fast_llm/engine/multi_stage/fast_llm_model.py | 5 +- fast_llm/engine/multi_stage/fsdp.py | 16 +- fast_llm/engine/multi_stage/multi_stage.py | 7 +- fast_llm/engine/multi_stage/stage.py | 18 +- fast_llm/engine/multi_stage/stage_base.py | 6 +- fast_llm/engine/schedule/schedule.py | 3 + fast_llm/functional/dpo.py | 78 ++++ fast_llm/layers/language_model/config.py | 17 + fast_llm/layers/language_model/head.py | 73 +++- .../layers/language_model/preprocessing.py | 51 +++ fast_llm/models/custom/config.py | 9 +- fast_llm/models/gpt/config.py | 30 +- fast_llm/models/gpt/model.py | 12 +- fast_llm/models/gpt/trainer.py | 1 + fast_llm/models/ssm/config.py | 17 +- fast_llm/models/ssm/trainer.py | 6 +- mkdocs.yaml | 1 + setup.cfg | 3 +- tests/common.py | 9 +- tests/conftest.py | 14 + tests/data/test_prepare_gpt_memmap.py | 37 ++ tests/test_checkpoint.py | 29 +- tests/test_functional.py | 149 +++++++ tests/test_gpt_generate_and_forward.py | 373 ++++++++++++++++++ tests/test_ssms.py | 1 + 38 files changed, 1276 insertions(+), 120 deletions(-) create mode 100644 docs/recipes/generate.md create mode 100644 fast_llm/functional/dpo.py create mode 100644 tests/test_gpt_generate_and_forward.py diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 549140ca..93191972 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -31,7 +31,9 @@ jobs: - run: | pip install "torch>=2.2.2" pip install pybind11 - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" + FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ + MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ + pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" - name: Build the documentation run: mkdocs build @@ -56,6 +58,8 @@ jobs: - run: | pip install "torch>=2.2.2" pip install pybind11 - FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" + FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ + MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ + pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]" - name: Publish the 
documentation run: mkdocs gh-deploy --force --dirty diff --git a/Megatron-LM b/Megatron-LM index cb6baf17..511e8f5c 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit cb6baf171d064db6c2fee52f32dc1b51a2e6538d +Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 diff --git a/docs/recipes/generate.md b/docs/recipes/generate.md new file mode 100644 index 00000000..e6bda803 --- /dev/null +++ b/docs/recipes/generate.md @@ -0,0 +1,77 @@ +--- +title: How to Generate with a Fast-LLM Model +--- + +Fast-LLM models support `generate` and `forward` operations through Hugging Face–compatible wrappers. + +⚠️ Limitations: + +- No support for `cache`, `past_key_values`, `labels`, `attention` outputs, or `inputs_embeds` +- `position_ids` are ignored and reconstructed from the attention mask +- **model-parallel** and **sequence-data-parallel** generation is **not** supported + +--- + +### 🔧 Generating Text from a Fast-LLM Model + +Below is a step-by-step example of how to generate text using a Fast-LLM model checkpoint from Hugging Face Hub. + +```python +# Import dependencies +import huggingface_hub +from transformers import AutoTokenizer +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM + +# Specify model and configuration +model = "HuggingFaceTB/SmolLM2-135M-Instruct" +checkpoint_format = LlamaGPTHuggingfaceCheckpointFormat +max_new_tokens = 50 + +# Download model checkpoint from the Hugging Face Hub to a local directory +model_path = huggingface_hub.snapshot_download(repo_id=model, local_dir="/tmp") + +# Load tokenizer from the downloaded model +tokenizer = AutoTokenizer.from_pretrained(model_path) + +# Optional: updates to Fast-LLM config before loading the model +updates = { + ("base_model", "transformer", "use_flash_attention"): True, + ("distributed", "training_dtype"): "bf16" +} + +# Load the model from the checkpoint with the given configuration +model = HuggingfaceGPTModelForCausalLM.from_pretrained( + CheckpointLoadConfig( + path=model_path, + format=checkpoint_format, + model_weights=True, + ), + updates, +) + +# Example input messages formatted for chat-style generation +messages = [ + {"role": "user", "content": "What is gravity?"}, + {"role": "user", "content": "Who is the president of EU?"}, +] + +# Convert messages into model input format using chat template +input_text = [tokenizer.apply_chat_template([el], tokenize=False) for el in messages] + +# Prepare tokenized input for the model +tokenizer.padding_side = "left" # Important for correct padding +inputs = tokenizer(input_text, padding="longest", return_tensors="pt").to("cuda") + +# Generate text using the model +outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=False) + +# Decode and display outputs +outputs = [tokenizer.decode(el, skip_special_tokens=True) for el in outputs] + +print("--------------------------------------------------------------------") +for el in outputs: + print(el) + print("--------------------------------------------------------------------") +``` diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 63b7f437..2b2c8b3b 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -20,8 +20,11 @@ def __init__( ): if sampling.parameters.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss 
masking.") + if sampling.parameters.use_preference_loss_spans: + raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset + self._seed = sampling.config.seed self._tokenizer = sampling.tokenizer if self._tokenizer is None: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 065eb94d..8bb5f737 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -30,6 +30,8 @@ class GPTSample: token_ids: np.ndarray loss_masking_spans: np.ndarray | None = None + chosen_span: np.ndarray | None = None + rejected_span: np.ndarray | None = None sequence_lengths: np.ndarray | None = None @@ -112,6 +114,14 @@ def __init__( self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") + + # keep document sizes and len filtered docs for preference loss masking + if self._parameters.use_preference_loss_spans: + self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) + self._doc_length_filtered_indicies = MemmapArray( + base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") + ) + # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() @@ -145,10 +155,14 @@ def _sample(self) -> None: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." ) + # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. 
- if self._truncate_documents: + if self._parameters.use_preference_loss_spans: + documents_per_epoch = (~long_docs_filter).sum().item() + num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) + elif self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -187,8 +201,8 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(loaded_yaml_data) - if not self._truncate_documents: + self._load_yaml_data(yaml_data) + if not self._truncate_documents and not self._parameters.use_preference_loss_spans: del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: @@ -251,6 +265,24 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + if self._parameters.use_preference_loss_spans: + yaml_data["unshuffled_tokens"] = 0 # not used, ignore + + # index of all documents less than seq length long + doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] + self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) + + # apply shuffling on doc_length_filtered_indicies + if shuffled_epochs > 0: + self._document_shuffling.save( + document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) + ) + self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + return + # To get a sample on the fly we need to know where it begins, # and this is a non-trivial information because the documents have variable length. # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. @@ -349,6 +381,40 @@ def __getitem__(self, index: int) -> typing.Any: The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). 
""" self._lazy_load() + + if self._parameters.use_preference_loss_spans: + if index < self._unshuffled_documents: + document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] + else: + document_index = self._doc_length_filtered_indicies[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self._document_sizes[document_index], + use_loss_masking_spans=self._parameters.use_loss_masking_spans, + use_preference_loss_spans=self._parameters.use_preference_loss_spans, + ) + + chosen_span_end = sample.chosen_span[1] + 1 + sequence_lengths = [ + chosen_span_end, + len(sample.token_ids) - chosen_span_end, + ] + + # compute padding size + padding = np.full((self._parameters.sequence_length + 1,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + if not self._parameters.cross_document_attention: + sample.sequence_lengths = np.array(sequence_lengths) + + return sample + # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -454,7 +520,9 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if "unshuffled_tokens" not in data: + if self._parameters.use_preference_loss_spans: + data["unshuffled_tokens"] = 0 # not used, ignore + elif "unshuffled_tokens" not in data: # Backward compatibility # TODO v0.x: Remove assert self._truncate_documents @@ -485,6 +553,8 @@ def __init__( ) self._config = sampling.config self._parameters = sampling.parameters + if self._parameters.use_preference_loss_spans: + raise NotImplementedError("Legacy sampling does not support preference loss masking.") if sampling.cache_directory is None: log_main_rank( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 23e497bf..0cba3aa1 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -74,6 +74,70 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict "num_tokens": num_tokens, } + def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + packed_texts = [] + chosen_spans = [] + rejected_spans = [] + + for conv_history, chosen_text, rejected_text in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.chosen_text], + batch[self._config.dataset.rejected_text], + ): + # compute chosen span + full_chosen_text = conv_history + chosen_text + self._tokenizer.tokenizer.eos_token + chosen_span = [len(conv_history), len(full_chosen_text) - 1] + offset = len(full_chosen_text) + chosen_spans.append(chosen_span) + + # compute rejected span + full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text + rejected_span = [ + offset + len(self._tokenizer.tokenizer.bos_token + conv_history), + offset + len(full_rejected_text) - 1, + ] + rejected_spans.append(rejected_span) + + # pack texts + packed_text = full_chosen_text + full_rejected_text + + assert ( + packed_text[chosen_span[0] : chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token + ), 
f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" + + assert ( + packed_text[rejected_span[0] : rejected_span[1] + 1] == rejected_text + ), f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" + packed_texts.append(packed_text) + + # tokenize with spans + input_ids, chosen_token_spans, rejected_token_spans = map( + list, + zip( + *[ + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans[0], dtype=np.int32), + np.array( + [token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32 + ), # adding 1 to end for eos token + ) + for input_ids, token_spans in [ + self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) + for text, chosen_span, rejected_span in zip(packed_texts, chosen_spans, rejected_spans) + ] + ] + ), + ) + + num_tokens = [len(x) for x in input_ids] + return { + "input_ids": input_ids, + "chosen_token_spans": chosen_token_spans, + "rejected_token_spans": rejected_token_spans, + "num_tokens": num_tokens, + } + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" @@ -86,6 +150,18 @@ def _document_generator(): np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) + elif ( + "chosen_token_spans" in shard_dataset.column_names + and "rejected_token_spans" in shard_dataset.column_names + and self._config.dataset.chosen_text is not None + and self._config.dataset.rejected_text is not None + ): + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), + chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), + rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), + ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) @@ -214,10 +290,24 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") + if self._config.dataset.loss_masking_spans is not None and ( + self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None + ): + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): + raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") + + # route tokenize function if self._config.dataset.loss_masking_spans is not None: if self._config.dataset.loss_masking_spans not in dataset.column_names: raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") tokenize_fn = self._tokenize_batch_with_spans + elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + if self._config.dataset.chosen_text not in dataset.column_names: + raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") + if self._config.dataset.rejected_text not in dataset.column_names: + raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") + tokenize_fn 
= self._tokenize_preference_batch_with_spans else: tokenize_fn = self._tokenize_batch diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee..9b1d8f04 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -52,6 +52,7 @@ def tokenize_with_spans( token_spans = [] char_pos = 0 beginning_of_text = True + for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 2eec57e0..e72a3a15 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -40,8 +40,8 @@ def __enter__(self) -> "SafeLoad": triton_fill(self_shard, math.nan) # Reset and count shard pads for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards): - for fsdp_shard in fsdp_shards.values(): - self._loaded += fsdp.reset_shard_pad(fsdp_shard) + for shard_name, fsdp_shard in fsdp_shards.items(): + self._loaded += fsdp.reset_shard_pad(fsdp_shard, shard_name) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -145,9 +145,21 @@ def _check_parameters(self, errors: list[str]) -> None: elif counter is not None and counter > 0: errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') if self._distributed.world_group is not None: - counter_tensor = torch.tensor( - [counter or 0 for counter in counter_per_parameter.values()], dtype=torch.int64 - ).to(self._distributed.device) + counter_list = [] + for parameter_name, counter in counter_per_parameter.items(): + parameter_stage = self._model.get_parameter_stage(parameter_name) + parameter_meta = parameter_stage.get_parameter_meta(parameter_name) + if ( + counter is None + or (not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0) + or parameter_stage.is_tied_weight_copy + ): + # Ignore the counter from missing or duplicate tensors. + counter = 0 + counter_list.append(counter) + + counter_tensor = torch.tensor(counter_list, dtype=torch.int64).to(self._distributed.device) + add_ephemeral_timeout(self._distributed.world_group, self._timeout) all_reduce(counter_tensor, group=self._distributed.world_group) counter_per_parameter = { diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 66d89e1a..8e2430d5 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -284,9 +284,8 @@ def _validate(self) -> None: self.batch_data_rank = self.data_rank // self.sequence_data_parallel self.tensor_rank = self.rank % self.tensor_parallel - if self.tensor_parallel == 1: - with self._set_implicit_default(): - self.sequence_tensor_parallel = False + if self.tensor_parallel == 1 and self.sequence_tensor_parallel: + self.sequence_tensor_parallel = False if self.reference_config is not None: self.reference_config.validate() diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index d4b46bcc..b09c88ba 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -1,3 +1,4 @@ +import copy import logging import os import pathlib @@ -36,6 +37,22 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = finally: transformers.configuration_utils.CONFIG_NAME = _backup + def __deepcopy__(self, memo): + # Hugging Face's PretrainedModel will deep copy the config + # when `generate` is enabled. 
However, `fast_llm_config` + # cannot be deep copied if the world size is greater than 1, + # as it will contain references to process groups. + # Therefore, we copy it by reference instead. + cls = self.__class__ + copied = cls.__new__(cls) + memo[id(self)] = copied + for k, v in self.__dict__.items(): + if k == "fast_llm_config": + setattr(copied, k, v) # Keep the same reference + else: + setattr(copied, k, copy.deepcopy(v, memo)) + return copied + @classmethod def _get_config_dict( cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadMetadataConfig, **kwargs diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 196310b4..e679cfd6 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -2,6 +2,8 @@ import pathlib import typing +import torch +import transformers.generation.utils import transformers.modeling_outputs from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat @@ -9,6 +11,8 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.utils import Assert class HuggingfacePreTrainedModel(transformers.PreTrainedModel): @@ -20,21 +24,36 @@ class HuggingfacePreTrainedModel(transformers.PreTrainedModel): # _supports_cache_class = False # _tied_weights_keys = [] - def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs): + def __init__( + self, + fast_llm_model: FastLLMModel, + config: HuggingfaceModelConfig | None = None, + runner: ScheduleRunner | None = None, + **kwargs, + ): + if config is None: + config = self.config_class(fast_llm_model.config) + assert self.runner_class.model_class.config_class is config.model_config_class assert config.fast_llm_config is fast_llm_model.config assert isinstance(config, self.config_class) super().__init__(config, **kwargs) - self._inference_runner = self.runner_class(fast_llm_model) - if not fast_llm_model.is_setup: - fast_llm_model.setup(mode=StageMode.inference) + self._inference_runner = self.runner_class(fast_llm_model, runner) + + # A model can be created from pretrained which set it up in the current HF wrapper api + # or set existing model which also must be setup, so, do not accept not setup model + assert fast_llm_model.is_setup + + # We only support data parallel for now + Assert.eq(fast_llm_model.distributed.config.model_parallel, 1) + Assert.eq(fast_llm_model.distributed.config.sequence_data_parallel, 1) + self._inference_runner.setup() + # Transformers needs to be able to inspect the base model. self.fast_llm_base_model = fast_llm_model.base_model - # TODO: Support distributed models? - assert fast_llm_model.config.distributed.world_size == 1 with transformers.modeling_utils.no_init_weights(): self.post_init() @@ -43,8 +62,12 @@ def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, def from_pretrained( cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadConfig, - *, - mode: StageMode = StageMode.inference, + *updates: dict[str | tuple[str, ...], typing.Any], + optimizer_state_names: tuple[str, ...] | None = None, + # setup: bool = True, + mode: StageMode = StageMode.training, + use_cpu: bool = False, + stage_filter: set | None = None, **kwargs, ) -> typing.Self: # Pretrained config. 
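The `__deepcopy__` override above keeps `fast_llm_config` as a shared reference while the rest of the wrapper config is deep-copied, because the Fast-LLM config can hold references to process groups. A minimal standalone sketch of the same pattern; `SharedHandle` and `WrapperConfig` are illustrative names, not part of this patch:

import copy


class SharedHandle:
    """Stand-in for state that must not be deep-copied (e.g. process groups)."""


class WrapperConfig:
    def __init__(self, shared: SharedHandle, generation_options: dict):
        self.shared = shared  # must remain a shared reference
        self.generation_options = generation_options  # safe to deep-copy

    def __deepcopy__(self, memo):
        cls = self.__class__
        copied = cls.__new__(cls)
        memo[id(self)] = copied
        for key, value in self.__dict__.items():
            if key == "shared":
                setattr(copied, key, value)  # copy by reference
            else:
                setattr(copied, key, copy.deepcopy(value, memo))
        return copied


config = WrapperConfig(SharedHandle(), {"max_new_tokens": 10})
clone = copy.deepcopy(config)
assert clone.shared is config.shared  # shared handle untouched
assert clone.generation_options is not config.generation_options  # everything else copied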
@@ -54,18 +77,37 @@ def from_pretrained( format=FastLLMCheckpointFormat, ) - updates = {} - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - updates[("distributed", "training_dtype")] = torch_dtype - # Create the model + # always set up model and crate distributed instance internally for now fast_llm_model = cls.runner_class.model_class.from_pretrained( - pretrained_model_name_or_path, updates, mode=mode + pretrained_model_name_or_path, + *updates, + optimizer_state_names=optimizer_state_names, + setup=True, + mode=mode, + use_cpu=use_cpu, + stage_filter=stage_filter, ) - config = cls.config_class(fast_llm_model.config) - return cls(config, fast_llm_model, **kwargs) + return cls(fast_llm_model, **kwargs) def _init_weights(self, module) -> None: raise NotImplementedError(module) + + +class HuggingfaceBaseModelForCausalLM(HuggingfacePreTrainedModel, transformers.generation.utils.GenerationMixin): + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values=None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | transformers.modeling_outputs.CausalLMOutputWithPast: + # Meant to be overridden in derived classes + raise NotImplementedError() diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py index 30f836b7..3003c5f9 100644 --- a/fast_llm/engine/inference/runner.py +++ b/fast_llm/engine/inference/runner.py @@ -7,27 +7,42 @@ from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.utils import Assert class InferenceRunner(abc.ABC): model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel batch_config_class: typing.ClassVar[type[BatchConfig]] = BatchConfig - def __init__(self, fast_llm_model: FastLLMModel): + def __init__( + self, + fast_llm_model: FastLLMModel, + runner: ScheduleRunner | None = None, + ): assert isinstance(fast_llm_model, self.model_class) self._fast_llm_model = fast_llm_model - # We only need a basic schedule and don't care about dimensions. - self._schedule_config = ScheduleConfig() - # TODO: Sort things out. + with NoAutoValidate(): self._batch_config = self.batch_config_class() self._batch_config.setup(self._fast_llm_model.config.distributed) self._batch_config.validate() - self._runner = ScheduleRunner( - config=self._schedule_config, - multi_stage=self._fast_llm_model, - distributed_config=self._fast_llm_model.config.distributed, - ) + + if runner is None: + # We only need a basic schedule and don't care about dimensions. + self._schedule_config = ScheduleConfig() + # TODO: Sort things out. + + self._runner = ScheduleRunner( + config=self._schedule_config, + multi_stage=self._fast_llm_model, + distributed_config=self._fast_llm_model.config.distributed, + ) + else: + self._schedule_config = runner.config + self._runner = runner + # External runner from training loop must be already setup + assert runner._is_setup + # TODO: Random state? 
(Distributed.set_step) self._schedule = Schedule( multi_stage=self._fast_llm_model, @@ -42,7 +57,11 @@ def fast_llm_model(self) -> FastLLMModel: return self._fast_llm_model def setup(self): - self._runner.setup(self._fast_llm_model.distributed) + if not self._runner._is_setup: + self._runner.setup(self._fast_llm_model.distributed) + else: + # Means external runner was passed, check it has the same distributed class as the model + Assert.is_(self._runner._distributed, self._fast_llm_model.distributed) def forward( self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index ad61c70f..5b140d4e 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -30,12 +30,17 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel + from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) +class ShardName: + weights = "weights" + grads = "grads" + + class StageMode(str, enum.Enum): # Allow forward and backward passes and optimizer. # TODO: Add mode for forward and backward but not optimizer? @@ -242,7 +247,7 @@ def get_model_class(cls) -> type["FastLLMModel"]: raise NotImplementedError @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfacePreTrainedModel"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]: raise NotImplementedError @classmethod @@ -372,9 +377,13 @@ def _from_dict( if "fast_llm_version" not in default: default["fast_llm_version"] = "0" + # Determine the model config class. + from fast_llm.models.auto import model_registry + model_config_class = default["model"] if isinstance(model_config_class, str): - model_config_class = FastLLMModelConfig.get_subclass(default["model"]) + Assert.incl(model_config_class, model_registry) + model_config_class = model_registry[model_config_class] default["model"] = model_config_class # TODO v0.3: Remove backward compatibility. 
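With the changes above, the Hugging Face wrapper accepts an already set-up `FastLLMModel` and can reuse an externally created `ScheduleRunner` instead of building its own. A condensed sketch of that construction path, mirroring the `_get_fast_llm_model_from_model` test helper added later in this patch; the checkpoint path is a placeholder:

import torch

from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import ScheduleConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig
from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM

# "/path/to/checkpoint" is a placeholder; the update keys mirror the test helper.
updates = {
    ("pretrained", "path"): "/path/to/checkpoint",
    ("pretrained", "model_weights"): True,
    ("pretrained", "format"): LlamaGPTHuggingfaceCheckpointFormat.name,
}
config = PretrainedGPTModelConfig.from_dict({}, updates)

multi_stage = config.model.get_model_class()(config.model)
runner = ScheduleRunner(
    config=ScheduleConfig(),
    multi_stage=multi_stage,
    distributed_config=config.model.distributed,
)
distributed = Distributed(config.model.distributed)

with torch.no_grad():
    # The wrapper now asserts the model is already set up, so set it up first.
    multi_stage.setup(distributed)
    runner.setup(distributed)

multi_stage.load_checkpoint(config.pretrained)

# `config` is derived from the model when omitted; the external runner is reused as-is.
hf_model = HuggingfaceGPTModelForCausalLM(multi_stage, runner=runner)

Alternatively, `HuggingfaceGPTModelForCausalLM.from_pretrained(...)` now always sets up the model and creates the distributed instance internally.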
diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index de26f9bf..56bae90f 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -89,9 +89,8 @@ def initialize_weights(self, timeout: float | None = None) -> None: stage.initialize_weights() for name, tied_parameter in self._tied_parameters.items(): if tied_parameter.group is not None: - broadcast( - self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group, timeout=timeout - ) + for fsdp in self._stages[tied_parameter.main_stage].fsdps: + broadcast(fsdp.weight_shard, 0, tied_parameter.group, timeout=timeout) self._finalize_load(reset_optimizer=True) def _finalize_load(self, reset_optimizer: bool = True) -> None: diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index e9c84aa3..5cf51dd5 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -10,8 +10,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedDim from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageMode -from fast_llm.functional.triton.pointwise import triton_add, triton_copy +from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode +from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta from fast_llm.utils import Assert, clamp, padded_cumsum @@ -246,13 +246,14 @@ def setup( ) self._parameter_buffers[parameter_name] = parameter_buffer - def reset_shard_pad(self, shard: torch.Tensor) -> int: + def reset_shard_pad(self, shard: torch.Tensor, shard_name: str) -> int: assert self._is_setup assert self._mode.on_device # TODO: Needed? # Prevent nans with the padded values # Also ensures a correct parameter count in loading context. 
- self._weight_shard_meta.validate(shard) + shard_meta = self._weight_shard_meta if shard_name == ShardName.weights else self._grad_shard_meta + shard_meta.validate(shard) if self._shard_pad > 0: shard[-self._shard_pad :].zero_() return self._shard_pad @@ -452,5 +453,12 @@ def copy_shard_overlaps( begin, end = self._parameter_range_in_shard(name) for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + continue + if loaded_shards[shard_name].numel() == 0: + shard[begin:end][overlap_mask] = 0 + counter += overlap_count + continue shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] counter += overlap_count diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 21d0fe55..497d1110 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -14,7 +14,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.config import ParamGroup @@ -24,11 +24,6 @@ logger = logging.getLogger(__name__) -class ShardName: - weights = "weights" - grads = "grads" - - class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b..a2a9d9d3 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -13,6 +13,9 @@ from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + pass + logger = logging.getLogger(__name__) @@ -111,6 +114,19 @@ def forward( metrics, ) self._log_layer_forward(output, kwargs, i) + + # TODO: very slow and memory consuming, only use for debugging for now + # TODO: decide if and how we want to return + # HF transformer style details from forward properly + if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: + # Last layer does not provide output + if output is not None: + meta = self._meta_outputs[i] + output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed) + kwargs["hidden_states"][self._layer_range[i]] = { + "layer_type": type(layer).__name__, + "tensor": output_global, + } return None if output is None else output.detach(), (input_, output) def backward( @@ -156,7 +172,7 @@ def reduce_gradients(self, accumulate=False) -> None: level=self._config.debug_param_gradients, global_=False, ) - if self._config.debug_all_param_gradients: + if self._config.debug_all_param_gradients and fsdp.requires_grad: fsdp.log_shard( name="gradient", shard=fsdp.grad_shard, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index fd50f55c..3ca28ba5 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.data_type import 
DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageConfig, StageMode +from fast_llm.engine.multi_stage.config import ShardName, StageConfig, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.logging import log_generator @@ -209,7 +209,7 @@ def initialize_weights(self) -> None: meta.init_parameter(parameter, self._distributed) if self.mode.on_device: - fsdp.reset_shard_pad(fsdp.weight_shard) + fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) if self._config.debug_param_init: log_generator("CPU generator after reset", torch.random.default_generator) @@ -316,7 +316,6 @@ def import_state_tensor( """ Given a global parameter tensor, set the associated slice of a local parameter shard. Return the size of the local slice. - TODO: Doesn't work """ fsdp_index = self._fsdp_index[parameter_name] return self._fsdps[fsdp_index].import_state_tensor(parameter_name, shards[fsdp_index], tensor) @@ -324,7 +323,6 @@ def import_state_tensor( def _export_shard( self, shards: tuple[torch.Tensor], data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: - # TODO: Doesn't work for fsdp, shard in zip(self._fsdps, shards, strict=True): yield from fsdp.export_shard(shard, self._distributed, data_type) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 44a5f677..91ce0d89 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -203,6 +203,7 @@ def _create_index(self) -> None: Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i # TODO: More configurable placement? 
+ step.pipeline_rank = step.stage % self._distributed.pipeline_parallel step.local_index = len(self._device_steps[step.pipeline_rank]) self._device_steps[step.pipeline_rank].append(step) @@ -222,6 +223,7 @@ def _create_index(self) -> None: Assert.empty(step_map) # Related steps + for i, step in enumerate(self._steps): if self._is_training: if step.type_ == StepType.forward: @@ -229,6 +231,7 @@ def _create_index(self) -> None: step.backward_step = self.get_step(StepType.backward, *step.map_index[1:]) else: step.forward_step = self.get_step(StepType.forward, *step.map_index[1:]) + if step.type_ == StepType.forward and step.stage == 0: step.prev_step = None elif step.type_ == StepType.backward and step.stage == self._num_stages - 1: diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py new file mode 100644 index 00000000..3a70f308 --- /dev/null +++ b/fast_llm/functional/dpo.py @@ -0,0 +1,78 @@ +import torch + + +def _compute_logprobs_for_preference_spans( + logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor +): + assert torch.all(targets < logits.size(-1)), "Target out of vocab range" + + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + # gather log probabilities corresponding to the target tokens + selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + + # apply chosen mask + chosen_logp = 0 + for idx, span in enumerate(chosen_spans): + chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() + + # apply rejected mask + rejected_logp = 0 + for idx, span in enumerate(rejected_spans): + rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() + + return chosen_logp, rejected_logp, selected_log_probs + + +def _compute_dpo_loss( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta: float, +): + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + diff_logratios = pi_logratios - ref_logratios + + losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) + return losses + + +def compute_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, + grad_output: float | None, +) -> tuple[torch.Tensor, torch.Tensor]: + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_() + reference_model_logits_ = reference_model_logits.float().detach() + + policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( + logits_, targets, chosen_spans, rejected_spans + ) + + reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( + reference_model_logits_, targets, chosen_spans, rejected_spans + ) + + losses = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps, + policy_rejected_logps=policy_rejected_logps, + reference_chosen_logps=reference_chosen_logps, + reference_rejected_logps=reference_rejected_logps, + beta=beta, + ) + + if grad_output is None: + loss = None + else: + loss = losses.mean() + loss.backward(torch.full_like(loss, grad_output)) + loss.detach() + return loss.detach(), logits_.grad.detach().to(logits.dtype) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0db76ad1..2d5fd843 
100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -34,6 +34,8 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" loss_mask = "loss_mask" @@ -87,6 +89,21 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", hint=FieldHint.feature, ) + enable_dpo: bool | None = Field( + default=False, + desc="Whether to enable DPO loss", + hint=FieldHint.feature, + ) + dpo_beta: float | None = Field( + default=1.0, + desc="Beta value for DPO loss.", + hint=FieldHint.feature, + ) + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) cross_entropy_impl: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 813dcc07..d6d1b8a5 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -12,6 +12,7 @@ from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( @@ -73,14 +74,18 @@ def __init__( self._init_output_weights(hidden_dim, config) - self._cross_entropy_impl = config.cross_entropy_impl - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused + self._use_dpo_loss = config.enable_dpo + if self._use_dpo_loss: + self.dpo_beta = config.dpo_beta + else: + self._cross_entropy_impl = config.cross_entropy_impl + if self._cross_entropy_impl == CrossEntropyImpl.auto: + if self._parallel_embeddings: + self._cross_entropy_impl = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + self._cross_entropy_impl = CrossEntropyImpl.triton + else: + self._cross_entropy_impl = CrossEntropyImpl.fused self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -143,12 +148,12 @@ def _forward_backward( ) -> tuple[torch.Tensor, torch.Tensor | None]: target = kwargs.get( LanguageModelKwargs.labels - if self._config.distillation_model is None + if self._use_dpo_loss or self._config.distillation_model is None else f"{self._config.distillation_model}_logits" ) # Loss mask for distillation. (Labels are already masked.) 
loss_mask = None - if target is not None: + if target is not None and not self._use_dpo_loss: if self._config.distillation_model is None: # MTP: Shift the labels target_sequence_length = ( @@ -175,6 +180,22 @@ def _forward_backward( with torch.enable_grad(): ln_output = self.final_norm(input_) + if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: + # The last hidden layer output is returned normalized in the HF Transformers-style output, at least for LLama style models. + # So, if needed, we gather the data after normalization and set it as the output of the previous layer. + dims = list(kwargs[TransformerKwargs.hidden_dims]) + sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims[sequence_index] = ( + TensorDim( + TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + ) + if self._sequence_parallel_logits + else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + ) + meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) + hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) + kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state + grad_output = kwargs[TransformerKwargs.grad_output] / ( self._group_size if self._sequence_parallel_logits else 1 ) @@ -309,16 +330,28 @@ def _logits_cross_entropy_forward_backward( if target is None: return logits * self._logits_scale_factor, None - loss, grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - target, - loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, - target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, - ) + if self._use_dpo_loss: + loss, grad = compute_dpo_loss( + logits, + target, + kwargs.get(f"{self._config.dpo_reference_model}_logits"), + kwargs[LanguageModelKwargs.chosen_spans], + kwargs[LanguageModelKwargs.rejected_spans], + self.dpo_beta, + grad_output, + ) + else: + loss, grad = cross_entropy_forward_backward( + logits.flatten(0, -2), + target, + loss_mask, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, + ) + # TODO: de-allocate earlier. 
del logits return loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 7e95bb5c..d719bef3 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -69,3 +69,54 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, ) + + +class PreferenceSpanPreprocessor(Preprocessor): + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + return + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels + + if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: + raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") + + chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] + chosen_valid_spans = [] + for spans in chosen_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans + + rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] + rejected_valid_spans = [] + for spans in rejected_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 963ffd35..4251bd27 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -2,9 +2,6 @@ from fast_llm.config import FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig if typing.TYPE_CHECKING: @@ -25,7 +22,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): pass -@config_class(dynamic_type={FastLLMModelConfig: "gpt_custom"}) +@config_class() class CustomModelConfig(GPTModelConfig): # TODO: Add 
custom model config parameters, if any (typically none). model_name: typing.ClassVar[str] = "gpt_custom" @@ -38,7 +35,7 @@ def get_model_class(cls) -> type["CustomModel"]: return CustomModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceCustomModelForCausalLM"]: from fast_llm.models.custom.huggingface import HuggingfaceCustomModelForCausalLM return HuggingfaceCustomModelForCausalLM @@ -49,7 +46,7 @@ class PretrainedCustomModelConfig(PretrainedGPTModelConfig): model: CustomModelConfig = FieldUpdate() -@config_class(dynamic_type={RunnableConfig: "train_gpt_custom", TrainerConfig: "gpt_custom"}) +@config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). data: CustomDataConfig = FieldUpdate() diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 64f6f1de..d9085c67 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,7 +4,6 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig @@ -126,7 +125,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) -@config_class(dynamic_type={FastLLMModelConfig: "gpt"}) +@config_class() class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" @@ -148,7 +147,7 @@ def get_model_class(cls) -> type["GPTModel"]: return GPTModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM return HuggingfaceGPTModelForCausalLM @@ -160,7 +159,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): model: GPTModelConfig = FieldUpdate() -@config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) +@config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() @@ -174,14 +173,29 @@ def _validate(self) -> None: if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - if (name := self.model.base_model.distillation_model) is None: - Assert.empty(self.reference_models) - else: - Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + + distillation_model = self.model.base_model.distillation_model + dpo_reference_model = self.model.base_model.dpo_reference_model + + if self.model.base_model.enable_dpo: + assert dpo_reference_model is not None + Assert.none(distillation_model) + else: + Assert.none(dpo_reference_model) + + if distillation_model is None and dpo_reference_model is None: + Assert.empty(self.reference_models) + else: + assert distillation_model is 
None or dpo_reference_model is None # currently don't support both + expected_names = {name for name in (distillation_model, dpo_reference_model) if name is not None} + Assert.eq(self.reference_models.keys(), expected_names) + for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) + Assert.none(reference_model.model.base_model.dpo_reference_model) # TODO: Support more LM head features. Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index d177a41d..b548ab52 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -13,7 +13,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, @@ -70,6 +70,9 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + if self._config.enable_dpo: # TODO better way to pass in? + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + def get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): @@ -283,6 +286,10 @@ def preprocess( tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + if batch.chosen_spans is not None: + kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans + if batch.rejected_spans is not None: + kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. 
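The `enable_dpo` path above routes `chosen_spans`/`rejected_spans` into `compute_dpo_loss`, whose core is `-logsigmoid(beta * ((pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)))`. A quick numeric check of that formula with made-up summed log-probabilities:

import torch
import torch.nn.functional as F

# Made-up summed log-probabilities over the chosen / rejected spans.
policy_chosen, policy_rejected = torch.tensor(-5.0), torch.tensor(-9.0)
reference_chosen, reference_rejected = torch.tensor(-6.0), torch.tensor(-8.0)
beta = 1.0

# Same formula as _compute_dpo_loss in fast_llm/functional/dpo.py.
advantage = (policy_chosen - policy_rejected) - (reference_chosen - reference_rejected)
loss = -F.logsigmoid(beta * advantage)
print(loss.item())  # ~0.127: the policy already separates chosen from rejected more than the reference does.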
@@ -294,7 +301,7 @@ def preprocess( TransformerKwargs.presents: presents, } if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 + sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] else: @@ -312,6 +319,7 @@ def preprocess( (spans[:, 0] <= sequence_k + prediction_heads - 1) & (spans[:, 1] >= sequence_offset) ] if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 3bdb05c3..cc39d7f7 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -28,6 +28,7 @@ def _get_sampling_parameters( "vocab_size": self._config.model.base_model.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + "use_preference_loss_spans": self._config.model.base_model.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, } diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 0c2b3c48..771a4fca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -5,7 +5,6 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -17,7 +16,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM from fast_llm.models.ssm.model import HybridSSMModel - from fast_llm.models.ssm.trainer import HybridSSMTrainer + from fast_llm.models.ssm.trainer import SSMTrainer logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): hint=FieldHint.architecture, ) hybrid_block_layout: list[str] = Field( - default=("m2",), + default_factory=lambda: ["m2"], desc="Pattern of blocks to use in the model. 
't' for Transformer, 'm' for Mamba1, 'm2' for Discrete Mamba2", hint=FieldHint.architecture, ) @@ -125,7 +124,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -@config_class(dynamic_type={FastLLMModelConfig: "hybrid_ssm"}) +@config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" @@ -157,13 +156,13 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): model: HybridSSMModelConfig = FieldUpdate() -@config_class(dynamic_type={RunnableConfig: "train_hybrid_ssm", TrainerConfig: "hybrid_ssm"}) -class HybridSSMTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): +@config_class() +class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() @classmethod - def get_trainer_class(cls) -> type["HybridSSMTrainer"]: - from fast_llm.models.ssm.trainer import HybridSSMTrainer + def get_trainer_class(cls) -> type["SSMTrainer"]: + from fast_llm.models.ssm.trainer import SSMTrainer - return HybridSSMTrainer + return SSMTrainer diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py index efa7b704..c0e5be26 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/ssm/trainer.py @@ -1,10 +1,10 @@ import typing from fast_llm.models.gpt.trainer import GPTTrainer -from fast_llm.models.ssm.config import HybridSSMTrainerConfig +from fast_llm.models.ssm.config import HybridTrainerConfig from fast_llm.models.ssm.model import HybridSSMModel -class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): - config_class: typing.ClassVar[type[HybridSSMTrainerConfig]] = HybridSSMTrainerConfig +class SSMTrainer[ConfigType: HybridTrainerConfig](GPTTrainer[ConfigType]): + config_class: typing.ClassVar[type[HybridTrainerConfig]] = HybridTrainerConfig model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/mkdocs.yaml b/mkdocs.yaml index a080bc83..ab71bc23 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -173,6 +173,7 @@ nav: - Continue training a model: recipes/continue-training.md - Upcycle Llama 3B to MoE: recipes/upcycle-llama-3b-to-moe.md - Instruction Finetuning: recipes/instruction-finetuning.md + - Generate: recipes/generate.md - Reference: - User Guide: - Configuration: user_guide/configuration.md diff --git a/setup.cfg b/setup.cfg index a48759c1..2e3b549f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,7 +65,8 @@ DOCS = mkdocs-git-revision-date-localized-plugin pypandoc_binary mkdocs-bibtex + cairosvg==2.7.0 [options.entry_points] console_scripts = - fast-llm = fast_llm.cli:fast_llm_main + fast-llm = fast_llm.tools.cli:fast_llm diff --git a/tests/common.py b/tests/common.py index bcc563d7..6179957b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -13,7 +13,6 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.config import ( @@ -25,6 +24,7 @@ Starcoder2GPTHuggingfaceCheckpointFormat, ) from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs # FIXME: figure out correct import 
of megatron modules without this hack @@ -361,6 +361,7 @@ def run_test_script( config: CompareConfig | None = None, prepare_fn=None, compare_fn=None, + do_compare: bool = True, ): if torch.cuda.device_count() < num_gpus: pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") @@ -392,7 +393,7 @@ def run_test_script( if is_megatron: script = [*script, f"--structured-logs-dir={path}", f"--data-cache-path={path}"] else: - script = ["train", model_type, *script, f"run.experiment_dir={path}"] + script = [model_type, *script, f"run.experiment_dir={path}"] header = ["Megatron-LM/pretrain_gpt.py"] if is_megatron else ["--no-python", "fast-llm", "train"] command = [ "python", @@ -408,12 +409,12 @@ def run_test_script( else: get_test_dataset() if num_gpus == 1 and not is_megatron: - RunnableConfig.parse_and_run(script) + CliTrainingConfig.parse_and_run(script) else: completed_proc = subprocess.run(command, env=env, timeout=60) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare: + if compare and do_compare: if compare_fn is not None: compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare) compare_tensor_logs( diff --git a/tests/conftest.py b/tests/conftest.py index 445f59bb..1c718c21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,19 @@ def pytest_addoption(parser): parser.addoption("--skip-slow", action="store_true") + parser.addoption( + "--run-extra-slow", + action="store_true", + default=False, + help="Run tests marked as extra_slow", + ) def pytest_configure(config): config.addinivalue_line("markers", "slow: Test is slow.") + config.addinivalue_line( + "markers", "extra_slow: Mark test as extra slow and skip unless --run-extra-slow is given." + ) def pytest_collection_modifyitems(config, items): @@ -15,3 +24,8 @@ def pytest_collection_modifyitems(config, items): for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) + if not config.getoption("--run-extra-slow"): + skip_extra_slow = pytest.mark.skip(reason="need --run-extra-slow option to run") + for item in items: + if "extra_slow" in item.keywords: + item.add_marker(skip_extra_slow) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9dd7975c..17ba5de0 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -39,6 +39,43 @@ def test_write_memmap_dataset(dtype): ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." 
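The conftest changes above add an `extra_slow` marker gated behind a `--run-extra-slow` flag. Illustrative usage (the test name here is hypothetical); marked tests are collected but skipped unless pytest is invoked with `--run-extra-slow`:

import pytest


@pytest.mark.extra_slow
def test_generate_against_reference():  # hypothetical test name
    ...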
+@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): + def generate_valid_span(max_seq_length): + span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) + return np.sort(span) + + vocab_size = 1000 + max_seq_length = 8192 + num_samples = 100 + + documents = [ + GPTSample( + token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), + chosen_span=generate_valid_span(max_seq_length=max_seq_length), + rejected_span=generate_valid_span(max_seq_length=max_seq_length), + ) + for _ in range(num_samples) + ] + with tempfile.TemporaryDirectory() as temp_dir: + prefix = pathlib.Path(temp_dir) + GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) + dataset = GPTMemmapDataset(name="foo", prefix=prefix) + for i, document in enumerate(documents): + dataset_item = dataset.get(i, use_preference_loss_spans=True) + assert np.array_equal( + dataset_item.token_ids, document.token_ids, equal_nan=True + ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." + + assert np.array_equal( + dataset_item.chosen_span, document.chosen_span, equal_nan=True + ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}." + + assert np.array_equal( + dataset_item.rejected_span, document.rejected_span, equal_nan=True + ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}." + + def test_load_metadata_from_hub(): with tempfile.TemporaryDirectory(suffix="test") as local_folder: get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata() diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 5c5f5b90..ca32082b 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,9 +14,9 @@ FastLLMCheckpointFormat, ModelConfigType, ) -from fast_llm.engine.checkpoint.convert import ConvertConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode -from fast_llm.engine.multi_stage.multi_stage import ShardName +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode +from fast_llm.models.auto import model_registry +from fast_llm.tools.convert import ConvertConfig from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -30,8 +30,8 @@ ) from tests.compare_tensor_logs import CompareConfig, compare_logged_tensor -TEST_MODEL_CONFIG_CLS = FastLLMModelConfig.get_subclass(TEST_MODEL_TYPE) -TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_class() +TEST_MODEL_CONFIG_CLS = model_registry[TEST_MODEL_TYPE] +TEST_MODEL_HF_CLS = TEST_MODEL_CONFIG_CLS.get_huggingface_model_for_causal_lm_class() TEST_MODEL_CLS = TEST_MODEL_CONFIG_CLS.get_model_class() TEST_BASE_MODEL_CONFIG_CLS = TEST_MODEL_CONFIG_CLS.get_base_model_config_class() @@ -75,6 +75,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_resume(): + # Resume from iteration=1 and compare outputs with the baseline run. run_test_script( f"test_{TEST_MODEL}_resume", CONFIG_COMMON @@ -89,6 +90,24 @@ def test_resume(): ) +@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +def test_resume_frozen(): + # Resume with frozen mlp. No comparison. 
+ run_test_script( + f"test_{TEST_MODEL}_resume_frozen", + CONFIG_COMMON + + [ + "training.checkpoint.interval=1", + "training.evaluations.validation.interval=2", + "training.evaluations.validation.iterations=1", + "model.base_model.transformer.mlp_lr_scale=0.", + ], + compare=f"test_{TEST_MODEL}_checkpoint_and_eval", + prepare_fn=_prepare_resume_fn, + do_compare=False, + ) + + def _run_conversion(config: ConvertConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) diff --git a/tests/test_functional.py b/tests/test_functional.py index 3e5c7f87..34a7a77f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,13 +1,162 @@ +import random + import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert from tests.common import requires_cuda +def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: + if temperature != 1.0: + logits.div_(temperature) + batch_dim = logits.shape[:-1] + last_dim = logits.shape[-1] + + output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") + log_probs_labels = -output.view(*batch_dim) + + return log_probs_labels + + +def ref_packed_get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + packed_seq_lens, +) -> torch.FloatTensor: + labels = labels[:, 1:] + logits = logits[:, :-1, :] + per_token_logps = ref_log_probs_from_logits(logits, labels) + + loss_masks = attention_mask.clone().bool() + + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + loss_masks[0, index : index + prompt_id_lens[i]] = False + index = index + seq_len + + loss_masks = loss_masks[:, 1:] + + logprobs_sums = [] + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + seq = per_token_logps[0, index : index + seq_len - 1] + mask = loss_masks[0, index : index + seq_len - 1] + logprobs_sums.append((seq * mask).sum()) + index = index + seq_len + chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] + rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] + + return torch.tensor(chosen_logps), torch.tensor(rejected_logps) + + +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seq_length", [1024, 4096, 8192]) +@pytest.mark.parametrize("vocab_size", [1000, 2000, 8000]) +def test_preference_logps(batch_size, seq_length, vocab_size): + random.seed(0) + torch.manual_seed(0) + + def random_split(seq_length): + min_val = int(seq_length * 0.3) + max_val = int(seq_length * 0.7) + + if max_val < min_val: + max_val = min_val + + a = random.randint(min_val, max_val) + b = seq_length - a + return [a, b] + + logits = torch.randn(batch_size, seq_length, vocab_size) + targets = torch.randint(0, vocab_size, (batch_size, seq_length)) + packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths + prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 75% prompt 25% generation + attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) + + chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due 
to label shifting + rejected_span = ( + torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 + ) # shift by 1 due to label shifting + + ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( + logits, targets, attention_mask, prompt_id_lens, packed_seq_lens + ) + + chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( + logits=logits, + targets=targets[:, 1:], + chosen_spans=chosen_span, + rejected_spans=rejected_span, + ) + + ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) + + # check all logps + Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) + + # check chosen and rejected summed logps + Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) + Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) + + +def ref_dpo_loss_fcn( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta=1, + label_smoothing=0, +) -> torch.Tensor: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) + - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing + ) + + loss = losses.mean() + + return loss + + +def test_dpo_loss(): + torch.manual_seed(0) + + NUM_SAMPLES = 20 + policy_chosen_logps = torch.rand(NUM_SAMPLES) + policy_rejected_logps = torch.rand(NUM_SAMPLES) + reference_chosen_logps = torch.rand(NUM_SAMPLES) + reference_rejected_logps = torch.rand(NUM_SAMPLES) + betas = torch.rand(NUM_SAMPLES) + + for i in range(NUM_SAMPLES): + fastllm_dpo_loss = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps[i], + policy_rejected_logps=policy_rejected_logps[i], + reference_chosen_logps=reference_chosen_logps[i], + reference_rejected_logps=reference_rejected_logps[i], + beta=betas[i].item(), + ) + ref_dpo_loss = ref_dpo_loss_fcn( + policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), + policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), + reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), + reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), + beta=betas[i].item(), + ) + Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) + + @requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( diff --git a/tests/test_gpt_generate_and_forward.py b/tests/test_gpt_generate_and_forward.py new file mode 100644 index 00000000..134b51e6 --- /dev/null +++ b/tests/test_gpt_generate_and_forward.py @@ -0,0 +1,373 @@ +import huggingface_hub +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat, PretrainedGPTModelConfig +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from tests.common import TEST_RESULTS_PATH, requires_cuda + + +def 
_prepare_checkpoint(model: str) -> str: + path = TEST_RESULTS_PATH.resolve() / "generate/model" + model_path = huggingface_hub.snapshot_download(repo_id=model, local_dir=path) + return model_path + + +def _prepare_data(tokenizer, use_batch_size2: bool): + messages = [ + {"role": "user", "content": "What is gravity?"}, + {"role": "user", "content": "Who is the president of EU?"}, + ] + if not use_batch_size2: + messages = messages[0:1] + + input_text = [tokenizer.apply_chat_template([el], tokenize=False) for el in messages] + + tokenizer.padding_side = "left" + inputs = tokenizer(input_text, padding="longest", return_tensors="pt").to("cuda") + return inputs + + +def _prepare_rand_data(vocab_size, use_batch_size2: bool): + inputs = torch.randint( + 1, + vocab_size, + [2 if use_batch_size2 else 1, 10], + dtype=torch.int64, + generator=torch.Generator().manual_seed(42), + ).cuda() + attention_mask = torch.ones_like(inputs) + # simulate left padding on one of the rows + if use_batch_size2: + inputs[1, :5] = 0 + attention_mask[1, :5] = 0 + return {"input_ids": inputs, "attention_mask": attention_mask} + + +def _get_hf_model(model_path: str, use_flash_attention: bool, use_bf16: bool): + hf_kwargs = {} + if use_flash_attention: + hf_kwargs["attn_implementation"] = "flash_attention_2" + hf_kwargs["torch_dtype"] = torch.bfloat16 + elif use_bf16: + hf_kwargs["torch_dtype"] = torch.bfloat16 + return AutoModelForCausalLM.from_pretrained(model_path, **hf_kwargs).to("cuda") + + +def _get_fast_llm_model( + model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat +): + updates = {} + if use_flash_attention: + updates[("base_model", "transformer", "use_flash_attention")] = True + updates[("distributed", "training_dtype")] = "bf16" + else: + updates[("base_model", "transformer", "use_flash_attention")] = False + if use_bf16: + updates[("distributed", "training_dtype")] = "bf16" + return HuggingfaceGPTModelForCausalLM.from_pretrained( + CheckpointLoadConfig( + path=model_path, + format=checkpoint_format, + model_weights=True, + ), + updates, + ) + + +def _get_fast_llm_model_from_model( + model_path: str, use_flash_attention: bool, use_bf16: bool, checkpoint_format=LlamaGPTHuggingfaceCheckpointFormat +): + updates = { + ("pretrained", "path"): model_path, + ("pretrained", "model_weights"): True, + ("pretrained", "format"): checkpoint_format.name, + } + + if use_flash_attention: + updates[("model", "base_model", "transformer", "use_flash_attention")] = True + updates[("model", "distributed", "training_dtype")] = "bf16" + else: + updates[("model", "base_model", "transformer", "use_flash_attention")] = False + if use_bf16: + updates[("model", "distributed", "training_dtype")] = "bf16" + + config = PretrainedGPTModelConfig.from_dict({}, updates) + multi_stage = config.model.get_model_class()(config.model) + schedule_config = ScheduleConfig() + runner = ScheduleRunner( + config=schedule_config, + multi_stage=multi_stage, + distributed_config=config.model.distributed, + ) + distributed = Distributed(config.model.distributed) + + with torch.no_grad(): + multi_stage.setup(distributed) + runner.setup(distributed) + + multi_stage.load_checkpoint(config.pretrained) + + return HuggingfaceGPTModelForCausalLM(multi_stage, runner=runner) + + +def _trim_output(output, inputs): + res = [] + for output_row, input_row in zip(output, inputs["input_ids"]): + res.append(output_row[len(input_row) :]) + return res + + +def _generate( + inputs, + hf_model, + fast_llm_model, + 
max_new_tokens: int, +): + return { + "hf": _trim_output(hf_model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=False), inputs), + "fast_llm": _trim_output( + fast_llm_model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=False), inputs + ), + } + + +def _compare_gen_outputs(outputs: dict[str, list[torch.Tensor]], min_matching_tokens: int | None = None): + for hf_output, fast_llm_output in zip(outputs["hf"], outputs["fast_llm"]): + if min_matching_tokens is not None: + hf_output = hf_output[:min_matching_tokens] + fast_llm_output = fast_llm_output[:min_matching_tokens] + assert torch.equal(hf_output, fast_llm_output) + + +def _test_for_batches( + hf_model, + fast_llm_model, + max_new_tokens, + min_matching_tokens_batch_size_1, + min_matching_tokens_batch_size_2, + tokenizer=None, +): + if tokenizer is not None: + inputs = _prepare_data(tokenizer, use_batch_size2=False) + else: + inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=False) + outputs = _generate( + inputs, + hf_model, + fast_llm_model, + max_new_tokens=max_new_tokens, + ) + _compare_gen_outputs(outputs, min_matching_tokens=min_matching_tokens_batch_size_1) + + if tokenizer is not None: + inputs = _prepare_data(tokenizer, use_batch_size2=True) + else: + inputs = _prepare_rand_data(fast_llm_model.config.fast_llm_config.base_model.vocab_size, use_batch_size2=True) + outputs = _generate( + inputs, + hf_model, + fast_llm_model, + max_new_tokens=max_new_tokens, + ) + _compare_gen_outputs(outputs, min_matching_tokens=min_matching_tokens_batch_size_2) + + +@pytest.fixture(scope="module") +def model_and_tokenizer(): + model = "HuggingFaceTB/SmolLM2-135M-Instruct" + fast_llm_checkpoint_format = LlamaGPTHuggingfaceCheckpointFormat + model_path = _prepare_checkpoint(model) + tokenizer = AutoTokenizer.from_pretrained(model_path) + return model_path, tokenizer, fast_llm_checkpoint_format + + +@pytest.fixture(scope="module") +def small_model(): + from .common import _CONFIGS, TEST_RESULTS_PATH, run_test_script + + _, _, _, common_config, fast_llm_checkpoint_format = _CONFIGS["llama"] + run_test_script( + f"test_llama_generate_and_forward", + common_config + + ["training.checkpoint.interval=1", "training.export.format=llama", "training.export.interval=1"], + ) + return TEST_RESULTS_PATH / "test_llama_generate_and_forward/export/llama/2", fast_llm_checkpoint_format + + +def _test_generate( + model_path, + fast_llm_checkpoint_format, + use_flash_attention, + use_bf16, + max_new_tokens, + min_matching_tokens_batch_size_1, + min_matching_tokens_batch_size_2, + tokenizer=None, +): + hf_model = _get_hf_model(model_path, use_flash_attention, use_bf16) + fast_llm_model = _get_fast_llm_model(model_path, use_flash_attention, use_bf16, fast_llm_checkpoint_format) + + _test_for_batches( + hf_model, + fast_llm_model, + max_new_tokens, + min_matching_tokens_batch_size_1, + min_matching_tokens_batch_size_2, + tokenizer=tokenizer, + ) + + +@pytest.mark.extra_slow +@requires_cuda +@pytest.mark.parametrize( + "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", + [ + # No flash attention + no bf16 + (False, False, 10, 10, 10), + # No flash attention + with bf16 + (False, True, 10, 10, 10), + # Flash attention must be paired with bf16 + (True, True, 10, 10, 10), + ], +) +def test_generate( + model_and_tokenizer, + use_flash_attention, + use_bf16, + max_new_tokens, + min_matching_tokens_batch_size_1, + 
min_matching_tokens_batch_size_2,
+):
+    model_path, tokenizer, fast_llm_checkpoint_format = model_and_tokenizer
+    _test_generate(
+        model_path,
+        fast_llm_checkpoint_format,
+        use_flash_attention,
+        use_bf16,
+        max_new_tokens,
+        min_matching_tokens_batch_size_1,
+        min_matching_tokens_batch_size_2,
+        tokenizer=tokenizer,
+    )
+
+
+@pytest.mark.slow
+@requires_cuda
+@pytest.mark.parametrize(
+    "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2",
+    [
+        # No flash attention + no bf16
+        (False, False, 10, 10, 10),
+        # No flash attention + with bf16
+        (False, True, 10, 10, 10),
+        # Flash attention must be paired with bf16
+        (True, True, 10, 10, 10),
+    ],
+)
+def test_small_generate(
+    small_model,
+    use_flash_attention,
+    use_bf16,
+    max_new_tokens,
+    min_matching_tokens_batch_size_1,
+    min_matching_tokens_batch_size_2,
+):
+    model_path, fast_llm_checkpoint_format = small_model
+    _test_generate(
+        model_path,
+        fast_llm_checkpoint_format,
+        use_flash_attention,
+        use_bf16,
+        max_new_tokens,
+        min_matching_tokens_batch_size_1,
+        min_matching_tokens_batch_size_2,
+    )
+
+
+def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format):
+    max_new_tokens = 10
+    min_matching_tokens_batch_size_1 = 10
+    min_matching_tokens_batch_size_2 = 10
+
+    # Use flash attention for speed
+    hf_model = _get_hf_model(model_path, True, True)
+    fast_llm_model = _get_fast_llm_model_from_model(model_path, True, True, fast_llm_checkpoint_format)
+
+    _test_for_batches(
+        hf_model,
+        fast_llm_model,
+        max_new_tokens,
+        min_matching_tokens_batch_size_1,
+        min_matching_tokens_batch_size_2,
+        tokenizer=tokenizer,
+    )
+
+
+@pytest.mark.extra_slow
+@requires_cuda
+def test_generate_from_model(
+    model_and_tokenizer,
+):
+    model_path, tokenizer, fast_llm_checkpoint_format = model_and_tokenizer
+    _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format)
+
+
+@pytest.mark.slow
+@requires_cuda
+def test_small_generate_from_model(
+    small_model,
+):
+    model_path, fast_llm_checkpoint_format = small_model
+    _test_generate_from_model(model_path, None, fast_llm_checkpoint_format)
+
+
+def _test_forward_return_hidden_states(
+    model_path,
+    fast_llm_checkpoint_format,
+    vocab_size: int | None = None,
+):
+    # Use flash attention for speed
+    # TODO: hidden states differ between HF and Fast-LLM even though the resulting logits are similar;
+    # decide whether to leave it as is.
+    # hf_model = _get_hf_model(model_path, True, True)
+    fast_llm_model = _get_fast_llm_model(model_path, True, True, fast_llm_checkpoint_format)
+
+    input_ids = torch.randint(
+        1,
+        fast_llm_model.config.fast_llm_config.base_model.vocab_size if vocab_size is None else vocab_size,
+        [1, 10],
+        dtype=torch.int64,
+        generator=torch.Generator().manual_seed(42),
+    ).cuda()
+
+    # res_hf = hf_model.forward(input_ids=input_ids, output_hidden_states=True, return_dict=True, use_cache=False)
+    res_fast_llm = fast_llm_model.forward(
+        input_ids=input_ids, output_hidden_states=True, return_dict=True, use_cache=False
+    )
+
+    # hidden_states also includes the embedding layer output, hence the `- 1`.
+    assert (
+        len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers
+    )
+
+
+@pytest.mark.extra_slow
+@requires_cuda
+def test_forward_return_hidden_states(
+    model_and_tokenizer,
+):
+    model_path, tokenizer, fast_llm_checkpoint_format = model_and_tokenizer
+    _test_forward_return_hidden_states(model_path, fast_llm_checkpoint_format, tokenizer.vocab_size)
+
+
+@pytest.mark.slow
+@requires_cuda
+def test_small_forward_return_hidden_states(small_model):
+    model_path, fast_llm_checkpoint_format = small_model
+    _test_forward_return_hidden_states(model_path, fast_llm_checkpoint_format)
diff --git a/tests/test_ssms.py b/tests/test_ssms.py
index 0fec3741..f1ef5654 100644
--- a/tests/test_ssms.py
+++ b/tests/test_ssms.py
@@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config):
     assert torch.allclose(logits, hf_logits, atol=1e-2)
 
 
+# TODO: Speed up this test or bring it back as an integration test.
 @pytest.mark.skip(reason="Too slow.")
 @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed")
 @pytest.mark.parametrize(