diff --git a/fast_llm/cli.py b/fast_llm/cli.py new file mode 100644 index 00000000..34546120 --- /dev/null +++ b/fast_llm/cli.py @@ -0,0 +1,35 @@ +import logging +import sys +import traceback + +from fast_llm.config import ValidationError +from fast_llm.engine.config_utils.logging import configure_logging +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.engine.config_utils.runnable import RunnableConfig + +# Import these submodules to ensure classes are added to the dynamic class registry. +import fast_llm.data.auto # isort: skip +import fast_llm.engine.checkpoint.convert # isort: skip +import fast_llm.models.auto # isort: skip + +logger = logging.getLogger(__name__) + + +def fast_llm_main(args: list[str] | None = None): + # TODO: Add hook to register model classes? (environment variable?) + # (Pre-)configure logging + configure_logging() + try: + RunnableConfig.parse_and_run(args) + except Exception as e: + if sys.gettrace(): + raise + if isinstance(e, ValidationError): + log_main_rank(traceback.format_exc(), log_fn=logger.error) + else: + logger.critical(traceback.format_exc()) + sys.exit(1) + + +if __name__ == "__main__": + fast_llm_main() diff --git a/fast_llm/config.py b/fast_llm/config.py index 380100e3..a99553aa 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) - _AUTO_VALIDATE = True MISSING = Tag("") @@ -245,7 +244,7 @@ def _process_config_class(cls: type["Config"]): def config_class[ T: Config -](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]: +](dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -270,11 +269,8 @@ def __init__(self, **kwargs): wrapped.__init__ = __init__ - wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None - if dynamic_type is not None: for cls_, name in dynamic_type.items(): - print(cls_, name, wrapped) cls_.register_subclass(name, wrapped) return wrapped @@ -316,7 +312,7 @@ class Config(metaclass=ConfigMeta): _setting_implicit_default: bool | None = Field(init=False) # A registry for all the config classes. - _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None + _registry: typing.ClassVar[Registry[str, type[typing.Self]]] = Registry[str, "type[Config]"]("Config", {}) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -371,17 +367,6 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: Validate a class and mark it as read-only This should not be overridden in derived classes. """ - # Should be handled in `from_dict`, but can fail if instantiating directly. - try: - expected_class = self.get_subclass(self.type) - except KeyError as e: - # Delayed instantiation error in `from_dict`. - raise ValidationError(*e.args) - - if expected_class is not None: - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.is_(self.__class__, expected_class) - if not self._validated: try: self._validate() @@ -401,6 +386,17 @@ def _validate(self) -> None: Can be extended to add custom post-processing (typically before the super() call) and validation (typically after) """ + # Should be handled in `from_dict`, but can fail if instantiating directly. 
+ try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValidationError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.is_(self.__class__, expected_class) + if self._abstract: raise ValidationError(f"{type(self).__name__} is abstract") if not self.__class_validated__: @@ -409,6 +405,8 @@ def _validate(self) -> None: ) errors = [] with self._set_implicit_default(None): + # Set the type field, or override it to the provided type with the actual class for clarity and safety. + self.type = self.__class__.__name__ for name, field in self.fields(): if not field.init or field._field_type != dataclasses._FIELD: # noqa continue @@ -486,6 +484,7 @@ def _validate_element(cls, value, type_, name: str): raise FieldTypeError(f"Not a type.") elif issubclass(type_, Config): cls._validate_element_type(value, type_, strict=False) + value.validate(_is_validating=True) else: value = cls._validate_simple(value, type_) @@ -737,7 +736,7 @@ def from_dict( for keys, value in update.items(): set_nested_dict_value(default, keys, value, update_type) - return cls._from_dict(default, strict) + return cls._from_dict(default, strict=strict) @classmethod def from_flat_dict( @@ -899,8 +898,6 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ @classmethod def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) - if cls._registry is None: - raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..") if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: @@ -916,7 +913,7 @@ def get_subclass(cls, name: str | None): return None cls_ = None for base_class in cls.__mro__: - if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry: + if issubclass(base_class, Config) and name in base_class._registry: if cls_ is None: cls_ = base_class._registry[name] if not issubclass(cls_, cls): @@ -937,6 +934,12 @@ def __init_subclass__(cls): We need to postpone validation until the class has been processed by the dataclass wrapper. """ Assert.eq(cls.__name__, cls.__qualname__) + cls._registry = Registry[str, type[cls]](cls.__name__, {}) + if not cls._abstract: + Config.register_subclass(cls.__name__, cls) + short_name = cls.__name__.strip("Config") + if short_name != cls.__name__: + Config.register_subclass(short_name, cls) for base_class in cls.__mro__: if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( @@ -982,7 +985,7 @@ def __init_subclass__(cls): cls.__annotations__[name] = base_class_field.type # Type for the field. At the end of class definition to avoid shadowing builtin. - type: str | None = Field( + type: str = Field( default=None, desc="The config class name.", hint=FieldHint.feature, diff --git a/fast_llm/data/auto.py b/fast_llm/data/auto.py index 902faf1c..c44e538f 100644 --- a/fast_llm/data/auto.py +++ b/fast_llm/data/auto.py @@ -1,13 +1,5 @@ -from fast_llm.data.preparator.config import DatasetPreparatorConfig -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig -from fast_llm.utils import Registry +""" +Import these submodules to ensure classes are added to the dynamic class registry. 
+""" -dataset_preparator_registry = Registry[str, DatasetPreparatorConfig]( - "DatasetPreparator", - { - dataset_preparator.preparator_name: dataset_preparator - for dataset_preparator in [ - GPTMemmapDatasetPreparatorConfig, - ] - }, -) +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index c6fece9d..02c1b6c0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,29 +32,18 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None - chosen_spans: list[torch.Tensor] | None = None - rejected_spans: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None - stacked_chosen_spans = None - stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] - stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), - loss_masking_spans=stacked_spans, - sequence_lengths=sequence_lengths, - chosen_spans=stacked_chosen_spans, - rejected_spans=stacked_rejected_spans, + token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths ) @@ -160,7 +149,6 @@ def get_iterator( sampling_parameters = self._sampling_parameters[dataset_name] Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - return iter( torch.utils.data.DataLoader( self._datasets[dataset_name], # noqa diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ae87e0e7..ed9f57fc 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,7 +73,6 @@ class GPTSamplingParameters(SamplingParameters): sequence_length: int vocab_size: int use_loss_masking_spans: bool = False - use_preference_loss_spans: bool = False cross_document_attention: bool = True # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
@@ -93,7 +92,7 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True -@config_class(registry=True) +@config_class() class GPTSampledDatasetConfig(SampledDatasetConfig): pass diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f..ef060b00 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -34,16 +34,13 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 - self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: + assert self._version in [1, 2], f"Unsupported version for gpt_memmap dataset: {self._version}." + if self._version == 2: self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 2: + if self._has_spans and self._version == 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -91,36 +83,6 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None ).reshape(-1, 2) ) - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -143,12 +105,7 @@ def __del__(self): del self._index_bin_buffer_mmap def get( - self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, + self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -159,53 +116,13 @@ def get( sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document + # adjust the spans for the offset and length sample_spans = sample_spans[ (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise 
ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset - - rejected_span = self._rejected_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset - - return GPTSample( - token_ids=token_ids, - loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, - ) + return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) @property def name(self) -> str: @@ -240,8 +157,6 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # number of spans for each document num_spans = [] spans = [] - chosen_spans = [] - rejected_spans = [] prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) @@ -267,10 +182,6 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 @@ -282,20 +193,15 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP spans = np.vstack(spans, dtype=np.int32) else: spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) # Data type idx_stream.write(struct.pack(" None: Assert.in_range(self.rank, 0, self.world_size) -@config_class() +@config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"}) class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): preparator_name: typing.ClassVar[str] = "gpt_memmap" diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 988e23e7..9b1d8f04 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -52,7 +52,7 @@ def tokenize_with_spans( token_spans = [] char_pos = 0 beginning_of_text = True - + for start, end in char_spans: if char_pos < 
start: curr_text = text[char_pos:start] diff --git a/fast_llm/tools/convert.py b/fast_llm/engine/checkpoint/convert.py similarity index 89% rename from fast_llm/tools/convert.py rename to fast_llm/engine/checkpoint/convert.py index 648138ec..0f670854 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -1,4 +1,3 @@ -import argparse import json import logging import math @@ -9,7 +8,6 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.functional.config import TritonConfig -from fast_llm.models.auto import model_registry from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -18,7 +16,7 @@ logger = logging.getLogger(__name__) -@config_class() +@config_class(dynamic_type={RunnableConfig: "convert"}) class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() @@ -27,24 +25,10 @@ class ConvertConfig(RunnableConfig): layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) - @classmethod - def _get_parser(cls): - parser = super()._get_parser() - parser.add_argument( - "model_type", - choices=model_registry.keys(), - help="The Fast-LLM model type to use. Must be defined in the model registry in `fast_llm.models.auto`.", - ) - return parser - - @classmethod - def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): - config = super()._from_parsed_args(parsed, unparsed) - config.model = model_registry[parsed.model_type] - return config - def _validate(self): assert self.model is not None + if isinstance(self.model, str): + self.model = FastLLMModelConfig.get_subclass(self.model) self.input.setup(self.model) self.output.setup(self.model) super()._validate() @@ -158,7 +142,3 @@ def run(self): # All good! (self.output.path / "ok").open("w") logger.info(f">>> All done!") - - -if __name__ == "__main__": - ConvertConfig.parse_and_run() diff --git a/fast_llm/engine/config_utils/runnable.py b/fast_llm/engine/config_utils/runnable.py index 6142de47..ac10225e 100644 --- a/fast_llm/engine/config_utils/runnable.py +++ b/fast_llm/engine/config_utils/runnable.py @@ -19,10 +19,17 @@ @config_class() class RunnableConfig(Config): @classmethod - def parse_and_run(cls, args=None) -> None: - parsed, unparsed = cls._get_parser().parse_known_args(args) + def parse_and_run(cls, args: list[str] | None = None) -> None: + if args is None: + args = sys.argv[1:] + cls_ = cls + while len(args) >= 1 and "=" not in args[0] and not args[0].startswith("-"): + # Allow chained dynamic type selection without the `type=`, ex. `train gpt`. 
+ cls_ = cls_.get_subclass(args[0]) + args = args[1:] + parsed, unparsed = cls_._get_parser().parse_known_args([f"type={cls_.__name__}"] + args) with NoAutoValidate(): - config: "RunnableConfig" = cls._from_parsed_args(parsed, unparsed) + config: "RunnableConfig" = cls_._from_parsed_args(parsed, unparsed) try: config.configure_logging() config.validate() diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index a5be2e7e..b0a7a26b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,6 +23,7 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig +from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -274,7 +275,7 @@ class ShutdownConfig(IntervalConfig): ) -@config_class() +@config_class(dynamic_type={RunnableConfig: "train"}) class TrainingConfig(Config): evaluations: dict[str, EvaluationConfig] = Field( default_factory=dict, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 054c26c3..f6fbd4f5 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,3 +1,4 @@ +import abc import enum import typing @@ -6,6 +7,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + import torch + from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization import LayerNorm, RMSNorm @@ -23,26 +26,42 @@ class NormalizationImplementation(str, enum.Enum): triton = "triton" -class NormalizationType(str, enum.Enum): - """ - An enum for the available normalization layers. - TODO: Add no_norm type? - """ +@config_class() +class NormalizationConfig(BaseModelConfig): + pass - layer_norm = "layer_norm" - rms_norm = "rms_norm" + @abc.abstractmethod + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + pass + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return LayerNormalizationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) -@config_class(registry=True) -class NormalizationConfig(BaseModelConfig): + +@config_class(dynamic_type={NormalizationConfig: "none"}) +class NoNormalizationConfig(NormalizationConfig): _abstract = False - # Normalization type - type: NormalizationType = Field( - default=NormalizationType.layer_norm, - desc="The type of normalization to use, for example Layer Norm or RMS Norm.", - hint=FieldHint.architecture, - ) + @abc.abstractmethod + def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + return torch.nn.Identity() + + +@config_class() +class LayerNormalizationBaseConfig(NormalizationConfig): + """ + Common configuration for layer norm and rms norm + """ + # TODO: Rename to normalization_epsilon epsilon: float = Field( default=1e-5, @@ -69,7 +88,6 @@ class NormalizationConfig(BaseModelConfig): ) def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ kwargs = { @@ -83,14 +101,12 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": kwargs["weight_init_method"] = init_uniform_( mean - self.initialization_range, mean + self.initialization_range ) - if self.type == NormalizationType.layer_norm: - if self.initialization_range: - kwargs["bias_init_method"] = init_uniform_(-self.initialization_range, self.initialization_range) - return LayerNorm(**kwargs) - elif self.type == NormalizationType.rms_norm: - return RMSNorm(**kwargs) - else: - raise ValueError(self.type) + return self.module_class(**kwargs) + + @property + @abc.abstractmethod + def module_class(self): + pass @classmethod def _from_dict( @@ -107,27 +123,47 @@ def _from_dict( return super()._from_dict(default, strict, flat) -for name in NormalizationType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - NormalizationConfig.register_subclass(name.value, NormalizationConfig) +@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) +class LayerNormalizationConfig(LayerNormalizationBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization import LayerNorm + + return LayerNorm -class PeftType(str, enum.Enum): - # TODO : Use a dynamic config type instead. - none = "none" - lora = "lora" +@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) +class RMSNormalizationConfig(LayerNormalizationBaseConfig): + _abstract = False + + @property + def module_class(self): + from fast_llm.layers.common.normalization import RMSNorm + + return RMSNorm @config_class() class PeftConfig(BaseModelConfig): + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + pass + + +@config_class() +class NoPeftConfig(PeftConfig): + _abstract = False + + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + return linear + + +@config_class() +class LoRAConfig(PeftConfig): _abstract = False - type: PeftType = Field( - default=PeftType.none, - desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.", - hint=FieldHint.core, - ) rank: int = Field( default=8, desc="The LoRA rank, i.e. 
the size of the intermediate dimension.", @@ -145,20 +181,15 @@ class PeftConfig(BaseModelConfig): ) def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - if self.type == PeftType.none: - return linear - elif self.type == PeftType.lora: - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) - else: - raise NotImplementedError(self.type) + from fast_llm.layers.common.peft import lora_linear + + # TODO: Init method? + return lora_linear( + linear, + linear.weight.param_init_method, + linear.weight.param_init_method, + self.rank, + self.alpha, + self.dropout, + **kwargs, + ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e7ef0b15..c83b0b43 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -1,3 +1,4 @@ +import abc import enum import functools import logging @@ -11,7 +12,7 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationConfig, PeftConfig, PeftType +from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, NormalizationConfig, PeftConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -88,21 +89,38 @@ class TransformerLossNames: router_z_loss = "router_z_loss" -class RotaryEmbeddingType(str, enum.Enum): - none = "none" - default = "default" - llama3 = "llama3" - yarn = "yarn" +@config_class() +class RotaryConfig(BaseModelConfig): + # TODO: Move rotary to its own submodule. + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is RotaryConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return DefaultRotaryConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @property + def enabled(self) -> bool: + return False -@config_class(registry=True) -class RotaryConfig(BaseModelConfig): + def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> "torch.Tensor": + raise NotImplementedError() + + +@config_class(dynamic_type={RotaryConfig: "none"}) +class NoRotaryConfig(RotaryConfig): + _abstract = False + + +@config_class(dynamic_type={RotaryConfig: "default"}) +class DefaultRotaryConfig(RotaryConfig): _abstract = False - type: RotaryEmbeddingType = Field( - default=RotaryEmbeddingType.none, - desc="The type of rotary embedding to use. Choices: none, default, llama3.", - hint=FieldHint.architecture, - ) theta: float = Field( default=10000, desc="Scale for the rotary positional embeddings", @@ -114,54 +132,147 @@ class RotaryConfig(BaseModelConfig): desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.", hint=FieldHint.architecture, ) - # TODO: These are not really architecture parameters, but we want to import them from huggingface. 
- scale_factor: float = Field( - default=8.0, desc="Scaling factor for llama3-type scaling.", hint=FieldHint.architecture - ) - low_frequency_factor: float = Field( - default=1.0, desc="Low frequency factor for llama3-type scaling.", hint=FieldHint.feature - ) - high_frequency_factor: float = Field( - default=4.0, desc="High frequency factor for llama3-type scaling.", hint=FieldHint.feature - ) - original_context_length: int = Field( - default=8192, desc="Original context length for llama3/yarn-type scaling.", hint=FieldHint.feature - ) + + @property + def enabled(self) -> bool: + return True + + @property + def complex_format(self) -> bool: + # TODO: Make a backup implementation that doesn't affect the layout. + return not self.triton + + def _validate(self) -> None: + super()._validate() + if self.triton and not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + + def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + from fast_llm.functional.rotary import convert_rotary_complex_to_real + + # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) + # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, + # `a = theta ** - (2 * (channel // 2) / kv_channels)`, + # where n is the position in the sequence. + # We preform the calculation in high precision because it matters for rotary embeddings. + positions = torch.arange(sequence_length, device=device, dtype=torch.float64) + angles = torch.outer(positions, self._get_angle_scales(kv_channels, device)) + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 + ).contiguous() + return frequencies + + def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + return self.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) + + +@config_class(dynamic_type={RotaryConfig: "llama3"}) +class Llama3RotaryConfig(DefaultRotaryConfig): + """ + Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 + """ + + # TODO: Add descriptions. 
+ scale_factor: float = Field(default=8.0, hint=FieldHint.feature) + low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature) + high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature) + original_context_length: int = Field(default=8192, hint=FieldHint.feature) + + def _validate(self) -> None: + super()._validate() + Assert.gt(self.high_frequency_factor, self.low_frequency_factor) + + def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + scales = super()._get_angle_scales(kv_channels, device) + low_frequency_wavelength = self.original_context_length / self.low_frequency_factor + high_frequency_wavelength = self.original_context_length / self.high_frequency_factor + new_scales = [] + for scale in scales: + wavelength = 2 * math.pi / scale + if wavelength < high_frequency_wavelength: + new_scales.append(scale) + elif wavelength > low_frequency_wavelength: + new_scales.append(scale / self.scale_factor) + else: + smooth = (self.original_context_length / wavelength - self.low_frequency_factor) / ( + self.high_frequency_factor - self.low_frequency_factor + ) + new_scales.append((1 - smooth) * scale / self.scale_factor + smooth * scale) + return torch.stack(new_scales) + + +@config_class(dynamic_type={RotaryConfig: "yarn"}) +class YarnRotaryConfig(DefaultRotaryConfig): + """ + Yarn scaling: + https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 + [original paper](https://arxiv.org/abs/2309.00071) + """ + + # TODO: Add descriptions. + scale_factor: float = Field(default=8.0, hint=FieldHint.feature) attention_factor: None | float = Field( default=None, - desc="Attention factor for yarn-type scaling.", hint=FieldHint.feature, ) beta_fast: float = Field( default=32.0, - desc="Beta-fast for yarn-type scaling.", hint=FieldHint.feature, ) beta_slow: float = Field( default=1.0, - desc="Beta-slow for yarn-type scaling.", hint=FieldHint.feature, ) - - @property - def enabled(self) -> bool: - return self.type != RotaryEmbeddingType.none - - @property - def complex_format(self) -> bool: - # TODO: Make a backup implementation that doesn't affect the layout. - return self.enabled and not self.triton + original_context_length: int = Field(default=8192, hint=FieldHint.feature) def _validate(self) -> None: + if self.attention_factor is None: + with self._set_implicit_default(): + self.attention_factor = 0.1 * math.log(self.scale_factor) + 1.0 super()._validate() - if self.triton and not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") + def _linear_ramp_factor(self, min, max, dim): + import torch + + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func -for name in RotaryEmbeddingType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. 
- RotaryConfig.register_subclass(name.value, RotaryConfig) + def get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> "torch.Tensor": + return super().get_frequencies(sequence_length, kv_channels, device) * self.attention_factor + + def _get_angle_scales(self, kv_channels: int, device="cuda") -> "torch.Tensor": + import torch + + scales = super()._get_angle_scales(kv_channels, device) + # TODO: max_position_embeddings or original_context_length? + # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 + low = max(self._get_correction(self.beta_slow, kv_channels), 0) + high = min(self._get_correction(self.beta_fast, kv_channels), kv_channels - 1) + if low == high: + high += 0.001 # Prevent singularity + + # Get n-dimensional rotational scaling corrected for extrapolation + extrapolation_factor = torch.clamp( + (torch.arange(kv_channels, dtype=torch.float32, device=scales.device) - low) / (high - low), 0, 1 + ) + return scales / self.scale_factor * extrapolation_factor + scales * (1 - extrapolation_factor) + + def _get_correction(self, beta: float, dim: int) -> float: + return math.floor( + dim * math.log(self.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self.theta)) + ) class AddLinearBiasChoices(str, enum.Enum): @@ -181,10 +292,51 @@ class TransformerSubLayerName(str, enum.Enum): mlp_2 = "mlp_2" -@config_class(registry=True) +@config_class() class TransformerPeftConfig(PeftConfig): + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + pass + + @abc.abstractmethod + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + pass + + @abc.abstractmethod + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return TransformerNoPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={TransformerPeftConfig: "none"}) +class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): + _abstract = False + + def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": + return super().apply_linear(linear) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + + +@config_class(dynamic_type={TransformerPeftConfig: "lora"}) +class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): layers: list[TransformerSubLayerName] = Field( - default=None, + default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) @@ -194,66 +346,50 @@ class TransformerPeftConfig(PeftConfig): ) def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if self.type != PeftType.none: - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False + if layer_type is None or self.layers is None or layer_type in self.layers: + if layer_type == TransformerSubLayerName.key: + return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + elif layer_type == TransformerSubLayerName.value_: + return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) + else: + return super().apply_linear(linear) + elif self.freeze_others: + linear.weight.requires_grad = False return linear def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.type != PeftType.none and self.freeze_others: + if self.freeze_others: for parameter in module.parameters(): parameter.requires_grad = False return module def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.type != PeftType.none and self.freeze_others: + if self.freeze_others: parameter.requires_grad = False return parameter def _validate(self) -> None: - if self.layers is None: - with self._set_implicit_default(): - # Setting the default layers only whee PeFT is enabled - # so they don't appear when serializing the default transformer config. - self.layers = ( - [TransformerSubLayerName.query, TransformerSubLayerName.value_] - if self.type == PeftType.lora - else [] - ) - if self.type != PeftType.none: - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). 
- raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + super()._validate() + if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") + if TransformerSubLayerName.dense in self.layers: + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for attention dense layer.") + if ( + sum( + name in self.layers + for name in ( + TransformerSubLayerName.key_value, + TransformerSubLayerName.key, + TransformerSubLayerName.value_, ) - - -for name in PeftType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) + ) + > 1 + ): + raise ValueError( + f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + ) @config_class() @@ -644,7 +780,7 @@ def _from_dict( default, "use_rotary_embeddings", ("rotary", "type"), - lambda x: RotaryEmbeddingType.default if x else RotaryEmbeddingType.none, + lambda x: "default" if x else "none", ) cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x)) cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 0697bd21..bcf41365 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -1,135 +1,16 @@ import logging -import math import typing import torch from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.functional.rotary import convert_rotary_complex_to_real -from fast_llm.layers.transformer.config import ( - RotaryConfig, - RotaryEmbeddingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, -) +from fast_llm.layers.transformer.config import RotaryConfig, TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) -def apply_llama3_scaling(config: RotaryConfig, frequencies: torch.Tensor) -> torch.Tensor: - """ - Llama3 scaling: https://github.com/meta-llama/llama-models/blob/baf7b01b6e62bc7126c7b558d2b67d4533142680/models/llama3/reference_impl/model.py#L45-L67 - """ - low_frequency_wavelength = config.original_context_length / config.low_frequency_factor - high_frequency_wavelength = config.original_context_length / config.high_frequency_factor - new_frequencies = [] - for frequency in frequencies: - wavelength = 2 * math.pi / frequency - if wavelength < high_frequency_wavelength: - new_frequencies.append(frequency) - elif wavelength > low_frequency_wavelength: - new_frequencies.append(frequency / config.scale_factor) - else: - assert low_frequency_wavelength != high_frequency_wavelength - smooth = (config.original_context_length / wavelength - config.low_frequency_factor) 
/ ( - config.high_frequency_factor - config.low_frequency_factor - ) - new_frequencies.append((1 - smooth) * frequency / config.scale_factor + smooth * frequency) - return torch.tensor(new_frequencies, dtype=frequencies.dtype, device=frequencies.device), 1.0 - - -def apply_yarn_scaling(config: RotaryConfig, frequencies: torch.Tensor, kv_channels, sequence_length) -> torch.Tensor: - """ - Yarn scaling: - https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 - [original paper](https://arxiv.org/abs/2309.00071) - """ - base = config.theta - partial_rotary_factor = 1.0 - dim = int(kv_channels * partial_rotary_factor) - factor = config.scale_factor - - attention_factor = config.attention_factor - if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): - """Find dimension range bounds based on rotations""" - low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # to expand the possible context length. In other words, interpolation = apply scaling factor. - # pos_freqs = base ** (torch.arange(0, dim, 2).float().to(frequencies.device) / dim) - # inv_freq_extrapolation = 1.0 / pos_freqs - # inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - inv_freq_extrapolation = frequencies - inv_freq_interpolation = frequencies / factor - - # TODO: max_position_embeddings or original_context_length? - # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 - low, high = find_correction_range(config.beta_fast, config.beta_slow, dim, base, config.original_context_length) - - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(frequencies.device) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor - ) - - return inv_freq, attention_factor - - -def get_rotary_frequencies( - config: RotaryConfig, - sequence_length, - kv_channels, - *, - device="cuda", -) -> torch.Tensor: - # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) - # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, - # `a = theta ** - (2 * (channel // 2) / kv_channels)`, - # where n is the position in the sequence. - # We preform the calculation in high precision because it matters for rotary embeddings. 
- positions = torch.arange(sequence_length, device=device, dtype=torch.float64) - frequencies = config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) - # Apply scaling - if config.type == RotaryEmbeddingType.llama3: - frequencies, attention_scaling = apply_llama3_scaling(config, frequencies) - elif config.type == RotaryEmbeddingType.yarn: - frequencies, attention_scaling = apply_yarn_scaling(config, frequencies, kv_channels, sequence_length) - else: - attention_scaling = 1.0 - angles = torch.outer(positions, frequencies) - frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not config.complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), kv_channels, 3 - ).contiguous() - # Advanced Rope types like yarn apply a post-processing scaling factor, equivalent to scaling attention. - frequencies = frequencies * attention_scaling - return frequencies - - class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim @@ -155,8 +36,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - self._rotary_embedding_frequencies = get_rotary_frequencies( - self._config, + self._rotary_embedding_frequencies = self._config.get_frequencies( sequence_length, self._kv_channels_dim.global_size, device=self._tensor_space.distributed.device, @@ -239,10 +119,7 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: ] if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( - [ - torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths - ] + [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) kwargs[TransformerKwargs.attention_mask] = ( diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 8f16aaea..3be74856 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -1,30 +1,7 @@ -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig -from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig -from fast_llm.utils import Registry +""" +Import these submodules to ensure classes are added to the dynamic class registry. 
+""" -model_registry = Registry[str, FastLLMModelConfig]( - "Model", - { - model.model_name: model - for model in [ - GPTModelConfig, - CustomModelConfig, - HybridSSMModelConfig, - ] - }, -) - -trainer_registry = Registry[str, TrainerConfig]( - "Model", - { - trainer.get_field("model").type.model_name: trainer - for trainer in [ - GPTTrainerConfig, - CustomTrainerConfig, - HybridTrainerConfig, - ] - }, -) +from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig # isort: skip +from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip +from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig # isort: skip diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bd733692..46264d29 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,8 +25,14 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex -from fast_llm.layers.common.config import NormalizationType -from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig +from fast_llm.layers.common.config import LayerNormalizationConfig, RMSNormalizationConfig +from fast_llm.layers.transformer.config import ( + DefaultRotaryConfig, + Llama3RotaryConfig, + RoutingType, + TransformerConfig, + YarnRotaryConfig, +) from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTModelConfig, @@ -184,7 +190,7 @@ def _create_transformer_layer_converters( self, fast_llm_layer_name: str, hf_layer_name: str, ignore_export: bool = False ) -> list[WeightConverter]: transformer_config: TransformerConfig = self._model.config.base_model.transformer - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] names_bias_cls = [ # Self-attn @@ -250,7 +256,7 @@ def _create_transformer_layer_converters( def _create_lm_head_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm + norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] # Next-token prediction head @@ -318,10 +324,11 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=RotaryEmbeddingType.default + fast_llm_names=(("transformer", "rotary", "type"),), fast_llm_value=DefaultRotaryConfig ), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.layer_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=LayerNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) @@ -350,75 +357,89 @@ class 
CommonLlamaHuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=RMSNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) ), RenameParamConverter( fast_llm_names=(("transformer", "kv_channels"),), - export_names=(("head_dim"),), + export_names=(("head_dim",),), ), ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), - RopeScalingParamConverter( - fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + LLamaRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=( + ("rope_theta",), + ("rope_scaling",), ), - export_names=(("rope_scaling",),), ), ] @dataclasses.dataclass -class RopeScalingParamConverter(ParamConverter): - _HUGGINGFACE_NAMES = ( - "rope_type", - "factor", - "low_freq_factor", - "high_freq_factor", - "original_max_position_embeddings", - "attention_factor", - "beta_fast", - "beta_slow", - ) - +class LLamaRotaryParamConverter(ParamConverter): def __post_init__(self): - Assert.eq(len(self.fast_llm_names), 8) - Assert.eq(len(self.export_names), 1) + Assert.eq(len(self.fast_llm_names), 1) + Assert.eq(len(self.export_names), 2) def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - rope_type, *parameters = fast_llm_values - if rope_type == RotaryEmbeddingType.default: - return (None,) - elif rope_type == RotaryEmbeddingType.llama3: - return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("llama3", *parameters), strict=True)},) - elif rope_type == RotaryEmbeddingType.yarn: - return ({key: value for key, value in zip(self._HUGGINGFACE_NAMES, ("yarn", *parameters), strict=True)},) + (rotary_config,) = fast_llm_values + if type(rotary_config) is DefaultRotaryConfig: + rotary_scaling = { + "rope_type": "default", + } + elif type(rotary_config) is Llama3RotaryConfig: + rotary_scaling = { + "rope_type": "llama3", + "factor": rotary_config.scale_factor, + "low_freq_factor": rotary_config.low_frequency_factor, + "high_freq_factor": rotary_config.high_frequency_factor, + "original_max_position_embeddings": rotary_config.original_context_length, + } + elif type(rotary_config) is YarnRotaryConfig: + rotary_scaling = { + "rope_type": "yarn", + "attention_factor": rotary_config.attention_factor, + "beta_fast": rotary_config.beta_fast, + "beta_slow": rotary_config.beta_slow, + "original_max_position_embeddings": rotary_config.original_context_length, + } else: - raise ValueError(f"Unsupported rotary scaling type: {rope_type}") + raise ValueError(f"Unsupported rotary type: {type(rotary_config).__name__}") + + return rotary_config.theta, rotary_scaling def import_params(self, export_values: tuple[typing.Any, ...]) -> 
tuple[typing.Any, ...]: - (export_value,) = export_values - if ( - export_value is None - or export_value is MISSING - or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default" - ): - return (RotaryEmbeddingType.default,) + (DEFAULT,) * 7 - elif rope_type == RotaryEmbeddingType.llama3: - return ("llama3", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) - elif rope_type == RotaryEmbeddingType.yarn: - return ("yarn", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) - else: - raise ValueError(f"Unsupported rotary scaling type: {rope_type}") + rotary_theta, rope_scaling = export_values + rotary_type = "default" if rope_scaling in (None, MISSING) else rope_scaling.get("rope_type", "default") + rotary_config = { + "type": rotary_type, + "theta": rotary_theta, + } + if rotary_type == "default": + pass + elif rotary_type == "llama3": + rotary_config.update( + { + "scale_factor": rope_scaling.get("factor", DEFAULT), + "low_frequency_factor": rope_scaling.get("low_freq_factor", DEFAULT), + "high_frequency_factor": rope_scaling.get("high_freq_factor", DEFAULT), + "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), + } + ) + elif rotary_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_scaling.get("attention_factor", DEFAULT), + "beta_fast": rope_scaling.get("beta_fast", DEFAULT), + "beta_slow": rope_scaling.get("beta_slow", DEFAULT), + "original_context_length": rope_scaling.get("original_max_position_embeddings", DEFAULT), + } + ) + return (rotary_config,) # RotaryConfig.from_dict(rotary_config) class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler): @@ -481,7 +502,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]), ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + fast_llm_names=(("transformer", "normalization", "type"),), + fast_llm_value=RMSNormalizationConfig.__name__, ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) @@ -490,18 +512,12 @@ def _create_config_converters(cls) -> list[ParamConverter]: ConstantImportParamConverter( fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv" ), - RopeScalingParamConverter( - fast_llm_names=( - ("transformer", "rotary", "type"), - ("transformer", "rotary", "scale_factor"), - ("transformer", "rotary", "low_frequency_factor"), - ("transformer", "rotary", "high_frequency_factor"), - ("transformer", "rotary", "original_context_length"), - ("transformer", "rotary", "attention_factor"), - ("transformer", "rotary", "beta_fast"), - ("transformer", "rotary", "beta_slow"), + LLamaRotaryParamConverter( + fast_llm_names=(("transformer", "rotary"),), + export_names=( + ("rope_theta",), + ("rope_scaling",), ), - export_names=(("rope_scaling",),), ), IgnoreImportQwen2SlidingWindowParamsConverter(), ] @@ -637,7 +653,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig def _create_lm_head_converters(self) -> list[WeightConverter]: num_layers = self._model.config.base_model.transformer.num_layers prediction_heads = self._model.config.base_model.prediction_heads - norm_bias: bool = self._model.config.base_model.transformer.normalization.type == 
@@ -481,7 +502,8 @@ def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
             ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
             ConstantImportParamConverter(
-                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+                fast_llm_names=(("transformer", "normalization", "type"),),
+                fast_llm_value=RMSNormalizationConfig.__name__,
             ),
             RenameParamConverter(
                 fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
@@ -490,18 +512,12 @@ def _create_config_converters(cls) -> list[ParamConverter]:
             ConstantImportParamConverter(
                 fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
             ),
-            RopeScalingParamConverter(
-                fast_llm_names=(
-                    ("transformer", "rotary", "type"),
-                    ("transformer", "rotary", "scale_factor"),
-                    ("transformer", "rotary", "low_frequency_factor"),
-                    ("transformer", "rotary", "high_frequency_factor"),
-                    ("transformer", "rotary", "original_context_length"),
-                    ("transformer", "rotary", "attention_factor"),
-                    ("transformer", "rotary", "beta_fast"),
-                    ("transformer", "rotary", "beta_slow"),
+            LLamaRotaryParamConverter(
+                fast_llm_names=(("transformer", "rotary"),),
+                export_names=(
+                    ("rope_theta",),
+                    ("rope_scaling",),
                 ),
-                export_names=(("rope_scaling",),),
             ),
             IgnoreImportQwen2SlidingWindowParamsConverter(),
         ]
@@ -637,7 +653,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig
     def _create_lm_head_converters(self) -> list[WeightConverter]:
         num_layers = self._model.config.base_model.transformer.num_layers
         prediction_heads = self._model.config.base_model.prediction_heads
-        norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
+        norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig)
         converters = []
 
         # Next-token prediction head
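The hunk above replaces the enum comparison with an isinstance check; a hedged sketch of the idea, assuming both normalization config classes live in fast_llm.layers.common.config and accept default construction:

# Illustrative only: bias handling for the head norm now follows the config class.
from fast_llm.layers.common.config import LayerNormalizationConfig, RMSNormalizationConfig

def needs_norm_bias(normalization) -> bool:
    # In this codebase only layer norm carries a bias term, which is what the old enum check expressed.
    return isinstance(normalization, LayerNormalizationConfig)

assert needs_norm_bias(LayerNormalizationConfig())
assert not needs_norm_bias(RMSNormalizationConfig())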
diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py
index cf7da387..0da4acbb 100644
--- a/fast_llm/models/gpt/huggingface.py
+++ b/fast_llm/models/gpt/huggingface.py
@@ -8,7 +8,7 @@
 from fast_llm.data.data.gpt.data import GPTBatch
 from fast_llm.engine.distributed.config import PhaseType
 from fast_llm.engine.inference.config import HuggingfaceModelConfig
-from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM
+from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel
 from fast_llm.layers.transformer.config import TransformerKwargs
 from fast_llm.models.gpt.config import GPTModelConfig
 from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner
@@ -22,7 +22,7 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig):
     fast_llm_config: GPTModelConfig
 
 
-class HuggingfaceGPTModelForCausalLM(HuggingfaceBaseModelForCausalLM):
+class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel):
     config_class = HuggingfaceGPTModelConfig
     config: HuggingfaceGPTModelConfig
     runner_class: typing.ClassVar[type[GPTInferenceRunner]] = GPTInferenceRunner
@@ -55,32 +55,21 @@ def forward(
         if output_attentions:
             raise NotImplementedError()
+        if output_hidden_states:
+            raise NotImplementedError()
+        if attention_mask is not None:
+            raise NotImplementedError()
+        if position_ids is not None:
+            raise NotImplementedError()
         if inputs_embeds is not None:
             raise NotImplementedError()
         if labels is not None:
             raise NotImplementedError()
 
-        # NOTE: We are ignoring position_ids as we reconstruct them from attention_mask via sequence_lengths.
-        if attention_mask is not None:
-            # First non zero indexes or zero index if the row is all zeros (invalid row)
-            first_non_zero_indexes = attention_mask.argmax(dim=1)
-
-            # Check if the sequence is left-padded and if the remaining ones are continuous 1-ns
-            assert (attention_mask.sum(axis=1) == (attention_mask.shape[1] - first_non_zero_indexes)).all()
-
-            sequence_lenghts = [
-                torch.tensor(
-                    [attention_mask.shape[1]] if el == 0 else [el, attention_mask.shape[1] - el], dtype=torch.int64
-                )
-                for el in first_non_zero_indexes.tolist()
-            ]
-        else:
-            sequence_lenghts = None
-
         # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM
         iteration = random.randint(0, 2**32)
         batch = self.fast_llm_base_model.preprocess(
-            GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration
+            GPTBatch(input_ids), phase=PhaseType.inference, iteration=iteration
         )
         ((input_, kwargs),) = batch
 
@@ -93,35 +82,23 @@ def forward(
         # The transformers will save the present keys and values to this list.
         kwargs[TransformerKwargs.presents] = []
 
-        if output_hidden_states:
-            kwargs["output_hidden_states"] = True
-            kwargs["hidden_states"] = {}
-        else:
-            kwargs["output_hidden_states"] = False
-
         self._inference_runner.forward(input_, kwargs, iteration=iteration)
 
         # TODO: Make a proper way of returning the model output.
         logits = kwargs["logits"]
-        # TODO: convert hidden state form dict to list to be the same as with HFs
-        hidden_states = None
-        if output_hidden_states:
-            hidden_states = kwargs["hidden_states"]
-
         if not return_dict:
-            # TODO: Then implementing cache, check hidden state goes before past in the tuple
-            if output_hidden_states:
-                outputs = (logits, hidden_states)
-            else:
-                outputs = (logits,)
-
+            outputs = (logits,)
             if use_cache:
                 outputs += (kwargs[TransformerKwargs.presents],)
             return outputs
 
         return transformers.modeling_outputs.CausalLMOutputWithPast(
             logits=logits,
-            hidden_states=hidden_states,
             past_key_values=kwargs[TransformerKwargs.presents],
         )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        raise NotImplementedError()
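As a usage note for the trimmed-down wrapper above, the forward pass now takes token ids only; a hedged sketch (how `model` is obtained and the shapes used are illustrative, not part of the patch):

import torch

# Assume `model` is a HuggingfaceGPTModelForCausalLM built from a Fast-LLM checkpoint
# (construction is outside this patch).
input_ids = torch.randint(0, 1000, (1, 16), dtype=torch.int64)
output = model(input_ids, use_cache=False, return_dict=True)
logits = output.logits  # CausalLMOutputWithPast; hidden states are no longer returned
# Passing attention_mask, position_ids, inputs_embeds, labels or requesting
# output_hidden_states now raises NotImplementedError, and prepare_inputs_for_generation
# raises as well, so the stock HF generate() path is effectively unsupported here.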
diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py
index 190b2ffa..783eb621 100644
--- a/fast_llm/models/ssm/conversion.py
+++ b/fast_llm/models/ssm/conversion.py
@@ -16,7 +16,7 @@
 from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig
 from fast_llm.functional.config import ActivationType
-from fast_llm.layers.common.config import NormalizationType
+from fast_llm.layers.common.config import RMSNormalizationConfig
 from fast_llm.models.gpt.conversion import MLPLayer2Converter
 from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat
 from fast_llm.models.ssm.model import HybridSSMModel
@@ -51,7 +51,8 @@ def _create_config_converters(cls) -> list[ParamConverter]:
                 fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),)
             ),
             ConstantImportParamConverter(
-                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
+                fast_llm_names=(("transformer", "normalization", "type"),),
+                fast_llm_value=RMSNormalizationConfig.__name__,
             ),
             RenameParamConverter(
                 fast_llm_names=(("vocab_size",),),
diff --git a/fast_llm/tools/__init__.py b/fast_llm/tools/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py
deleted file mode 100644
index 8df884fe..00000000
--- a/fast_llm/tools/cli.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import argparse
-import logging
-import sys
-import traceback
-
-from fast_llm.config import ValidationError
-from fast_llm.engine.config_utils.logging import configure_logging
-from fast_llm.engine.config_utils.run import log_main_rank
-
-logger = logging.getLogger(__name__)
-
-
-def fast_llm(args=None):
-    # TODO: Add hook to register model classes? (environment variable?)
-    # (Pre-)configure logging
-    configure_logging()
-    parser = argparse.ArgumentParser(add_help=False)
-    parser.add_argument("subcommand", choices=["train", "convert", "prepare"])
-    parsed, unparsed = parser.parse_known_args(args)
-    try:
-        if parsed.subcommand == "train":
-            from fast_llm.tools.train import CliTrainingConfig as Runnable
-        elif parsed.subcommand == "convert":
-            from fast_llm.tools.convert import ConvertConfig as Runnable
-        elif parsed.subcommand == "prepare":
-            from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable
-        else:
-            raise RuntimeError("Unknown subcommand")
-        Runnable.parse_and_run(unparsed)
-    except ValidationError:
-        if sys.gettrace():
-            raise
-        log_main_rank(traceback.format_exc(), log_fn=logger.error)
-        sys.exit(1)
-    except Exception:  # noqa
-        if sys.gettrace():
-            raise
-        logger.critical(traceback.format_exc())
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    fast_llm()
diff --git a/fast_llm/tools/prepare_dataset.py b/fast_llm/tools/prepare_dataset.py
deleted file mode 100644
index aafe2690..00000000
--- a/fast_llm/tools/prepare_dataset.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import argparse
-
-from fast_llm.data.auto import dataset_preparator_registry
-from fast_llm.engine.config_utils.runnable import RunnableConfig
-
-
-class PrepareDatasetConfig(RunnableConfig):
-    @classmethod
-    def _get_parser(cls):
-        parser = super()._get_parser()
-        parser.add_argument(
-            "model_type",
-            choices=dataset_preparator_registry.keys(),
-            help="The Fast-LLM model type to use. Must be defined in the model registry in `fast_llm.models.auto`.",
-        )
-        return parser
-
-    @classmethod
-    def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]):
-        return dataset_preparator_registry[parsed.model_type]._from_parsed_args(parsed, unparsed)
-
-
-if __name__ == "__main__":
-    PrepareDatasetConfig.parse_and_run()
diff --git a/fast_llm/tools/train.py b/fast_llm/tools/train.py
deleted file mode 100644
index ae902279..00000000
--- a/fast_llm/tools/train.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import argparse
-
-from fast_llm.engine.config_utils.runnable import RunnableConfig
-from fast_llm.models.auto import trainer_registry
-
-
-class CliTrainingConfig(RunnableConfig):
-    @classmethod
-    def _get_parser(cls):
-        parser = super()._get_parser()
-        parser.add_argument(
-            "model_type",
-            choices=trainer_registry.keys(),
-            help="The Fast-LLM model type to use. Must be defined in the trainer registry in `fast_llm.models.auto`.",
-        )
-        return parser
-
-    @classmethod
-    def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]):
-        return trainer_registry[parsed.model_type]._from_parsed_args(parsed, unparsed)
-
-
-if __name__ == "__main__":
-    CliTrainingConfig.parse_and_run()
diff --git a/fast_llm/utils.py b/fast_llm/utils.py
index d89c9d76..460745e3 100644
--- a/fast_llm/utils.py
+++ b/fast_llm/utils.py
@@ -218,6 +218,11 @@ def __setitem__(self, key: KeyType, value: ValueType):
             raise KeyError(f"Entry {key} already in {self._name} registry")
         self._data[key] = value
 
+    def __delitem__(self, key: KeyType):
+        if key not in self:
+            raise KeyError(f"Entry {key} not found in {self._name} registry")
+        del self._data[key]
+
     def keys(self) -> list[KeyType]:
         return list(self._data)
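A small sketch of the registry behaviour touched above: `__setitem__` still rejects duplicate keys and the new `__delitem__` mirrors it for missing keys (the key and value types here are illustrative):

from fast_llm.utils import Registry

registry = Registry[str, int]("example", {})
registry["gpt"] = 1
# Re-assigning registry["gpt"] would raise KeyError ("already in ... registry").
del registry["gpt"]  # now supported via __delitem__
# Deleting it a second time would raise KeyError ("not found in ... registry").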
diff --git a/tests/data/common.py b/tests/data/common.py
index cacb28e6..00c3ff20 100644
--- a/tests/data/common.py
+++ b/tests/data/common.py
@@ -189,9 +189,10 @@ def validate_indexed_dataset_sampling(
     return token_ids
 
 
-@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"})
+@config_class()
 class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
     _abstract: typing.ClassVar[bool] = False
+    type_: typing.ClassVar[str | None] = "mock_memmap"
     num_documents: int | None = Field(
         default=None,
         desc="Expected number of documents in the dataset.",
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 7578a5f0..4ddba660 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -11,7 +11,6 @@
 from fast_llm.engine.multi_stage.config import StageConfig
 from fast_llm.engine.multi_stage.stage import Stage
 from fast_llm.functional.config import CrossEntropyImpl
-from fast_llm.layers.common.config import NormalizationType
 from fast_llm.layers.language_model.config import LanguageModelKwargs
 from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT
 from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead
@@ -91,7 +90,7 @@ def test_lm_head(
     config = GPTBaseModelConfig.from_dict(
         {
             "transformer": {
-                "normalization": {"type": NormalizationType.rms_norm},
+                "normalization": {"type": "rms_norm"},
                 "hidden_size": HIDDEN_SIZE,
                 "num_layers": 0,
             },
diff --git a/tests/test_config.py b/tests/test_config.py
index 80bed418..07617f35 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -11,8 +11,7 @@
 from fast_llm.engine.config_utils.data_type import DataType
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.layers.transformer.config import TransformerConfig
-from fast_llm.models.auto import trainer_registry
-from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig
+from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig
 from fast_llm.utils import Assert, check_equal_nested
 from tests.common import TEST_RESULTS_PATH
 
@@ -32,7 +31,7 @@ def run_without_import(cmd: str):
             "sys.path=[p for p in sys.path if not any(x in p for x in ('site-packages', 'dist-packages', '.egg'))]",
             # We still want to enable imports from within Fast-llm
             f"sys.path.append('{repo_path}')",
-            "from fast_llm.tools.cli import fast_llm as main",
+            "from fast_llm.cli import fast_llm_main as main",
            cmd,
         ]
     ),
@@ -44,24 +43,24 @@ def run_without_import(cmd: str):
 
 
 def test_validate_train_gpt_without_import():
-    run_without_import("main(['train', 'gpt', '-v'])")
+    run_without_import("main(['train_gpt', '-v'])")
 
 
 def test_validate_prepare_gpt_memmap_without_import():
     run_without_import(
-        "main(['prepare', 'gpt_memmap', '-v', 'dataset.path=test', 'output_path=test', 'tokenizer.path=test'])"
+        "main(['prepare_gpt_memmap', '-v', 'dataset.path=test', 'output_path=test', 'tokenizer.path=test'])"
     )
 
 
 def test_validate_convert_gpt_without_import():
-    run_without_import("main(['convert', 'gpt', '-v'])")
+    run_without_import("main(['convert', 'model=gpt', '-v'])")
 
 
 def test_validate_example_config():
     fast_llm_config_dict = yaml.safe_load(
         (pathlib.Path(__file__).parents[1] / "examples" / "mistral.yaml").read_text()
     )
-    trainer_registry["gpt"].from_dict(fast_llm_config_dict)
+    GPTTrainerConfig.from_dict(fast_llm_config_dict)
 
 
 def test_do_use_flash_attention():
@@ -111,7 +110,7 @@ def test_pretrained_config(load_config: ModelConfigType):
                 "rotary": {"type": "default"},
                 "num_layers": 12,  # Default
                 "hidden_size": 1024,  # Default
-                "window_size": 32,  # Non-architecture
+                "window_size": 32,
                 "ffn_hidden_size": 4096,  # Implicit default, default value
                 "activation_type": "silu",  # Implicit default, non-default value
                 "head_groups": 4,
@@ -132,7 +131,7 @@ def test_pretrained_config(load_config: ModelConfigType):
         "transformer": {
             # rotary: Don't override nested.
             "normalization": {"implementation": "triton"},  # Update non-default nested
-            "peft": {"freeze_others": False},  # Update default nested, non-architecture
+            "peft": {"type": "lora", "freeze_others": False},  # Update default nested, change type
             "hidden_size": 512,  # Override, affects derived value (kv channels)
             "head_groups": 1,  # Override to default
         },
@@ -157,9 +156,9 @@ def test_pretrained_config(load_config: ModelConfigType):
     if load_config in (ModelConfigType.fast_llm, ModelConfigType.model):
         expected_config["base_model"] = {
             "transformer": {
-                "normalization": {"type": "rms_norm", "implementation": "triton"},
-                "rotary": {"type": "default"},
-                "peft": {"freeze_others": False},
+                "normalization": {"type": "RMSNormalizationConfig", "implementation": "triton"},
+                "rotary": {"type": "DefaultRotaryConfig"},
+                "peft": {"type": "TransformerLoRAConfig", "freeze_others": False, "layers": ["query", "value"]},
                 "num_layers": 12,
                 "hidden_size": 512,
                 "ffn_hidden_size": 4096,
@@ -171,6 +170,11 @@ def test_pretrained_config(load_config: ModelConfigType):
             "vocab_size": 1000,
         }
     else:
+        base_model_update["transformer"]["peft"] = {
+            "type": "TransformerLoRAConfig",
+            "freeze_others": False,
+            "layers": ["query", "value"],
+        }
         expected_config["base_model"] = base_model_update
 
     check_equal_nested(serialized_config, expected_config)
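The updated tests above also document the new CLI surface: a single registered subcommand name replaces the old `subcommand model_type` pair. A hedged sketch of equivalent invocations, mirroring the strings the tests exercise (argument values are the tests' placeholders):

from fast_llm.cli import fast_llm_main

# Old CLI (removed): fast_llm(["train", "gpt", "-v"]), fast_llm(["prepare", "gpt_memmap", ...])
fast_llm_main(["train_gpt", "-v"])
fast_llm_main(["prepare_gpt_memmap", "-v", "dataset.path=test", "output_path=test", "tokenizer.path=test"])
fast_llm_main(["convert", "model=gpt", "-v"])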
diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py
index bb468ceb..aeea5b8c 100644
--- a/tests/test_multi_stage.py
+++ b/tests/test_multi_stage.py
@@ -2,14 +2,14 @@
 from fast_llm.engine.training.config import TrainerConfig
 from fast_llm.engine.training.trainer import Trainer
 from fast_llm.layers.transformer.transformer import TransformerLayer
-from fast_llm.tools.train import CliTrainingConfig
 from fast_llm.utils import Assert
 from tests.common import CONFIG_COMMON, requires_cuda
 
 
 def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer:
-    parsed, unparsed = CliTrainingConfig._get_parser().parse_known_args([model_type] + args)
-    config: TrainerConfig = CliTrainingConfig._from_parsed_args(parsed, unparsed)
+    cls = TrainerConfig.get_subclass(model_type)
+    parsed, unparsed = cls._get_parser().parse_known_args(args)
+    config: TrainerConfig = cls._from_parsed_args(parsed, unparsed)
     distributed = Distributed(config.model.distributed)
     trainer = config.get_trainer_class()(config=config)
     trainer.setup(distributed, config.get_run(distributed))
diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py
index 108a2898..b44c0580 100644
--- a/tests/test_triton_kernels.py
+++ b/tests/test_triton_kernels.py
@@ -28,8 +28,7 @@
 from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill
 from fast_llm.functional.triton.rotary import triton_rotary_
 from fast_llm.functional.triton.sparse_copy import get_sparse_map
-from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType
-from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies
+from fast_llm.layers.transformer.config import DefaultRotaryConfig
 from fast_llm.utils import Assert, rms_diff
 from tests.common import requires_cuda
 
@@ -92,8 +91,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels):
     y1 = apply_rotary_embeddings(
         x,
-        get_rotary_frequencies(
-            RotaryConfig(type=RotaryEmbeddingType.default, triton=False),
+        DefaultRotaryConfig(triton=False).get_frequencies(
             sequence_length,
             kv_channels,
             device="cuda",
@@ -103,12 +101,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels):
     y2 = convert_rotary_real_to_complex(
         triton_rotary_(
             convert_rotary_complex_to_real(x, kv_channels, 3),
-            get_rotary_frequencies(
-                RotaryConfig(type=RotaryEmbeddingType.default, triton=True),
-                sequence_length,
-                kv_channels,
-                device="cuda",
-            ),
+            DefaultRotaryConfig(triton=True).get_frequencies(sequence_length, kv_channels, device="cuda"),
         ),
         kv_channels,
         3,
diff --git a/tools/push_model.py b/tools/push_model.py
index edab3312..39a3b914 100644
--- a/tools/push_model.py
+++ b/tools/push_model.py
@@ -27,7 +27,7 @@
     raise ImportError("Please install huggingface_hub to use this script") from e
 
 
-from fast_llm.tools.convert import ConvertConfig  # isort:skip
+from fast_llm.engine.checkpoint.convert import ConvertConfig  # isort:skip
 
 logger = logging.getLogger(__name__)