From 513775730095c5e53e4d982139a250aba9c99b52 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 00:10:47 -0400 Subject: [PATCH 001/114] stuff --- fast_llm/config.py | 177 ++++++++++++++++++----- fast_llm/data/data/config.py | 11 +- fast_llm/data/data/gpt/config.py | 12 +- fast_llm/data/dataset/config.py | 26 ++-- fast_llm/data/dataset/gpt/config.py | 13 +- fast_llm/engine/checkpoint/config.py | 5 +- fast_llm/engine/checkpoint/external.py | 4 +- fast_llm/engine/distributed/config.py | 3 +- fast_llm/engine/schedule/config.py | 7 +- fast_llm/engine/training/config.py | 4 +- fast_llm/layers/language_model/config.py | 23 +-- fast_llm/layers/transformer/config.py | 94 ++++++------ fast_llm/profile.py | 4 +- fast_llm/utils.py | 34 ----- tests/data/common.py | 7 +- 15 files changed, 241 insertions(+), 183 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c88965..326845f0 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import enum import logging @@ -9,7 +10,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value +from fast_llm.utils import Assert, Tag, get_type_name, header, log logger = logging.getLogger(__name__) @@ -43,6 +44,13 @@ class _ConfigDictFormat(str, enum.Enum): tuple = "tuple" +class UpdateType(str, enum.Enum): + # Override entries no matter what they contais. + override = "override" + # Override atomic entries and lists, but update dicts recursively by setting or overriding only the specified entries. + update = "update" + + class FieldHint: """ A label defined for each config field, to let the user and some methods know how important each field is. @@ -125,6 +133,9 @@ def __init__( # Should raise an Exception in case of failure, and return the validated value. # Run before the default validation (type check). valid: typing.Optional[typing.Callable[[typing.Any], typing.Any]] = None, + # Option to skip (postpone) instantiation of a `Config` field. + # Note: The config still needs to be instantiated for validation to succeed. + # auto_instantiate: bool = True, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init: bool = True, @@ -152,6 +163,7 @@ def __init__( self.doc = doc self.hint = hint self.valid = valid + # self.auto_instantiate = auto_instantiate class FieldUpdate(dict): @@ -254,7 +266,16 @@ def config_class(cls=None): def wrap(cls): Assert.custom(issubclass, cls, Config) - return _process_config_class(dataclasses.dataclass(cls)) + wrapped = _process_config_class(dataclasses.dataclass(cls)) + + wrapped_init = cls.__init__ + + def __init__(self, **kwargs): + wrapped_init(self, **kwargs) + self._explicit_fields = set(kwargs) + + cls.__init__ = __init__ + return wrapped # See if we're being called as @config_class or @config_class(). if cls is None: @@ -277,9 +298,17 @@ class Config: # We can't use @config_class on this one because it needs this class to be defined, so we assume this one is OK. __class_validated__: typing.ClassVar[bool] = True + # Set to true to prevent instantiation. _abstract: typing.ClassVar[bool] = False + # Keep track of whether an instance has been validated _validated: bool = Field(init=False, repr=False) + # Keep track of unknown fields so they can be reported during validation. _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + # Keep track of explicitly set fields to ensure they get serialized and used as config updates. + _explicit_fields: set[str] = Field(init=False, repr=False) + # Used within `_set_implicit_default` to set implicit defaults for fields + # without them being automatically added to `_explicit_fields`. + _setting_implicit_default: bool = Field(init=False, repr=False) def __post_init__(self): """ @@ -288,6 +317,7 @@ def __post_init__(self): and all post-processing should be done in `_validate` """ self._validated = False + self._setting_implicit_default = False if _AUTO_VALIDATE: self.validate() @@ -305,6 +335,12 @@ def __setattr__(self, key: str, value: typing.Any) -> None: f"Cannot set attribute `{key}`" f" in configuration class `{get_type_name(type(self))}` after validation." ) + elif not getattr(self, "_setting_implicit_default", True): + field = self.get_field(key) + if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + # Adding to explicit field list except within `_set_implicit_default` context + # and during dataclass initialization (`_setting_implicit_default` not yet set). + self._explicit_fields.add(key) super().__setattr__(key, value) def __delattr__(self, key: str) -> None: @@ -318,6 +354,12 @@ def __delattr__(self, key: str) -> None: ) super().__delattr__(key) + @contextlib.contextmanager + def _set_implicit_default(self): + self._setting_implicit_default = True + yield + self._setting_implicit_default = False + def validate[T](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only @@ -332,6 +374,7 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True + print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -344,16 +387,17 @@ def _validate(self) -> None: """ self._check_abstract() errors = [] - for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa - continue - value = getattr(self, name) - if value is DEFAULT: - # Replace the value with its default. - # We still need to validate because some fields have invalid defaults. - value = field.default - new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) - setattr(self, name, new_value) + with self._set_implicit_default(): + for name, field in self.fields(): + if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + continue + value = getattr(self, name) + if value is DEFAULT: + # Replace the value with its default. + # We still need to validate because some fields have invalid defaults. + value = field.default + new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) + setattr(self, name, new_value) for name in getattr(self, "_unknown_fields", {}): errors.append(f"Unknown field `{name}` in class {self._get_class_name()}") if errors: @@ -555,9 +599,8 @@ def _to_dict( return arg_dict - @classmethod def _add_field_to_args( - cls, + self, args: dict | list, name: str | None, field: Field | None, @@ -574,46 +617,48 @@ def _add_field_to_args( ): # Exclude class variables and derived fields unless requested explicitly. return - elif isinstance(value, Config): + explicit_field = ( + field is None + or name in self._explicit_fields + or (verbose is not None and verbose >= FieldHintImportance[field.hint]) + ) + if isinstance(value, Config): field_value = value._to_dict( verbose=verbose, all_fields=all_fields, format_=format_, serializable=serializable, ) + # Empty configs can safely be trimmed. + explicit_field = all_fields elif isinstance(value, (list, tuple, set)): field_value = {} if format_ == _ConfigDictFormat.tuple else [] for i, list_value in enumerate(value): - cls._add_field_to_args( + self._add_field_to_args( field_value, str(i), None, list_value, verbose, all_fields, format_, serializable ) elif isinstance(value, dict): field_value = {} for dict_name, dict_value in value.items(): - cls._add_field_to_args( + self._add_field_to_args( field_value, dict_name, None, dict_value, verbose, all_fields, format_, serializable ) - elif ( - verbose is not None - and field is not None - and FieldHintImportance[field.hint] > verbose - and value == field.default - ): - # Exclude unimportant default values. - return - else: + elif explicit_field: field_value = value if serializable: - field_value = cls._serialize_value(value) + field_value = self._serialize_value(value) if format_ == _ConfigDictFormat.tuple: field_value = {(): field_value} + else: + # Exclude unimportant (implicit or explicit) default values. + return if serializable: - name = cls._serialize_value(name) + name = self._serialize_value(name) if format_ == _ConfigDictFormat.tuple: args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) elif format_ == _ConfigDictFormat.nested: - if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or all_fields: + if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: if isinstance(args, dict): args[name] = field_value else: @@ -671,6 +716,7 @@ def from_dict( default: typing.Union["Config", dict[str, typing.Any]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, + update_type: UpdateType = UpdateType.override, ) -> typing.Self: if isinstance(default, Config): default = default._to_dict() @@ -678,7 +724,7 @@ def from_dict( if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) for keys, value in update.items(): - set_nested_dict_value(default, keys, value) + set_nested_dict_value(default, keys, value, update_type) return cls._from_dict(default, strict) @@ -712,10 +758,7 @@ def _from_dict( continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): - if flat: - out_arg_dict[name] = field.type._from_dict(default, False, True) - else: - out_arg_dict[name] = field.type._from_dict(default.pop(name, {}), strict) + out_arg_dict[name] = field.type._from_dict(default, False, True) elif name in default: out_arg_dict[name] = default.pop(name) else: @@ -916,3 +959,69 @@ def __init__(self, config: ConfigType, *args, **kwargs): @property def config(self) -> ConfigType: return self._config + + +def set_nested_dict_value[ + KeyType, ValueType +]( + d: dict[KeyType, ValueType], + keys: KeyType | tuple[KeyType, ...], + value: ValueType, + update_type: UpdateType = UpdateType.override, +) -> None: + if isinstance(keys, tuple): + for key in keys[:-1]: + d = d.setdefault(key, {}) + assert isinstance(d, dict) + key = keys[-1] + else: + key = keys + if update_type == UpdateType.override: + d[key] = value + elif update_type == UpdateType.update: + # TODO: Improve error messages, ex. for nested cases? + if isinstance(d[key], Config): + raise ValueError("Cannot update an already instantiated config.") + elif isinstance(value, Config): + raise ValueError("Cannot update a config dict with an already instantiated config.") + elif isinstance(d, dict): + if key in d: + Assert.custom(isinstance, d[key], dict) + else: + d[key] = {} + for key_, value_ in value.items(): + set_nested_dict_value(d, key_, value_, update_type) + elif ( + isinstance(value, (list, set, tuple)) + and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) + ) or ( + isinstance(d[key], (list, set, tuple)) + and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) + ): + raise ValueError("Update not supported for nested lists.") + else: + d[key] = value + else: + raise NotImplementedError(update_type) + + +def get_nested_dict_value[ + KeyType, ValueType +](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: + if isinstance(keys, tuple): + for key in keys: + d = d[key] + return d + else: + return d[keys] + + +def pop_nested_dict_value[ + KeyType, ValueType +](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: + if isinstance(keys, tuple): + for key in keys[:-1]: + d = d[key] + return d.pop(keys[-1]) + else: + return d.pop(keys) diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 752fdfd1..25850ac3 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -1,18 +1,9 @@ import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, config_class +from fast_llm.config import Config, Field, config_class from fast_llm.data.dataset.config import SamplingConfig, SamplingData -@config_class() -class SamplingDefaultConfig(SamplingConfig): - seed: int = FieldUpdate( - default=784569, - desc="Seed for random sampling.", - hint=FieldHint.feature, - ) - - @config_class() class DataConfig(Config): _abstract = True diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index cbbfa036..d1d6bd40 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -2,13 +2,12 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.config import MultiprocessingContext, TokenizerConfig -from fast_llm.data.data.config import DataConfig, SamplingDefaultConfig +from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.gpt.config import ( GPTLegacyConfig, GPTLegacyDatasetConfig, GPTSampledDatasetConfig, GPTSamplingConfig, - ShufflingType, ) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert @@ -16,13 +15,6 @@ logger = logging.getLogger(__name__) -@config_class() -class GPTSamplingDefaultConfig(SamplingDefaultConfig, GPTSamplingConfig): - gpu: bool = FieldUpdate(default=True) - use_loss_masking_spans: bool = FieldUpdate(default=False) - shuffle: ShufflingType = FieldUpdate(default=ShufflingType.epoch) - - @config_class() class GPTDataConfig(DataConfig, GPTLegacyConfig): """ @@ -44,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingDefaultConfig = FieldUpdate(default_factory=GPTSamplingDefaultConfig) + sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 431a28a0..7808158b 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -5,7 +5,7 @@ import pathlib import typing -from fast_llm.config import Config, Field, FieldHint, FieldVerboseLevel, check_field, config_class +from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities @@ -17,20 +17,12 @@ @config_class() class SamplingConfig(Config): - seed: int | None = Field( - default=None, + seed: int = Field( + default=784569, desc="Seed for random sampling.", hint=FieldHint.feature, ) - @property - def updates(self) -> dict[str, typing.Any]: - return { - key: value - for key, value in self.to_serialized(verbose=FieldVerboseLevel.everything).items() - if value is not None - } - @dataclasses.dataclass(kw_only=True) class SamplingData: @@ -44,10 +36,10 @@ class SamplingData: # Using a mutable rather than an int so it's shared with all copies made with `update`. _rank_counter: typing.Iterator[int] = itertools.count - def update(self, config: SamplingConfig, **kwargs): - if config_updates := config.updates: - kwargs["config"] = self.config.to_copy(config_updates) - return dataclasses.replace(self, **kwargs) if kwargs else self + def update_config(self, update: SamplingConfig): + return dataclasses.replace( + self, config=self.config.from_dict(self.config, update, update_type=UpdateType.update) + ) def get_next_rank(self) -> int: # Counter that loops over ranks to try to distribute workloads evenly between ranks. @@ -163,7 +155,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. """ - _abstract = False + _abstract = True sampling: SamplingConfig = Field( default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", @@ -176,7 +168,7 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): ) def build_and_sample(self, data: SamplingData) -> SampledDataset: - return self.dataset.build_and_sample(data.update(self.sampling)) + return self.dataset.build_and_sample(data.update_config(self.sampling)) @config_class() diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 74d8a0c3..118b3039 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -45,20 +45,20 @@ class ShufflingType(str, enum.Enum): @config_class() class GPTSamplingConfig(SamplingConfig): - gpu: bool | None = Field( - default=None, + gpu: bool = Field( + default=True, desc="Enable fast sampling on GPU." " Note that random sampling works differently on GPU," " so the sample won't match the CPU equivalent.", hint=FieldHint.feature, ) - use_loss_masking_spans: bool | None = Field( - default=None, + use_loss_masking_spans: bool = Field( + default=False, desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - shuffle: ShufflingType | None = Field( - default=None, + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, desc="Shuffling strategy.", hint=FieldHint.feature, ) @@ -210,6 +210,7 @@ def build(self) -> "GPTDatasetSlice": @config_class() class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): + _abstract = False type_: typing.ClassVar[str | None] = "sampled" sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 92f1165d..46c8f483 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -164,8 +164,9 @@ class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateCon def _validate(self) -> None: if self.optimizer_state is None: - # TODO: Make sure it's a type - self.optimizer_state = self.format.support_optimizer + with self._set_implicit_default(): + # TODO: Make sure it's a type + self.optimizer_state = self.format.support_optimizer super()._validate() if self.optimizer_state: assert self.format.support_optimizer diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 83514c86..76f5e336 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -7,14 +7,14 @@ import torch from fast_llm import __version__ -from fast_llm.config import MISSING +from fast_llm.config import MISSING, get_nested_dict_value, set_nested_dict_value from fast_llm.engine.base_model.config import BaseModelArchitectureConfig from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig from fast_llm.engine.checkpoint.state_dict import StateDictCheckpointHandler from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.tensor import SafeTensorSlice -from fast_llm.utils import Assert, get_nested_dict_value, set_nested_dict_value +from fast_llm.utils import Assert logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 1b3e73bb..76c496ac 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -279,7 +279,8 @@ def _validate(self) -> None: self.tensor_rank = self.rank % self.tensor_parallel if self.tensor_parallel == 1: - self.sequence_tensor_parallel = False + with self._set_implicit_default(): + self.sequence_tensor_parallel = False self.distributed_dims = {} diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 83d3d51a..91256deb 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -79,10 +79,6 @@ def setup(self, distributed_config: DistributedConfig) -> None: def num_inputs(self) -> int: return self.sequential_micro_batches * self.num_micro_sequences - @property - def _is_setup(self) -> bool: - return hasattr(self, "_distributed") - def _validate(self) -> None: # Use the distributed properties to determine the batch size and its breakdown. # Requires post-processed distributed config args @@ -133,7 +129,8 @@ def _validate(self) -> None: " Use at your own risk." ) if self.micro_sequence_length is None: - self.micro_sequence_length = self.sequence_length + with self._set_implicit_default(): + self.micro_sequence_length = self.sequence_length self.num_micro_sequences = div(self.sequence_length, self.micro_sequence_length) super()._validate() diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 30add2f4..3a65bbc9 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -42,7 +42,8 @@ class IntervalConfig(Config): def _validate(self) -> None: if self.interval: - self.offset %= self.interval + with self._set_implicit_default(): + self.offset %= self.interval super()._validate() def enabled(self, iteration: int | None = None) -> bool: @@ -109,6 +110,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) + post_alerts: bool = Field(init=False, repr=False) def _validate(self) -> None: if self.status_updates is None: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467c..fa5d4920 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -60,7 +60,8 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig): def _validate(self) -> None: if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.rotary.enabled + with self._set_implicit_default(): + self.use_position_embeddings = not self.transformer.rotary.enabled super()._validate() def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @@ -175,14 +176,14 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.transformer.init_method_std is None: - self.transformer.init_method_std = self.transformer.hidden_size**-0.5 - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min - if self.init_method_max_embed is not None and self.init_method_min_embed is not None: - Assert.leq(self.init_method_min_embed, self.init_method_max_embed) + self.transformer.validate() + with self._set_implicit_default(): + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min + if self.init_method_max_embed is not None and self.init_method_min_embed is not None: + Assert.leq(self.init_method_min_embed, self.init_method_max_embed) super()._validate() diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 1352c7f0..13983137 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -250,12 +250,13 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig): ) def _validate(self) -> None: - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + if self.kv_channels is None: + self.kv_channels = div(self.hidden_size, self.num_attention_heads) + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu self.projection_size = self.num_attention_heads * self.kv_channels self.num_unshared_experts = self.num_experts - self.num_shared_experts @@ -569,46 +570,47 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): ) def _validate(self) -> None: - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: - self.mlp_lr_scale = [None] - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + with self._set_implicit_default(): + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 + if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: + self.mlp_lr_scale = [None] + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) super()._validate() Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) diff --git a/fast_llm/profile.py b/fast_llm/profile.py index a0fc3946..a3902cf1 100644 --- a/fast_llm/profile.py +++ b/fast_llm/profile.py @@ -94,7 +94,9 @@ def _validate(self) -> None: self.global_attention_layers = set() profile_ranks = set(self.ranks or []) Assert.eq(len(profile_ranks), len(self.ranks or [])) - self.ranks = profile_ranks # noqa + with self._set_implicit_default(): + self.ranks = profile_ranks # noqa + super()._validate() def get_profiler( self, *, distributed_config: DistributedConfig | None = None, start_step: int = 0 diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d650fa94..b0e48231 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -249,40 +249,6 @@ def normalize_probabilities(p: "npt.ArrayLike", return_array: bool = False) -> " return out if return_array else out.tolist() -def set_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...], value: ValueType) -> None: - if isinstance(keys, tuple): - for key in keys[:-1]: - d = d.setdefault(key, {}) - assert isinstance(d, dict) - d[keys[-1]] = value - else: - d[keys] = value - - -def get_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: - if isinstance(keys, tuple): - for key in keys: - d = d[key] - return d - else: - return d[keys] - - -def pop_nested_dict_value[ - KeyType, ValueType -](d: dict[KeyType, ValueType], keys: KeyType | tuple[KeyType, ...]) -> ValueType: - if isinstance(keys, tuple): - for key in keys[:-1]: - d = d[key] - return d.pop(keys[-1]) - else: - return d.pop(keys) - - class InvalidObject: """ Store an error and raise it if accessed. diff --git a/tests/data/common.py b/tests/data/common.py index 917b4914..bdfd54a7 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -5,12 +5,13 @@ import torch from fast_llm.config import Field, FieldHint, NoAutoValidate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig, GPTSamplingDefaultConfig +from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import ( GPTIndexedDatasetConfig, GPTSampledDatasetConfig, + GPTSamplingConfig, GPTSamplingData, ShufflingType, ) @@ -39,7 +40,7 @@ def get_sampling_data( ) -> GPTSamplingData: # Config with convenient defaults. return GPTSamplingData( - config=GPTSamplingDefaultConfig( + config=GPTSamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -76,7 +77,7 @@ def get_test_data_and_compare_samples( distributed_config = DistributedConfig(seed=seed if legacy else 87522) distributed = Distributed(distributed_config, use_cpu=True) assert "sampling" not in config - config["sampling"] = GPTSamplingDefaultConfig( + config["sampling"] = GPTSamplingConfig( seed=87522 if legacy else seed, gpu=gpu, shuffle=shuffle, From f26010ef9f8cfd070734751f9dec45a364496308 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:21:45 -0400 Subject: [PATCH 002/114] Update pretrained config --- fast_llm/config.py | 5 +-- fast_llm/engine/checkpoint/config.py | 1 + fast_llm/engine/checkpoint/distributed.py | 6 +-- fast_llm/engine/huggingface/config.py | 5 +-- fast_llm/engine/huggingface/model.py | 8 ++-- fast_llm/engine/multi_stage/config.py | 42 +++---------------- fast_llm/engine/multi_stage/fast_llm_model.py | 7 ++-- fast_llm/layers/transformer/config.py | 17 ++++---- fast_llm/layers/transformer/mlp.py | 4 +- 9 files changed, 28 insertions(+), 67 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 326845f0..5436a294 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -374,7 +374,6 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True - print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -713,8 +712,8 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str, typing.Any]], - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + default: "Config| dict[str, typing.Any]]", + *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, ) -> typing.Self: diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 46c8f483..621f7fe8 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -200,6 +200,7 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): + # TODO!!!!!!! _abstract = False load_config: ModelConfigType = Field( diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 9c171bef..a920a52c 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -13,7 +13,6 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, DistributedCheckpointFormat, - ModelConfigType, export_safetensors_metadata, ) from fast_llm.engine.checkpoint.safe_load import SafeLoad @@ -43,15 +42,14 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: # TODO: More safety checks - loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) - loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata) num_shards = self.get_num_shards(config) shard_names = self.get_shard_names(config) Assert.eq(metadata.shards[:num_shards], list(shard_names)) same_format = ( - loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) + type(metadata.config) == type(self._model.config) and config.optimizer_state + and metadata.config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) ) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index e02abc28..e79857c9 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -73,10 +73,7 @@ def _get_config_dict( torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: updates[("distributed", "training_dtype")] = torch_dtype - fast_llm_config = cls.model_config_class.from_metadata( - pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates - ) - + fast_llm_config = cls.model_config_class.from_dict(metadata.config, kwargs.pop("fast_llm_config", {}), updates) config_dict = {"fast_llm_config": fast_llm_config} return config_dict, kwargs diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index 499f0af1..e4f2cd99 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -73,15 +73,13 @@ def from_pretrained( format=FastLLMCheckpointFormat, ) - config_updates = {} + updates = {} torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: - config_updates[("distributed", "training_dtype")] = torch_dtype + updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained( - pretrained_model_name_or_path, config_updates=config_updates, mode=mode - ) + fast_llm_model = cls.model_class.from_pretrained(pretrained_model_name_or_path, updates, mode=mode) config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d6997105..d8333c9b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -246,46 +246,12 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]: @classmethod def from_pretrained( - cls, - pretrained: CheckpointLoadMetadataConfig, - default: typing.Self | None = None, + cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any] ) -> typing.Self: # TODO: Add *updates? assert pretrained.path is not None metadata = cls.load_metadata(pretrained) - return cls.from_metadata(pretrained, metadata, default) - - @classmethod - def from_metadata( - cls, - pretrained: CheckpointLoadMetadataConfig, - metadata: "CheckpointMetadata", - default: typing.Self | None = None, - updates: dict[str | tuple[str, ...], typing.Any] | None = None, - ) -> typing.Self: - # TODO: Standardize to *updates? - # TODO v0.3: Update, remove support for older checkpoints. - if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2): - raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}") - pretrained_config = cls.from_dict(metadata.config) - if not pretrained.load_config.load_architecture: - assert default is not None - config = default.to_copy() - config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) - elif pretrained.load_config.load_fast_llm: - config = pretrained_config - else: - with NoAutoValidate(): - config = cls() if default is None else default.to_copy() - if pretrained.load_config.load_base_model: - config.base_model = pretrained_config.base_model - else: - config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture()) - config.validate() - - if updates: - config = config.to_copy(updates) - return config + return cls.from_dict(metadata.config, *updates) @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": @@ -328,7 +294,7 @@ def _validate(self) -> None: self.pretrained.setup(self.model) self.pretrained.validate() if self.pretrained.path is not None: - self.model = self.model.from_pretrained(self.pretrained, default=self.model) + self.model = self.model.from_pretrained(self.pretrained, self.model) self._setup() super()._validate() @@ -380,6 +346,8 @@ def _validate(self) -> None: self.format = self.model.get_checkpoint_format(self.format) super()._validate() + if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2): + raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}") Assert.eq(self.config.__class__, self.model) @classmethod diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index b268ec29..22e5ccac 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,6 +1,7 @@ import logging import typing +from fast_llm.config import UpdateType from fast_llm.core.distributed import broadcast from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.distributed.distributed import Distributed @@ -45,9 +46,7 @@ def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] def from_pretrained( cls, pretrained_config: CheckpointLoadConfig, - default_config: FastLLMModelConfig = None, - *, - config_updates: dict[str | tuple[str, ...], typing.Any] | None = None, + *updates: dict[str | tuple[str, ...], typing.Any], optimizer_state_names: tuple[str, ...] | None = None, setup: bool = True, mode: StageMode = StageMode.training, @@ -55,7 +54,7 @@ def from_pretrained( stage_filter: set | None = None, ) -> typing.Self: metadata = cls.config_class.load_metadata(pretrained_config) - config = cls.config_class.from_metadata(pretrained_config, metadata, default_config, config_updates) + config = cls.config_class.from_dict(metadata.config, *updates, update_type=UpdateType.update) if mode.support_training: # TODO v0.3: Make metadata.shards mandatory? if metadata.shards: diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 13983137..9410157b 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -532,8 +532,8 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: list[float | None] = Field( - default_factory=list, + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, @@ -581,8 +581,6 @@ def _validate(self) -> None: self.init_method_std_mlp_1 = self.init_method_std if self.init_method_std_mlp_2 is None: self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 - if self.mlp_lr_scale is None or len(self.mlp_lr_scale) == 0: - self.mlp_lr_scale = [None] if self.init_method_max_qkv is None: self.init_method_max_qkv = self.init_method_max if self.init_method_min_qkv is None: @@ -614,10 +612,13 @@ def _validate(self) -> None: super()._validate() Assert.geq(self.attention_dropout, 0) Assert.geq(self.hidden_dropout, 0) - Assert.incl(len(self.mlp_lr_scale), (1, self.num_experts)) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index adc6242d..ff4eaf26 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -45,7 +45,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale), + lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +55,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale), + lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, ) From b930a391b37703e7dce23fdb544b08fe98d42084 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:27:40 -0400 Subject: [PATCH 003/114] stuff --- fast_llm/config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 326845f0..5436a294 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -374,7 +374,6 @@ def validate[T](self: T, *, _is_validating: bool = False) -> T: else: raise type(e)("\n".join(e.args)) from None self._validated = True - print("WLIEHGIUWERGNHBWIO", self.__class__.__name__, self._explicit_fields) return self def _validate(self) -> None: @@ -713,8 +712,8 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str, typing.Any]], - *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + default: "Config| dict[str, typing.Any]]", + *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, ) -> typing.Self: From 8117c47b483c26853bf5015ef85b4e94472de1b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:40:37 -0400 Subject: [PATCH 004/114] fixes --- fast_llm/engine/multi_stage/config.py | 7 +++---- fast_llm/engine/multi_stage/fast_llm_model.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index d8333c9b..342a453b 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -10,6 +10,7 @@ Field, FieldHint, NoAutoValidate, + UpdateType, ValidationError, check_field, config_class, @@ -248,13 +249,11 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]: def from_pretrained( cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any] ) -> typing.Self: - # TODO: Add *updates? - assert pretrained.path is not None - metadata = cls.load_metadata(pretrained) - return cls.from_dict(metadata.config, *updates) + return cls.from_dict(cls.load_metadata(pretrained).config, *updates, update_type=UpdateType.update) @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": + assert config.path is not None with NoAutoValidate(): metadata = config.format.get_handler_class().load_metadata(config) try: diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 22e5ccac..2dec7959 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -36,11 +36,11 @@ def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. here - fast_llm_metadata = self.config_class.load_metadata(config) + metadata = self.config_class.load_metadata(config) converter = config.format.get_handler_class()(self) - converter.load(config, fast_llm_metadata) + converter.load(config, metadata) self._finalize_load(reset_optimizer=not config.optimizer_state) - return fast_llm_metadata.metadata + return metadata.metadata @classmethod def from_pretrained( From 1c995d3e76be57ec80f9f305d83b613e0c8bdba3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 26 Mar 2025 21:50:00 -0400 Subject: [PATCH 005/114] fix --- fast_llm/engine/checkpoint/config.py | 16 ---------------- fast_llm/engine/checkpoint/distributed.py | 6 +++--- fast_llm/engine/checkpoint/huggingface.py | 2 +- 3 files changed, 4 insertions(+), 20 deletions(-) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 621f7fe8..a3472523 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -200,24 +200,8 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): - # TODO!!!!!!! _abstract = False - load_config: ModelConfigType = Field( - default=ModelConfigType.architecture, - desc="Configuration to save/load.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - super()._validate() - if self.format.enforce_architecture_match: - assert self.load_config.load_architecture - - @property - def compare_log_fn(self): - return ValueError if self.load_config.load_architecture else logger.warning - @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index a920a52c..953cdef8 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -67,11 +67,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) else: log_main_rank("Checkpoint format doesn't match, using safe load") - self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn) + self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) with SafeLoad(self._model, num_shards=num_shards, timeout=config.timeout) as context: - for rank in range(loaded_config.distributed.world_size): + for rank in range(metadata.config.distributed.world_size): loaded_model = self._model.__class__( - loaded_config.to_copy({("distributed", "rank"): rank}), + metadata.config.to_copy({("distributed", "rank"): rank}), optimizer_state_names=shard_names[1:], verbose=False, ) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 87651dc4..d4533663 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -41,7 +41,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: assert not config.optimizer_state - self._model.config.base_model.compare_architecture(metadata.config.base_model, config.compare_log_fn) + self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) super().load(config, metadata) @classmethod From 506fe92917b28fc2d865edf69bad9827c5f92bfa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 16:04:35 -0400 Subject: [PATCH 006/114] fixes --- fast_llm/config.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 5436a294..222a3ec7 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -266,13 +266,21 @@ def config_class(cls=None): def wrap(cls): Assert.custom(issubclass, cls, Config) - wrapped = _process_config_class(dataclasses.dataclass(cls)) + if hasattr(cls, "__post_init__"): + raise TypeError(f"`__post_init__` should not be implemented for `Config` classes") + + wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True)) wrapped_init = cls.__init__ def __init__(self, **kwargs): + # This is similar to `__post_init__`, but has access to the list of arguments passed to `__init__`. wrapped_init(self, **kwargs) self._explicit_fields = set(kwargs) + self._validated = False + self._setting_implicit_default = False + if _AUTO_VALIDATE: + self.validate() cls.__init__ = __init__ return wrapped @@ -310,17 +318,6 @@ class Config: # without them being automatically added to `_explicit_fields`. _setting_implicit_default: bool = Field(init=False, repr=False) - def __post_init__(self): - """ - Perform validation unless prevented with `NoAutoValidate`. - In general this should not be overridden in derived classes, - and all post-processing should be done in `_validate` - """ - self._validated = False - self._setting_implicit_default = False - if _AUTO_VALIDATE: - self.validate() - def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -983,13 +980,15 @@ def set_nested_dict_value[ raise ValueError("Cannot update an already instantiated config.") elif isinstance(value, Config): raise ValueError("Cannot update a config dict with an already instantiated config.") - elif isinstance(d, dict): + elif isinstance(value, dict): if key in d: Assert.custom(isinstance, d[key], dict) else: d[key] = {} for key_, value_ in value.items(): set_nested_dict_value(d, key_, value_, update_type) + elif isinstance(d[key], dict): + raise ValueError("Cannot replace a dict with a non-dict value.") elif ( isinstance(value, (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) From 971d3ef23297f7dd64550facff25f8609c0fb097 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 18:32:07 -0400 Subject: [PATCH 007/114] fixes --- fast_llm/config.py | 14 +++++++++----- fast_llm/engine/huggingface/config.py | 5 +++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 222a3ec7..7cb54919 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -90,12 +90,12 @@ class FieldHint: class FieldVerboseLevel: - nothing = -1 + explicit = None core = 0 optional = 10 performance = 20 debug = 50 - everything = None + everything = 2**31 FieldHintDoc = { @@ -680,7 +680,7 @@ def to_copy[ ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: return self.from_dict(self, *updates, strict=strict) - def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]: + def to_serialized(self, verbose: int | None = FieldVerboseLevel.explicit) -> dict[str, typing.Any]: return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) def to_logs[ @@ -863,8 +863,12 @@ def _handle_renamed_field( def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? - self_dict = self._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) - other_dict = other._to_dict(format_=_ConfigDictFormat.tuple, serializable=True) + self_dict = self._to_dict( + format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything + ) + other_dict = other._to_dict( + format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything + ) compare = { key: (self_dict.get(key, MISSING), other_dict.get(key, MISSING)) for key in self_dict.keys() | other_dict.keys() diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index e02abc28..2b240e4b 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,6 +5,7 @@ import transformers +from fast_llm.config import FieldVerboseLevel from fast_llm.engine.checkpoint.config import CheckpointLoadMetadataConfig, FastLLMCheckpointFormat from fast_llm.engine.multi_stage.config import FastLLMModelConfig @@ -90,12 +91,12 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=None) + out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: out = super().to_diff_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized() + out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.explicit) return out def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None: From 6bf20cb2d72faabbf5eb6eea4de4f46180f836f8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 27 Mar 2025 21:26:59 -0400 Subject: [PATCH 008/114] Tests wip --- fast_llm/config.py | 40 ++++-- tests/config/__init__.py | 0 tests/config/common.py | 37 ++++++ tests/config/test_field.py | 176 ++++++++++++++++++++++++++ tests/data/test_dataset_from_file.py | 1 - tests/data/test_prepare_gpt_memmap.py | 1 - tests/test_config.py | 58 +++------ 7 files changed, 258 insertions(+), 55 deletions(-) create mode 100644 tests/config/__init__.py create mode 100644 tests/config/common.py create mode 100644 tests/config/test_field.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 7cb54919..67aa5b7a 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,4 +1,5 @@ import contextlib +import copy import dataclasses import enum import logging @@ -316,7 +317,7 @@ class Config: _explicit_fields: set[str] = Field(init=False, repr=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False, repr=False) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -332,12 +333,20 @@ def __setattr__(self, key: str, value: typing.Any) -> None: f"Cannot set attribute `{key}`" f" in configuration class `{get_type_name(type(self))}` after validation." ) - elif not getattr(self, "_setting_implicit_default", True): - field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: - # Adding to explicit field list except within `_set_implicit_default` context - # and during dataclass initialization (`_setting_implicit_default` not yet set). - self._explicit_fields.add(key) + if getattr(self, "_setting_implicit_default", None) is not None: + if self._setting_implicit_default: + if key in self._explicit_fields: + raise RuntimeError( + f"Trying to set an implicit default for field `{key}`," + f"but the field has already been set explicitly." + ) + else: + field = self.get_field(key) + if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + # Adding to explicit field list except within `_set_implicit_default` context, + # during dataclass initialization (`_setting_implicit_default` not yet set) + # and during automated config validation (`_setting_implicit_default=None`) + self._explicit_fields.add(key) super().__setattr__(key, value) def __delattr__(self, key: str) -> None: @@ -352,8 +361,9 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self): - self._setting_implicit_default = True + def _set_implicit_default(self, _value: bool | int = True): + assert self._setting_implicit_default is False + self._setting_implicit_default = _value yield self._setting_implicit_default = False @@ -383,7 +393,7 @@ def _validate(self) -> None: """ self._check_abstract() errors = [] - with self._set_implicit_default(): + with self._set_implicit_default(None): for name, field in self.fields(): if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue @@ -567,7 +577,7 @@ def get_field(cls, name: str) -> Field: def _to_dict( self, - verbose: int | None = None, + verbose: int | None = FieldVerboseLevel.explicit, all_fields: bool = False, format_: _ConfigDictFormat = _ConfigDictFormat.nested, serializable: bool = False, @@ -716,6 +726,8 @@ def from_dict( ) -> typing.Self: if isinstance(default, Config): default = default._to_dict() + else: + default = copy.deepcopy(default) for update in updates: if isinstance(update, Config): update = update._to_dict(format_=_ConfigDictFormat.tuple) @@ -980,7 +992,7 @@ def set_nested_dict_value[ d[key] = value elif update_type == UpdateType.update: # TODO: Improve error messages, ex. for nested cases? - if isinstance(d[key], Config): + if isinstance(d.get(key), Config): raise ValueError("Cannot update an already instantiated config.") elif isinstance(value, Config): raise ValueError("Cannot update a config dict with an already instantiated config.") @@ -991,13 +1003,13 @@ def set_nested_dict_value[ d[key] = {} for key_, value_ in value.items(): set_nested_dict_value(d, key_, value_, update_type) - elif isinstance(d[key], dict): + elif isinstance(d.get(key), dict): raise ValueError("Cannot replace a dict with a non-dict value.") elif ( isinstance(value, (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in value) ) or ( - isinstance(d[key], (list, set, tuple)) + isinstance(d.get(key), (list, set, tuple)) and any(isinstance(value_, (list, set, tuple, dict, Config)) for value_ in d[key]) ): raise ValueError("Update not supported for nested lists.") diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/config/common.py b/tests/config/common.py new file mode 100644 index 00000000..3109175a --- /dev/null +++ b/tests/config/common.py @@ -0,0 +1,37 @@ +import enum +import pathlib + +from fast_llm.config import Config, Field, FieldHint, config_class + + +class TestEnum(str, enum.Enum): + a = "a" + b = "b" + c = "c" + + +@config_class +class TestConfig(Config): + int_field: int = Field(default=0, hint=FieldHint.optional) + bool_field: bool = Field(default=False, hint=FieldHint.optional) + str_field: str = Field(default="", hint=FieldHint.optional) + path_field: pathlib.Path = Field(default="", hint=FieldHint.optional) + float_field: float = Field(default=4.0, hint=FieldHint.optional) + optional_field: str | None = Field(default=None, hint=FieldHint.optional) + union_field: str | int = Field(default=7, hint=FieldHint.optional) + implicit_field: str = Field(default=None, hint=FieldHint.optional) + list_field: list[int] = Field(default_factory=list, hint=FieldHint.optional) + tuple_field: tuple[int, ...] = Field(default=(), hint=FieldHint.optional) + # tuple_fixed_length_field: tuple[int, str] = Field(default=(5, "text"), hint=FieldHint.optional) + set_field: set[int] = Field(default_factory=set, hint=FieldHint.optional) + dict_field: dict[int, int] = Field(default_factory=dict, hint=FieldHint.optional) + type_field: type[int] = Field(default=int, hint=FieldHint.optional) + enum_field: TestEnum = Field(default=TestEnum.a, hint=FieldHint.optional) + core_field: int = Field(default=4, hint=FieldHint.core) + complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.implicit_field is None: + self.implicit_field = "implicit" + super()._validate() diff --git a/tests/config/test_field.py b/tests/config/test_field.py new file mode 100644 index 00000000..27e7c8b5 --- /dev/null +++ b/tests/config/test_field.py @@ -0,0 +1,176 @@ +import math +import pathlib + +import numpy +import pytest + +from fast_llm.config import FieldVerboseLevel, ValidationError +from fast_llm.utils import Assert +from tests.config.common import TestConfig, TestEnum + + +def check_config(internal_config, *alternate, serialized_config=None): + serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config + for init_config in (internal_config, *alternate): + config = TestConfig.from_dict(init_config) + Assert.eq(config.to_serialized(), serialized_config) + Assert.eq(config._to_dict(), internal_config) + + +def check_invalid_config(config): + with pytest.raises(ValidationError): + TestConfig.from_dict(config) + + +def test_create_and_serialize_config(): + Assert.eq(TestConfig.from_dict({}).to_serialized(), {}) + + +@pytest.mark.parametrize("value", (0, -6, 3, True)) +def test_int_field(value): + check_config({"int_field": value}) + + +@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4])) +def test_int_field_invalid(value): + check_invalid_config({"int_field": value}) + + +@pytest.mark.parametrize("value", (True, False)) +def test_bool_field(value): + check_config({"bool_field": value}) + + +@pytest.mark.parametrize("value", (1, "True", None, [True])) +def test_bool_field_invalid(value): + check_invalid_config({"bool_field": value}) + + +@pytest.mark.parametrize("value", ("", "text", "1", TestEnum.a)) +def test_str_field(value): + check_config({"str_field": value}) + + +@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"))) +def test_str_field_invalid(value): + check_invalid_config({"str_field": value}) + + +@pytest.mark.parametrize("value", (".", "text", "/a/b/c.d")) +def test_path_field(value): + check_config({"path_field": pathlib.Path(value)}, {"path_field": value}) + + +@pytest.mark.parametrize("value", (1, True, None, [pathlib.Path("a")])) +def test_path_field_invalid(value): + check_invalid_config({"path_field": value}) + + +@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, True, numpy.float64(3), math.nan)) +def test_float_field(value): + check_config({"float_field": value}) + + +@pytest.mark.parametrize("value", (None, [4.7], "0.0")) +def test_float_field_invalid(value): + check_invalid_config({"float_field": value}) + + +@pytest.mark.parametrize("value", ("", None, "text")) +def test_optional_field(value): + check_config({"optional_field": value}) + + +@pytest.mark.parametrize("value", (True, 6, [None])) +def test_optional_field_invalid(value): + check_invalid_config({"optional": value}) + + +@pytest.mark.parametrize("value", ("", 0, True, "text", 7)) +def test_union_field(value): + check_config({"union_field": value}) + + +@pytest.mark.parametrize("value", (6.0, [""])) +def test_union_field_invalid(value): + check_invalid_config({"optional": value}) + + +@pytest.mark.parametrize("value", ("implicit", "", "text")) +def test_implicit_field(value): + check_config({"implicit_field": value}) + + +TUPLE_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_list_field(value): + check_config( + {"list_field": list(value)}, + {"list_field": value}, + serialized_config={"list_field": list(value)}, + ) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_tuple_field(value): + check_config( + {"tuple_field": list(value)}, + {"tuple_field": value}, + serialized_config={"tuple_field": list(value)}, + ) + + +@pytest.mark.parametrize("value", TUPLE_VALUES) +def test_set_field(value): + check_config( + {"set_field": list(set(value))}, + {"set_field": set(value)}, + {"set_field": list(value)}, + {"set_field": tuple(value)}, + serialized_config={"set_field": list(set(value))}, + ) + + +# @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (True, "True"))) +# def test_tuple_fixed_length_field(value): +# expected_config = {"tuple_variable_length_field": value} +# Assert.eq(TestConfig.from_dict(expected_config).to_serialized(), expected_config) +# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": list(value)}).to_serialized(), expected_config) +# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) + + +@pytest.mark.parametrize("value", ({}, {True: 2}, {1: 2, 3: 4})) +def test_dict_field(value): + check_config({"dict_field": value}) + + +class IntClass(int): + pass + + +@pytest.mark.parametrize("value", (int, bool, IntClass)) +def test_type_field(value): + check_config({"type_field": value}, serialized_config={"type_field": str(value)}) + + +@pytest.mark.parametrize("value", (TestEnum.a, TestEnum.b, TestEnum.c)) +def test_enum_field(value): + check_config({"enum_field": value}, {"enum_field": value.value}) + + +def test_core_field(): + Assert.eq(TestConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + + +@pytest.mark.parametrize( + "value", + ( + {}, + {3: None, "text": [], False: [["", 3], ["a", -7]]}, + {0: [[".", 8]]}, + ), +) +def test_complex_field(value): + check_config({"complex_field": value}) diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index 4ac2fcdf..280b3413 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -8,5 +8,4 @@ def test_dataset_from_file(): get_test_dataset() dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() - print("kjhbwiugfberibgiujebi", len(dataset)) compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9a15a051..a6fd3246 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -148,7 +148,6 @@ def test_split_datasets_1(): { "training": { "type": "blended", - "name": "blended", "datasets": [ dataset_config_0.to_serialized(), { diff --git a/tests/test_config.py b/tests/test_config.py index 7141812a..5c45db0b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,19 +1,14 @@ import pathlib -import pytest import subprocess import unittest.mock -import yaml +import pytest +import yaml -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerArchitectureConfig, - AddLinearBiasChoices, -) -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingConfig from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.config import ValidationError - +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry @@ -90,33 +85,6 @@ def test_do_use_flash_attention(): config.do_use_flash_attention(mock_distributed_config) -def test_add_linear_biases_valid_values(): - # Valid boolean values - assert TransformerArchitectureConfig(add_linear_biases=True).add_linear_biases is True - assert TransformerArchitectureConfig(add_linear_biases=False).add_linear_biases is False - - # Valid enum values - assert TransformerArchitectureConfig(add_linear_biases="nowhere").add_linear_biases == AddLinearBiasChoices.nowhere - assert ( - TransformerArchitectureConfig(add_linear_biases="everywhere").add_linear_biases - == AddLinearBiasChoices.everywhere - ) - assert ( - TransformerArchitectureConfig(add_linear_biases="only_attn_qkv").add_linear_biases == AddLinearBiasChoices.only_attn_qkv - ) - - -def test_add_linear_biases_invalid_values(): - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases="invalid_value") - - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=123) - - with pytest.raises(ValidationError): - TransformerArchitectureConfig(add_linear_biases=None) - - def test_add_mlp_bias(): assert TransformerArchitectureConfig(add_linear_biases=True).add_mlp_bias is True assert TransformerArchitectureConfig(add_linear_biases=False).add_mlp_bias is False @@ -130,7 +98,9 @@ def test_add_attn_qkv_bias(): assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_qkv_bias is False assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_qkv_bias is True assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_qkv_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + assert ( + TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_qkv_bias is True + ) def test_add_attn_dense_bias(): @@ -138,4 +108,14 @@ def test_add_attn_dense_bias(): assert TransformerArchitectureConfig(add_linear_biases=False).add_attn_dense_bias is False assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.everywhere).add_attn_dense_bias is True assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.nowhere).add_attn_dense_bias is False - assert TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias is False + assert ( + TransformerArchitectureConfig(add_linear_biases=AddLinearBiasChoices.only_attn_qkv).add_attn_dense_bias + is False + ) + + +@pytest.mark.parametrize("cls", (GPTSamplingConfig,)) +def test_serialize_default_config_updates(cls): + # Config classes used as config updates should have a default that serializes to an empty dict + # so no value is incorrectly overridden. + assert cls.from_dict({}).to_serialized() == {} From c13fb19f8763b0aebe83058b375b9732e70721d2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 28 Mar 2025 22:28:57 -0400 Subject: [PATCH 009/114] misc --- fast_llm/config.py | 29 +++++----- fast_llm/data/dataset/gpt/config.py | 4 +- fast_llm/utils.py | 2 +- tests/config/common.py | 6 +- tests/config/test_field.py | 86 ++++++++++++++++++++++------- 5 files changed, 88 insertions(+), 39 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 67aa5b7a..c311abf4 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -468,10 +468,10 @@ def _validate_element(cls, value, type_, name: str): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type.") elif issubclass(type_, Config): - cls._validate_element_type(value, type_, name) + cls._validate_element_type(value, type_, strict=False) value.validate(_is_validating=True) else: - value = cls._validate_simple(value, type_, name) + value = cls._validate_simple(value, type_) return value @classmethod @@ -491,7 +491,7 @@ def _validate_union(cls, value, type_, name: str): @classmethod def _validate_array(cls, value, type_, name: str): origin = type_.__origin__ - cls._validate_element_type(value, (origin, list, tuple), name) + cls._validate_element_type(value, (origin, list, tuple), strict=False) args = getattr(type_, "__args__", [typing.Any, ...] if origin is tuple else [typing.Any]) errors = [] if issubclass(origin, tuple) and not (len(args) == 2 and args[1] is ...): @@ -518,7 +518,7 @@ def _validate_dict(cls, value, type_, name: str): if len(args) > 2: raise FieldTypeError(f"Invalid dict specification `{get_type_name(type_)}` for field `{name}`") args.extend([typing.Any for _ in range(2 - len(args))]) - cls._validate_element_type(value, type_.__origin__, name) + cls._validate_element_type(value, type_.__origin__, strict=False) errors = [] new_value = {} old_keys = {} @@ -534,19 +534,22 @@ def _validate_dict(cls, value, type_, name: str): return new_value @classmethod - def _validate_simple(cls, value, type_, name: str): + def _validate_simple(cls, value, type_, strict: bool = True): if hasattr(type_, "__fast_llm_validator__"): value = type_.__fast_llm_validator__(value) - elif type_ is float and isinstance(value, int): + elif type_ is float and type(value) == int: # Ints are ok too. value = float(value) elif issubclass(type_, enum.Enum) and not isinstance(value, type_) and issubclass(type_, type(value)): # Enum values are ok too. value = type_(value) - elif issubclass(type_, pathlib.PurePath) and isinstance(value, str): - # Str paths are ok too. - value = type_(value) - cls._validate_element_type(value, type_, name) + elif issubclass(type_, pathlib.PurePath): + if isinstance(value, str): + # Str paths are ok too. + value = type_(value) + # Path type may depend on the OS. + strict = False + cls._validate_element_type(value, type_, strict) return value @classmethod @@ -560,9 +563,9 @@ def _validate_type(cls, value, type_: type | tuple[type, ...], name): raise ValidationError(f"Field value `{value} is not a subclass of `{get_type_name(type_)}`") @classmethod - def _validate_element_type(cls, value, type_: type | tuple[type, ...], name): - if not isinstance(value, type_): - raise ValidationError(f"Unexpected type `{get_type_name(type(value))}`") + def _validate_element_type(cls, value, type_: type | tuple[type, ...], strict: bool = True): + if not (type(value) == type_ if strict else isinstance(value, type_)): + raise ValidationError(f"Unexpected field type: {get_type_name(type(value))} != {get_type_name(type_)}") @classmethod def fields(cls) -> typing.Iterable[tuple[str, Field]]: diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 118b3039..4f15492a 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -484,8 +484,8 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: "type": "slice", # TODO: this duplicates memmap datasets for each phase. "dataset": {"type": "memmap", "path": prefix}, - "begin": phase_splits[phase_index], - "end": phase_splits[phase_index + 1], + "begin": float(phase_splits[phase_index]), + "end": float(phase_splits[phase_index + 1]), } for prefix in dataset_prefixes ] diff --git a/fast_llm/utils.py b/fast_llm/utils.py index aac6f607..4edd8b98 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -86,7 +86,7 @@ class Assert: @staticmethod def eq(x, *args, msg=None): for arg in args: - assert x == arg, f"{x} != {arg} " + f"| {msg}" if msg else "" + assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") @staticmethod def is_(x, y): diff --git a/tests/config/common.py b/tests/config/common.py index 3109175a..143d770c 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -4,14 +4,14 @@ from fast_llm.config import Config, Field, FieldHint, config_class -class TestEnum(str, enum.Enum): +class ExampleEnum(enum.StrEnum): a = "a" b = "b" c = "c" @config_class -class TestConfig(Config): +class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) str_field: str = Field(default="", hint=FieldHint.optional) @@ -26,7 +26,7 @@ class TestConfig(Config): set_field: set[int] = Field(default_factory=set, hint=FieldHint.optional) dict_field: dict[int, int] = Field(default_factory=dict, hint=FieldHint.optional) type_field: type[int] = Field(default=int, hint=FieldHint.optional) - enum_field: TestEnum = Field(default=TestEnum.a, hint=FieldHint.optional) + enum_field: ExampleEnum = Field(default=ExampleEnum.a, hint=FieldHint.optional) core_field: int = Field(default=4, hint=FieldHint.core) complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 27e7c8b5..4f39f741 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -6,32 +6,59 @@ from fast_llm.config import FieldVerboseLevel, ValidationError from fast_llm.utils import Assert -from tests.config.common import TestConfig, TestEnum +from tests.config.common import ExampleConfig, ExampleEnum + + +def _check_equal(config_a, config_b): + # Check for equality of both values and types. + for key in config_a.keys() | config_b.keys(): + assert key in config_a and key in config_b, key + Assert.eq(type(config_a[key]), type(config_b[key])) + if isinstance(config_a[key], (list, tuple, set)): + Assert.eq(len(config_a[key]), len(config_b[key])) + for i in range(len(config_a[key])): + _check_equal({"": config_a[key][i]}, {"": config_b[key][i]}) + elif isinstance(config_a[key], dict): + _check_equal(config_a[key], config_b[key]) + else: + try: + Assert.eq(config_a[key], config_b[key]) + except AssertionError as e: + # Special case for `math.nan` + if config_a[key] is not config_b[key]: + raise e + + +def check_equal(config_a, config_b): + try: + _check_equal(config_a, config_b) + except AssertionError as e: + raise AssertionError(config_a, config_b, *e.args) def check_config(internal_config, *alternate, serialized_config=None): serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config for init_config in (internal_config, *alternate): - config = TestConfig.from_dict(init_config) - Assert.eq(config.to_serialized(), serialized_config) - Assert.eq(config._to_dict(), internal_config) + config = ExampleConfig.from_dict(init_config) + check_equal(config.to_serialized(), serialized_config) + check_equal(config._to_dict(), internal_config) def check_invalid_config(config): with pytest.raises(ValidationError): - TestConfig.from_dict(config) + ExampleConfig.from_dict(config) def test_create_and_serialize_config(): - Assert.eq(TestConfig.from_dict({}).to_serialized(), {}) + Assert.eq(ExampleConfig.from_dict({}).to_serialized(), {}) -@pytest.mark.parametrize("value", (0, -6, 3, True)) +@pytest.mark.parametrize("value", (0, -6, 3)) def test_int_field(value): check_config({"int_field": value}) -@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4])) +@pytest.mark.parametrize("value", (4.0, math.inf, "1", None, [4], True)) def test_int_field_invalid(value): check_invalid_config({"int_field": value}) @@ -46,12 +73,12 @@ def test_bool_field_invalid(value): check_invalid_config({"bool_field": value}) -@pytest.mark.parametrize("value", ("", "text", "1", TestEnum.a)) +@pytest.mark.parametrize("value", ("", "text", "1")) def test_str_field(value): - check_config({"str_field": value}) + check_config({"str_field": str(value)}, {"str_field": value}) -@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"))) +@pytest.mark.parametrize("value", (1, True, None, ["text"], pathlib.Path("a"), ExampleEnum.a)) def test_str_field_invalid(value): check_invalid_config({"str_field": value}) @@ -66,12 +93,14 @@ def test_path_field_invalid(value): check_invalid_config({"path_field": value}) -@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, True, numpy.float64(3), math.nan)) +@pytest.mark.parametrize("value", (4.0, math.pi, math.inf, 3, math.nan)) def test_float_field(value): - check_config({"float_field": value}) + check_config( + {"float_field": float(value)}, {"float_field": value}, serialized_config={"float_field": float(value)} + ) -@pytest.mark.parametrize("value", (None, [4.7], "0.0")) +@pytest.mark.parametrize("value", (None, [4.7], "0.0", True, numpy.float64(3))) def test_float_field_invalid(value): check_invalid_config({"float_field": value}) @@ -86,16 +115,20 @@ def test_optional_field_invalid(value): check_invalid_config({"optional": value}) -@pytest.mark.parametrize("value", ("", 0, True, "text", 7)) +@pytest.mark.parametrize("value", ("", 0, "text", 7)) def test_union_field(value): check_config({"union_field": value}) -@pytest.mark.parametrize("value", (6.0, [""])) +@pytest.mark.parametrize("value", (6.0, [""], True)) def test_union_field_invalid(value): check_invalid_config({"optional": value}) +def test_implicit_field_value(): + Assert.eq(ExampleConfig.from_dict({}).implicit_field, "implicit") + + @pytest.mark.parametrize("value", ("implicit", "", "text")) def test_implicit_field(value): check_config({"implicit_field": value}) @@ -141,11 +174,16 @@ def test_set_field(value): # Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) -@pytest.mark.parametrize("value", ({}, {True: 2}, {1: 2, 3: 4})) +@pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) def test_dict_field(value): check_config({"dict_field": value}) +@pytest.mark.parametrize("value", ({True: 2}, {4: "3"}, {4: {1: 4}}, None, 4, {1}, [5, 7], "text")) +def test_dict_field_invalid(value): + check_invalid_config({"dict_field": value}) + + class IntClass(int): pass @@ -155,22 +193,30 @@ def test_type_field(value): check_config({"type_field": value}, serialized_config={"type_field": str(value)}) -@pytest.mark.parametrize("value", (TestEnum.a, TestEnum.b, TestEnum.c)) +@pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) def test_enum_field(value): check_config({"enum_field": value}, {"enum_field": value.value}) def test_core_field(): - Assert.eq(TestConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + Assert.eq(ExampleConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) @pytest.mark.parametrize( "value", ( {}, - {3: None, "text": [], False: [["", 3], ["a", -7]]}, + {3: None, "text": [], 0: [["", 3], ["a", -7]]}, {0: [[".", 8]]}, ), ) def test_complex_field(value): check_config({"complex_field": value}) + + +@pytest.mark.parametrize( + "value", + ({3: None, "text": [], False: [["", 3], ["a", -7]]},), +) +def test_complex_field_invalid(value): + check_invalid_config({"complex_field": value}) From a20fcecfb870fb076bfa067b8622c6a31aa4d928 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 31 Mar 2025 20:23:42 -0400 Subject: [PATCH 010/114] tests --- tests/config/common.py | 21 ++++++ tests/config/test_field.py | 133 +++++++++++++++++++++++++++---------- 2 files changed, 120 insertions(+), 34 deletions(-) diff --git a/tests/config/common.py b/tests/config/common.py index 143d770c..f9449507 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -35,3 +35,24 @@ def _validate(self) -> None: if self.implicit_field is None: self.implicit_field = "implicit" super()._validate() + + +@config_class +class ExampleVerboseConfig(Config): + # These fields will have non-empty default serialized values. + list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) + tuple_default_field: tuple[int, ...] = Field(default=(0, 1), hint=FieldHint.optional) + tuple_fixed_length_field: tuple[int, str] = Field(default=(5, "text"), hint=FieldHint.optional) + set_default_field: set[int] = Field(default_factory=lambda: {0, 1, 2}, hint=FieldHint.optional) + dict_default_field: dict[str, int] = Field(default_factory=lambda: {"0": 0, "1": 1}, hint=FieldHint.optional) + explicit_field: str = Field(default=None, hint=FieldHint.optional) + + def _validate(self) -> None: + if self.explicit_field is None: + self.explicit_field = "explicit" + super()._validate() + + +@config_class +class ExampleNestedConfig(ExampleConfig): + nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 4f39f741..bed9c181 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -4,29 +4,29 @@ import numpy import pytest -from fast_llm.config import FieldVerboseLevel, ValidationError +from fast_llm.config import Config, FieldVerboseLevel, ValidationError from fast_llm.utils import Assert -from tests.config.common import ExampleConfig, ExampleEnum +from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig def _check_equal(config_a, config_b): # Check for equality of both values and types. - for key in config_a.keys() | config_b.keys(): - assert key in config_a and key in config_b, key - Assert.eq(type(config_a[key]), type(config_b[key])) - if isinstance(config_a[key], (list, tuple, set)): - Assert.eq(len(config_a[key]), len(config_b[key])) - for i in range(len(config_a[key])): - _check_equal({"": config_a[key][i]}, {"": config_b[key][i]}) - elif isinstance(config_a[key], dict): + Assert.eq(type(config_a), type(config_b)) + if isinstance(config_a, dict): + for key in config_a.keys() | config_b.keys(): + assert key in config_a and key in config_b, key _check_equal(config_a[key], config_b[key]) - else: - try: - Assert.eq(config_a[key], config_b[key]) - except AssertionError as e: - # Special case for `math.nan` - if config_a[key] is not config_b[key]: - raise e + elif isinstance(config_a, (list, tuple, set)): + Assert.eq(len(config_a), len(config_b)) + for i in range(len(config_a)): + _check_equal(config_a[i], config_b[i]) + else: + try: + Assert.eq(config_a, config_b) + except AssertionError: + # Special case for `math.nan` + if config_a is not config_b: + raise def check_equal(config_a, config_b): @@ -36,17 +36,30 @@ def check_equal(config_a, config_b): raise AssertionError(config_a, config_b, *e.args) -def check_config(internal_config, *alternate, serialized_config=None): +def check_config( + internal_config, + *alternate, + serialized_config=None, + cls: type[Config] = ExampleConfig, + fields: list[str] | None = None, +): serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config for init_config in (internal_config, *alternate): - config = ExampleConfig.from_dict(init_config) - check_equal(config.to_serialized(), serialized_config) - check_equal(config._to_dict(), internal_config) + config = cls.from_dict(init_config) + serialized_config_ = config.to_serialized() + internal_config_ = config._to_dict() + if fields is None: + check_equal(serialized_config_, serialized_config) + check_equal(internal_config_, internal_config) + else: + for field in fields: + check_equal(serialized_config_[field], serialized_config[field]) + check_equal(internal_config_[field], internal_config[field]) -def check_invalid_config(config): +def check_invalid_config(config, cls: type[Config] = ExampleConfig): with pytest.raises(ValidationError): - ExampleConfig.from_dict(config) + cls.from_dict(config) def test_create_and_serialize_config(): @@ -134,10 +147,11 @@ def test_implicit_field(value): check_config({"implicit_field": value}) -TUPLE_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) +ARRAY_VALUES = ((), (1,), (3, 4, 6), (4, 5, 4)) +ARRAY_VALUES_INVALID = (6.0, {}, True, "text") -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_list_field(value): check_config( {"list_field": list(value)}, @@ -146,7 +160,12 @@ def test_list_field(value): ) -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_list_field_invalid(value): + check_invalid_config({"list_field": value}) + + +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_tuple_field(value): check_config( {"tuple_field": list(value)}, @@ -155,7 +174,12 @@ def test_tuple_field(value): ) -@pytest.mark.parametrize("value", TUPLE_VALUES) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_tuple_field_invalid(value): + check_invalid_config({"tuple_field": value}) + + +@pytest.mark.parametrize("value", ARRAY_VALUES) def test_set_field(value): check_config( {"set_field": list(set(value))}, @@ -166,12 +190,9 @@ def test_set_field(value): ) -# @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (True, "True"))) -# def test_tuple_fixed_length_field(value): -# expected_config = {"tuple_variable_length_field": value} -# Assert.eq(TestConfig.from_dict(expected_config).to_serialized(), expected_config) -# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": list(value)}).to_serialized(), expected_config) -# Assert.eq(TestConfig.from_dict({"tuple_variable_length_field": set(value)}).to_serialized(), {"tuple_variable_length_field": tuple(set(value))}) +@pytest.mark.parametrize("value", ARRAY_VALUES_INVALID) +def test_tuple_field_invalid(value): + check_invalid_config({"set_field": value}) @pytest.mark.parametrize("value", ({}, {1: 2, 3: 4})) @@ -193,9 +214,19 @@ def test_type_field(value): check_config({"type_field": value}, serialized_config={"type_field": str(value)}) +@pytest.mark.parametrize("value", (5, None, [], "text")) +def test_type_field_invalid(value): + check_invalid_config({"type_field": value}) + + @pytest.mark.parametrize("value", (ExampleEnum.a, ExampleEnum.b, ExampleEnum.c)) def test_enum_field(value): - check_config({"enum_field": value}, {"enum_field": value.value}) + check_config({"enum_field": value}, {"enum_field": str(value)}) + + +@pytest.mark.parametrize("value", (5, None, [], "text")) +def test_enum_field_invalid(value): + check_invalid_config({"type_field": value}) def test_core_field(): @@ -220,3 +251,37 @@ def test_complex_field(value): ) def test_complex_field_invalid(value): check_invalid_config({"complex_field": value}) + + +def test_verbose_config_default(): + default_values = { + "list_default_field": [0], + "tuple_default_field": [0, 1], + "tuple_fixed_length_field": [5, "text"], + "set_default_field": [0, 1, 2], + "dict_default_field": {"0": 0, "1": 1}, + "explicit_field": "explicit", + } + config = ExampleVerboseConfig.from_dict({}) + check_equal(config.to_serialized(), default_values) + check_equal(config._to_dict(), default_values) + + +@pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) +def test_tuple_fixed_length_field(value): + check_config( + {"tuple_fixed_length_field": list(value)}, + {"tuple_fixed_length_field": value}, + serialized_config={"tuple_fixed_length_field": list(value)}, + cls=ExampleVerboseConfig, + fields=["tuple_fixed_length_field"], + ) + + +@pytest.mark.parametrize("value", ((), (5,), ("", 0), ("0", "True"), (0, "", "text"))) +def test_tuple_fixed_length_field_invalid(value): + check_invalid_config({"tuple_fixed_length_field": value}, cls=ExampleVerboseConfig) + + +# TODO: Test other fields with defaults. +# TODO: Test nested fields. From 9af372df69e71e7a818bb52f4e7d26706d42e19c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 18:07:02 -0400 Subject: [PATCH 011/114] Tests, fixes, remove tuple format --- fast_llm/config.py | 111 ++++++------------ fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/sampled.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 2 +- fast_llm/engine/checkpoint/distributed.py | 8 +- fast_llm/engine/checkpoint/external.py | 2 +- fast_llm/engine/checkpoint/huggingface.py | 2 +- fast_llm/engine/checkpoint/state_dict.py | 2 +- fast_llm/engine/config_utils/run.py | 4 +- fast_llm/engine/huggingface/config.py | 4 +- fast_llm/engine/training/wandb.py | 2 +- fast_llm/utils.py | 32 +++++ tests/config/common.py | 31 ++++- tests/config/test_field.py | 67 ++--------- tests/config/test_update.py | 52 ++++++++ tests/data/test_prepare_gpt_memmap.py | 20 ++-- tests/test_config.py | 2 +- tools/moe_add_experts.py | 2 +- 18 files changed, 185 insertions(+), 162 deletions(-) create mode 100644 tests/config/test_update.py diff --git a/fast_llm/config.py b/fast_llm/config.py index c311abf4..0abd9073 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -11,7 +11,7 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log +from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log logger = logging.getLogger(__name__) @@ -38,13 +38,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): _AUTO_VALIDATE = self._old_value -class _ConfigDictFormat(str, enum.Enum): - # TODO v0.3: delete class - flat = "flat" - nested = "nested" - tuple = "tuple" - - class UpdateType(str, enum.Enum): # Override entries no matter what they contais. override = "override" @@ -578,33 +571,26 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]: def get_field(cls, name: str) -> Field: return cls.__dataclass_fields__[name] # noqa - def _to_dict( + def to_dict( self, verbose: int | None = FieldVerboseLevel.explicit, all_fields: bool = False, - format_: _ConfigDictFormat = _ConfigDictFormat.nested, - serializable: bool = False, + serialized: bool = True, ) -> dict[str, typing.Any]: """ Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`. - When not flat, the dict includes a `__class__` entry which allows support for derived classes. Args: all_fields: Include the derived fields, with `init=False`. - format_: The config format used to represent nested configs. Options: - * `ConfigDictFormat.nested`: Preserve the nested config structure by returning nested dicts. - Also save a `__class__` entry to support derived classes. Standard format. - * `ConfigDictFormat.tuple`: Preserve the nested config structure by returning tuples of keys. - Used for config updates. - serializable: Ensure the dict is serializable to json or yaml. Information may be lost. + serialized: Ensure the dict is serializable to json or yaml. Information may be lost. """ arg_dict = {} for name, field in self.fields(): value = getattr(self, name, MISSING) - self._add_field_to_args(arg_dict, name, field, value, verbose, all_fields, format_, serializable) + self._add_field_to_args(arg_dict, name, field, value, verbose, all_fields, serialized) if hasattr(self, "_unknown_fields"): for name, value in self._unknown_fields.items(): - self._add_field_to_args(arg_dict, f"!!! {name}", None, value, None, all_fields, format_, serializable) + self._add_field_to_args(arg_dict, f"!!! {name}", None, value, None, all_fields, serialized) return arg_dict @@ -616,13 +602,12 @@ def _add_field_to_args( value: typing.Any, verbose: int | None = None, all_fields: bool = False, - format_: _ConfigDictFormat = _ConfigDictFormat.nested, - serializable: bool = False, + serializable: bool = True, ) -> None: if ( field is not None and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not (all_fields) + and not all_fields ): # Exclude class variables and derived fields unless requested explicitly. return @@ -632,48 +617,36 @@ def _add_field_to_args( or (verbose is not None and verbose >= FieldHintImportance[field.hint]) ) if isinstance(value, Config): - field_value = value._to_dict( + field_value = value.to_dict( verbose=verbose, all_fields=all_fields, - format_=format_, - serializable=serializable, + serialized=serializable, ) # Empty configs can safely be trimmed. explicit_field = all_fields elif isinstance(value, (list, tuple, set)): - field_value = {} if format_ == _ConfigDictFormat.tuple else [] + field_value = [] for i, list_value in enumerate(value): - self._add_field_to_args( - field_value, str(i), None, list_value, verbose, all_fields, format_, serializable - ) + self._add_field_to_args(field_value, str(i), None, list_value, verbose, all_fields, serializable) elif isinstance(value, dict): field_value = {} for dict_name, dict_value in value.items(): - self._add_field_to_args( - field_value, dict_name, None, dict_value, verbose, all_fields, format_, serializable - ) + self._add_field_to_args(field_value, dict_name, None, dict_value, verbose, all_fields, serializable) elif explicit_field: field_value = value if serializable: field_value = self._serialize_value(value) - if format_ == _ConfigDictFormat.tuple: - field_value = {(): field_value} else: # Exclude unimportant (implicit or explicit) default values. return if serializable: name = self._serialize_value(name) - if format_ == _ConfigDictFormat.tuple: - args.update({(name,) + name_: value_ for name_, value_ in field_value.items()}) - elif format_ == _ConfigDictFormat.nested: - if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: - if isinstance(args, dict): - args[name] = field_value - else: - args.append(field_value) - else: - raise NotImplementedError(format_) + if not isinstance(field_value, (dict, list)) or len(field_value) > 0 or explicit_field or all_fields: + if isinstance(args, dict): + args[name] = field_value + else: + args.append(field_value) @classmethod def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: @@ -689,12 +662,14 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: return value def to_copy[ - T - ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: - return self.from_dict(self, *updates, strict=strict) - - def to_serialized(self, verbose: int | None = FieldVerboseLevel.explicit) -> dict[str, typing.Any]: - return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) + T: Config, + ]( + self: T, + *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + strict: bool = True, + update_type: UpdateType = UpdateType.override, + ) -> T: + return self.from_dict(self, *updates, strict=strict, update_type=update_type) def to_logs[ T @@ -706,7 +681,7 @@ def to_logs[ width: int = 80, fill_char: str = "-", ) -> T: - arg_dict = self.to_serialized(verbose=verbose) + arg_dict = self.to_dict(verbose=verbose) if title is None: title = self._get_class_name() return log_fn( @@ -728,12 +703,14 @@ def from_dict( update_type: UpdateType = UpdateType.override, ) -> typing.Self: if isinstance(default, Config): - default = default._to_dict() + default = default.to_dict(serialized=False) else: default = copy.deepcopy(default) for update in updates: if isinstance(update, Config): - update = update._to_dict(format_=_ConfigDictFormat.tuple) + update = update.to_dict(serialized=False) + else: + update = copy.deepcopy(update) for keys, value in update.items(): set_nested_dict_value(default, keys, value, update_type) @@ -878,27 +855,15 @@ def _handle_renamed_field( def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typing.Callable] = ValueError): # TODO: Check classes? - self_dict = self._to_dict( - format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything - ) - other_dict = other._to_dict( - format_=_ConfigDictFormat.tuple, serializable=True, verbose=FieldVerboseLevel.everything - ) - compare = { - key: (self_dict.get(key, MISSING), other_dict.get(key, MISSING)) - for key in self_dict.keys() | other_dict.keys() - } - diff = { - key: (self_value, other_value) - for key, (self_value, other_value) in compare.items() - if self_value != other_value - } - if diff: - log( + self_dict = self.to_dict(verbose=FieldVerboseLevel.everything) + other_dict = other.to_dict(verbose=FieldVerboseLevel.everything) + errors = compare_nested(self_dict, other_dict) + if errors: + return log( f"Config diff:\n " + "\n ".join( f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" - for key, (self_value, other_value) in diff.items() + for key, (self_value, other_value) in errors.items() ), log_fn=log_fn, ) @@ -1005,7 +970,7 @@ def set_nested_dict_value[ else: d[key] = {} for key_, value_ in value.items(): - set_nested_dict_value(d, key_, value_, update_type) + set_nested_dict_value(d[key], key_, value_, update_type) elif isinstance(d.get(key), dict): raise ValueError("Cannot replace a dict with a non-dict value.") elif ( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 9c5e6f13..0958f118 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -505,7 +505,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: dataset_config = { "type": "fim", "dataset": dataset_config, - **self.fim.to_serialized(), + **self.fim.to_dict(), } # Legacy sampling config dataset_config = { diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f5d23031..25529ef0 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -154,7 +154,7 @@ def _sample(self) -> None: "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, "sequence_length": self._sequence_length, - "config": self._config.to_serialized(), + "config": self._config.to_dict(), } self._load_yaml_data(yaml_data) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df..23e497bf 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -281,7 +281,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( - dataset_config.to_serialized(), + dataset_config.to_dict(), output_path.open("w"), ) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 503839f0..f27fff5d 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -32,7 +32,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: - serialized_metadata = metadata.to_serialized() + serialized_metadata = metadata.to_dict() if self._model.config.distributed.rank == 0: yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( @@ -50,10 +50,8 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) - same_format = ( - loaded_config.to_serialized(verbose=None) == self._model.config.to_serialized(verbose=None) - and config.optimizer_state - ) + # Using `log_fn=bool` sets the output to true if the error list is non-empty. + same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 98cab927..654ba21f 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -232,7 +232,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetada fast_llm_version=__version__, model=cls._model_class, format=config.format, - config=cls._model_class.from_dict({"base_model": imported_model_config.to_serialized()}), + config=cls._model_class.from_dict({"base_model": imported_model_config.to_dict()}), shards=["weights"], ) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 87651dc4..f335015a 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -34,7 +34,7 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch huggingface_config = self._export_config(self._model.config.base_model) self._save_config(config.path, huggingface_config) return { - "fast_llm_metadata": metadata.to_serialized(), + "fast_llm_metadata": metadata.to_dict(), "model_config": huggingface_config, "format": "pt", } diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 5288d49f..71c83ece 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -71,7 +71,7 @@ def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metada def _serialize_metadata( self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata ) -> dict[str, typing.Any]: - return metadata.to_serialized() + return metadata.to_dict() def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 0ac46339..d6377409 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -147,8 +147,8 @@ def __init__( self._is_pipeline_parallel_main_rank = ( self._distributed_config.data_rank == 0 and self._distributed_config.tensor_rank == 0 ) - config_dict = config.to_serialized() - config_dict_verbose = config.to_serialized(verbose=FieldVerboseLevel.performance) + config_dict = config.to_dict() + config_dict_verbose = config.to_dict(verbose=FieldVerboseLevel.performance) if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 2b240e4b..d4b46bcc 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -91,12 +91,12 @@ def __eq__(self, other) -> bool: def to_dict(self) -> dict[str, typing.Any]: out = super().to_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.everything) + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.everything) return out def to_diff_dict(self) -> dict[str, typing.Any]: out = super().to_diff_dict() - out["fast_llm_config"] = self.fast_llm_config.to_serialized(verbose=FieldVerboseLevel.explicit) + out["fast_llm_config"] = self.fast_llm_config.to_dict(verbose=FieldVerboseLevel.explicit) return out def to_json_file(self, json_file_path: str | os.PathLike, use_diff: bool = True) -> None: diff --git a/fast_llm/engine/training/wandb.py b/fast_llm/engine/training/wandb.py index e3d421a3..185b89c2 100644 --- a/fast_llm/engine/training/wandb.py +++ b/fast_llm/engine/training/wandb.py @@ -40,7 +40,7 @@ def __init__(self, config: WandbConfig, run: Run, experiment_config: Config): if wandb_path is not None: yaml.safe_dump(wandb_config, wandb_path.open("w")) # TODO: Does wandb work with nested configs? - self._wandb = wandb.init(config=experiment_config.to_serialized(), **wandb_config) + self._wandb = wandb.init(config=experiment_config.to_dict(), **wandb_config) else: self._wandb = None diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 4edd8b98..da083eef 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -289,3 +289,35 @@ def new_decorator(*args, **kwargs): return out return new_decorator + + +def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): + if errors is None: + errors = [] + # Check for equality of both values and types. + if type(config_a) != type(config_b): + errors.append(f"Type mismatch for key `{".".join(prefix)}`: {type(config_a)} != {type(config_b)}") + if isinstance(config_a, dict): + for key in config_a.keys() | config_b.keys(): + key_ = prefix + (key,) + if key not in config_a: + errors.append(f"Key `{".".join(key_)}` missing in lhs.") + elif key not in config_b: + errors.append(f"Key `{".".join(key_)}` missing in rhs.") + else: + compare_nested(config_a[key], config_b[key], errors, key_) + elif isinstance(config_a, (list, tuple, set)): + if len(config_a) != len(config_b): + errors.append(f"Length mismatch for key `{".".join(prefix)}`: {len(config_a)} != {len(config_b)}.") + else: + for i in range(len(config_a)): + compare_nested(config_a[i], config_b[i], errors, prefix + (str(i),)) + elif config_a != config_b and config_a is not config_b: + # `is not` needed for special cases like `math.nan` + errors.append(f"Different value for key `{".".join(prefix)}`: {config_a} != {config_b}.") + return errors + + +def check_equal_nested(config_a, config_b): + if errors := compare_nested(config_a, config_b): + raise ValueError("\n".join(errors)) diff --git a/tests/config/common.py b/tests/config/common.py index f9449507..a2657926 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -1,7 +1,10 @@ import enum import pathlib -from fast_llm.config import Config, Field, FieldHint, config_class +import pytest + +from fast_llm.config import Config, Field, FieldHint, ValidationError, config_class +from fast_llm.utils import check_equal_nested class ExampleEnum(enum.StrEnum): @@ -56,3 +59,29 @@ def _validate(self) -> None: @config_class class ExampleNestedConfig(ExampleConfig): nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + + +def check_config( + internal_config, + *alternate, + serialized_config=None, + cls: type[Config] = ExampleConfig, + fields: list[str] | None = None, +): + serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config + for init_config in (internal_config, *alternate): + config = cls.from_dict(init_config) + serialized_config_ = config.to_dict() + internal_config_ = config.to_dict(serialized=False) + if fields is None: + check_equal_nested(serialized_config_, serialized_config) + check_equal_nested(internal_config_, internal_config) + else: + for field in fields: + check_equal_nested(serialized_config_[field], serialized_config[field]) + check_equal_nested(internal_config_[field], internal_config[field]) + + +def check_invalid_config(config, cls: type[Config] = ExampleConfig): + with pytest.raises(ValidationError): + cls.from_dict(config) diff --git a/tests/config/test_field.py b/tests/config/test_field.py index bed9c181..91b5c0d8 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -4,66 +4,13 @@ import numpy import pytest -from fast_llm.config import Config, FieldVerboseLevel, ValidationError -from fast_llm.utils import Assert -from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig - - -def _check_equal(config_a, config_b): - # Check for equality of both values and types. - Assert.eq(type(config_a), type(config_b)) - if isinstance(config_a, dict): - for key in config_a.keys() | config_b.keys(): - assert key in config_a and key in config_b, key - _check_equal(config_a[key], config_b[key]) - elif isinstance(config_a, (list, tuple, set)): - Assert.eq(len(config_a), len(config_b)) - for i in range(len(config_a)): - _check_equal(config_a[i], config_b[i]) - else: - try: - Assert.eq(config_a, config_b) - except AssertionError: - # Special case for `math.nan` - if config_a is not config_b: - raise - - -def check_equal(config_a, config_b): - try: - _check_equal(config_a, config_b) - except AssertionError as e: - raise AssertionError(config_a, config_b, *e.args) - - -def check_config( - internal_config, - *alternate, - serialized_config=None, - cls: type[Config] = ExampleConfig, - fields: list[str] | None = None, -): - serialized_config = serialized_config if serialized_config else alternate[0] if alternate else internal_config - for init_config in (internal_config, *alternate): - config = cls.from_dict(init_config) - serialized_config_ = config.to_serialized() - internal_config_ = config._to_dict() - if fields is None: - check_equal(serialized_config_, serialized_config) - check_equal(internal_config_, internal_config) - else: - for field in fields: - check_equal(serialized_config_[field], serialized_config[field]) - check_equal(internal_config_[field], internal_config[field]) - - -def check_invalid_config(config, cls: type[Config] = ExampleConfig): - with pytest.raises(ValidationError): - cls.from_dict(config) +from fast_llm.config import FieldVerboseLevel +from fast_llm.utils import Assert, check_equal_nested +from tests.config.common import ExampleConfig, ExampleEnum, ExampleVerboseConfig, check_config, check_invalid_config def test_create_and_serialize_config(): - Assert.eq(ExampleConfig.from_dict({}).to_serialized(), {}) + Assert.eq(ExampleConfig.from_dict({}).to_dict(), {}) @pytest.mark.parametrize("value", (0, -6, 3)) @@ -230,7 +177,7 @@ def test_enum_field_invalid(value): def test_core_field(): - Assert.eq(ExampleConfig.from_dict({}).to_serialized(verbose=FieldVerboseLevel.core), {"core_field": 4}) + Assert.eq(ExampleConfig.from_dict({}).to_dict(verbose=FieldVerboseLevel.core), {"core_field": 4}) @pytest.mark.parametrize( @@ -263,8 +210,8 @@ def test_verbose_config_default(): "explicit_field": "explicit", } config = ExampleVerboseConfig.from_dict({}) - check_equal(config.to_serialized(), default_values) - check_equal(config._to_dict(), default_values) + check_equal_nested(config.to_dict(), default_values) + check_equal_nested(config.to_dict(serialized=False), default_values) @pytest.mark.parametrize("value", ((0, ""), (5, "text"), (7, "True"))) diff --git a/tests/config/test_update.py b/tests/config/test_update.py new file mode 100644 index 00000000..ad534d49 --- /dev/null +++ b/tests/config/test_update.py @@ -0,0 +1,52 @@ +import pytest + +from fast_llm.config import UpdateType +from fast_llm.utils import check_equal_nested +from tests.config.common import ExampleNestedConfig + +TEST_CONFIGS = ( + ( + # Empty config + {}, + {}, + {}, + None, + ), + ( + # Update unset field; don't update set field; update + {"int_field": 4, "str_field": "text"}, + {"float_field": 3.0, "str_field": ""}, + {"int_field": 4, "float_field": 3.0, "str_field": ""}, + None, + ), + ( + # Update/override nested field. + {"nested_field": {"int_field": 4, "str_field": "text"}}, + {"nested_field": {"float_field": 3.0, "str_field": ""}}, + {"nested_field": {"int_field": 4, "float_field": 3.0, "str_field": ""}}, + {"nested_field": {"float_field": 3.0, "str_field": ""}}, + ), + # TODO: Add more complex cases +) + + +@pytest.mark.parametrize(("base", "update", "updated", "overridden"), TEST_CONFIGS) +def test_update(base, update, updated, overridden) -> None: + if overridden is None: + overridden = updated + check_equal_nested(ExampleNestedConfig.from_dict(base, update, update_type=UpdateType.update).to_dict(), updated) + check_equal_nested( + ExampleNestedConfig.from_dict(base) + .to_copy(ExampleNestedConfig.from_dict(update), update_type=UpdateType.update) + .to_dict(), + updated, + ) + check_equal_nested( + ExampleNestedConfig.from_dict(base, update, update_type=UpdateType.override).to_dict(), overridden + ) + check_equal_nested( + ExampleNestedConfig.from_dict(base) + .to_copy(ExampleNestedConfig.from_dict(update), update_type=UpdateType.override) + .to_dict(), + overridden, + ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index a6fd3246..9dd7975c 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -95,20 +95,20 @@ def test_split_dataset(): {"training": 3, "validation": 1}, pathlib.Path("."), ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, { "training": { "type": "slice", - "dataset": dataset_config_0.to_serialized(), + "dataset": dataset_config_0.to_dict(), "begin": 0, "end": 0.75, }, "validation": { "type": "slice", - "dataset": dataset_config_0.to_serialized(), + "dataset": dataset_config_0.to_dict(), "begin": 0.75, "end": 1, }, @@ -124,13 +124,13 @@ def test_split_datasets_0(): {"training": 1, "validation": 1}, pathlib.Path("."), ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, { - "training": dataset_config_0.to_serialized(), - "validation": dataset_config_1.to_serialized(), + "training": dataset_config_0.to_dict(), + "validation": dataset_config_1.to_dict(), }, ) @@ -141,7 +141,7 @@ def test_split_datasets_1(): config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) - config = {key: value.to_serialized() for key, value in config.items()} + config = {key: value.to_dict() for key, value in config.items()} Assert.eq( config, @@ -149,10 +149,10 @@ def test_split_datasets_1(): "training": { "type": "blended", "datasets": [ - dataset_config_0.to_serialized(), + dataset_config_0.to_dict(), { "type": "slice", - "dataset": dataset_config_1.to_serialized(), + "dataset": dataset_config_1.to_dict(), "begin": 0, "end": 0.5, }, @@ -161,7 +161,7 @@ def test_split_datasets_1(): }, "validation": { "type": "slice", - "dataset": dataset_config_1.to_serialized(), + "dataset": dataset_config_1.to_dict(), "begin": 0.5, "end": 1, }, diff --git a/tests/test_config.py b/tests/test_config.py index 5c45db0b..79437e9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -118,4 +118,4 @@ def test_add_attn_dense_bias(): def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. - assert cls.from_dict({}).to_serialized() == {} + assert cls.from_dict({}).to_dict() == {} diff --git a/tools/moe_add_experts.py b/tools/moe_add_experts.py index 975ece86..69311017 100644 --- a/tools/moe_add_experts.py +++ b/tools/moe_add_experts.py @@ -93,7 +93,7 @@ def run(self): model.save_pretrained(self.output_dir, state_dict=state_dict) # Save surgery config as yaml - yaml.safe_dump(self.to_serialized(), (self.output_dir / "surgery_config.yaml").open("w")) + yaml.safe_dump(self.to_dict(), (self.output_dir / "surgery_config.yaml").open("w")) logger.info("Done!") From dded00af39930f7cc57ade985dd65e314e3b62a4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 20:19:15 -0400 Subject: [PATCH 012/114] fix --- fast_llm/config.py | 10 ++++------ fast_llm/utils.py | 5 +++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 0abd9073..62db786d 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -391,9 +391,11 @@ def _validate(self) -> None: if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa continue value = getattr(self, name) - if value is DEFAULT: + if isinstance(value, Tag): + Assert.is_(value, DEFAULT) # Replace the value with its default. # We still need to validate because some fields have invalid defaults. + # TODO: Improve (still needed with new config update format? Do earlier to allow implicit defaults?) value = field.default new_value = self._validate_nested(value, field.type, field.name, field.valid, errors, False) setattr(self, name, new_value) @@ -860,11 +862,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ errors = compare_nested(self_dict, other_dict) if errors: return log( - f"Config diff:\n " - + "\n ".join( - f"{'.'.join(key)}`: `{self_value}` != `{other_value}`" - for key, (self_value, other_value) in errors.items() - ), + f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index da083eef..a8c5eac6 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -71,12 +71,17 @@ def rms_diff(x: "torch.Tensor", y: "torch.Tensor") -> "torch.Tensor": class Tag: + __slots__ = ("value",) + def __init__(self, value: str): self.value = value def __repr__(self) -> str: return self.value + def __deepcopy__(self, memodict: dict[str, typing.Any]) -> typing.Self: + return self + class Assert: """ From 986f9f3c9a5ebdc40dd9879540449a0fdb2aa80f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 20:27:32 -0400 Subject: [PATCH 013/114] fix --- tests/test_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d5685a71..d446f414 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -409,6 +409,7 @@ def test_load_pretrained_distributed_with_config(): ) +@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" From 8e3e7957b759d17c194d78edf736af7136d0586d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 21:21:25 -0400 Subject: [PATCH 014/114] fixes --- fast_llm/engine/checkpoint/distributed.py | 2 +- tests/common.py | 4 ++-- tests/test_checkpoint.py | 3 --- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index e3cd7d16..4225a404 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -48,7 +48,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. - same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool) + same_format = config.optimizer_state and not metadata.config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) diff --git a/tests/common.py b/tests/common.py index 14ec5c61..cc749901 100644 --- a/tests/common.py +++ b/tests/common.py @@ -54,7 +54,7 @@ "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", - "model.base_model.transformer.init_method_std=0.022", + # "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -101,7 +101,7 @@ "--global-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.022", + "--init-method-std=0.0625", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index d446f414..6793a670 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -259,7 +259,6 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=DistributedCheckpointFormat, optimizer_state=True, - load_config=ModelConfigType.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config.base_model, model.config.base_model) @@ -409,7 +408,6 @@ def test_load_pretrained_distributed_with_config(): ) -@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1" @@ -454,7 +452,6 @@ def test_load_pretrained_in_dp2_match_checkpoint(): assert (stage_shard_test[stage_shard_ref.numel() :] == 0).all() # noqa -@pytest.mark.skip(reason="Fails because of incorrect init config.") @pytest.mark.slow @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): From da6eb7bf7b16b709c81f06df50a5cac342ee7915 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 01:17:37 -0400 Subject: [PATCH 015/114] fixes --- fast_llm/data/dataset/gpt/sampled.py | 4 +- fast_llm/engine/checkpoint/config.py | 30 +++++- fast_llm/engine/checkpoint/distributed.py | 24 +++-- fast_llm/engine/checkpoint/external.py | 2 +- fast_llm/engine/checkpoint/huggingface.py | 5 +- fast_llm/engine/checkpoint/state_dict.py | 4 +- fast_llm/engine/huggingface/config.py | 5 +- fast_llm/engine/multi_stage/fast_llm_model.py | 7 +- fast_llm/engine/training/trainer.py | 1 + tests/common.py | 6 +- tests/test_checkpoint.py | 95 +++++++++++-------- tests/test_config.py | 11 ++- 12 files changed, 124 insertions(+), 70 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index c96eb35f..fa486216 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -65,7 +65,7 @@ def __getitem__(self, item: typing.Any) -> np.ndarray: def _lazy_load(self): if self._array is None: - assert self.exists() + assert self.exists(), self._path self._array = np.load(self._path, mmap_mode="r") @@ -432,7 +432,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if unshuffled_tokens := data.get("unshuffled_tokens") is not None: + if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None: self._unshuffled_tokens = unshuffled_tokens else: self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"] diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 7dbd5ce7..55440a5c 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -202,6 +202,17 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): _abstract = False + load_config: ModelConfigType = Field( + default=ModelConfigType.model, + desc="Configuration to save/load.", + hint=FieldHint.core, + ) + + def _validate(self) -> None: + super()._validate() + if self.format.enforce_architecture_match: + assert self.load_config.load_architecture + @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): @@ -225,8 +236,23 @@ def __init__(self, model: "FastLLMModel"): # TODO: save_metadata? @classmethod - @abc.abstractmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": + updates = {} + metadata = cls._load_metadata(config) + if not config.load_config.load_fast_llm: + updates[("config", "multi_stage")] = {} + updates[("config", "distributed")] = {} + if not config.load_config.load_architecture: + updates[("config", "base_model")] = {} + elif not config.load_config.load_base_model: + updates[("config", "base_model")] = metadata.config.base_model.get_architecture() + if updates: + metadata = metadata.to_copy(updates) + return metadata + + @classmethod + @abc.abstractmethod + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": pass @abc.abstractmethod @@ -234,7 +260,7 @@ def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"): pass @abc.abstractmethod - def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"): + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: pass def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]: diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index 4225a404..ac06df5c 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -13,6 +13,7 @@ CheckpointLoadMetadataConfig, CheckpointSaveConfig, DistributedCheckpointFormat, + ModelConfigType, export_safetensors_metadata, ) from fast_llm.engine.checkpoint.safe_load import SafeLoad @@ -27,7 +28,7 @@ class DistributedCheckpointHandler(CheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: @@ -40,15 +41,16 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No metadata=export_safetensors_metadata(serialized_metadata), ) - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: More safety checks + loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm})) shard_names = self.get_shard_names(config) # Make sure all shards to load are in the checkpoint. - Assert.leq(set(self.get_shard_names(config)), set(metadata.shards)) - Assert.eq(metadata.shards[: len(shard_names)], list(shard_names)) + Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards)) + Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names)) # Using `log_fn=bool` sets the output to true if the error list is non-empty. - same_format = config.optimizer_state and not metadata.config.compare(self._model.config, log_fn=bool) + same_format = config.optimizer_state and not loaded_metadata.config.compare(self._model.config, log_fn=bool) # Make sure all nodes agree on which loading scheme to use. # Note: they may not agree before the broadcast because of the rank comparison, but that's ok. same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group) @@ -67,7 +69,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) for shard_name in shard_names: self._model.get_shard(shard_name).copy_( - f.get_slice("state_shard")[metadata.shards.index(shard_name)] + f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] ) else: # TODO: Does this copy twice? @@ -76,11 +78,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No else: log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info) - self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) + self._model.config.base_model.compare_architecture(loaded_metadata.config.base_model, logger.warning) with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context: - for rank in range(metadata.config.distributed.world_size): + for rank in range(loaded_metadata.config.distributed.world_size): loaded_model = self._model.__class__( - metadata.config.to_copy({("distributed", "rank"): rank}), + loaded_metadata.config.to_copy({("distributed", "rank"): rank}), optimizer_state_names=shard_names[1:], verbose=False, ) @@ -94,7 +96,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No # TODO v0.3: Use checkpoint version? Drop support? log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning) loaded_shards = { - shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)] + shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)] for shard_name in shard_names } else: @@ -119,3 +121,5 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No ) context.mark_as_loaded(counter.item()) + + return loaded_metadata.metadata diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 654ba21f..e3b6dcf2 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -226,7 +226,7 @@ def __init__(self, model: "FastLLMModel"): } @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: imported_model_config = cls._import_config(cls._load_config(config.path), True) return CheckpointMetadata( fast_llm_version=__version__, diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 7357b722..a5777d45 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -39,10 +39,11 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch "format": "pt", } - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: assert not config.optimizer_state + metadata = self._model.config.load_metadata(config) self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning) - super().load(config, metadata) + super().load(config) @classmethod def get_huggingface_model_type(self) -> str: diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 71c83ece..1bb47e5c 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -73,7 +73,7 @@ def _serialize_metadata( ) -> dict[str, typing.Any]: return metadata.to_dict() - def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None: + def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context: # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from # `state_dict` that are ready for conversion, @@ -116,7 +116,7 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = FastLLMCheckpointFormat @classmethod - def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: + def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: path = config.path / f"metadata.yaml" logger.warning(f"Loading metadata from {path}") return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 08070804..d4b46bcc 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -74,7 +74,10 @@ def _get_config_dict( torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: updates[("distributed", "training_dtype")] = torch_dtype - fast_llm_config = cls.model_config_class.from_dict(metadata.config, kwargs.pop("fast_llm_config", {}), updates) + fast_llm_config = cls.model_config_class.from_metadata( + pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates + ) + config_dict = {"fast_llm_config": fast_llm_config} return config_dict, kwargs diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index e2255faa..de26f9bf 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -31,16 +31,15 @@ def save_checkpoint( ) converter.save(config, fast_llm_metadata) - def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any]: + def load_checkpoint(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None: # TODO: Simplify branching. # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. here - metadata = self.config_class.load_metadata(config) converter = config.format.get_handler_class()(self) - converter.load(config, metadata) + metadata = converter.load(config) self._finalize_load(reset_optimizer=not config.optimizer_state) - return metadata.metadata + return metadata @classmethod def from_pretrained( diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index f2ed4a38..c6daa081 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -494,6 +494,7 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) -> metadata = self._multi_stage.load_checkpoint( config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout) ) + assert metadata is not None self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. diff --git a/tests/common.py b/tests/common.py index cc749901..9ecb60ff 100644 --- a/tests/common.py +++ b/tests/common.py @@ -54,7 +54,7 @@ "model.base_model.transformer.num_layers=2", "model.base_model.transformer.hidden_size=256", "model.base_model.transformer.num_attention_heads=8", - # "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.init_method_std=0.022", f"model.base_model.vocab_size={TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -101,7 +101,7 @@ "--global-batch-size=8", "--max-position-embeddings=512", "--seq-length=512", - "--init-method-std=0.0625", + "--init-method-std=0.022", "--lr=0.0001", "--num-workers=0", "--valid-num-workers=0", @@ -394,7 +394,7 @@ def run_test_script( if num_gpus == 1 and not is_megatron: CliTrainingConfig.parse_and_run(script) else: - completed_proc = subprocess.run(command, env=env) + completed_proc = subprocess.run(command, env=env, timeout=30) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") if compare: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 6793a670..4171581a 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,7 +14,7 @@ FastLLMCheckpointFormat, ModelConfigType, ) -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -246,8 +246,12 @@ def test_converted_huggingface(): assert (h0[key] == h1[key]).all() -def _compare_configs(config_ref, config_test): - config_ref.compare(config_test) +def _compare_model_configs(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): + config_ref.base_model.compare(config_test.base_model) + + +def _compare_architectures(config_ref: FastLLMModelConfig, config_test: FastLLMModelConfig): + config_ref.base_model.get_architecture().compare(config_test.base_model.get_architecture()) @pytest.mark.depends(on=["test_converted_distributed"]) @@ -261,7 +265,7 @@ def test_load_pretrained_distributed_checkpoint(): optimizer_state=True, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) - _compare_configs(config.base_model, model.config.base_model) + _compare_model_configs(config, model.config) state_shards = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) ) @@ -271,20 +275,24 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) - pretrained_config_0 = CheckpointLoadConfig( - path=_CONVERT_PATH / "distributed_0", - format=DistributedCheckpointFormat, + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) ) - pretrained_config_1 = CheckpointLoadConfig( - path=_CONVERT_PATH / "distributed_1", - format=DistributedCheckpointFormat, + + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "distributed_0", + format=DistributedCheckpointFormat, + ) ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "distributed_1", + format=DistributedCheckpointFormat, + ) + ) + _compare_architectures(config_ref, model.config) + _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -293,14 +301,17 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) - pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) - pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + ) + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) + ) + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) + ) + _compare_architectures(config_ref, model.config) + _compare_architectures(config_ref, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -309,23 +320,27 @@ def test_load_converted_fast_llm_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig( - path=_CKPT_PATH, - format=DistributedCheckpointFormat, + config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + ) ) - pretrained_config_0 = CheckpointLoadConfig( - path=_CONVERT_PATH / "huggingface_0", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + model = TEST_MODEL_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_1", + format=HUGGINGFACE_CHECKPOINT_FORMAT, + ), + mode=StageMode.weights, ) - pretrained_config_1 = CheckpointLoadConfig( - path=_CONVERT_PATH / "huggingface_1", - format=HUGGINGFACE_CHECKPOINT_FORMAT, + config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( + CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_0", + format=HUGGINGFACE_CHECKPOINT_FORMAT, + ) ) - config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) - model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) - config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) - _compare_configs(config.base_model, model.config.base_model) - _compare_configs(config.base_model, config_1.base_model) + _compare_architectures(config_ref, model.config) + _compare_model_configs(model.config, config_alt) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] @@ -423,7 +438,7 @@ def test_load_pretrained_in_dp2_match_checkpoint(): ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) - _compare_configs(config_ref.base_model, config_test.base_model) + _compare_model_configs(config_ref, config_test) shards_ref = safetensors.torch.load_file(_CKPT_PATH / "rank_0.safetensors") shards_test = [safetensors.torch.load_file(test_ckpt_path / f"rank_{i}.safetensors") for i in range(2)] ref_model = TEST_MODEL_CLS(config_ref) @@ -467,7 +482,7 @@ def test_load_distributed_checkpoint_dp2(): ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) - _compare_configs(config.base_model, model.config.base_model) + _compare_model_configs(config, model.config) weight_shard = safetensors.torch.load_file( _CKPT_PATH / "rank_0.safetensors", device=str(model._distributed.device) )[WEIGHT_SHARD_SAVE_NAME] diff --git a/tests/test_config.py b/tests/test_config.py index 79437e9d..ed758965 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -10,6 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.utils import check_equal_nested def run_without_import(cmd: str): @@ -114,8 +116,11 @@ def test_add_attn_dense_bias(): ) -@pytest.mark.parametrize("cls", (GPTSamplingConfig,)) -def test_serialize_default_config_updates(cls): +@pytest.mark.parametrize( + ("cls", "default"), + ((GPTSamplingConfig, {}), (GPTModelConfig, {"distributed": {"world_size": 1, "rank": 0, "local_world_size": 1}})), +) +def test_serialize_default_config_updates(cls, default): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. - assert cls.from_dict({}).to_dict() == {} + check_equal_nested(cls.from_dict({}).to_dict(), default) From baad705d6960d9578a2f5e29664284250d569980 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 19:16:01 -0400 Subject: [PATCH 016/114] fix --- fast_llm/layers/transformer/config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a1cb658e..cf409e77 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -186,7 +186,7 @@ class TransformerSubLayerName(str, enum.Enum): @config_class() class TransformerPeftConfig(PeftConfig): layers: list[TransformerSubLayerName] = Field( - default_factory=lambda: [TransformerSubLayerName.query, TransformerSubLayerName.value_], + default=None, desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) @@ -220,6 +220,15 @@ def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": return parameter def _validate(self) -> None: + if self.layers is None: + with self._set_implicit_default(): + # Setting the default layers only whee PeFT is enabled + # so they don't appear when serializing the default transformer config. + self.layers = ( + [TransformerSubLayerName.query, TransformerSubLayerName.value_] + if self.type == PeftType.lora + else [] + ) if self.type != PeftType.none: if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: # TODO: Add MLP support. From b7028378a2f8cb4e6e863ac55af69b0f11f71cff Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 4 Apr 2025 21:48:10 -0400 Subject: [PATCH 017/114] Test, fixes --- fast_llm/engine/checkpoint/config.py | 11 +-- fast_llm/engine/checkpoint/distributed.py | 7 ++ fast_llm/engine/checkpoint/huggingface.py | 6 +- fast_llm/engine/checkpoint/state_dict.py | 17 ++++- fast_llm/engine/multi_stage/config.py | 14 ++-- tests/test_config.py | 84 ++++++++++++++++++++++- 6 files changed, 123 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 55440a5c..62928ed0 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -201,9 +201,9 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf @config_class() class CheckpointLoadMetadataConfig(CheckpointPathConfigBase): _abstract = False - + # TODO: Set default to model? (Not backward compatible) load_config: ModelConfigType = Field( - default=ModelConfigType.model, + default=ModelConfigType.architecture, desc="Configuration to save/load.", hint=FieldHint.core, ) @@ -233,7 +233,10 @@ class CheckpointHandler(abc.ABC): def __init__(self, model: "FastLLMModel"): self._model = model - # TODO: save_metadata? + @classmethod + @abc.abstractmethod + def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: "CheckpointMetadata"): + pass @classmethod def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata": @@ -245,7 +248,7 @@ def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetad if not config.load_config.load_architecture: updates[("config", "base_model")] = {} elif not config.load_config.load_base_model: - updates[("config", "base_model")] = metadata.config.base_model.get_architecture() + updates[("config", "base_model")] = metadata.config.base_model.get_architecture().to_dict() if updates: metadata = metadata.to_copy(updates) return metadata diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py index ac06df5c..de1625f6 100644 --- a/fast_llm/engine/checkpoint/distributed.py +++ b/fast_llm/engine/checkpoint/distributed.py @@ -12,6 +12,7 @@ CheckpointLoadConfig, CheckpointLoadMetadataConfig, CheckpointSaveConfig, + CheckpointSaveMetadataConfig, DistributedCheckpointFormat, ModelConfigType, export_safetensors_metadata, @@ -27,6 +28,12 @@ class DistributedCheckpointHandler(CheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat + @classmethod + def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata): + config.path.mkdir(parents=True, exist_ok=True) + serialized_metadata = metadata.to_dict() + yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w")) + @classmethod def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata: return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r"))) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index a5777d45..2972a4fa 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -20,8 +20,10 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): - def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: - path = config.path / f"{self.base_file_name}.safetensors.index.json" + @classmethod + def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: + config.path.mkdir(parents=True, exist_ok=True) + path = config.path / f"{cls.base_file_name}.safetensors.index.json" logger.info(f"Saving index to {path}") # Save the index. json.dump( diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 1bb47e5c..556e97be 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -30,6 +30,13 @@ class StateDictCheckpointHandler(CheckpointHandler): base_file_name: typing.ClassVar[str] = "model" + @classmethod + def save_metadata( + cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata, index: dict | None = None + ): + serialized_metadata = cls._serialize_metadata(config, metadata) + cls._save_serialized_metadata(config, serialized_metadata, {} if index is None else index) + def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None: serialized_metadata = self._serialize_metadata(config, metadata) saver = StateDictSaver( @@ -64,12 +71,14 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No if self._model.config.distributed.rank == 0: self._save_serialized_metadata(config, serialized_metadata, index) + @classmethod @abc.abstractmethod - def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: + def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None: pass + @classmethod def _serialize_metadata( - self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata + cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata ) -> dict[str, typing.Any]: return metadata.to_dict() @@ -121,9 +130,11 @@ def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetad logger.warning(f"Loading metadata from {path}") return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r"))) + @classmethod def _save_serialized_metadata( - self, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict + cls, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict ) -> None: + config.path.mkdir(parents=True, exist_ok=True) path = config.path / f"metadata.yaml" logger.info(f"Saving metadata to {path}") if "metadata" not in serialized_metadata: diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 43b412fb..6a0c8813 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -187,11 +187,12 @@ class MultiStageConfig(StageConfig): def _validate(self) -> None: super()._validate() if self.zero_stage is not None: - Assert.in_range_incl(self.zero_stage, 1, 3) - if self.zero_stage >= 2: - self.num_grad_buffers = 2 - if self.zero_stage >= 3: - self.num_weight_buffers = 2 + with self._set_implicit_default(): + Assert.in_range_incl(self.zero_stage, 1, 3) + if self.zero_stage >= 2: + self.num_grad_buffers = 2 + if self.zero_stage >= 3: + self.num_weight_buffers = 2 if self.num_grad_buffers is not None: Assert.geq(self.num_grad_buffers, 1) if self.num_weight_buffers is not None: @@ -281,6 +282,9 @@ def to_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> "Checkp **kwargs, ) + def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None: + self.get_checkpoint_handler_class(config.format).save_metadata(config, self.to_metadata(config, **kwargs)) + @config_class() class PretrainedFastLLMModelConfig(Config): diff --git a/tests/test_config.py b/tests/test_config.py index ed758965..79c6738d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,13 +5,16 @@ import pytest import yaml +from fast_llm.config import NoAutoValidate from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AddLinearBiasChoices, TransformerArchitectureConfig, TransformerConfig from fast_llm.models.auto import trainer_registry -from fast_llm.models.gpt.config import GPTModelConfig -from fast_llm.utils import check_equal_nested +from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig +from fast_llm.utils import Assert, check_equal_nested +from tests.common import TEST_RESULTS_PATH def run_without_import(cmd: str): @@ -124,3 +127,80 @@ def test_serialize_default_config_updates(cls, default): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. check_equal_nested(cls.from_dict({}).to_dict(), default) + + +@pytest.mark.parametrize("load_config", tuple(ModelConfigType)) +def test_pretrained_config(load_config: ModelConfigType): + config_path = TEST_RESULTS_PATH / "pretrained_config" + pretrained_model_config = GPTModelConfig.from_dict( + { + "base_model": { + "transformer": { + "normalization": {"type": "rms_norm"}, # Nested + "rotary": {"type": "default"}, + "num_layers": 12, # Default + "hidden_size": 1024, # Default + "window_size": 32, # Non-architecture + "ffn_hidden_size": 4096, # Implicit default, default value + "activation_type": "silu", # Implicit default, non-default value + "head_groups": 4, + }, + "tie_word_embeddings": False, + }, + "multi_stage": {"zero_stage": 3}, + "distributed": {"training_dtype": "bfloat16"}, + } + ) + with NoAutoValidate(): + save_config = CheckpointSaveMetadataConfig.from_dict({"format": "fast_llm", "path": config_path}) + save_config.setup(GPTModelConfig) + save_config.validate() + pretrained_model_config.save_metadata(save_config) + + base_model_update = { + "transformer": { + # rotary: Don't override nested. + "normalization": {"implementation": "triton"}, # Update non-default nested + "peft": {"freeze_others": False}, # Update default nested, non-architecture + "hidden_size": 512, # Override, affects derived value (kv channels) + "head_groups": 1, # Override to default + }, + "vocab_size": 1000, + } + pretrained_config = PretrainedGPTModelConfig.from_dict( + { + "model": { + "base_model": base_model_update, + "distributed": {"seed": 1234, "training_dtype": "float16"}, + }, + "pretrained": {"format": "fast_llm", "path": config_path, "load_config": load_config}, + } + ) + Assert.eq(pretrained_config.model.base_model.transformer.kv_channels, 64) + serialized_config = pretrained_config.model.to_dict() + expected_config = {"distributed": DistributedConfig().to_dict()} + + if load_config == ModelConfigType.fast_llm: + expected_config["multi_stage"] = {"zero_stage": 3} + expected_config["distributed"].update({"seed": 1234, "training_dtype": "float16"}) + if load_config in (ModelConfigType.architecture, ModelConfigType.fast_llm, ModelConfigType.model): + expected_config["base_model"] = { + "transformer": { + "normalization": {"type": "rms_norm", "implementation": "triton"}, + "rotary": {"type": "default"}, + "peft": {"freeze_others": False}, + "num_layers": 12, + "hidden_size": 512, + "ffn_hidden_size": 4096, + "activation_type": "silu", + "head_groups": 1, + }, + "tie_word_embeddings": False, + "vocab_size": 1000, + } + if load_config != ModelConfigType.architecture: + expected_config["base_model"]["transformer"]["window_size"] = 32 + else: + expected_config["base_model"] = base_model_update + + check_equal_nested(serialized_config, expected_config) From a8684f869a3377f13fbf96c87a7fb850aed52757 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 11 Apr 2025 00:28:08 -0400 Subject: [PATCH 018/114] Knowledge distillation, fix cross-entropy --- fast_llm/functional/config.py | 6 + fast_llm/functional/cross_entropy.py | 191 +++++++++++--------- fast_llm/functional/triton/cross_entropy.py | 123 ++++++++++--- tests/test_triton_kernels.py | 73 ++++++-- 4 files changed, 263 insertions(+), 130 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 9f1fe005..7284ca07 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -91,3 +91,9 @@ class CrossEntropyImpl(str, enum.Enum): torch = "torch" fused = "fused" triton = "triton" + + +class TargetFormat(enum.StrEnum): + labels = "labels" + logits = "logits" + probabilities = "probabilities" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index e87581f1..62c61e8e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -3,7 +3,7 @@ import torch.autograd from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert @@ -12,34 +12,65 @@ def torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, grad_output: float | None, - logits_scale_factor: float = 1.0, + logits_scale_factor: float, + target_format: TargetFormat, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. - TODO: loss masking only works for this method if the masking index is set to -100. + TODO: loss masking only works for with labels format and if the masking index is set to -100. """ # Torch compile doesn't understand this. - with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_() - if logits_scale_factor != 1.0: - logits_ *= logits_scale_factor + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.float().detach().requires_grad_(grad_output is not None) + if target_format == TargetFormat.logits: + target = torch.softmax(target, dim=-1) + loss = torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target + ).mean() if grad_output is None: - loss = None + grad = None else: - loss = torch.nn.functional.cross_entropy(logits_, target).mean() loss.backward(torch.full_like(loss, grad_output)) - loss.detach_() - return loss.detach(), logits_.grad.detach().to(logits.dtype) + grad = logits_.grad.detach().to(logits.dtype) + return loss.detach_(), grad + + +# @torch.compile +def _fused_softmax_base( + logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + logits = logits.float() + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + logits_max = torch.max(logits, dim=dim, keepdim=True)[0] + if group is not None: + all_reduce(logits_max, op=ReduceOp.MAX, group=group) + logits_norm = (logits - logits_max).float() + exp_logits = logits_norm.exp() + sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) + if group is not None: + all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) + return logits_norm, exp_logits, sum_exp_logits + + +# @torch.compile +def fused_softmax( + logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 +) -> torch.Tensor: + _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) + return exp_logits / sum_exp_logits -@torch.compile +# @torch.compile def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, grad_output: float | None, - logits_scale_factor: float = 1.0, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -48,82 +79,67 @@ def fused_cross_entropy_forward_backward( """ # Do the forward and backward passes all at once, and fused with dtype conversion. # Way faster and more memory-efficient than the pytorch version. - loss_mask = target >= 0 - # Ignore_index can go out of bounds, so set masked values to zero. - target = (target * loss_mask).unsqueeze(1) - logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float() - if logits_scale_factor != 1.0: - logits_norm *= logits_scale_factor - exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=-1) - - if grad_output is None: - grad = None - else: - exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1)) - # exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits - exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) - if logits_scale_factor != 1.0: - exp_logits *= logits_scale_factor - - grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0) - - per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) - return per_sample_loss.mean(), grad + if target_format == TargetFormat.logits: + target = fused_softmax(target, logits_scale_factor, group) - -@torch.compile -def parallel_cross_entropy_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - grad_output: float | None, - group: ProcessGroup, - logits_scale_factor: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A fused implementation of cross-entropy with torch compile, with support for tensor parallelism. - Comes with a noticeable overhead, but reduces memory usage. - """ - # TODO: Compiled version incorrect for some inputs (32 bit indexing issue?). - # TODO: Optimize, overlap/combine reductions - loss_mask = target >= 0 - target = target.unsqueeze(1) - - logits_max = torch.max(logits, dim=-1)[0] - all_reduce(logits_max, op=ReduceOp.MAX, group=group) - logits_norm = logits.sub(logits_max.unsqueeze(dim=-1)).float() - if logits_scale_factor != 1.0: - logits_norm *= logits_scale_factor - - exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=-1) - all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) - - # Mask the target (fused) - # TODO: Could mask earlier on cpu or overlap with reduce? - vocab_start_index = logits.size(-1) * group.rank() - target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask + if target_format == TargetFormat.labels: + target = target.unsqueeze(-1) + loss_mask = target >= 0 + if group is None: + # Keep values within range for scatter and gather ops to work. + target = target * loss_mask + target_mask = None + else: + # Mask the target (fused) + # TODO: Could mask earlier on cpu or overlap with reduce? + vocab_start_index = logits.size(-1) * group.rank() + target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) + target = (target - vocab_start_index) * target_mask + else: + # TODO: Support masking + loss_mask = None + # Target should be tensor-parallel already, no further manipulation needed. + target_mask = None if grad_output is None: grad = None else: - exp_logits1 = exp_logits.scatter( - 1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1) - ) - exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1)) + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + if target_format == TargetFormat.labels: + grad_base = exp_logits.scatter_add( + 1, target, -sum_exp_logits if target_mask is None else -target_mask * sum_exp_logits + ) + else: + grad_base = exp_logits - sum_exp_logits * target + + grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) if logits_scale_factor != 1.0: - exp_logits2 *= logits_scale_factor + grad *= logits_scale_factor + grad = grad.to(logits.dtype) + if loss_mask is not None: + grad = torch.where(loss_mask, grad.to(logits.dtype), 0) + + # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + if target_format == TargetFormat.labels: + predicted_logits = logits_norm.gather(1, target) + if group is not None: + predicted_logits = target_mask * predicted_logits + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + else: + predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) - grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0) + per_sample_loss = sum_exp_logits.log() - predicted_logits + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask - predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1) - all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) - per_sample_loss = sum_exp_logits.log().sub(predicted_logits) * loss_mask + loss = per_sample_loss.mean() + if target_format != TargetFormat.labels and group is not None: + all_reduce(loss, op=ReduceOp.MEAN, group=group) - return per_sample_loss.mean(), grad + return loss, grad _CROSS_ENTROPY_IMPLEMENTATIONS = { @@ -134,12 +150,13 @@ def parallel_cross_entropy_forward_backward( def cross_entropy_forward_backward( - logits, - target, + logits: torch.Tensor, + target: torch.Tensor, grad_output: float | None, - group: ProcessGroup | None, + group: ProcessGroup | None = None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, logits_scale_factor: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -147,12 +164,18 @@ def cross_entropy_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. """ + if target_format == TargetFormat.labels: + Assert.eq(target.shape, logits.shape[:-1]) + Assert.eq(target.dtype, torch.int64) + else: + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype if group: Assert.eq(implementation, CrossEntropyImpl.fused) - return parallel_cross_entropy_forward_backward( - logits, target, grad_output, group, logits_scale_factor=logits_scale_factor + return fused_cross_entropy_forward_backward( + logits, target, grad_output, logits_scale_factor, target_format, group ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor=logits_scale_factor + logits, target, grad_output, logits_scale_factor, target_format ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 8b622849..321bd0fa 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,6 +1,6 @@ import torch -from fast_llm.functional.config import TritonConfig +from fast_llm.functional.config import TargetFormat, TritonConfig from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit @@ -11,9 +11,9 @@ def triton_cross_entropy_forward_backward_kernel( grad_logits_ptr, losses_ptr, grad_losses, - n_cols, - logits_stride_0, - grad_logits_stride_0, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, block_size: tl_constexpr, ): @@ -33,27 +33,78 @@ def triton_cross_entropy_forward_backward_kernel( label_idx = tl.load(labels_ptr + block_idx) - label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) if label_idx < 0: + # Loss mask loss = 0.0 else: + label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + label_logits *= logits_scale_factor loss = tl.log(sum_exp_logits) + max_logits - label_logits tl.store(losses_ptr + block_idx, loss) - grad_logits_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + if grad_losses is not None: + if label_idx < 0: + grad_losses = 0.0 + grad_base = exp_logits / sum_exp_logits + grad_logits = grad_losses * tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) + if logits_scale_factor != 1.0: + grad_logits *= logits_scale_factor + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + + +@triton_jit() +def triton_cross_entropy_from_distribution_forward_backward_kernel( + logits_ptr, + target_ptr, + grad_logits_ptr, + losses_ptr, + grad_losses, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + grad_logits_stride_0: tl_constexpr, + logits_scale_factor: tl_constexpr, + from_logits: tl_constexpr, + block_size: tl_constexpr, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) col_offsets = tl.arange(0, block_size) - label_idx = tl.load(labels_ptr + block_idx) - exp_logits = exp_logits / sum_exp_logits + logits_ptr = logits_ptr + block_idx * logits_stride_0 + mask = col_offsets < n_cols + + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if from_logits: + max_target_logits = tl.max(logits, 0) + exp_target_logits = tl.exp(target - max_target_logits) + sum_exp_target_logits = tl.sum(exp_target_logits, 0) + target = exp_target_logits / sum_exp_target_logits + + # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) + loss = tl.log(sum_exp_logits) - tl.sum(target * logits, 0) + tl.store(losses_ptr + block_idx, loss) + + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. if logits_scale_factor != 1.0: exp_logits *= logits_scale_factor - if label_idx < 0: - grad_losses = 0.0 - grad_logits = grad_losses * tl.where(col_offsets == label_idx, exp_logits - 1.0, exp_logits) - tl.store(grad_logits_ptr + col_offsets, grad_logits, mask=mask) + grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) def triton_cross_entropy_forward_backward( - logits, target, grad_output: float | None, logits_scale_factor: float = 1.0 + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -72,18 +123,34 @@ def triton_cross_entropy_forward_backward( num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? - grad_logits = torch.empty_like(logits) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( - logits, - target, - grad_logits, - losses, - 1 if grad_output is None else grad_output / n_rows, - n_cols, - logits.stride(0), - grad_logits.stride(0), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, - ) - return losses.mean(), None if grad_output is None else grad_logits + grad_logits = None if grad_output is None else torch.empty_like(logits) + if target_format == TargetFormat.labels: + triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + losses, + None if grad_output is None else grad_output / n_rows, + n_cols, + logits.stride(0), + None if grad_output is None else grad_logits.stride(0), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + ) + else: + triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + losses, + None if grad_output is None else grad_output / n_rows, + n_cols, + logits.stride(0), + None if grad_output is None else grad_logits.stride(0), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + from_logits=target_format == TargetFormat.logits, + ) + return losses.mean(), grad_logits diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index e61c2d51..e52073aa 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -1,14 +1,20 @@ import pytest import torch -from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, ActivationType, TritonConfig +from fast_llm.functional.config import ( + MAX_DROPLESS_BLOCK_SIZE_ROW, + ActivationType, + CrossEntropyImpl, + TargetFormat, + TritonConfig, +) +from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, ) from fast_llm.functional.triton.adam import triton_adam -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.functional.triton.mlp import ( torch_mlp_activation, triton_mlp_activation_backward, @@ -184,24 +190,55 @@ def test_triton_mlp_activation(gated, activation_type, recompute): @requires_cuda -def test_triton_cross_entropy(): +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor"), + ( + (8192, 1.0, 1.0), + (8192, None, 1.0), + (8192, 1.0, 4.0), + (8192, 4.0, 1.0), + (65536, 1.0, 1.0), + (131072, 1.0, 1.0), + ), +) +@pytest.mark.parametrize("target_format", (TargetFormat.labels,)) # TargetFormat.logits, TargetFormat.probabilities)) +def test_cross_entropy(num_columns, grad_output, logits_scale_factor, target_format): + # TODO: Test tensor-parallel implementation. assert TritonConfig.TRITON_ENABLED - logits = torch.randn(1024, 8192, dtype=torch.bfloat16, device="cuda", requires_grad=True) - labels = torch.randint(0, 8192, (1024,), dtype=torch.int64, device="cuda") - - from fast_llm.functional.cross_entropy import ( - fused_cross_entropy_forward_backward, - torch_cross_entropy_forward_backward, - ) - - c1, g1 = torch_cross_entropy_forward_backward(logits, labels, 1) - c2, g2 = fused_cross_entropy_forward_backward(logits, labels, 1) - c3, g3 = triton_cross_entropy_forward_backward(logits, labels, 1) + logits = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda", requires_grad=True) + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + else: + target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") + + kwargs = { + "logits": logits, + "target": target, + "grad_output": grad_output, + "logits_scale_factor": logits_scale_factor, + "target_format": target_format, + } + # Torch serves as the reference implementation. + out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) + + out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) + Assert.rms_close(out_fused, out_torch, 5e-3) + if grad_output is None: + assert grad_torch is None + assert grad_fused is None + else: + Assert.rms_close(grad_fused, grad_torch, 5e-3) - Assert.rms_close(c2, c3, 5e-3) - Assert.rms_close(c1, c3, 5e-3) - Assert.rms_close(g1, g3, 5e-3) - Assert.rms_close(g2, g3, 5e-3) + if target_format == TargetFormat.probabilities or num_columns > 65536: + with pytest.raises(AssertionError): + cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) + else: + out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) + Assert.rms_close(out_triton, out_torch, 5e-3) + if grad_output is None: + assert grad_triton is None + else: + Assert.rms_close(grad_triton, grad_torch, 5e-3) @requires_cuda From b781729d684b4c2415585277f333afc75999874d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sun, 13 Apr 2025 11:15:17 -0400 Subject: [PATCH 019/114] Fixes, distillation --- fast_llm/functional/cross_entropy.py | 4 +- fast_llm/functional/triton/cross_entropy.py | 32 +++++++++------ fast_llm/layers/language_model/config.py | 11 +++++- fast_llm/layers/language_model/head.py | 43 +++++++++++++-------- fast_llm/models/gpt/config.py | 7 +++- fast_llm/models/gpt/model.py | 2 +- fast_llm/utils.py | 2 +- tests/test_triton_kernels.py | 14 +++++-- 8 files changed, 78 insertions(+), 37 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 62c61e8e..0a611832 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -25,6 +25,8 @@ def torch_cross_entropy_forward_backward( with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) if target_format == TargetFormat.logits: + if logits_scale_factor != 1.0: + target = target * logits_scale_factor target = torch.softmax(target, dim=-1) loss = torch.nn.functional.cross_entropy( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target @@ -63,7 +65,7 @@ def fused_softmax( return exp_logits / sum_exp_logits -# @torch.compile +@torch.compile def fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 321bd0fa..62ed2e0e 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -62,6 +62,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, from_logits: tl_constexpr, @@ -70,33 +71,40 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) col_offsets = tl.arange(0, block_size) - logits_ptr = logits_ptr + block_idx * logits_stride_0 mask = col_offsets < n_cols - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( + tl.float32 + ) if logits_scale_factor != 1.0: logits *= logits_scale_factor max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) + logits_norm = logits - max_logits + exp_logits = tl.exp(logits_norm) sum_exp_logits = tl.sum(exp_logits, 0) - target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( + tl.float32 + ) if from_logits: - max_target_logits = tl.max(logits, 0) + if logits_scale_factor != 1.0: + target *= logits_scale_factor + max_target_logits = tl.max(target, 0) exp_target_logits = tl.exp(target - max_target_logits) sum_exp_target_logits = tl.sum(exp_target_logits, 0) target = exp_target_logits / sum_exp_target_logits # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) - loss = tl.log(sum_exp_logits) - tl.sum(target * logits, 0) + loss = tl.log(sum_exp_logits) - tl.sum(target * logits_norm, 0) tl.store(losses_ptr + block_idx, loss) - # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - if logits_scale_factor != 1.0: - exp_logits *= logits_scale_factor - grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + if grad_losses is not None: + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) + if logits_scale_factor != 1.0: + grad_logits *= logits_scale_factor + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) def triton_cross_entropy_forward_backward( @@ -117,7 +125,6 @@ def triton_cross_entropy_forward_backward( assert logits.is_contiguous() assert target.is_contiguous() n_rows, n_cols = logits.shape - assert target.shape == (n_rows,) block_size = triton.next_power_of_2(n_cols) assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) @@ -147,6 +154,7 @@ def triton_cross_entropy_forward_backward( None if grad_output is None else grad_output / n_rows, n_cols, logits.stride(0), + target.stride(0), None if grad_output is None else grad_logits.stride(0), logits_scale_factor, block_size=block_size, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3bd79603..22cce43a 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -151,6 +151,12 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for knowledge distillation." + "If provided, replace the loss with a distillation loss.", + hint=FieldHint.feature, + ) # Tensor-parallel word embeddings # (Default init std is different, dropout won't match, needs seq_first = False.) # (disable to allow for sequence-parallel embeddings and logits, better for larger models) @@ -195,6 +201,9 @@ def _validate(self) -> None: self.init_method_max_embed = self.transformer.init_method_max if self.init_method_min_embed is None: self.init_method_min_embed = self.transformer.init_method_min + super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) - super()._validate() + if self.distillation_model is not None: + if self.prediction_heads > 1: + raise NotImplementedError("Multi-token prediction not supported with distillation.") diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 1286121c..04e4020f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss @@ -145,12 +145,22 @@ def forward( def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: - labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None - # MTP: Shift the labels - labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None + target = kwargs.get( + LanguageModelKwargs.labels + if self._config.distillation_model is None + else f"{self._config.distillation_model}_logits" + ) + if target is not None: + if self._config.distillation_model is None: + # Target is labels (token ids) + # MTP: Shift the labels + target = target[:, self._prediction_distance :].flatten() + else: + # Target is reference model logits. + target = target.flatten(0, -2) if self._sequence_parallel_logits: - labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0) - do_grad = labels is not None and self.training + target = split_op(target, self._tensor_space.distributed.tensor_group, 0) + do_grad = target is not None and self.training input_ = input_.detach().requires_grad_(do_grad) with torch.enable_grad(): # MTP: truncate the input @@ -166,7 +176,7 @@ def _forward_backward( output_weights = self._get_output_weights(kwargs) loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), labels, output_weights, grad_output, kwargs, losses + ln_output.detach(), target, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -185,29 +195,29 @@ def _get_output_weights(self, kwargs: dict) -> torch.Tensor: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - labels: torch.Tensor | None, + target: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._cross_entropy_splits is None or labels is None: + if self._cross_entropy_splits is None or target is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, labels, weight, grad_output, kwargs, losses + input_, target, weight, grad_output, kwargs, losses ) - if labels is None: + if target is None: # TODO: Make a proper way of returning the model output. kwargs["logits"] = loss return None, None else: loss = None # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length - split_size = div(labels.numel(), self._cross_entropy_splits) + split_size = div(target.size(0), self._cross_entropy_splits) grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) logit_input_grad = torch.empty_like(logit_input) for logit_input_, labels_, logit_input_grad_ in zip( - logit_input.split(split_size), labels.split(split_size), logit_input_grad.split(split_size) + logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size) ): loss_, grad_ = self._logits_cross_entropy_forward_backward( logit_input_, @@ -231,7 +241,7 @@ def _logits_cross_entropy_forward_backward_split( def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, - labels: torch.Tensor | None, + target: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -285,15 +295,16 @@ def _logits_cross_entropy_forward_backward( scale=self._logits_scale_factor, ) - if labels is None: + if target is None: return logits * self._logits_scale_factor, None loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), - labels, + target, group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, grad_output=grad_output, implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, ) # TODO: de-allocate earlier. del logits diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 19c8e6ac..09c3e757 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -139,9 +139,14 @@ def _validate(self) -> None: self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.cross_entropy_splits) - super()._validate() + Assert.none(reference_model.model.base_model.distillation_model) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c672b216..c0eabc45 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -269,7 +269,7 @@ def preprocess( labels = batch.token_ids[sequence_offset : sequence_k + 1] else: # TODO: Avoid multiple contiguous calls? - labels = batch.token_ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous() + labels = batch.token_ids[:, sequence_offset : sequence_k + 1].contiguous() # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss # TODO: take ignore_index from config if batch.loss_masking_spans is not None: diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 0a4ce007..2499676c 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -86,7 +86,7 @@ class Assert: @staticmethod def eq(x, *args, msg=None): for arg in args: - assert x == arg, f"{x} != {arg} " + f"| {msg}" if msg else "" + assert x == arg, f"{x} != {arg} " + (f"| {msg}" if msg else "") @staticmethod def is_(x, y): diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index e52073aa..b6970ddf 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -190,6 +190,7 @@ def test_triton_mlp_activation(gated, activation_type, recompute): @requires_cuda +@pytest.mark.slow @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor"), ( @@ -201,15 +202,20 @@ def test_triton_mlp_activation(gated, activation_type, recompute): (131072, 1.0, 1.0), ), ) -@pytest.mark.parametrize("target_format", (TargetFormat.labels,)) # TargetFormat.logits, TargetFormat.probabilities)) +@pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) def test_cross_entropy(num_columns, grad_output, logits_scale_factor, target_format): # TODO: Test tensor-parallel implementation. assert TritonConfig.TRITON_ENABLED - logits = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda", requires_grad=True) + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") / 3 if target_format == TargetFormat.labels: target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device="cuda") + logits = (torch.nn.functional.one_hot(target, num_columns) + logits_var).requires_grad_() else: target = torch.randn(256, num_columns, dtype=torch.bfloat16, device="cuda") + logits = (target + logits_var).requires_grad_() + if target_format == TargetFormat.probabilities: + target = torch.softmax(target, -1) kwargs = { "logits": logits, @@ -229,16 +235,16 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, target_for else: Assert.rms_close(grad_fused, grad_torch, 5e-3) - if target_format == TargetFormat.probabilities or num_columns > 65536: + if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) else: out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - Assert.rms_close(out_triton, out_torch, 5e-3) if grad_output is None: assert grad_triton is None else: Assert.rms_close(grad_triton, grad_torch, 5e-3) + Assert.rms_close(out_triton, out_torch, 5e-3) @requires_cuda From db6504b0546b8462e79173e5f03e960c2d694d6d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 14:55:18 -0400 Subject: [PATCH 020/114] fixes --- fast_llm/config.py | 1 + fast_llm/engine/multi_stage/config.py | 3 ++- fast_llm/engine/training/config.py | 8 +++----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c88965..443925ce 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -30,6 +30,7 @@ def __enter__(self): global _AUTO_VALIDATE self._old_value = _AUTO_VALIDATE _AUTO_VALIDATE = False + return _AUTO_VALIDATE def __exit__(self, exc_type, exc_val, exc_tb): global _AUTO_VALIDATE diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e6de074f..ee94ce61 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -336,7 +336,8 @@ def _validate(self) -> None: self.pretrained.setup(self.model) self.pretrained.validate() if self.pretrained.path is not None: - self.model = self.model.from_pretrained(self.pretrained, default=self.model) + with NoAutoValidate(): + self.model = self.model.from_pretrained(self.pretrained, default=self.model) self._setup() super()._validate() diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 4f5164b0..9819ced3 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -383,17 +383,15 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) - self.model.validate() + for reference_model in self.reference_models.values(): + _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) + super()._validate() if self.reference_models: # TODO: Add support. Assert.eq(self.model.distributed.pipeline_parallel, 1) # TODO: Check if these work. Assert.eq(self.model.distributed.tensor_parallel, 1) Assert.eq(self.model.distributed.sequence_data_parallel, 1) - - for reference_model in self.reference_models.values(): - _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) - super()._validate() if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() From cff9892d44a9380a992f33692500ed7e08191824 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 18:11:05 -0400 Subject: [PATCH 021/114] fixes --- fast_llm/engine/inference/huggingface.py | 4 ++- tests/test_checkpoint.py | 33 +++++++++++++++++++++--- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 75aea9dd..196310b4 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -60,7 +60,9 @@ def from_pretrained( updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained(pretrained_model_name_or_path, updates, mode=mode) + fast_llm_model = cls.runner_class.model_class.from_pretrained( + pretrained_model_name_or_path, updates, mode=mode + ) config = cls.config_class(fast_llm_model.config) return cls(config, fast_llm_model, **kwargs) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 4171581a..0c5e177d 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -263,6 +263,7 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=DistributedCheckpointFormat, optimizer_state=True, + load_config=ModelConfigType.model, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_model_configs(config, model.config) @@ -276,19 +277,25 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, + ) ) model = TEST_MODEL_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) @@ -302,13 +309,25 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_fast_llm", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_fast_llm_checkpoint(): config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CKPT_PATH, format=DistributedCheckpointFormat) + CheckpointLoadConfig( + path=_CKPT_PATH, + format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, + ) ) model = TEST_MODEL_CLS.from_pretrained( - CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat) + CheckpointLoadConfig( + path=_CONVERT_PATH / "fast_llm_0", + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) ) config_alt = TEST_MODEL_CONFIG_CLS.from_pretrained( - CheckpointLoadConfig(path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat) + CheckpointLoadConfig( + path=_CONVERT_PATH / "fast_llm_1", + format=FastLLMCheckpointFormat, + load_config=ModelConfigType.model, + ) ) _compare_architectures(config_ref, model.config) _compare_architectures(config_ref, config_alt) @@ -324,12 +343,14 @@ def test_load_converted_huggingface_checkpoint(): CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) model = TEST_MODEL_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ), mode=StageMode.weights, ) @@ -337,6 +358,7 @@ def test_load_converted_huggingface_checkpoint(): CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ) ) _compare_architectures(config_ref, model.config) @@ -353,6 +375,7 @@ def test_run_converted_model(): CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) ) test_input = torch.randint( @@ -364,6 +387,7 @@ def test_run_converted_model(): CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, + load_config=ModelConfigType.model, ) ) errors = [] @@ -479,6 +503,7 @@ def test_load_distributed_checkpoint_dp2(): pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoint" / "1", format=DistributedCheckpointFormat, + load_config=ModelConfigType.model, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) From b67006a0ea650145357e023d6c4f517a4b7de2c2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 14 Apr 2025 20:33:57 -0400 Subject: [PATCH 022/114] fixes --- fast_llm/engine/distributed/config.py | 3 ++- fast_llm/engine/training/config.py | 7 +++---- fast_llm/functional/triton/cross_entropy.py | 2 +- tests/test_triton_kernels.py | 13 +++++++------ 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 8f04a705..66d89e1a 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -293,7 +293,6 @@ def _validate(self) -> None: if self.reference_config.reference_config is not None: self.reference_config = self.reference_config.reference_config assert self.reference_config.reference_config is None - self.compare(self.reference_config, ValueError) self.distributed_dims = self.reference_config.distributed_dims else: self.distributed_dims = {} @@ -368,6 +367,8 @@ def _validate(self) -> None: super()._validate() + if self.reference_config is not None: + self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 578937fb..8b4cadc3 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -429,13 +429,12 @@ def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelC def new_setup(): # Make sure the distributed config isn't set - # TODO!!!!!!!!!!!!!: Uncomment after #205 - # pretrained.model.distributed.validate() - # Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + pretrained.model.distributed.validate() + Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) with NoAutoValidate(): pretrained.model.distributed = distributed.to_copy() # Allow sharing the `Distributed` instance. pretrained.model.distributed.reference_config = distributed old_setup() - pretrained._setup = new_setup + object.__setattr__(pretrained, "_setup", new_setup) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 62ed2e0e..d825af03 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -96,7 +96,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = exp_target_logits / sum_exp_target_logits # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) - loss = tl.log(sum_exp_logits) - tl.sum(target * logits_norm, 0) + loss = tl.log(sum_exp_logits) - tl.sum(tl.where(mask, target * logits_norm, 0), 0) tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: diff --git a/tests/test_triton_kernels.py b/tests/test_triton_kernels.py index b6970ddf..1ace81d7 100644 --- a/tests/test_triton_kernels.py +++ b/tests/test_triton_kernels.py @@ -194,12 +194,13 @@ def test_triton_mlp_activation(gated, activation_type, recompute): @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor"), ( - (8192, 1.0, 1.0), - (8192, None, 1.0), - (8192, 1.0, 4.0), - (8192, 4.0, 1.0), - (65536, 1.0, 1.0), - (131072, 1.0, 1.0), + (8192, 1.0, 1.0), # Simple + (5000, 1.0, 1.0), # Not a power of 2 + (5000, None, 1.0), # No grad + (5000, 1.0, 4.0), # Loss scaling + (5000, 4.0, 1.0), # Grad scaling + (65536, 1.0, 1.0), # Max block size + (65537, 1.0, 1.0), # Above max block size ), ) @pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) From 20141081adc406e3315cce9825df5b58b9630258 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 11:13:00 -0400 Subject: [PATCH 023/114] Add constraints --- fast_llm/models/gpt/config.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 09c3e757..f0c314d6 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -144,9 +144,37 @@ def _validate(self) -> None: Assert.empty(self.reference_models) else: Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + # TODO: Support distinct preprocessing + reference_model.model.base_model.transformer.rotary.compare( + self.model.base_model.transformer.rotary, + NotImplementedError, + ) + Assert.eq( + reference_model.model.base_model.use_absolute_position_embeddings, + self.model.base_model.use_absolute_position_embeddings, + ) + if reference_model.model.base_model.use_absolute_position_embeddings: + assert self.model.base_model.use_absolute_position_embeddings + Assert.geq( + reference_model.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length + ) + use_flash = reference_model.model.base_model.transformer.do_use_flash_attention( + reference_model.model.distributed + ) + Assert.eq(use_flash, self.model.base_model.transformer.do_use_flash_attention(self.model.distributed)) + if use_flash: + Assert.eq( + reference_model.model.base_model.transformer.window_size, + self.model.base_model.transformer.window_size, + ) @classmethod def get_trainer_class(cls) -> type["GPTTrainer"]: From fa3d556f371a75c29da53e082474c47a557dfa29 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 12:35:16 -0400 Subject: [PATCH 024/114] Add constraints --- fast_llm/models/gpt/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b02dd7be..705e9918 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -186,6 +186,9 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + if self.model.base_model.distillation_model is not None: + # TODO: Support loss masking for distillation? + assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From 6c2c887b47dd1220dc626ac6edb817e004e0173d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 14:30:09 -0400 Subject: [PATCH 025/114] Separate reference model preprocessing --- fast_llm/engine/base_model/base_model.py | 27 +++++++++---- fast_llm/engine/training/trainer.py | 9 +---- fast_llm/models/gpt/config.py | 23 ----------- fast_llm/models/gpt/model.py | 50 ++++++++++++++++-------- fast_llm/models/gpt/trainer.py | 23 ----------- 5 files changed, 54 insertions(+), 78 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 76da0f9b..3835b190 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -6,13 +6,16 @@ import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.engine.inference.runner import InferenceRunner + class Module(torch.nn.Module, abc.ABC): """ """ @@ -80,6 +83,7 @@ def get_layers(self) -> list[Layer]: class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig + _is_setup: bool = False def __init__( self, @@ -96,6 +100,16 @@ def __init__( # Rename to the parameter full name value.tensor_name = key + # Reference models + # TODO: Add basic handling (preprocessor) in this class. + self._reference_models: dict[str, "InferenceRunner"] = {} + + def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._tensor_space.distributed_config) + self._tensor_space.setup(distributed) + self._is_setup = True + @classmethod def architecture_cls(cls) -> type[BaseModelArchitectureConfig]: return cls.config_class.architecture_class @@ -104,10 +118,6 @@ def architecture_cls(cls) -> type[BaseModelArchitectureConfig]: def get_layers(self) -> list[Layer]: pass - @abc.abstractmethod - def setup(self, distributed: Distributed) -> None: - pass - @abc.abstractmethod def preprocess_meta(self, batch_meta: typing.Any, phase: PhaseType) -> list[tuple[TensorMeta, dict]]: pass @@ -136,6 +146,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: pass - def add_preprocessor(self, preprocessor: Preprocessor): - # TODO: Generalize preprocessors. - raise NotImplementedError() + def add_reference_model(self, name: str, inference_runner: InferenceRunner) -> None: + assert name not in self._reference_models + assert not self._is_setup + self._reference_models[name] = inference_runner diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 66f1ad86..abd8f9dc 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -12,11 +12,9 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer @@ -55,9 +53,7 @@ def __init__(self, config: TrainerConfig): self._reference_models[name] = self._config.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) - self._multi_stage.base_model.add_preprocessor( - self._get_reference_model_preprocessor(name, self._reference_models[name]) - ) + self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) phase: PhaseType self._runner = ScheduleRunner( @@ -562,6 +558,3 @@ def _get_last_checkpoint(self) -> int | None: def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats pass - - def _get_reference_model_preprocessor(self, name: str, inference_runner: InferenceRunner) -> Preprocessor: - raise NotImplementedError() diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 705e9918..e6230116 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -195,29 +195,6 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - # TODO: Support distinct preprocessing - reference_model.model.base_model.transformer.rotary.compare( - self.model.base_model.transformer.rotary, - NotImplementedError, - ) - Assert.eq( - reference_model.model.base_model.use_absolute_position_embeddings, - self.model.base_model.use_absolute_position_embeddings, - ) - if reference_model.model.base_model.use_absolute_position_embeddings: - assert self.model.base_model.use_absolute_position_embeddings - Assert.geq( - reference_model.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length - ) - use_flash = reference_model.model.base_model.transformer.do_use_flash_attention( - reference_model.model.distributed - ) - Assert.eq(use_flash, self.model.base_model.transformer.do_use_flash_attention(self.model.distributed)) - if use_flash: - Assert.eq( - reference_model.model.base_model.transformer.window_size, - self.model.base_model.transformer.window_size, - ) @classmethod def _from_dict( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 55b08f2e..77faa8a3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames @@ -41,7 +40,6 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): """ config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - _is_setup: bool = False _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _mask: torch.Tensor @@ -59,6 +57,7 @@ def __init__( for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) @@ -113,12 +112,6 @@ def get_layers(self) -> list[Layer]: *self.get_output_layers(), ] - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -186,12 +179,20 @@ def preprocess_meta( TransformerKwargs.sequence_q_dim: sequence_q_dim, } - preprocessed_meta = [] - for sequence_k_past in range( + sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, - ): + ) + reference_preprocessed_metas = {} + for name, reference_model in self._reference_models.items(): + reference_preprocessed_metas[name] = reference_model.fast_llm_model.base_model.preprocess_meta( + batch_meta, phase + ) + Assert.eq(len(reference_preprocessed_metas[name]), len(sequence_k_pasts)) + + preprocessed_meta = [] + for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) @@ -209,6 +210,15 @@ def preprocess_meta( ) for preprocessor in self._preprocessors: preprocessor.preprocess_meta(kwargs) + reference_kwargs = {} + for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): + reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] + for key, value in common_kwargs.items(): + Assert.eq(reference_kwargs_[key], value) + Assert.eq(reference_kwargs_[TransformerKwargs.sequence_k_dim], sequence_k_dim) + reference_kwargs[name] = reference_kwargs_ + kwargs["reference_models"] = reference_kwargs + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -237,13 +247,22 @@ def preprocess( dtype=torch.int64, non_blocking=True, ) + + reference_logits = {} + for name, reference_model in self._reference_models.items(): + reference_logits[name] = [] + for _, kwargs_meta in preprocessed_meta: + reference_tokens, reference_kwargs = kwargs_meta["reference_models"][name] + reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) + reference_logits[name].append(reference_kwargs["logits"]) + if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None - for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): + for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] @@ -286,6 +305,9 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + for name, reference_logits_ in reference_logits.items(): + kwargs[f"{name}_logits"] = reference_logits_[i] + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) @@ -361,10 +383,6 @@ def loss_defs(self) -> list[LossDef]: ) return loss_defs - def add_preprocessor(self, preprocessor: Preprocessor): - assert not self._is_setup - self._preprocessors.append(preprocessor) - class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a269f5a6..57327b27 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -3,33 +3,13 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTTrainerConfig -from fast_llm.models.gpt.model import GPTInferenceRunner logger = logging.getLogger(__name__) -class GPTReferenceModelPreprocessor(Preprocessor): - def __init__(self, name: str, inference_runner: GPTInferenceRunner): - self._name = name - self._inference_runner = inference_runner - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - # TODO: Fix random state/iteration. - preprocess_kwargs = kwargs.copy() - del preprocess_kwargs[LanguageModelKwargs.labels] - self._inference_runner.forward(batch, preprocess_kwargs, iteration=1) - # TODO: Improve. - kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"] - - class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig @@ -101,6 +81,3 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, hardware_flops = flops_per_iteration + 7 / 6 * attn_flops ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 return model_tflops / ratio, hardware_flops / ratio - - def _get_reference_model_preprocessor(self, name: str, inference_runner: GPTInferenceRunner) -> Preprocessor: - return GPTReferenceModelPreprocessor(name, inference_runner) From 67f9db637242498329f932dcf0650b4829e6599e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 16 Apr 2025 16:10:19 -0400 Subject: [PATCH 026/114] fix --- fast_llm/engine/base_model/base_model.py | 2 +- fast_llm/engine/multi_stage/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 3835b190..2dbf8cc8 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -146,7 +146,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: pass - def add_reference_model(self, name: str, inference_runner: InferenceRunner) -> None: + def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: assert name not in self._reference_models assert not self._is_setup self._reference_models[name] = inference_runner diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 69bf3695..e2d04f80 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -30,7 +30,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.model import HuggingfacePreTrainedModel + from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) From 537deca2a96353b149908aa14c0dbeb8b2188bc3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 17 Apr 2025 15:47:47 -0400 Subject: [PATCH 027/114] fix --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 0a611832..1eb6c8c0 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -112,7 +112,7 @@ def fused_cross_entropy_forward_backward( # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. if target_format == TargetFormat.labels: grad_base = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -target_mask * sum_exp_logits + 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) ) else: grad_base = exp_logits - sum_exp_logits * target From d2b3154f22c574ec4b36b4cc706a42964b55684e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Apr 2025 10:49:24 -0400 Subject: [PATCH 028/114] misc --- fast_llm/layers/language_model/head.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8edf1bc8..c2974415 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -50,9 +50,7 @@ def __init__( self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings - ) + self._sequence_parallel_logits = self._sequence_parallel and not self._parallel_embeddings self._cross_entropy_splits = config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings @@ -215,12 +213,12 @@ def _logits_cross_entropy_forward_backward_split( grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) logit_input_grad = torch.empty_like(logit_input) - for logit_input_, labels_, logit_input_grad_ in zip( + for logit_input_, target_, logit_input_grad_ in zip( logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size) ): loss_, grad_ = self._logits_cross_entropy_forward_backward( logit_input_, - labels_, + target_, weight, grad_output, kwargs, From a0ba05161cd03443b2b247e6db98322e5285b8b4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 25 Apr 2025 15:50:50 -0400 Subject: [PATCH 029/114] fixes --- fast_llm/engine/config_utils/tensor_space.py | 3 ++ fast_llm/engine/training/config.py | 30 ++++++++-------- fast_llm/models/gpt/model.py | 36 +++++++++++++------- 3 files changed, 42 insertions(+), 27 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0384fdac..5020bc65 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -27,6 +27,9 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + @property def name(self) -> str: return self._name diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 8b4cadc3..1e990e9c 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,7 +23,6 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -386,7 +385,7 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): - _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) + self._add_reference_distributed_to_pretrained(reference_model) super()._validate() if self.reference_models: # TODO: Add support. @@ -396,6 +395,8 @@ def _validate(self) -> None: Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() + for reference_model in self.reference_models.values(): + assert reference_model.model.distributed.reference_config is self.model.distributed def _setup(self): super()._setup() @@ -423,18 +424,17 @@ def runnable(): return runnable + def _add_reference_distributed_to_pretrained(self, pretrained: PretrainedFastLLMModelConfig): + old_setup = pretrained._setup -def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelConfig, distributed: DistributedConfig): - old_setup = pretrained._setup - - def new_setup(): - # Make sure the distributed config isn't set - pretrained.model.distributed.validate() - Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) - with NoAutoValidate(): - pretrained.model.distributed = distributed.to_copy() - # Allow sharing the `Distributed` instance. - pretrained.model.distributed.reference_config = distributed - old_setup() + def new_setup(): + # Make sure the distributed config isn't set + pretrained.model.distributed.validate() + Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + with NoAutoValidate(): + pretrained.model.distributed = self.model.distributed.to_copy() + # Allow sharing the `Distributed` instance. + pretrained.model.distributed.reference_config = self.model.distributed + old_setup() - object.__setattr__(pretrained, "_setup", new_setup) + object.__setattr__(pretrained, "_setup", new_setup) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index c955eec5..5c408a60 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -125,7 +125,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= 1 + sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) @@ -187,7 +187,7 @@ def preprocess_meta( reference_preprocessed_metas = {} for name, reference_model in self._reference_models.items(): reference_preprocessed_metas[name] = reference_model.fast_llm_model.base_model.preprocess_meta( - batch_meta, phase + batch_meta, PhaseType.inference ) Assert.eq(len(reference_preprocessed_metas[name]), len(sequence_k_pasts)) @@ -213,9 +213,14 @@ def preprocess_meta( reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] - for key, value in common_kwargs.items(): - Assert.eq(reference_kwargs_[key], value) - Assert.eq(reference_kwargs_[TransformerKwargs.sequence_k_dim], sequence_k_dim) + for key in ( + TransformerKwargs.sequence_first, + TransformerKwargs.hidden_dims, + TransformerKwargs.sequence_length, + TransformerKwargs.sequence_q_dim, + TransformerKwargs.sequence_k_dim, + ): + Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ kwargs["reference_models"] = reference_kwargs @@ -249,13 +254,21 @@ def preprocess( non_blocking=True, ) - reference_logits = {} + reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): - reference_logits[name] = [] - for _, kwargs_meta in preprocessed_meta: - reference_tokens, reference_kwargs = kwargs_meta["reference_models"][name] + reference_preprocessed_meta = [ + (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta + ] + + reference_batch = reference_model.fast_llm_model.base_model.preprocess( + batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration + ) + + # TODO: Do things work with >1? + Assert.eq(len(reference_batch), len(preprocessed_meta), 1) + for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - reference_logits[name].append(reference_kwargs["logits"]) + reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. @@ -308,8 +321,7 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - for name, reference_logits_ in reference_logits.items(): - kwargs[f"{name}_logits"] = reference_logits_[i] + kwargs.update(reference_logits[i]) for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) From 9ddfb69115d6d84d960e76c36d4e65994eae76cc Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 25 Apr 2025 21:43:23 +0000 Subject: [PATCH 030/114] add per-layer lr-scale --- fast_llm/layers/common/config.py | 3 ++- fast_llm/layers/common/normalization.py | 5 +++++ fast_llm/layers/language_model/config.py | 13 +++++++++++++ fast_llm/layers/language_model/embedding.py | 2 ++ fast_llm/layers/language_model/head.py | 1 + fast_llm/layers/transformer/attention.py | 13 ++++++++----- fast_llm/layers/transformer/config.py | 6 ++++++ .../layers/transformer/mixture_of_experts.py | 9 ++++++--- fast_llm/layers/transformer/mlp.py | 15 ++++++++++----- fast_llm/layers/transformer/transformer.py | 5 +++-- fast_llm/models/gpt/model.py | 2 +- fast_llm/utils.py | 18 ++++++++++++++++++ 12 files changed, 75 insertions(+), 17 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 71c15c9b..6e596751 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -82,7 +82,7 @@ class NormalizationConfig(NormalizationArchitectureConfig, BaseModelConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ @@ -91,6 +91,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": "eps": self.epsilon, "implementation": self.implementation, "zero_centered": self.zero_centered, + "lr_scale": lr_scale, } if self.initialization_range: mean = 0 if self.zero_centered else 1 diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 04123014..984778f8 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -152,6 +152,7 @@ def __init__( weight_init_method=None, bias_init_method=init_zeros_, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -190,12 +191,14 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=bias_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape @@ -230,6 +233,7 @@ def __init__( implementation: NormalizationImplementation = NormalizationImplementation.auto, weight_init_method=None, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -263,6 +267,7 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=True, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b4b4e187..c99ee4f6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -202,6 +202,19 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed..e0386d8d 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -62,6 +62,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,6 +73,7 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c2974415..1153fb2c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -102,6 +102,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5..54fff228 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -84,7 +84,7 @@ def __init__( super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + # Assert.in_range_incl(layer_index, 1, self._config.num_layers) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer @@ -110,6 +110,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, @@ -118,7 +121,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -127,7 +130,7 @@ def __init__( weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -139,7 +142,7 @@ def __init__( weight_init_method=init_method_std_attn_proj, bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf409e77..c13c2a09 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -636,6 +636,12 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, ) + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) router_lr_scale: float | None = Field( default=None, desc="Custom learning rate for the MoE router weight.", diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f..49778c63 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -21,7 +21,7 @@ from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." @@ -59,6 +59,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + self.router = Linear( tensor_space.get_tensor_dim(TransformerDimNames.hidden), tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), @@ -66,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max ), - lr_scale=config.router_lr_scale, + lr_scale=router_lr_scale, ) dropless_moe = config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 1c38705f..c4d8afdc 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -10,13 +10,14 @@ from fast_llm.layers.common.linear import LinearBase from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): super().__init__() self._name = name + self._layer_index = layer_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -38,6 +39,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale + lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, @@ -45,7 +50,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +60,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) # PEFT. @@ -64,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.eq(config.num_experts, 1) super().__init__(config, tensor_space, name) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 92df1893..9e1e0bcf 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -39,8 +39,9 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) self._create_mixer() diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a7ec58d6..873c8f80 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -80,7 +80,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=self._config.transformer.num_layers, + layer_index=self._config.transformer.num_layers + i, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, diff --git a/fast_llm/utils.py b/fast_llm/utils.py index a8c5eac6..c524a315 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -326,3 +326,21 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple def check_equal_nested(config_a, config_b): if errors := compare_nested(config_a, config_b): raise ValueError("\n".join(errors)) + + +def get_lr_scale( + lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None +) -> float | None | tuple[float | None, ...]: + """ + Combine module and layer lr_scale. + If one is None, return the other. + """ + if lr_scale is None: + return layer_lr_scale + if layer_lr_scale is None: + return lr_scale + if isinstance(lr_scale, float): + return lr_scale * layer_lr_scale + if isinstance(lr_scale, tuple): + return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) + raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") From 5e282cc0369d40ecbd81963fecadd7c699317b2a Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 28 Apr 2025 16:24:17 +0000 Subject: [PATCH 031/114] modeling mtp llamba --- .../ssm/external/configuration_mtp_llamba.py | 94 +++++ .../models/ssm/external/discrete_mamba2.py | 382 +++++++++++++++++ .../ssm/external/modeling_mtp_llamba.py | 389 ++++++++++++++++++ 3 files changed, 865 insertions(+) create mode 100644 fast_llm/models/ssm/external/configuration_mtp_llamba.py create mode 100644 fast_llm/models/ssm/external/discrete_mamba2.py create mode 100644 fast_llm/models/ssm/external/modeling_mtp_llamba.py diff --git a/fast_llm/models/ssm/external/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/configuration_mtp_llamba.py new file mode 100644 index 00000000..b8173b73 --- /dev/null +++ b/fast_llm/models/ssm/external/configuration_mtp_llamba.py @@ -0,0 +1,94 @@ +from enum import Enum + +from transformers.configuration_utils import PretrainedConfig + + +class StateUpdateKernel(Enum): + ssu_verification = "ssu_verification" # selective scan for multi-token verification, not implemented yet + cs = "chunk_scan" # see https://proceedings.mlr.press/v262/wu24a.html + ssu = "standard" # usual one token per time-step inference using selective-scan update, no verification + + +class MTPLlambaConfig(PretrainedConfig): + r"""Configuration class for the CustomMamba model. + + This configuration is used to instantiate the CustomMamba model according to the specified arguments, + defining the model architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the model. + tie_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + pad_vocab_size_multiple (`int`, *optional*, defaults to 8): + Pad the vocabulary size up to the next multiple of this value. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether the LM head includes a bias term. + d_model (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + lm_head_prenorm (`str`, *optional*, defaults to "rms"): + Normalization type for LM head. + n_layer (`int`, *optional*, defaults to 32): + Number of layers in the model. + resid_dropout (`float`, *optional*, defaults to 0.0): + Dropout rate for residual connections. + norm_epsilon (`float`, *optional*, defaults to 1e-5): + Epsilon value used for normalization layers. + mlp_cfg (`dict`, *optional*): + Configuration for the MLP (Multi-Layer Perceptron) layer, including intermediate size, activation function, and whether to use bias. + ssm_cfg (`dict`, *optional*): + Configuration for the SSM (State Space Model) layer, including d_state, number of heads, expansion, and other parameters. + + """ + + model_type = "llamba" + + def __init__( + self, + vocab_size: int, + d_model: int, + tie_embeddings: bool = False, + pad_vocab_size_multiple: int = 8, + lm_head_bias: bool = False, + n_layer: int = 32, + resid_dropout: float = 0.0, + norm_epsilon: float = 1e-5, + mlp_cfg: dict = None, + ssm_cfg: dict = None, + prediction_heads=1, + state_update_kernel: StateUpdateKernel = StateUpdateKernel.cs, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.tie_embeddings = tie_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.lm_head_bias = lm_head_bias + self.d_model = d_model + self.n_layer = n_layer + self.resid_dropout = resid_dropout + self.norm_epsilon = norm_epsilon + self.prediction_heads = prediction_heads + assert ( + state_update_kernel != StateUpdateKernel.ssu_verification + ), "Only chunk scan and standard modes are supported for now" + self.state_update_kernel = state_update_kernel + + # MLP (Multi-Layer Perceptron) Config + self.mlp_cfg = mlp_cfg or { + "intermediate_size": 14336, + "bias": False, + "act_fn": "silu", + } + + # SSM (State Space Model) Config + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + } diff --git a/fast_llm/models/ssm/external/discrete_mamba2.py b/fast_llm/models/ssm/external/discrete_mamba2.py new file mode 100644 index 00000000..bb8afaa7 --- /dev/null +++ b/fast_llm/models/ssm/external/discrete_mamba2.py @@ -0,0 +1,382 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined + +from .configuration_mtp_llamba import StateUpdateKernel + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + + +class DiscreteMamba2(nn.Module): + """DiscreteMamba2 (taken github.com/goombalab/phi-mamba.git).""" + + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + verification_mode: StateUpdateKernel = StateUpdateKernel.cs, + **kwargs, + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr". + + Other options are all experimental and should not need to be configured. + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + self.inference_mode = verification_mode + assert verification_mode in [ + StateUpdateKernel.cs, + StateUpdateKernel.standard, + ], "Only chunk scan and standard selective scan are supported for now" + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + """Returns the output dimension of the model.""" + return self.d_model + + @property + def state_to_tensor(self): + """Returns the state of the model as a tensor.""" + return self.layer.state_to_tensor + + def forward(self, u, inference_params=None, **kwargs): + """ + Args: + u: (B, L, D), + inference_params: dict.. Here we assume it contains a mask tensor of shape (B, L) with 1s for valid tokens and 0s for no-op tokens. + + Returns: + outputs: dict. + outputs["hidden_states"]: (B, L, D). + outputs["state"]: inference cache. + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + + if ( + state is not None + and inference_params.seqlen_offset > 0 # meaning we are in the middle of the sequence + and seqlen == 1 + and self.inference_mode != StateUpdateKernel.cs + ): + # we go in here for standard 1 token per time-step inference. + # seqlen_offset > 0 means we are in the middle of a sequence + # States are updated inplace + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _ = self.step(u, state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward( + xBC, padded_len, mask=inference_params.mask if inference_params is not None else None + ) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + # TODO: this kernel needs to be aupdated to use the mask! If used solely for throughout benchmarking, it is enough to call it as is. + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + return outputs + + def step(self, u, state, **kwargs): + """ + Args: + u: (B, D), + state: dict. + + Returns: + out: (B, D), + state: dict. + + """ + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate memory for inference cache.""" + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + Get states from cache. + + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len, mask=None): + """Convolutional layer forward pass for the full sequence.""" + seqlen = xBC.shape[1] + mask_seql = -1 if mask is None else mask.shape[1] + # If seqlen != mask_seql, this likely means we preallocated mask for static generation, + # but here we are in the prefill phase. + # Note, mask is needed to prevent state upodate for no-op tokens as described in https://proceedings.mlr.press/v262/wu24a.html + # Note, if we want to use joint attanimnet and advancement in selective-scan mode, we would need to implement masking into the kernel of causal_conv1d_fn and mamba_chunk_scan_combined + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + else: + # TODO: note, this only works for chunked inference, for autoregressive mode we need to update the kernel to make sure conv state is not poluted + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + + if mask_seql == seqlen: + xBC = xBC * mask.unsqueeze(-1) + return xBC + + def convolutional_step(self, xBC, conv_state): + """Convolutional layer forward pass for a single step.""" + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state diff --git a/fast_llm/models/ssm/external/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/modeling_mtp_llamba.py new file mode 100644 index 00000000..6d9746db --- /dev/null +++ b/fast_llm/models/ssm/external/modeling_mtp_llamba.py @@ -0,0 +1,389 @@ +# Copyright (c) 2024, Kevin Li, Aviv Bick. + +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin +from mamba_ssm.utils.generation import GenerationMixin +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.utils.generic import ModelOutput + +from .configuration_mtp_llamba import MTPLlambaConfig as LlambaConfig +from .discrete_mamba2 import DiscreteMamba2 + + +class LlamaRMSNorm(nn.Module): + """LlamaRMSNorm (taken from transformers.models.llama.modeling_llama.LlamaRMSNorm).""" + + def __init__(self, hidden_size, eps=1e-6, factory_kwargs=None): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """ + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + """Set the extra representation of the module.""" + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LlamaMLP(nn.Module): + """LlamaMLP (taken from transformers.models.llama.modeling_llama.LlamaMLP).""" + + def __init__(self, hidden_size, intermediate_size, bias, act_fn, factory_kwargs=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias, **factory_kwargs) + self.act_fn = ACT2FN[act_fn] + + def forward(self, x): + """ + Args: + x: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class MTPLlambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin): + """MambaLM model with a language modeling head on top (linear layer).""" + + def __init__(self, config, initializer_cfg=None, device=None, dtype=None, **kwargs) -> None: + super().__init__() + + # Load config + if not isinstance(config, LlambaConfig): + config = LlambaConfig(**config) + self.config = config + + # Factory kwargs + factory_kwargs = {"device": device, "dtype": dtype} + + # Pad vocab size to be a multiple of pad_vocab_size_multiple + vocab_size = config.vocab_size + pad_vocab_size_multiple = config.pad_vocab_size_multiple + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.config.vocab_size = vocab_size + + # Mixer model + self.backbone = MixerModel( + input_size=vocab_size, + config=self.config, + initializer_cfg=initializer_cfg, + **factory_kwargs, + ) + + # MTP heads + self.mtp_heads = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=layer_idx, + ).to(device) + for layer_idx in range(config.n_layer, config.n_layer + config.prediction_heads - 1) + ] + ) + + self.mtp_norms = nn.ModuleList( + [ + LlamaRMSNorm(config.d_model, eps=config.norm_epsilon, factory_kwargs=factory_kwargs) + for _ in range(config.prediction_heads - 1) + ] + ) + # LM head + if not self.config.tie_embeddings: + self.lm_head = nn.Linear( + in_features=self.config.d_model, + out_features=self.config.vocab_size, + bias=self.config.lm_head_bias, + **factory_kwargs, + ) + else: + self.lm_head = lambda x: x @ self.backbone.embedding.weight.t() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + + mtps = { + i + self.config.n_layer: layer.allocate_inference_cache(*args, **kwargs) + for i, layer in enumerate(self.mtp_heads) + } + return {**self.backbone.allocate_inference_cache(*args, **kwargs), **mtps} + + def forward( + self, + input_ids, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + ): + """ + Args: + input_ids: torch.Tensor of shape (batch_size, seq_len), + position_ids: torch.Tensor of shape (batch_size, seq_len), optional, not used (just for compatibility), + return_hidden_states: bool, optional, + return_logits: bool, optional, whether to compute the logits with the LM head, + inference_params: dict, optional, the model's inference cache, + num_last_tokens: int, optional. If > 0, only return the logits for the last n tokens. + + Returns: + CustomMambaCausalLMOutput. + + """ + outputs = self.backbone( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + # MTP heads processing + latents = [] + hidden_states = outputs["last_hidden_state"] + hidden_states_before_last = outputs["hidden_state_before_last"] + + # last layer already has layer norm applied + latents.append(hidden_states) + + # Process through MTP heads + for i, mtp_head in enumerate(self.mtp_heads): + mtp_outputs = mtp_head( + hidden_states_before_last, + inference_params=inference_params, + position_ids=position_ids, + ) + mtp_hidden_states = mtp_outputs["hidden_states"] + latents.append(self.mtp_norms[i](mtp_hidden_states)) + + # Stack the latents to get (batch_size, seq_len, num_prediction_heads, hidden_size) + stacked_latents = torch.stack(latents, dim=-2) + + if return_logits: + if isinstance(self.lm_head, nn.Linear): + # Apply lm_head to each prediction head's output + logits = self.lm_head(stacked_latents).float() + else: + # Using the tied embedding weights + logits = self.lm_head(stacked_latents) + + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=stacked_latents, + ) + + def save_pretrained(self, save_directory): + """ + Minimal implementation of save_pretrained for MambaLMHeadModel. + Save the model and its configuration file to a directory. + """ + # Ensure save_directory exists + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + # Save the model's state_dict + model_path = os.path.join(save_directory, "pytorch_model.bin") + torch.save(self.state_dict(), model_path) + + # Save the configuration of the model + config_path = os.path.join(save_directory, "config.json") + with open(config_path, "w") as f: + json.dump(self.config.to_dict(), f) + + +class MixerModel(nn.Module): + """Mixer model with a stack of Mixer layers.""" + + def __init__(self, input_size, config=None, device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.config = config + self.embedding = nn.Embedding(input_size, self.config.d_model, **factory_kwargs) + + self.layers = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=i, + ).to(device) + for i in range(self.config.n_layer) + ] + ) + + self.final_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, + eps=self.config.norm_epsilon, + factory_kwargs=factory_kwargs, + ) + + return + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + def forward( + self, + input_ids, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ): + """Run the model.""" + # Start running the layers + hidden_states = self.embedding(input_ids) + + # Initialize outputs + outputs = { + "last_hidden_state": None, + "hidden_state_before_last": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + # Run the layers + for layer in self.layers: + layer_outputs = layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + if layer == self.layers[-1]: + outputs["hidden_state_before_last"] = hidden_states + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + # Last layer, apply layer norm + outputs["last_hidden_state"] = self.final_layernorm(hidden_states) + return outputs + + +class Block(nn.Module): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + + def __init__(self, config, factory_kwargs, layer_idx, **kwargs): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + # Mixer + self.mixer = DiscreteMamba2( + d_model=self.config.d_model, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + # Other components + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.post_attention_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + ) + self.mlp = LlamaMLP( + hidden_size=self.config.d_model, + **config.mlp_cfg, + factory_kwargs=factory_kwargs, + ) + + def forward( + self, + hidden_states: Tensor, + inference_params=None, + **kwargs, + ): + """ + Pass the input through the encoder layer. + + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + inference_params: dict, optional, + + Returns: + dict with keys: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + mamba_hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + transfer_matrix: torch.Tensor of shape (batch_size, seq_len, seq_len). + """ + outputs = {} + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Apply Mixer + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) From 87b319769cbdf5a1130edc41c758b4c981fd5847 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 01:23:20 +0000 Subject: [PATCH 032/114] modeling apriel ssm --- .../ssm/external/configuration_ssm_apriel.py | 101 +++ .../ssm/external/modeling_ssm_apriel.py | 730 ++++++++++++++++++ 2 files changed, 831 insertions(+) create mode 100644 fast_llm/models/ssm/external/configuration_ssm_apriel.py create mode 100644 fast_llm/models/ssm/external/modeling_ssm_apriel.py diff --git a/fast_llm/models/ssm/external/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/configuration_ssm_apriel.py new file mode 100644 index 00000000..0c75ca65 --- /dev/null +++ b/fast_llm/models/ssm/external/configuration_ssm_apriel.py @@ -0,0 +1,101 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Apriel SSM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + pass + + +class AprielSSMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + .... + ```""" + + model_type = "apriel_ssm" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mlp_bias=False, + rms_norm_eps=1e-5, + ssm_cfg: dict = None, + **kwargs, + ): + self.vocab_size = vocab_size + # self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + # self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + # self.rope_theta = rope_theta + self.mlp_bias = mlp_bias + # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_inner": 4104, # to make sure we have 24 heads + } + + +__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_apriel.py new file mode 100644 index 00000000..d30d5b66 --- /dev/null +++ b/fast_llm/models/ssm/external/modeling_ssm_apriel.py @@ -0,0 +1,730 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.utils.generation import GenerationMixin +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from .configuration_ssm_apriel import AprielSSMConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # States are updated inplace + out, _ = self.step(u, state) + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u.squeeze(1)) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + ) + + self.mlp = AprielMLP(config) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, inference_params=None, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`] + Args: + config: AprielSSMConfig + """ + + def __init__(self, config: AprielSSMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [AprielDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ) -> Union[tuple, BaseModelOutputWithPast]: + + hidden_states = self.embed_tokens(input_ids) + + # decoder layers + outputs = { + "last_hidden_state": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + + layer_outputs = decoder_layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + outputs["last_hidden_state"] = self.norm(hidden_states) + return outputs + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = AprielSSMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=outputs["last_hidden_state"], + ) + + +__all__ = [ + "AprielSSMForCausalLM", + "AprielModel", + "AprielSSMPreTrainedModel", +] From d3e1df246279e367cbd45bbf4b5492165b6e4af5 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 12:25:05 +0000 Subject: [PATCH 033/114] Apriel to SSM --- .../models/ssm/external/ariel_to_ssm.ipynb | 447 ++++++++++++++++++ 1 file changed, 447 insertions(+) create mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb new file mode 100644 index 00000000..8c5f64ae --- /dev/null +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from mamba_ssm.models.config_mamba import MambaConfig\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.90it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "AprielForCausalLM(\n", + " (model): AprielModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()\n", + "apriel_model.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.bfloat16" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_model.config.torch_dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.83207168" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_params/1e9" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", + " hidden_size=config.hidden_size,\n", + " intermediate_size=config.intermediate_size,\n", + " num_hidden_layers=config.num_hidden_layers,\n", + " hidden_act=config.hidden_act,\n", + " initializer_range=config.initializer_range,\n", + " use_cache=config.use_cache,\n", + " mlp_bias=config.mlp_bias,\n", + " tie_word_embeddings=config.tie_word_embeddings,\n", + " pad_token_id=config.pad_token_id,\n", + " bos_token_id=config.bos_token_id,\n", + " eos_token_id=config.eos_token_id,\n", + " rms_norm_eps=config.rms_norm_eps)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMConfig {\n", + " \"_attn_implementation_autoset\": true,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 8192,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"apriel_ssm\",\n", + " \"num_hidden_layers\": 28,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"d_inner\": 4104,\n", + " \"d_state\": 64,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 24,\n", + " \"n_v_heads\": 24\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"transformers_version\": \"4.48.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm_config" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N params SSM: 5.660780512\n" + ] + } + ], + "source": [ + "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load State dict into SSM" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/ariel_ssm\")" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "24" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.model.layers[0].mixer.n_v_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=12320, bias=False)\n", + " (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Try a forward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", + "batch_size = 1\n", + "max_length = 128\n", + "state = SimpleNamespace()\n", + "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", + "state.batch_size = batch_size\n", + "state.seqlen_offset = 0\n", + "static_inputs = {\"inference_params\": state,\n", + " \"input_ids\": input_ids,\n", + " \"use_cache\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-5.4688, -1.6641, 0.4609, ..., -7.1562, -3.7812, -5.9062],\n", + " [-3.5000, 1.4297, 4.3125, ..., -5.3438, -4.9375, -2.9844],\n", + " [-3.1094, 0.7930, 2.2969, ..., -3.1250, -4.1875, -2.1250],\n", + " ...,\n", + " [-5.3438, -3.0938, -3.9062, ..., -4.9062, -3.0000, -3.9688],\n", + " [-3.0625, -3.2188, 5.6562, ..., -2.7812, -2.5938, -6.6562],\n", + " [-1.8438, -1.7500, 5.9062, ..., -3.7188, -2.1250, -0.8281]]],\n", + " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[ 1.2266, 0.5547, -1.1953, ..., 0.1089, -2.5781, 0.6328],\n", + " [-0.4395, 0.5938, -0.1562, ..., -0.6719, -0.6367, -0.3086],\n", + " [ 0.0077, 0.6680, -1.0703, ..., -3.6875, 0.2207, 0.1299],\n", + " ...,\n", + " [-0.0703, 0.4551, 0.1104, ..., 1.3438, 1.3984, 1.1641],\n", + " [-0.0613, 1.9141, -0.5430, ..., -1.0312, -0.6680, 0.0518],\n", + " [-0.6172, 0.2148, -0.5977, ..., -1.2734, -0.1914, 2.2344]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.forward(**static_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hymba2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 082cf22c941a2d8ea992d8f088f27fc57c92c4de Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 13:03:23 +0000 Subject: [PATCH 034/114] Apriel SSM conversion --- fast_llm/layers/ssm/config.py | 10 +- fast_llm/models/ssm/config.py | 24 +- fast_llm/models/ssm/conversion.py | 298 ++++++++++++++---- .../models/ssm/external/ariel_to_ssm.ipynb | 115 ++++++- 4 files changed, 374 insertions(+), 73 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 984858fc..2effa8a6 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -55,7 +55,7 @@ class SSMArchitectureConfig(BaseModelArchitectureConfig): hint=FieldHint.core, ) - dt_rank: int = Field( + dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.core, @@ -85,12 +85,16 @@ class SSMArchitectureConfig(BaseModelArchitectureConfig): hint=FieldHint.core, ) + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu - if self.dt_rank is None: - self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation super()._validate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index b38467d3..9d8c9bfd 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -49,12 +49,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" ) - if self.ssm.dt_rank < 0: + if self.ssm.dt_rank is None: mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) else: mamba_dt_rank = self.ssm.dt_rank - d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) + d_inner = ( + int(self.ssm.expansion_factor * self.transformer.hidden_size) + if self.ssm.d_inner is None + else self.ssm.d_inner + ) # Hidden dimension tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) # Mamba-specific dimensions @@ -115,12 +119,26 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler +class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielSSMHuggingfaceCheckpointHandler + + return AprielSSMHuggingfaceCheckpointHandler + + @config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) - checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) + checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( + LLambaHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + ) @classmethod def get_model_class(cls) -> type["HybridSSMModel"]: diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 190b2ffa..a8b6ceff 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -5,6 +5,7 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + ConstantExportParamConverter, ConstantImportParamConverter, IgnoreImportWeightConverter, MappedConfigParamConverter, @@ -18,7 +19,11 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType from fast_llm.models.gpt.conversion import MLPLayer2Converter -from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import ( + AprielSSMHuggingfaceCheckpointFormat, + HybridSSMModelConfig, + LLambaHuggingfaceCheckpointFormat, +) from fast_llm.models.ssm.model import HybridSSMModel from fast_llm.utils import Assert @@ -26,74 +31,17 @@ pass -class LLambaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): +class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - """ - Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json - """ return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("n_layer",),), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - # TODO: is there an equivalen of pad_vocab_size_multiple in FastLLM, does it matter? - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - RenameParamConverter( - fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), RenameParamConverter( fast_llm_names=(("vocab_size",),), export_names=(("vocab_size",),), ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_embeddings",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("d_model",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=( - ( - "mlp_cfg", - "intermediate_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), - export_names=( - ( - "mlp_cfg", - "bias", - ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=( - ( - "mlp_cfg", - "act_fn", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), RenameParamConverter( fast_llm_names=(("ssm", "state_size"),), export_names=( @@ -161,6 +109,238 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] + +class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + """ + Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json + """ + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("n_layer",),), + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=( + ( + "mlp_cfg", + "act_fn", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "add_linear_biases"),), + export_names=( + ( + "mlp_cfg", + "bias", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=( + ( + "mlp_cfg", + "intermediate_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_embeddings",),), + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + norm_bias: bool = False + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + # Embedding and output + if self._model.config.base_model.tie_word_embeddings: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + else: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + ) + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"backbone.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"backbone.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + # Norm + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + ) + + # MLP + converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + + return converters + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), + ] + def _create_weight_converters(self) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb index 8c5f64ae..85608075 100644 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -48,7 +48,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.90it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.68it/s]\n" ] }, { @@ -246,7 +246,52 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", + " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -255,7 +300,7 @@ "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" ] }, - "execution_count": 49, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -266,14 +311,58 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", + " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "\n", "apriel_ssm.to(device).to(dtype=torch.bfloat16)" ] }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# apriel_ssm.state_dict()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -283,11 +372,21 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:2714: UserWarning: `save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead.\n", + " warnings.warn(\n" + ] + } + ], "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/ariel_ssm\")" + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm\",\n", + " save_config=True)\n" ] }, { From 0d4d5c5b0cae2503a11b905e832d0693fa643c89 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Apr 2025 13:14:00 -0400 Subject: [PATCH 035/114] fix --- fast_llm/layers/language_model/config.py | 2 ++ fast_llm/models/gpt/model.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b4b4e187..4fb471fb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,6 +215,8 @@ def _validate(self) -> None: super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) + if self.prediction_heads > 1: + Assert.gt(self.transformer.num_layers, 1) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 54c3d882..9e28373b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -102,6 +102,9 @@ def get_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, layer_index=i + 1, + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], From c43e535ce8a296ed95e468b370e45d6165d5beb3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 18:52:28 +0000 Subject: [PATCH 036/114] wip --- fast_llm/models/ssm/config.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9d8c9bfd..7cad0d52 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -10,7 +10,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -169,9 +169,36 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + reference_models: dict[str, PretrainedGPTModelConfig] = ( + FieldUpdate() + ) # TODO: make sure any reference mdoel can be suported @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: from fast_llm.models.ssm.trainer import SSMTrainer return SSMTrainer + + def _validate(self) -> None: + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + if self.model.base_model.distillation_model is not None: + # TODO: Support loss masking for distillation? + assert not self.batch.use_loss_masking_spans + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner From a1f44d41d47c4424b2b033ae5515d7a067d110c7 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 21:31:43 +0000 Subject: [PATCH 037/114] conversion apriel ssm --- fast_llm/models/gpt/config.py | 2 ++ fast_llm/models/ssm/conversion.py | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 988f27b8..b82dd3e8 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -226,4 +226,6 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: from fast_llm.models.gpt.model import GPTInferenceRunner + # TODO" we dont have inference runner for SSM/Hybrid yet, should return None? + return GPTInferenceRunner diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index a8b6ceff..fb7776d5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -349,58 +349,58 @@ def _create_weight_converters(self) -> list[WeightConverter]: # Embedding and output if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias ) for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_weight", - f"backbone.layers.{i}.mixer.conv1d.weight", + f"model.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_bias", - f"backbone.layers.{i}.mixer.conv1d.bias", + f"model.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # Norm converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias ) # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") return converters From fbec02df596e91d013ea814d4cd00b4c346d2849 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 21:32:21 +0000 Subject: [PATCH 038/114] config apriel --- fast_llm/models/ssm/external/configuration_ssm_apriel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/ssm/external/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/configuration_ssm_apriel.py index 0c75ca65..2e5d5810 100644 --- a/fast_llm/models/ssm/external/configuration_ssm_apriel.py +++ b/fast_llm/models/ssm/external/configuration_ssm_apriel.py @@ -56,6 +56,7 @@ def __init__( mlp_bias=False, rms_norm_eps=1e-5, ssm_cfg: dict = None, + head_dim: int = 128, **kwargs, ): self.vocab_size = vocab_size @@ -71,8 +72,7 @@ def __init__( self.use_cache = use_cache # self.rope_theta = rope_theta self.mlp_bias = mlp_bias - # self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - + self.head_dim = head_dim # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. # if self.rope_scaling is not None and "type" in self.rope_scaling: @@ -94,8 +94,9 @@ def __init__( "chunk_size": 128, "activation": "identity", "bias": False, - "d_inner": 4104, # to make sure we have 24 heads + "d_inner": 24 * self.head_dim, # num_heads * head_dim } + assert self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"] __all__ = ["AprielConfig"] From 75d646059fcaac0762aaedae1c8a175ea2a67e37 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 29 Apr 2025 23:53:19 +0000 Subject: [PATCH 039/114] temp checkpoint conversion --- fast_llm/models/ssm/conversion.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index fb7776d5..a9a139b5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -1,3 +1,4 @@ +import dataclasses import json import os import pathlib @@ -31,6 +32,23 @@ pass +@dataclasses.dataclass(kw_only=True) +class HybridBlockLayoutConverter(ParamConverter): + num_layers_getter: typing.Callable[[typing.Any], int] = lambda config: config.transformer.num_layers + + # TODO: generalize this t + def __post_init__(self) -> None: + Assert.eq(len(self.fast_llm_names), 1) + Assert.eq(len(self.export_names), 1) + + def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + # Use the expanded list as-is + return (["m2"],) + + def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: + return (["m2"],) + + class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig @@ -170,6 +188,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_embeddings",),), ), + HybridBlockLayoutConverter( + fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) + ), ] def _create_weight_converters(self) -> list[WeightConverter]: @@ -338,7 +359,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_word_embeddings",),), ), - ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), + # ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), + HybridBlockLayoutConverter( + fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) + ), ] def _create_weight_converters(self) -> list[WeightConverter]: From 73a4252e2f93b5ae5cdb36a9a026bd81db440a01 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 30 Apr 2025 13:33:04 +0000 Subject: [PATCH 040/114] block pattern for hybrid conversion --- fast_llm/models/ssm/conversion.py | 56 +++++++++++++++++++------------ 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index a9a139b5..fb862a02 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -1,4 +1,3 @@ -import dataclasses import json import os import pathlib @@ -32,21 +31,39 @@ pass -@dataclasses.dataclass(kw_only=True) -class HybridBlockLayoutConverter(ParamConverter): - num_layers_getter: typing.Callable[[typing.Any], int] = lambda config: config.transformer.num_layers +class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + """ + This is a temporary solution for importing/exporting hybrid models. Since there is no standard solution for this in HF, we just use the block_pattern. + If block_pattern is None, it will multiply the provided default block type by the number of layers and export/import it. + If block_pattern is provided, it will export/import it as-is. + """ - # TODO: generalize this t - def __post_init__(self) -> None: - Assert.eq(len(self.fast_llm_names), 1) - Assert.eq(len(self.export_names), 1) + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _default_block_type: str = "m2" - def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - # Use the expanded list as-is - return (["m2"],) + @classmethod + def _import_config(cls, config, architecture_only: bool = False): + cls.num_layers = config["n_layer"] if "n_layer" in config else config["num_hidden_layers"] + cls.block_pattern = config.get("hybrid_block_layout", None) + return super()._import_config(config, architecture_only) - def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: - return (["m2"],) + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + if cls.block_pattern is not None: + block_converter = MappedConfigParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), + fast_llm_value=cls.block_pattern, + export_value=cls.block_pattern, + ) + else: + block_converter = ConstantImportParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + fast_llm_value=[cls._default_block_type] * cls.num_layers, + ) + + return super()._create_config_converters() + [block_converter] class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @@ -128,10 +145,11 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] -class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): +class LLambaHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + _default_block_type: str = "m2" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -188,9 +206,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_embeddings",),), ), - HybridBlockLayoutConverter( - fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) - ), ] def _create_weight_converters(self) -> list[WeightConverter]: @@ -316,9 +331,10 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An json.dump(config, f) -class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): +class AprielSSMHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat + _default_block_type: str = "m2" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -359,10 +375,6 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("tie_word_embeddings",),), export_names=(("tie_word_embeddings",),), ), - # ConstantImportParamConverter(fast_llm_names=(("hybrid_block_layout"),), fast_llm_value=["m2"]), - HybridBlockLayoutConverter( - fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),) - ), ] def _create_weight_converters(self) -> list[WeightConverter]: From 5afc7dc21ead590a67c7bc4fd2765f4f95456205 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 30 Apr 2025 14:18:55 +0000 Subject: [PATCH 041/114] SSMBlockType --- fast_llm/layers/ssm/config.py | 13 +++++++++++++ fast_llm/models/ssm/config.py | 19 +++++++++++-------- fast_llm/models/ssm/conversion.py | 7 ++++--- fast_llm/models/ssm/model.py | 10 +++++----- 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 2effa8a6..459401df 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,3 +1,5 @@ +import enum + from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.functional.config import ActivationType @@ -20,6 +22,17 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads +class SSMBlockType(str, enum.Enum): + """ + An enum for the available mamba types for the MLP layer. + """ + + mamba = "m" + mamba2_discrete = "m2d" + mamba2 = "m2" + transformer = "t" + + @config_class() class SSMArchitectureConfig(BaseModelArchitectureConfig): _abstract = False diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7cad0d52..d77d206b 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMDimNames +from fast_llm.layers.ssm.config import SSMBlockType, SSMDimNames from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert @@ -24,8 +24,8 @@ class HybridSSMArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False hybrid_block_layout: list[str] = Field( - default_factory=lambda: ["m2"], - desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Descrete Mamba2.", + default_factory=lambda: [SSMBlockType.mamba2_discrete.value], + desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, ) @@ -44,9 +44,12 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: Some of these can be setup directly in the layer config, but keeping them here for clarity. """ super().setup_tensor_space(tensor_space) - if not "m2" in self.hybrid_block_layout and not "m" in self.hybrid_block_layout: + if ( + not SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout + and not SSMBlockType.mamba.value in self.hybrid_block_layout + ): raise ValueError( - "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" + f"Block pattern must contain at least one '{SSMBlockType.mamba2_discrete.value}' or '{SSMBlockType.mamba.value}', use gpt model for transformer only architectures" ) if self.ssm.dt_rank is None: @@ -69,7 +72,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - if "m2" in self.hybrid_block_layout: + if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: # Mamba2 specific dimensions # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 headdim = d_inner // self.ssm.n_v_heads @@ -101,8 +104,8 @@ def _validate(self): Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) Assert.custom( - lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'", + lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), + f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", ) super()._validate() diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index fb862a02..c2e54ca0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -18,6 +18,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.models.gpt.conversion import MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHuggingfaceCheckpointFormat, @@ -40,7 +41,7 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - _default_block_type: str = "m2" + _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _import_config(cls, config, architecture_only: bool = False): @@ -149,7 +150,7 @@ class LLambaHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSM _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - _default_block_type: str = "m2" + _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -334,7 +335,7 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An class AprielSSMHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat - _default_block_type: str = "m2" + _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _create_config_converters(cls) -> list[ParamConverter]: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 33d2c185..c9d1ba7d 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -11,7 +11,7 @@ from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig +from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def get_layers(self) -> list[Layer]: # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == "t": + if block_type == SSMBlockType.transformer.value: # Transformer block layers.append( TransformerLayer( @@ -52,7 +52,7 @@ def get_layers(self) -> list[Layer]: layer_index=i + 1, ) ) - elif block_type == "m2": + elif block_type == SSMBlockType.mamba2_discrete.value: mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, config_ssm=self._config.ssm, @@ -62,7 +62,7 @@ def get_layers(self) -> list[Layer]: ) layers.append(mamba_block) - elif block_type == "m": + elif block_type == SSMBlockType.mamba.value: # Create Mamba block mamba_block = self.SSM_BLOCK_CLS( config_transformer=self._config.transformer, @@ -74,7 +74,7 @@ def get_layers(self) -> list[Layer]: layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'") + raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") # Add the language model head layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)) From 8e9facf9584ff5e8a72cecdcbc8c0b0d6fcdaab8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 30 Apr 2025 19:02:48 +0000 Subject: [PATCH 042/114] wip --- fast_llm/layers/ssm/mamba2.py | 354 +++ .../models/ssm/external/ariel_to_ssm.ipynb | 2240 ++++++++++++++++- 2 files changed, 2507 insertions(+), 87 deletions(-) create mode 100644 fast_llm/layers/ssm/mamba2.py diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py new file mode 100644 index 00000000..5763cb92 --- /dev/null +++ b/fast_llm/layers/ssm/mamba2.py @@ -0,0 +1,354 @@ +""" +This code is adapted from https://github.com/jxiw/MambaInLlama/blob/main/mamba2/hybrid_mamba_layer.py +""" + +import math + +import causal_conv1d +import einops +import mamba_ssm.ops.triton.ssd_combined +import torch +from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.linear import Linear +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.tensor import kaiming_init_ + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Mamba2(torch.nn.Module): + def __init__( + self, + config: SSMConfig, + layer_idx: int, + tensor_space: TensorSpace, + ): + # factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.config: SSMConfig = config + bias = config.add_bias_linear + self.layer_idx = layer_idx + + td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) + tensor_space.get_tensor_dim(SSMDimNames.state_dim) + td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) + tensor_space.get_tensor_dim(SSMDimNames.conv_dim) + tensor_space.get_tensor_dim(SSMDimNames.qk_heads) + tensor_space.get_tensor_dim(SSMDimNames.v_heads) + tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) + + # self.d_model = d_model + # self.d_state = d_state + # self.d_conv = d_conv + # self.conv_init = conv_init + # self.expand = expand + # self.process_group = process_group + # self.sequence_parallel = sequence_parallel + # self.world_size = 1 if process_group is None else process_group.size() + # self.local_rank = 0 if process_group is None else process_group.rank() + # self.d_inner = d_inner if d_inner is not None else (self.expand * self.d_model) // self.world_size + # # assert self.d_inner * self.world_size == self.expand * self.d_model + # self.headdim = headdim + # self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size + # assert ngroups % self.world_size == 0 + # self.ngroups = ngroups // self.world_size + # assert self.d_ssm % self.headdim == 0 + # self.nheads = self.d_ssm // self.headdim + # self.D_has_hdim = D_has_hdim + # self.rmsnorm = rmsnorm + # self.norm_before_gate = norm_before_gate + # self.dt_limit = dt_limit + # self.activation = "silu" + # self.chunk_size = chunk_size + # self.use_mem_eff_path = use_mem_eff_path + # self.layer_idx = layer_idx + # self.d_xb = d_xb + # self.repeat_group = self.d_inner // self.d_xb + # self.repeat_kv_before_conv = repeat_kv_before_conv + + assert self.d_inner == self.ngroups * self.d_state + assert self.d_inner == self.d_ssm + + self.nheads = self.ngroups + self.headdim = self.d_state + + # Order: [z, x, B, C, dt] + # [hidden_dim, hidden_dim, d_state] + d_in_proj = self.d_inner + self.d_xb + self.d_xb + self.d_inner + self.nheads + # d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + if self.process_group is None: + self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) + else: + self.in_proj = ColumnParallelLinear( + self.d_model, + d_in_proj * self.world_size, + bias=bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) + + # conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state + + if self.repeat_kv_before_conv: + conv_dim = self.d_inner + self.d_inner + self.d_inner + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + else: + conv_dim = self.d_inner + self.d_xb + self.d_xb + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + + self.act = nn.SiLU() + + # Initialize log dt bias + dt = torch.exp( + torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] + A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) + A_log = torch.log(A).to(dtype=dtype) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) + self.D._no_weight_decay = True + + if self.rmsnorm: + assert RMSNormGated is not None + self.norm = RMSNormGated( + self.d_ssm, + eps=1e-5, + norm_before_gate=self.norm_before_gate, + group_size=self.d_ssm // ngroups, + **factory_kwargs, + ) + + # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + self.out_proj = Linear( + td_inner, + td_model, + bias=bias, + weight_init_method=kaiming_init_(td_inner.size), + ) + + def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): + """ + u: (batch, seqlen, hidden_dim) if seqlen=None. + If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we + split u during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + Returns: same shape as u + """ + seqlen_og = seqlen + if seqlen is None: + batch, seqlen, dim = u.shape + else: + batch_seqlen, dim = u.shape + batch = batch_seqlen // seqlen + + conv_state, ssm_state = None, None + if inference_params is not None: + inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch + conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(u, conv_state, ssm_state) + return out + + zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) + if seqlen_og is not None: + zxbcdt = einops.rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) + # If the model is loaded in fp16, without the .float() here, A might be -inf + A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state) + dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) + + # [z, x, B, C, dt] + d_mlp = (zxbcdt.shape[-1] - 2 * self.d_inner - 2 * self.d_xb - self.nheads) // 2 + z0, x0, z, xBC, dt = torch.split( + zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.d_xb, self.nheads], dim=-1 + ) + + if self.repeat_kv_before_conv: + x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) + # minic the GQA + x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + x = repeat_kv(x, self.repeat_group) + # x shape: (bsz, n_group, l, dim) + B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + B = repeat_kv(B, self.repeat_group) + # combine x, B, C + x = einops.rearrange(x, "b g l p -> b l (g p)") + B = einops.rearrange(B, "b g l p -> b l (g p)") + xBC = torch.cat((x, B, C), dim=-1) + + if conv_state is not None: + if cu_seqlens is None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = einops.rearrange(xBC, "b l d -> b d l") + conv_state.copy_( + torch.nn.functional.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)) + ) # Update state (B D W) + else: + assert ( + causal_conv1d.causal_conv1d_varlen_states is not None + ), "varlen inference requires causal_conv1d package" + assert batch == 1, "varlen inference only supports batch dimension 1" + conv_varlen_states = causal_conv1d.causal_conv1d_varlen_states( + xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] + ) + conv_state.copy_(conv_varlen_states) + assert self.activation in ["silu", "swish"] + + if causal_conv1d.causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" + xBC = self.act( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1) :] + ) # (B, L, self.d_ssm + 2 * ngroups * d_state) + else: + xBC = causal_conv1d.causal_conv1d_fn( + xBC.transpose(1, 2), + einops.rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=seq_idx, + ).transpose(1, 2) + + if self.repeat_kv_before_conv: + x, B, C = torch.split( + xBC, [self.ngroups * self.d_state, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1 + ) + + y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( + einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + dt, + A, + einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + **dt_limit_kwargs, + return_final_states=ssm_state is not None, + return_varlen_states=cu_seqlens is not None and inference_params is not None, + ) + + else: + # self.d_xb + self.d_xb + self.d_inner + x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) + + # minic the GQA + x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + x = repeat_kv(x, self.repeat_group) + # x shape: (bsz, n_group, l, dim) + + B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) + B = repeat_kv(B, self.repeat_group) + + y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( + # einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), + einops.rearrange(x, "b g l p -> b l g p"), + dt, + A, + # einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), + einops.rearrange(B, "b g l n -> b l g n"), + einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), + chunk_size=self.chunk_size, + D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, + z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + **dt_limit_kwargs, + return_final_states=ssm_state is not None, + return_varlen_states=cu_seqlens is not None and inference_params is not None, + ) + + if ssm_state is not None: + y, last_state, *rest = y + if cu_seqlens is None: + ssm_state.copy_(last_state) + else: + varlen_states = rest[0] + ssm_state.copy_(varlen_states) + y = einops.rearrange(y, "b l h p -> b l (h p)") + if self.rmsnorm: + y = self.norm(y, z) + if d_mlp > 0: + y = torch.cat([torch.nn.functional.silu(z0) * x0, y], dim=-1) + if seqlen_og is not None: + y = einops.rearrange(y, "b l d -> (b l) d") + out = self.out_proj(y) + return out + + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + (batch_size,) + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ).transpose(1, 2) + ssm_state = torch.zeros( + batch_size, + self.nheads, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb index 85608075..a8390fa3 100644 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -41,14 +41,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.68it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" ] }, { @@ -82,7 +82,7 @@ ")" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -97,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -106,7 +106,7 @@ "torch.bfloat16" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -135,7 +135,7 @@ "4.83207168" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -146,7 +146,58 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" + ] + } + ], + "source": [ + "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", + "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'n_qk_heads'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" + ] + } + ], + "source": [ + "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", + "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -162,12 +213,13 @@ " pad_token_id=config.pad_token_id,\n", " bos_token_id=config.bos_token_id,\n", " eos_token_id=config.eos_token_id,\n", + " head_dim=config.head_dim,\n", " rms_norm_eps=config.rms_norm_eps)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -176,60 +228,1984 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AprielSSMConfig {\n", - " \"_attn_implementation_autoset\": true,\n", - " \"bos_token_id\": 1,\n", - " \"eos_token_id\": 2,\n", - " \"hidden_act\": \"silu\",\n", - " \"hidden_size\": 4096,\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 8192,\n", - " \"mlp_bias\": false,\n", - " \"model_type\": \"apriel_ssm\",\n", - " \"num_hidden_layers\": 28,\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"ssm_cfg\": {\n", - " \"activation\": \"identity\",\n", - " \"bias\": false,\n", - " \"chunk_size\": 128,\n", - " \"d_inner\": 4104,\n", - " \"d_state\": 64,\n", - " \"expand\": 1,\n", - " \"n_qk_heads\": 24,\n", - " \"n_v_heads\": 24\n", - " },\n", - " \"tie_word_embeddings\": false,\n", - " \"transformers_version\": \"4.48.1\",\n", - " \"use_cache\": true,\n", - " \"vocab_size\": 131072\n", - "}" + "OrderedDict([('model.embed_tokens.weight',\n", + " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", + " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", + " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", + " ...,\n", + " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", + " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", + " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", + " ('model.layers.0.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.0.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.0.mixer.in_proj.weight',\n", + " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", + " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", + " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", + " ...,\n", + " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", + " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", + " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", + " ('model.layers.0.mixer.conv1d.weight',\n", + " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", + " \n", + " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", + " \n", + " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", + " \n", + " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", + " \n", + " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", + " ('model.layers.0.mixer.conv1d.bias',\n", + " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", + " ('model.layers.0.mixer.out_proj.weight',\n", + " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", + " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", + " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", + " ...,\n", + " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", + " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", + " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", + " ('model.layers.0.mlp.gate_proj.weight',\n", + " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", + " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", + " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", + " ...,\n", + " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", + " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", + " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", + " ('model.layers.0.mlp.up_proj.weight',\n", + " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", + " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", + " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", + " ...,\n", + " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", + " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", + " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", + " ('model.layers.0.mlp.down_proj.weight',\n", + " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", + " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", + " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", + " ...,\n", + " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", + " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", + " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", + " ('model.layers.0.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.0.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.1.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.in_proj.weight',\n", + " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", + " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", + " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", + " ...,\n", + " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", + " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", + " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", + " ('model.layers.1.mixer.conv1d.weight',\n", + " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", + " \n", + " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", + " \n", + " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", + " \n", + " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", + " \n", + " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", + " ('model.layers.1.mixer.conv1d.bias',\n", + " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", + " ('model.layers.1.mixer.out_proj.weight',\n", + " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", + " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", + " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", + " ...,\n", + " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", + " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", + " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", + " ('model.layers.1.mlp.gate_proj.weight',\n", + " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", + " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", + " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", + " ...,\n", + " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", + " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", + " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", + " ('model.layers.1.mlp.up_proj.weight',\n", + " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", + " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", + " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", + " ...,\n", + " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", + " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", + " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", + " ('model.layers.1.mlp.down_proj.weight',\n", + " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", + " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", + " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", + " ...,\n", + " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", + " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", + " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", + " ('model.layers.1.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.2.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.in_proj.weight',\n", + " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", + " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", + " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", + " ...,\n", + " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", + " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", + " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", + " ('model.layers.2.mixer.conv1d.weight',\n", + " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", + " \n", + " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", + " \n", + " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", + " \n", + " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", + " \n", + " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", + " ('model.layers.2.mixer.conv1d.bias',\n", + " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", + " ('model.layers.2.mixer.out_proj.weight',\n", + " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", + " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", + " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", + " ...,\n", + " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", + " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", + " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", + " ('model.layers.2.mlp.gate_proj.weight',\n", + " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", + " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", + " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", + " ...,\n", + " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", + " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", + " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", + " ('model.layers.2.mlp.up_proj.weight',\n", + " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", + " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", + " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", + " ...,\n", + " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", + " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", + " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", + " ('model.layers.2.mlp.down_proj.weight',\n", + " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", + " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", + " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", + " ...,\n", + " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", + " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", + " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", + " ('model.layers.2.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.3.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.in_proj.weight',\n", + " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", + " -2.1648e-03, 2.8147e-03],\n", + " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", + " 4.9989e-03, 6.4200e-03],\n", + " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", + " 9.7000e-03, 3.1136e-05],\n", + " ...,\n", + " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", + " 4.6152e-04, 5.5341e-02],\n", + " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", + " -1.0645e-02, -1.8442e-03],\n", + " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", + " -4.7119e-03, 3.2352e-02]])),\n", + " ('model.layers.3.mixer.conv1d.weight',\n", + " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", + " \n", + " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", + " \n", + " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", + " \n", + " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", + " \n", + " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", + " ('model.layers.3.mixer.conv1d.bias',\n", + " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", + " ('model.layers.3.mixer.out_proj.weight',\n", + " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", + " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", + " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", + " ...,\n", + " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", + " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", + " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", + " ('model.layers.3.mlp.gate_proj.weight',\n", + " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", + " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", + " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", + " ...,\n", + " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", + " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", + " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", + " ('model.layers.3.mlp.up_proj.weight',\n", + " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", + " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", + " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", + " ...,\n", + " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", + " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", + " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", + " ('model.layers.3.mlp.down_proj.weight',\n", + " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", + " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", + " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", + " ...,\n", + " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", + " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", + " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", + " ('model.layers.3.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.4.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.in_proj.weight',\n", + " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", + " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", + " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", + " ...,\n", + " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", + " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", + " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", + " ('model.layers.4.mixer.conv1d.weight',\n", + " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", + " \n", + " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", + " \n", + " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", + " \n", + " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", + " \n", + " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", + " ('model.layers.4.mixer.conv1d.bias',\n", + " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", + " ('model.layers.4.mixer.out_proj.weight',\n", + " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", + " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", + " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", + " ...,\n", + " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", + " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", + " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", + " ('model.layers.4.mlp.gate_proj.weight',\n", + " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", + " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", + " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", + " ...,\n", + " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", + " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", + " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", + " ('model.layers.4.mlp.up_proj.weight',\n", + " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", + " 3.6365e-02, -8.1017e-03],\n", + " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", + " -3.1505e-02, 2.1047e-03],\n", + " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", + " -2.7856e-04, -8.7440e-04],\n", + " ...,\n", + " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", + " -1.9620e-02, 2.5112e-02],\n", + " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", + " -1.6838e-03, 2.0189e-02],\n", + " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", + " -1.3502e-02, -1.8923e-02]])),\n", + " ('model.layers.4.mlp.down_proj.weight',\n", + " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", + " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", + " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", + " ...,\n", + " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", + " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", + " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", + " ('model.layers.4.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.5.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.in_proj.weight',\n", + " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", + " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", + " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", + " ...,\n", + " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", + " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", + " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", + " ('model.layers.5.mixer.conv1d.weight',\n", + " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", + " \n", + " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", + " \n", + " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", + " \n", + " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", + " \n", + " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", + " ('model.layers.5.mixer.conv1d.bias',\n", + " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", + " ('model.layers.5.mixer.out_proj.weight',\n", + " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", + " 7.0706e-03, -4.7970e-03],\n", + " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", + " 2.2802e-02, -2.9729e-03],\n", + " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", + " -4.2582e-03, 2.4007e-02],\n", + " ...,\n", + " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", + " -2.0391e-02, 5.3912e-03],\n", + " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", + " -2.0596e-02, -8.2915e-03],\n", + " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", + " -3.8566e-02, 2.7005e-02]])),\n", + " ('model.layers.5.mlp.gate_proj.weight',\n", + " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", + " 3.6159e-02, -1.7253e-02],\n", + " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", + " 4.4487e-03, -9.6723e-03],\n", + " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", + " 1.2872e-02, -1.0443e-02],\n", + " ...,\n", + " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", + " 2.8752e-02, -8.0048e-03],\n", + " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", + " -3.2179e-02, -1.6437e-02],\n", + " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", + " 7.7639e-03, -2.2494e-03]])),\n", + " ('model.layers.5.mlp.up_proj.weight',\n", + " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", + " -1.8538e-02, 6.7395e-03],\n", + " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", + " -3.0374e-02, 2.5795e-02],\n", + " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", + " 1.4115e-02, -2.6438e-02],\n", + " ...,\n", + " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", + " -4.9858e-02, -1.1575e-02],\n", + " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", + " 1.8568e-02, 5.2615e-03],\n", + " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", + " 1.2643e-02, -1.1568e-05]])),\n", + " ('model.layers.5.mlp.down_proj.weight',\n", + " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", + " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", + " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", + " ...,\n", + " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", + " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", + " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", + " ('model.layers.5.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.6.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.in_proj.weight',\n", + " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", + " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", + " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", + " ...,\n", + " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", + " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", + " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", + " ('model.layers.6.mixer.conv1d.weight',\n", + " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", + " \n", + " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", + " \n", + " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", + " \n", + " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", + " \n", + " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", + " ('model.layers.6.mixer.conv1d.bias',\n", + " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", + " ('model.layers.6.mixer.out_proj.weight',\n", + " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", + " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", + " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", + " ...,\n", + " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", + " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", + " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", + " ('model.layers.6.mlp.gate_proj.weight',\n", + " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", + " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", + " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", + " ...,\n", + " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", + " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", + " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", + " ('model.layers.6.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", + " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", + " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", + " ...,\n", + " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", + " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", + " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", + " ('model.layers.6.mlp.down_proj.weight',\n", + " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", + " 4.8023e-03, 9.8474e-03],\n", + " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", + " -9.5275e-04, -1.8121e-02],\n", + " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", + " 2.8131e-03, 5.8219e-03],\n", + " ...,\n", + " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", + " -8.3965e-03, 2.0074e-02],\n", + " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", + " 3.4947e-02, -2.2158e-02],\n", + " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", + " 4.0735e-05, 3.3651e-02]])),\n", + " ('model.layers.6.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.7.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.in_proj.weight',\n", + " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", + " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", + " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", + " ...,\n", + " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", + " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", + " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", + " ('model.layers.7.mixer.conv1d.weight',\n", + " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", + " \n", + " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", + " \n", + " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", + " \n", + " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", + " \n", + " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", + " ('model.layers.7.mixer.conv1d.bias',\n", + " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", + " ('model.layers.7.mixer.out_proj.weight',\n", + " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", + " 1.4225e-02, 2.8166e-02],\n", + " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", + " 3.7454e-03, 1.7579e-02],\n", + " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", + " 1.4766e-02, -1.8002e-02],\n", + " ...,\n", + " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", + " -1.3072e-02, -1.0412e-02],\n", + " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", + " -1.1047e-02, 1.5815e-02],\n", + " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", + " -2.1834e-02, -2.1826e-02]])),\n", + " ('model.layers.7.mlp.gate_proj.weight',\n", + " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", + " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", + " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", + " ...,\n", + " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", + " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", + " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", + " ('model.layers.7.mlp.up_proj.weight',\n", + " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", + " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", + " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", + " ...,\n", + " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", + " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", + " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", + " ('model.layers.7.mlp.down_proj.weight',\n", + " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", + " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", + " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", + " ...,\n", + " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", + " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", + " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", + " ('model.layers.7.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.8.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.in_proj.weight',\n", + " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", + " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", + " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", + " ...,\n", + " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", + " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", + " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", + " ('model.layers.8.mixer.conv1d.weight',\n", + " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", + " \n", + " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", + " \n", + " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", + " \n", + " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", + " \n", + " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", + " ('model.layers.8.mixer.conv1d.bias',\n", + " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", + " ('model.layers.8.mixer.out_proj.weight',\n", + " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", + " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", + " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", + " ...,\n", + " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", + " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", + " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", + " ('model.layers.8.mlp.gate_proj.weight',\n", + " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", + " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", + " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", + " ...,\n", + " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", + " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", + " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", + " ('model.layers.8.mlp.up_proj.weight',\n", + " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", + " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", + " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", + " ...,\n", + " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", + " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", + " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", + " ('model.layers.8.mlp.down_proj.weight',\n", + " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", + " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", + " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", + " ...,\n", + " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", + " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", + " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", + " ('model.layers.8.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.9.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.in_proj.weight',\n", + " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", + " 7.1750e-04, 1.5247e-02],\n", + " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", + " -1.1202e-02, -2.1476e-02],\n", + " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", + " 1.8544e-02, -2.3633e-02],\n", + " ...,\n", + " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", + " -5.6073e-03, -4.8212e-03],\n", + " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", + " 4.5898e-02, -3.5822e-02],\n", + " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", + " 1.3560e-02, 3.3366e-02]])),\n", + " ('model.layers.9.mixer.conv1d.weight',\n", + " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", + " \n", + " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", + " \n", + " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", + " \n", + " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", + " \n", + " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", + " ('model.layers.9.mixer.conv1d.bias',\n", + " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", + " ('model.layers.9.mixer.out_proj.weight',\n", + " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", + " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", + " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", + " ...,\n", + " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", + " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", + " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", + " ('model.layers.9.mlp.gate_proj.weight',\n", + " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", + " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", + " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", + " ...,\n", + " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", + " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", + " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", + " ('model.layers.9.mlp.up_proj.weight',\n", + " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", + " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", + " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", + " ...,\n", + " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", + " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", + " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", + " ('model.layers.9.mlp.down_proj.weight',\n", + " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", + " -3.0707e-02, -1.3281e-02],\n", + " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", + " -1.1585e-03, -2.7604e-02],\n", + " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", + " 1.1647e-02, 1.1245e-02],\n", + " ...,\n", + " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", + " -4.5122e-03, 4.0012e-03],\n", + " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", + " 2.3422e-02, -3.2098e-02],\n", + " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", + " 2.2118e-02, 1.0696e-02]])),\n", + " ('model.layers.9.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.10.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.in_proj.weight',\n", + " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", + " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", + " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", + " ...,\n", + " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", + " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", + " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", + " ('model.layers.10.mixer.conv1d.weight',\n", + " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", + " \n", + " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", + " \n", + " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", + " \n", + " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", + " \n", + " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", + " ('model.layers.10.mixer.conv1d.bias',\n", + " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", + " ('model.layers.10.mixer.out_proj.weight',\n", + " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", + " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", + " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", + " ...,\n", + " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", + " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", + " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", + " ('model.layers.10.mlp.gate_proj.weight',\n", + " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", + " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", + " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", + " ...,\n", + " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", + " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", + " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", + " ('model.layers.10.mlp.up_proj.weight',\n", + " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", + " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", + " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", + " ...,\n", + " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", + " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", + " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", + " ('model.layers.10.mlp.down_proj.weight',\n", + " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", + " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", + " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", + " ...,\n", + " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", + " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", + " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", + " ('model.layers.10.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.11.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.in_proj.weight',\n", + " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", + " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", + " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", + " ...,\n", + " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", + " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", + " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", + " ('model.layers.11.mixer.conv1d.weight',\n", + " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", + " \n", + " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", + " \n", + " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", + " \n", + " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", + " \n", + " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", + " ('model.layers.11.mixer.conv1d.bias',\n", + " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", + " ('model.layers.11.mixer.out_proj.weight',\n", + " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", + " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", + " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", + " ...,\n", + " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", + " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", + " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", + " ('model.layers.11.mlp.gate_proj.weight',\n", + " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", + " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", + " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", + " ...,\n", + " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", + " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", + " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", + " ('model.layers.11.mlp.up_proj.weight',\n", + " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", + " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", + " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", + " ...,\n", + " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", + " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", + " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", + " ('model.layers.11.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", + " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", + " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", + " ...,\n", + " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", + " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", + " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", + " ('model.layers.11.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.12.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.in_proj.weight',\n", + " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", + " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", + " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", + " ...,\n", + " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", + " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", + " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", + " ('model.layers.12.mixer.conv1d.weight',\n", + " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", + " \n", + " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", + " \n", + " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", + " \n", + " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", + " \n", + " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", + " ('model.layers.12.mixer.conv1d.bias',\n", + " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", + " ('model.layers.12.mixer.out_proj.weight',\n", + " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", + " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", + " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", + " ...,\n", + " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", + " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", + " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", + " ('model.layers.12.mlp.gate_proj.weight',\n", + " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", + " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", + " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", + " ...,\n", + " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", + " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", + " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", + " ('model.layers.12.mlp.up_proj.weight',\n", + " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", + " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", + " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", + " ...,\n", + " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", + " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", + " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", + " ('model.layers.12.mlp.down_proj.weight',\n", + " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", + " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", + " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", + " ...,\n", + " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", + " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", + " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", + " ('model.layers.12.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.13.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.in_proj.weight',\n", + " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", + " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", + " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", + " ...,\n", + " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", + " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", + " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", + " ('model.layers.13.mixer.conv1d.weight',\n", + " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", + " \n", + " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", + " \n", + " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", + " \n", + " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", + " \n", + " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", + " ('model.layers.13.mixer.conv1d.bias',\n", + " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", + " ('model.layers.13.mixer.out_proj.weight',\n", + " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", + " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", + " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", + " ...,\n", + " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", + " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", + " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", + " ('model.layers.13.mlp.gate_proj.weight',\n", + " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", + " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", + " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", + " ...,\n", + " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", + " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", + " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", + " ('model.layers.13.mlp.up_proj.weight',\n", + " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", + " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", + " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", + " ...,\n", + " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", + " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", + " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", + " ('model.layers.13.mlp.down_proj.weight',\n", + " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", + " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", + " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", + " ...,\n", + " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", + " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", + " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", + " ('model.layers.13.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.14.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.in_proj.weight',\n", + " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", + " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", + " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", + " ...,\n", + " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", + " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", + " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", + " ('model.layers.14.mixer.conv1d.weight',\n", + " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", + " \n", + " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", + " \n", + " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", + " \n", + " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", + " \n", + " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", + " ('model.layers.14.mixer.conv1d.bias',\n", + " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", + " ('model.layers.14.mixer.out_proj.weight',\n", + " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", + " -2.4306e-02, 5.7698e-03],\n", + " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", + " -1.4523e-02, 1.2266e-02],\n", + " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", + " 1.0787e-02, 8.4053e-03],\n", + " ...,\n", + " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", + " 1.5601e-02, -6.3227e-03],\n", + " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", + " 1.2259e-02, 5.5500e-02],\n", + " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", + " 1.8864e-02, -1.2508e-02]])),\n", + " ('model.layers.14.mlp.gate_proj.weight',\n", + " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", + " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", + " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", + " ...,\n", + " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", + " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", + " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", + " ('model.layers.14.mlp.up_proj.weight',\n", + " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", + " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", + " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", + " ...,\n", + " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", + " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", + " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", + " ('model.layers.14.mlp.down_proj.weight',\n", + " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", + " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", + " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", + " ...,\n", + " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", + " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", + " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", + " ('model.layers.14.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.15.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.in_proj.weight',\n", + " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", + " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", + " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", + " ...,\n", + " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", + " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", + " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", + " ('model.layers.15.mixer.conv1d.weight',\n", + " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", + " \n", + " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", + " \n", + " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", + " \n", + " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", + " \n", + " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", + " ('model.layers.15.mixer.conv1d.bias',\n", + " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", + " ('model.layers.15.mixer.out_proj.weight',\n", + " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", + " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", + " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", + " ...,\n", + " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", + " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", + " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", + " ('model.layers.15.mlp.gate_proj.weight',\n", + " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", + " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", + " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", + " ...,\n", + " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", + " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", + " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", + " ('model.layers.15.mlp.up_proj.weight',\n", + " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", + " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", + " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", + " ...,\n", + " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", + " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", + " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", + " ('model.layers.15.mlp.down_proj.weight',\n", + " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", + " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", + " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", + " ...,\n", + " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", + " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", + " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", + " ('model.layers.15.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.16.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.in_proj.weight',\n", + " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", + " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", + " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", + " ...,\n", + " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", + " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", + " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", + " ('model.layers.16.mixer.conv1d.weight',\n", + " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", + " \n", + " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", + " \n", + " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", + " \n", + " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", + " \n", + " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", + " ('model.layers.16.mixer.conv1d.bias',\n", + " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", + " ('model.layers.16.mixer.out_proj.weight',\n", + " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", + " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", + " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", + " ...,\n", + " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", + " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", + " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", + " ('model.layers.16.mlp.gate_proj.weight',\n", + " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", + " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", + " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", + " ...,\n", + " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", + " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", + " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", + " ('model.layers.16.mlp.up_proj.weight',\n", + " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", + " -1.2243e-02, -1.3005e-02],\n", + " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", + " -1.5193e-02, 2.6111e-03],\n", + " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", + " 1.0694e-02, 3.7682e-03],\n", + " ...,\n", + " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", + " -4.8128e-03, 1.6899e-02],\n", + " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", + " -5.5857e-03, 4.0449e-03],\n", + " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", + " -8.0242e-03, -1.0925e-02]])),\n", + " ('model.layers.16.mlp.down_proj.weight',\n", + " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", + " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", + " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", + " ...,\n", + " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", + " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", + " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", + " ('model.layers.16.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.17.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.in_proj.weight',\n", + " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", + " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", + " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", + " ...,\n", + " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", + " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", + " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", + " ('model.layers.17.mixer.conv1d.weight',\n", + " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", + " \n", + " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", + " \n", + " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", + " \n", + " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", + " \n", + " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", + " ('model.layers.17.mixer.conv1d.bias',\n", + " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", + " ('model.layers.17.mixer.out_proj.weight',\n", + " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", + " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", + " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", + " ...,\n", + " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", + " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", + " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", + " ('model.layers.17.mlp.gate_proj.weight',\n", + " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", + " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", + " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", + " ...,\n", + " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", + " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", + " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", + " ('model.layers.17.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", + " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", + " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", + " ...,\n", + " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", + " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", + " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", + " ('model.layers.17.mlp.down_proj.weight',\n", + " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", + " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", + " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", + " ...,\n", + " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", + " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", + " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", + " ('model.layers.17.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.18.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.in_proj.weight',\n", + " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", + " -6.8715e-03, -1.3565e-02],\n", + " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", + " -1.0968e-02, -8.0445e-03],\n", + " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", + " 2.0198e-03, -1.7366e-03],\n", + " ...,\n", + " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", + " 1.4452e-02, -2.5152e-02],\n", + " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", + " 3.8135e-05, 1.7263e-02],\n", + " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", + " -1.1129e-02, -9.6768e-03]])),\n", + " ('model.layers.18.mixer.conv1d.weight',\n", + " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", + " \n", + " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", + " \n", + " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", + " \n", + " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", + " \n", + " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", + " ('model.layers.18.mixer.conv1d.bias',\n", + " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", + " ('model.layers.18.mixer.out_proj.weight',\n", + " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", + " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", + " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", + " ...,\n", + " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", + " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", + " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", + " ('model.layers.18.mlp.gate_proj.weight',\n", + " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", + " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", + " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", + " ...,\n", + " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", + " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", + " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", + " ('model.layers.18.mlp.up_proj.weight',\n", + " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", + " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", + " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", + " ...,\n", + " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", + " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", + " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", + " ('model.layers.18.mlp.down_proj.weight',\n", + " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", + " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", + " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", + " ...,\n", + " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", + " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", + " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", + " ('model.layers.18.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.19.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.in_proj.weight',\n", + " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", + " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", + " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", + " ...,\n", + " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", + " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", + " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", + " ('model.layers.19.mixer.conv1d.weight',\n", + " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", + " \n", + " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", + " \n", + " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", + " \n", + " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", + " \n", + " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", + " ('model.layers.19.mixer.conv1d.bias',\n", + " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", + " ('model.layers.19.mixer.out_proj.weight',\n", + " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", + " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", + " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", + " ...,\n", + " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", + " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", + " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", + " ('model.layers.19.mlp.gate_proj.weight',\n", + " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", + " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", + " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", + " ...,\n", + " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", + " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", + " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", + " ('model.layers.19.mlp.up_proj.weight',\n", + " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", + " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", + " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", + " ...,\n", + " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", + " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", + " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", + " ('model.layers.19.mlp.down_proj.weight',\n", + " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", + " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", + " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", + " ...,\n", + " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", + " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", + " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", + " ('model.layers.19.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.20.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.in_proj.weight',\n", + " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", + " -2.3083e-02, -6.2962e-02],\n", + " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", + " 9.0683e-03, 2.1583e-04],\n", + " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", + " 3.2650e-03, -3.1967e-02],\n", + " ...,\n", + " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", + " 1.0566e-02, -4.9561e-03],\n", + " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", + " -7.5403e-03, -1.3599e-02],\n", + " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", + " -3.3432e-02, 3.8337e-03]])),\n", + " ('model.layers.20.mixer.conv1d.weight',\n", + " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", + " \n", + " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", + " \n", + " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", + " \n", + " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", + " \n", + " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", + " ('model.layers.20.mixer.conv1d.bias',\n", + " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", + " ('model.layers.20.mixer.out_proj.weight',\n", + " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", + " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", + " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", + " ...,\n", + " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", + " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", + " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", + " ('model.layers.20.mlp.gate_proj.weight',\n", + " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", + " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", + " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", + " ...,\n", + " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", + " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", + " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", + " ('model.layers.20.mlp.up_proj.weight',\n", + " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", + " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", + " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", + " ...,\n", + " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", + " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", + " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", + " ('model.layers.20.mlp.down_proj.weight',\n", + " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", + " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", + " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", + " ...,\n", + " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", + " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", + " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", + " ('model.layers.20.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.21.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.in_proj.weight',\n", + " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", + " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", + " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", + " ...,\n", + " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", + " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", + " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", + " ('model.layers.21.mixer.conv1d.weight',\n", + " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", + " \n", + " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", + " \n", + " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", + " \n", + " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", + " \n", + " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", + " ('model.layers.21.mixer.conv1d.bias',\n", + " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", + " ('model.layers.21.mixer.out_proj.weight',\n", + " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", + " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", + " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", + " ...,\n", + " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", + " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", + " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", + " ('model.layers.21.mlp.gate_proj.weight',\n", + " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", + " -5.8295e-03, -1.4284e-02],\n", + " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", + " -3.9076e-02, 5.3379e-03],\n", + " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", + " -2.5180e-02, 6.1512e-04],\n", + " ...,\n", + " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", + " -7.2165e-03, 1.6297e-02],\n", + " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", + " -7.7709e-03, 4.3964e-04],\n", + " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", + " 5.0683e-02, -2.5481e-02]])),\n", + " ('model.layers.21.mlp.up_proj.weight',\n", + " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", + " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", + " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", + " ...,\n", + " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", + " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", + " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", + " ('model.layers.21.mlp.down_proj.weight',\n", + " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", + " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", + " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", + " ...,\n", + " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", + " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", + " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", + " ('model.layers.21.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.22.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.in_proj.weight',\n", + " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", + " -1.5757e-02, -1.5315e-03],\n", + " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", + " 8.9105e-05, -3.8658e-03],\n", + " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", + " -1.6604e-02, -4.5245e-03],\n", + " ...,\n", + " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", + " 2.3792e-02, -5.5165e-03],\n", + " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", + " 5.2496e-03, -6.0055e-03],\n", + " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", + " -1.5446e-02, -2.7164e-02]])),\n", + " ('model.layers.22.mixer.conv1d.weight',\n", + " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", + " \n", + " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", + " \n", + " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", + " \n", + " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", + " \n", + " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", + " ('model.layers.22.mixer.conv1d.bias',\n", + " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", + " ('model.layers.22.mixer.out_proj.weight',\n", + " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", + " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", + " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", + " ...,\n", + " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", + " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", + " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", + " ('model.layers.22.mlp.gate_proj.weight',\n", + " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", + " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", + " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", + " ...,\n", + " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", + " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", + " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", + " ('model.layers.22.mlp.up_proj.weight',\n", + " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", + " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", + " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", + " ...,\n", + " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", + " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", + " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", + " ('model.layers.22.mlp.down_proj.weight',\n", + " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", + " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", + " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", + " ...,\n", + " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", + " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", + " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", + " ('model.layers.22.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.23.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.in_proj.weight',\n", + " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", + " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", + " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", + " ...,\n", + " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", + " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", + " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", + " ('model.layers.23.mixer.conv1d.weight',\n", + " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", + " \n", + " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", + " \n", + " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", + " \n", + " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", + " \n", + " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", + " ('model.layers.23.mixer.conv1d.bias',\n", + " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", + " ('model.layers.23.mixer.out_proj.weight',\n", + " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", + " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", + " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", + " ...,\n", + " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", + " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", + " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", + " ('model.layers.23.mlp.gate_proj.weight',\n", + " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", + " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", + " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", + " ...,\n", + " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", + " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", + " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", + " ('model.layers.23.mlp.up_proj.weight',\n", + " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", + " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", + " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", + " ...,\n", + " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", + " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", + " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", + " ('model.layers.23.mlp.down_proj.weight',\n", + " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", + " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", + " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", + " ...,\n", + " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", + " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", + " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", + " ('model.layers.23.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.24.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.in_proj.weight',\n", + " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", + " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", + " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", + " ...,\n", + " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", + " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", + " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", + " ('model.layers.24.mixer.conv1d.weight',\n", + " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", + " \n", + " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", + " \n", + " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", + " \n", + " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", + " \n", + " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", + " ('model.layers.24.mixer.conv1d.bias',\n", + " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", + " ('model.layers.24.mixer.out_proj.weight',\n", + " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", + " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", + " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", + " ...,\n", + " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", + " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", + " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", + " ('model.layers.24.mlp.gate_proj.weight',\n", + " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", + " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", + " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", + " ...,\n", + " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", + " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", + " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", + " ('model.layers.24.mlp.up_proj.weight',\n", + " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", + " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", + " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", + " ...,\n", + " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", + " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", + " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", + " ('model.layers.24.mlp.down_proj.weight',\n", + " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", + " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", + " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", + " ...,\n", + " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", + " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", + " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", + " ('model.layers.24.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.25.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.in_proj.weight',\n", + " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", + " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", + " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", + " ...,\n", + " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", + " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", + " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", + " ('model.layers.25.mixer.conv1d.weight',\n", + " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", + " \n", + " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", + " \n", + " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", + " \n", + " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", + " \n", + " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", + " ('model.layers.25.mixer.conv1d.bias',\n", + " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", + " ('model.layers.25.mixer.out_proj.weight',\n", + " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", + " -1.8566e-02, 1.3872e-02],\n", + " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", + " 7.3446e-03, -1.6368e-02],\n", + " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", + " -1.6554e-02, 7.2751e-03],\n", + " ...,\n", + " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", + " -1.0262e-02, -4.9292e-03],\n", + " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", + " 2.2522e-02, -6.0243e-03],\n", + " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", + " -5.7889e-03, 1.1855e-02]])),\n", + " ('model.layers.25.mlp.gate_proj.weight',\n", + " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", + " -1.2723e-02, -9.7756e-03],\n", + " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", + " -1.5727e-02, -9.8748e-03],\n", + " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", + " 1.2017e-04, -1.5734e-02],\n", + " ...,\n", + " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", + " 1.8795e-02, 1.0221e-02],\n", + " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", + " 8.5950e-05, 1.0542e-02],\n", + " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", + " -3.0686e-02, -9.6116e-03]])),\n", + " ('model.layers.25.mlp.up_proj.weight',\n", + " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", + " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", + " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", + " ...,\n", + " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", + " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", + " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", + " ('model.layers.25.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", + " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", + " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", + " ...,\n", + " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", + " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", + " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", + " ('model.layers.25.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.26.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.in_proj.weight',\n", + " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", + " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", + " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", + " ...,\n", + " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", + " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", + " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", + " ('model.layers.26.mixer.conv1d.weight',\n", + " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", + " \n", + " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", + " \n", + " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", + " \n", + " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", + " \n", + " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", + " ('model.layers.26.mixer.conv1d.bias',\n", + " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", + " ('model.layers.26.mixer.out_proj.weight',\n", + " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", + " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", + " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", + " ...,\n", + " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", + " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", + " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", + " ('model.layers.26.mlp.gate_proj.weight',\n", + " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", + " 1.3550e-02, -1.1203e-02],\n", + " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", + " 5.3891e-03, -2.8310e-02],\n", + " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", + " -2.8479e-03, -1.3105e-02],\n", + " ...,\n", + " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", + " -1.8015e-02, -1.2234e-02],\n", + " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", + " 1.7886e-02, -4.4699e-02],\n", + " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", + " 3.3468e-03, 2.7677e-02]])),\n", + " ('model.layers.26.mlp.up_proj.weight',\n", + " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", + " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", + " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", + " ...,\n", + " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", + " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", + " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", + " ('model.layers.26.mlp.down_proj.weight',\n", + " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", + " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", + " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", + " ...,\n", + " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", + " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", + " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", + " ('model.layers.26.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.27.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.in_proj.weight',\n", + " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", + " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", + " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", + " ...,\n", + " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", + " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", + " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", + " ('model.layers.27.mixer.conv1d.weight',\n", + " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", + " \n", + " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", + " \n", + " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", + " \n", + " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", + " \n", + " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", + " ('model.layers.27.mixer.conv1d.bias',\n", + " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", + " ('model.layers.27.mixer.out_proj.weight',\n", + " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", + " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", + " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", + " ...,\n", + " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", + " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", + " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", + " ('model.layers.27.mlp.gate_proj.weight',\n", + " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", + " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", + " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", + " ...,\n", + " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", + " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", + " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", + " ('model.layers.27.mlp.up_proj.weight',\n", + " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", + " 3.2845e-02, 2.1107e-02],\n", + " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", + " -1.1685e-02, 1.2504e-02],\n", + " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", + " -3.3912e-04, 3.0527e-02],\n", + " ...,\n", + " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", + " 1.9532e-02, 1.3339e-02],\n", + " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", + " 2.8262e-02, -1.2526e-02],\n", + " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", + " -2.0571e-03, -8.3943e-03]])),\n", + " ('model.layers.27.mlp.down_proj.weight',\n", + " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", + " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", + " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", + " ...,\n", + " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", + " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", + " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", + " ('model.layers.27.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('lm_head.weight',\n", + " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", + " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", + " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", + " ...,\n", + " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", + " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", + " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" ] }, - "execution_count": 9, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "apriel_ssm_config" + "apriel_ssm.state_dict()" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "N params SSM: 5.660780512\n" + "N params SSM: 5.305533088\n" ] } ], @@ -246,7 +2222,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -258,10 +2234,10 @@ " (layers): ModuleList(\n", " (0-27): 28 x AprielDecoderLayer(\n", " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", - " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", " (act): Identity()\n", - " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", " )\n", " (mlp): AprielMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", @@ -279,7 +2255,7 @@ ")" ] }, - "execution_count": 11, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -291,7 +2267,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -300,7 +2276,7 @@ "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" ] }, - "execution_count": 12, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -311,7 +2287,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -323,10 +2299,10 @@ " (layers): ModuleList(\n", " (0-27): 28 x AprielDecoderLayer(\n", " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=11304, bias=False)\n", - " (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", " (act): Identity()\n", - " (out_proj): Linear(in_features=4104, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", " )\n", " (mlp): AprielMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", @@ -344,7 +2320,7 @@ ")" ] }, - "execution_count": 13, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -356,7 +2332,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -372,26 +2348,29 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 2, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:2714: UserWarning: `save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead.\n", - " warnings.warn(\n" + "ename": "NameError", + "evalue": "name 'apriel_ssm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" ] } ], "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm\",\n", + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", " save_config=True)\n" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -400,7 +2379,7 @@ "24" ] }, - "execution_count": 60, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -411,7 +2390,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -423,10 +2402,10 @@ " (layers): ModuleList(\n", " (0-27): 28 x AprielDecoderLayer(\n", " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=12320, bias=False)\n", - " (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", " (act): Identity()\n", - " (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", " )\n", " (mlp): AprielMLP(\n", " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", @@ -439,13 +2418,12 @@ " )\n", " )\n", " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", " )\n", " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", ")" ] }, - "execution_count": 10, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -463,7 +2441,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -482,30 +2460,30 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-5.4688, -1.6641, 0.4609, ..., -7.1562, -3.7812, -5.9062],\n", - " [-3.5000, 1.4297, 4.3125, ..., -5.3438, -4.9375, -2.9844],\n", - " [-3.1094, 0.7930, 2.2969, ..., -3.1250, -4.1875, -2.1250],\n", + "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", + " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", + " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", " ...,\n", - " [-5.3438, -3.0938, -3.9062, ..., -4.9062, -3.0000, -3.9688],\n", - " [-3.0625, -3.2188, 5.6562, ..., -2.7812, -2.5938, -6.6562],\n", - " [-1.8438, -1.7500, 5.9062, ..., -3.7188, -2.1250, -0.8281]]],\n", - " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[ 1.2266, 0.5547, -1.1953, ..., 0.1089, -2.5781, 0.6328],\n", - " [-0.4395, 0.5938, -0.1562, ..., -0.6719, -0.6367, -0.3086],\n", - " [ 0.0077, 0.6680, -1.0703, ..., -3.6875, 0.2207, 0.1299],\n", + " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", + " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", + " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", + " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", + " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", + " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", " ...,\n", - " [-0.0703, 0.4551, 0.1104, ..., 1.3438, 1.3984, 1.1641],\n", - " [-0.0613, 1.9141, -0.5430, ..., -1.0312, -0.6680, 0.0518],\n", - " [-0.6172, 0.2148, -0.5977, ..., -1.2734, -0.1914, 2.2344]]],\n", + " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", + " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", + " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" ] }, - "execution_count": 56, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -520,6 +2498,94 @@ "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import enum" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "class SSMBlockType(str, enum.Enum):\n", + " \"\"\"\n", + " An enum for the available mamba types for the MLP layer.\n", + " \"\"\"\n", + "\n", + " mamba = \"m\"\n", + " mamba2_discrete = \"m2d\"\n", + " mamba2 = \"m2\"\n", + " transformer = \"t\"" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_values([, , , ])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'m' in SSMBlockType.__members__.values()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'m'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[21], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mm\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[43mSSMBlockType\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mname\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/enum.py:808\u001b[0m, in \u001b[0;36mEnumType.__getitem__\u001b[0;34m(cls, name)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, name):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 806\u001b[0m \u001b[38;5;124;03m Return the member matching `name`.\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 808\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_member_map_\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'm'" + ] + } + ], + "source": [ + "\"m\" == SSMBlockType[\"m\"].name\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'m2d'" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "SSMBlockType.mamba2_discrete.value" + ] } ], "metadata": { From 77ad39f7730f314d01c3c8f5da14f1ac8aabf5f4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 30 Apr 2025 20:21:08 +0000 Subject: [PATCH 043/114] add token-prediction loss coefficients --- fast_llm/layers/language_model/config.py | 10 ++++++++++ fast_llm/layers/language_model/head.py | 5 ++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c99ee4f6..c675361a 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,6 +215,12 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): doc="May be used to freeze the output weights by setting their scale to zero.", hint=FieldHint.feature, ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -231,3 +237,7 @@ def _validate(self) -> None: if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 1153fb2c..014a617c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -57,6 +57,9 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + self._loss_coefficient = ( + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) self.final_norm = config.transformer.normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor @@ -133,7 +136,7 @@ def forward( else: if self.training: # Backward hook to compute the gradient of the loss - shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient) # MTP: Return shared_hidden to be used by the next head. return shared_hidden From da9bf1a78efbf4a34af23b9ea55ee1b98343230c Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 1 May 2025 13:54:27 +0000 Subject: [PATCH 044/114] eval apriel ssm --- .../ssm/external/eval/apriel_eval_wrapper.py | 59 +++++++++++++++++++ .../models/ssm/external/eval/run_lm_eval.py | 6 ++ .../ssm/external/modeling_ssm_apriel.py | 55 ++++++++++------- 3 files changed, 98 insertions(+), 22 deletions(-) create mode 100644 fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py create mode 100644 fast_llm/models/ssm/external/eval/run_lm_eval.py diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py new file mode 100644 index 00000000..94537c33 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("apriel_ssm") +class AprielSSMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig + + self._config = AprielSSMConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM + + self._model = AprielSSMForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py new file mode 100644 index 00000000..af07869a --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -0,0 +1,6 @@ +from lm_eval.__main__ import cli_evaluate + +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_apriel.py index d30d5b66..5a1b8db4 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from .configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) @@ -35,12 +35,13 @@ class CustomMambaCausalLMOutput(ModelOutput): class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): """ AprielRMSNorm is equivalent to T5LayerNorm """ + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) self.variance_epsilon = eps def forward(self, hidden_states): @@ -58,14 +59,15 @@ def extra_repr(self): class AprielMLP(nn.Module): - def __init__(self, config): - super().__init__() + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -437,19 +439,21 @@ def convolutional_step(self, xBC, conv_state): class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMConfig, layer_idx: int): - super().__init__() + def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.hidden_size = config.hidden_size self.mixer = DiscreteMamba2( d_model=config.hidden_size, layer_idx=layer_idx, **config.ssm_cfg, + **factory_kwargs, ) - self.mlp = AprielMLP(config) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) def forward( self, hidden_states: torch.Tensor, inference_params=None, **kwargs @@ -598,16 +602,16 @@ class AprielSSMModel(AprielSSMPreTrainedModel): config: AprielSSMConfig """ - def __init__(self, config: AprielSSMConfig): - super().__init__(config) + def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) self.layers = nn.ModuleList( - [AprielDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -664,11 +668,12 @@ class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config): - super().__init__(config) + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.model = AprielSSMModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) # Initialize weights and apply final processing self.post_init() @@ -722,6 +727,12 @@ def forward( last_hidden_state=outputs["last_hidden_state"], ) + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + __all__ = [ "AprielSSMForCausalLM", From ac4a5982d8ea8d04f8d8ebdb693e59f82f58e6b9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 1 May 2025 12:26:50 -0400 Subject: [PATCH 045/114] fix --- fast_llm/models/gpt/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9e28373b..80c9caa2 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -214,7 +214,6 @@ def preprocess_meta( reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( TransformerKwargs.sequence_first, - TransformerKwargs.hidden_dims, TransformerKwargs.sequence_length, TransformerKwargs.sequence_q_dim, TransformerKwargs.sequence_k_dim, From 0c0e7d9cca6b3ca31b11d4764e9664263badc8cc Mon Sep 17 00:00:00 2001 From: Luke Nitish Kumar Date: Thu, 1 May 2025 13:23:47 -0400 Subject: [PATCH 046/114] adding check for missing `rope_type` (#246) --- fast_llm/models/gpt/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index bc8bea26..d4d581cd 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -407,7 +407,7 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (export_value,) = export_values - if export_value is None or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default": + if export_value is None or export_value is MISSING or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default": return (RotaryEmbeddingType.default,) + (DEFAULT,) * 7 elif rope_type == RotaryEmbeddingType.llama3: return ("llama3", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) From 97ba9d44554908260dd24e1bb452d6bc00eee11a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 1 May 2025 14:04:02 -0400 Subject: [PATCH 047/114] Loss masking for distillation --- fast_llm/functional/cross_entropy.py | 40 ++++++++++++++------- fast_llm/functional/triton/cross_entropy.py | 7 ++++ fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/language_model/head.py | 21 ++++++++--- fast_llm/models/gpt/model.py | 8 +++-- 5 files changed, 58 insertions(+), 19 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 1eb6c8c0..53b3e59b 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -8,9 +8,10 @@ from fast_llm.utils import Assert -def torch_cross_entropy_forward_backward( +def _torch_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -22,15 +23,25 @@ def torch_cross_entropy_forward_backward( TODO: loss masking only works for with labels format and if the masking index is set to -100. """ # Torch compile doesn't understand this. + if loss_mask is not None: + raise NotImplementedError(f"Torch cross-entropy from {target_format} doesn't support loss masking.") with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) if target_format == TargetFormat.logits: if logits_scale_factor != 1.0: target = target * logits_scale_factor target = torch.softmax(target, dim=-1) - loss = torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target - ).mean() + if loss_mask is None: + loss = torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target + ) + else: + loss = ( + torch.nn.functional.cross_entropy( + logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" + ) + * loss_mask + ).mean() if grad_output is None: grad = None else: @@ -57,8 +68,8 @@ def _fused_softmax_base( return logits_norm, exp_logits, sum_exp_logits -# @torch.compile -def fused_softmax( +@torch.compile +def _fused_softmax( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 ) -> torch.Tensor: _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) @@ -66,9 +77,10 @@ def fused_softmax( @torch.compile -def fused_cross_entropy_forward_backward( +def _fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -85,7 +97,7 @@ def fused_cross_entropy_forward_backward( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target = fused_softmax(target, logits_scale_factor, group) + target = _fused_softmax(target, logits_scale_factor, group) if target_format == TargetFormat.labels: target = target.unsqueeze(-1) @@ -101,8 +113,6 @@ def fused_cross_entropy_forward_backward( target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) target = (target - vocab_start_index) * target_mask else: - # TODO: Support masking - loss_mask = None # Target should be tensor-parallel already, no further manipulation needed. target_mask = None @@ -145,8 +155,8 @@ def fused_cross_entropy_forward_backward( _CROSS_ENTROPY_IMPLEMENTATIONS = { - CrossEntropyImpl.torch: torch_cross_entropy_forward_backward, - CrossEntropyImpl.fused: fused_cross_entropy_forward_backward, + CrossEntropyImpl.torch: _torch_cross_entropy_forward_backward, + CrossEntropyImpl.fused: _fused_cross_entropy_forward_backward, CrossEntropyImpl.triton: triton_cross_entropy_forward_backward, } @@ -154,6 +164,7 @@ def fused_cross_entropy_forward_backward( def cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, group: ProcessGroup | None = None, implementation: CrossEntropyImpl = CrossEntropyImpl.fused, @@ -169,12 +180,15 @@ def cross_entropy_forward_backward( if target_format == TargetFormat.labels: Assert.eq(target.shape, logits.shape[:-1]) Assert.eq(target.dtype, torch.int64) + assert loss_mask is None else: Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, target.shape) if group: Assert.eq(implementation, CrossEntropyImpl.fused) - return fused_cross_entropy_forward_backward( + return _fused_cross_entropy_forward_backward( logits, target, grad_output, logits_scale_factor, target_format, group ) else: diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index d825af03..02dc1ce7 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -57,12 +57,14 @@ def triton_cross_entropy_forward_backward_kernel( def triton_cross_entropy_from_distribution_forward_backward_kernel( logits_ptr, target_ptr, + loss_mask_ptr, grad_logits_ptr, losses_ptr, grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, target_stride_0: tl_constexpr, + loss_mask_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, from_logits: tl_constexpr, @@ -87,6 +89,8 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( tl.float32 ) + if loss_mask_ptr is not None: + loss_mask = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=0) if from_logits: if logits_scale_factor != 1.0: target *= logits_scale_factor @@ -110,6 +114,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -149,12 +154,14 @@ def triton_cross_entropy_forward_backward( triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, + loss_mask, grad_logits, losses, None if grad_output is None else grad_output / n_rows, n_cols, logits.stride(0), target.stride(0), + None if loss_mask is None else loss_mask.stride(0), None if grad_output is None else grad_logits.stride(0), logits_scale_factor, block_size=block_size, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4fb471fb..0371eff4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -35,6 +35,7 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + loss_mask = "loss_mask" @config_class() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 3b476f6a..9b1dd4d8 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -146,6 +146,8 @@ def _forward_backward( if self._config.distillation_model is None else f"{self._config.distillation_model}_logits" ) + # Loss mask for distillation. (Labels are already masked.) + loss_mask = None if target is not None: if self._config.distillation_model is None: # MTP: Shift the labels @@ -160,9 +162,12 @@ def _forward_backward( else: # Target is reference model logits. target = target.flatten(0, -2) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) if self._sequence_parallel_logits: target = split_op(target, self._tensor_space.distributed.tensor_group, 0) + if loss_mask is not None: + loss_mask = split_op(loss_mask, self._tensor_space.distributed.tensor_group, 0) do_grad = target is not None and self.training input_ = input_.detach().requires_grad_(do_grad) with torch.enable_grad(): @@ -174,7 +179,7 @@ def _forward_backward( output_weights = self._get_output_weights(kwargs) loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), target, output_weights, grad_output, kwargs, losses + ln_output.detach(), target, loss_mask, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -194,6 +199,7 @@ def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, target: torch.Tensor | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -201,7 +207,7 @@ def _logits_cross_entropy_forward_backward_split( ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if self._cross_entropy_splits is None or target is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( - input_, target, weight, grad_output, kwargs, losses + input_, target, loss_mask, weight, grad_output, kwargs, losses ) if target is None: # TODO: Make a proper way of returning the model output. @@ -214,12 +220,17 @@ def _logits_cross_entropy_forward_backward_split( grad_output /= self._cross_entropy_splits logit_input = input_.flatten(0, -2) logit_input_grad = torch.empty_like(logit_input) - for logit_input_, target_, logit_input_grad_ in zip( - logit_input.split(split_size), target.split(split_size), logit_input_grad.split(split_size) + for logit_input_, target_, loss_mask_, logit_input_grad_ in zip( + logit_input.split(split_size), + target.split(split_size), + [None] * self._cross_entropy_splits if loss_mask is None else loss_mask.split(split_size), + logit_input_grad.split(split_size), + strict=True, ): loss_, grad_ = self._logits_cross_entropy_forward_backward( logit_input_, target_, + loss_mask_, weight, grad_output, kwargs, @@ -240,6 +251,7 @@ def _logits_cross_entropy_forward_backward( self, input_: torch.Tensor, target: torch.Tensor | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -298,6 +310,7 @@ def _logits_cross_entropy_forward_backward( loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), target, + loss_mask, group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, grad_output=grad_output, implementation=self._cross_entropy_impl, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 80c9caa2..9084fb40 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -313,11 +313,15 @@ def preprocess( valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset + loss_mask = torch.ones_like(labels, dtype=torch.bool) for start, end in valid_spans: if sequence_first: - labels[start : end + 1, i] = -100 + loss_mask[start : end + 1, i] = False else: - labels[i, start : end + 1] = -100 + loss_mask[i, start : end + 1] = False + if self._config.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) kwargs[LanguageModelKwargs.labels] = labels kwargs.update(reference_logits[i]) From 231d5d82e535dbc805756604c420b366516584b5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 1 May 2025 14:33:53 -0400 Subject: [PATCH 048/114] test, misc --- fast_llm/functional/cross_entropy.py | 6 +-- tests/layers/test_lm_head.py | 71 ++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 53b3e59b..34c69d79 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -23,8 +23,6 @@ def _torch_cross_entropy_forward_backward( TODO: loss masking only works for with labels format and if the masking index is set to -100. """ # Torch compile doesn't understand this. - if loss_mask is not None: - raise NotImplementedError(f"Torch cross-entropy from {target_format} doesn't support loss masking.") with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) if target_format == TargetFormat.logits: @@ -40,7 +38,7 @@ def _torch_cross_entropy_forward_backward( torch.nn.functional.cross_entropy( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" ) - * loss_mask + * loss_mask.unsqueeze(-1) ).mean() if grad_output is None: grad = None @@ -185,7 +183,7 @@ def cross_entropy_forward_backward( Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: - Assert.eq(loss_mask.shape, target.shape) + Assert.eq(loss_mask.shape, logits.shape[:-1]) if group: Assert.eq(implementation, CrossEntropyImpl.fused) return _fused_cross_entropy_forward_backward( diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 79101f34..14edecff 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -25,6 +25,7 @@ def _lm_head( input_: torch.Tensor, target: torch.Tensor, + loss_mask: torch.Tensor | None, *, # config:LanguageModelBaseConfig, rms_weight: torch.Tensor, @@ -43,7 +44,13 @@ def _lm_head( if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + if target.ndim == logits.ndim: + loss = torch.nn.functional.cross_entropy(logits, target, reduction="none") + if loss_mask is not None: + loss = loss * loss_mask.unsqueeze(-1) + loss = loss.mean() + else: + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) loss.backward(torch.full_like(loss, grad_output)) return loss, z_loss @@ -58,22 +65,26 @@ def _lm_head( @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( - ("config_dict", "distributed_config_dict"), + ("config_dict", "distributed_config_dict", "loss_masking"), ( - ({}, {}), - ({}, {"training_dtype": DataType.bfloat16}), - ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}), - ({"sequence_first": True}, {}), - ({"logit_z_loss": 1e-3}, {}), - ({"logits_scale_factor": 5.0}, {}), - ({"tie_word_embeddings": False}, {}), - ({"prediction_heads": 2}, {}), + ({}, {}, False), + ({}, {"training_dtype": DataType.bfloat16}, False), + ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}, False), + ({"sequence_first": True}, {}, False), + ({"logit_z_loss": 1e-3}, {}, False), + ({"logits_scale_factor": 5.0}, {}, False), + ({"tie_word_embeddings": False}, {}, False), + ({"prediction_heads": 2}, {}, False), + ({}, {}, True), + ({"distillation_model": "distillation"}, {}, False), + ({"distillation_model": "distillation"}, {}, True), ), ) def test_lm_head( cross_entropy_impl: CrossEntropyImpl, config_dict: dict[str, typing.Any], distributed_config_dict: dict[str, typing.Any], + loss_masking: bool, ): config = GPTBaseModelConfig.from_dict( { @@ -99,17 +110,6 @@ def test_lm_head( sequence_first = config.sequence_first or ( config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 ) - target = torch.randint( - 0, - VOCAB_SIZE, - ( - (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) - ), - dtype=torch.int64, - device=distributed.device, - ) input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( @@ -120,6 +120,34 @@ def test_lm_head( device=distributed.device, requires_grad=True, ) + label_shape = ( + (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) + if sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) + ) + if loss_masking: + loss_mask = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.bool, + device=distributed.device, + ) + else: + loss_mask = None + if config.distillation_model is None: + target = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=distributed.device, + ) + if loss_mask is not None: + target *= loss_mask + else: + assert config.prediction_heads == 1 + target = torch.randn_like(input_) kwargs = { TransformerKwargs.sequence_first: sequence_first, LanguageModelKwargs.labels: target, @@ -173,6 +201,7 @@ def test_lm_head( if sequence_first else target[:, prediction_distance : prediction_distance + SEQUENCE_LENGTH] ), + loss_mask, rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, logit_scale_factor=config.logits_scale_factor, From 30a75b003360a37d907659668d4b4dcf9f61ac3a Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 1 May 2025 20:33:41 +0000 Subject: [PATCH 049/114] eval apriel ssm --- .../ssm/external/eval/apriel_eval_wrapper.py | 59 +++++++++++++++++++ .../models/ssm/external/eval/run_lm_eval.py | 6 ++ .../ssm/external/modeling_ssm_apriel.py | 55 ++++++++++------- 3 files changed, 98 insertions(+), 22 deletions(-) create mode 100644 fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py create mode 100644 fast_llm/models/ssm/external/eval/run_lm_eval.py diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py new file mode 100644 index 00000000..94537c33 --- /dev/null +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("apriel_ssm") +class AprielSSMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig + + self._config = AprielSSMConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM + + self._model = AprielSSMForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py new file mode 100644 index 00000000..af07869a --- /dev/null +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -0,0 +1,6 @@ +from lm_eval.__main__ import cli_evaluate + +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_apriel.py index d30d5b66..5a1b8db4 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from .configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) @@ -35,12 +35,13 @@ class CustomMambaCausalLMOutput(ModelOutput): class AprielRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): """ AprielRMSNorm is equivalent to T5LayerNorm """ + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) self.variance_epsilon = eps def forward(self, hidden_states): @@ -58,14 +59,15 @@ def extra_repr(self): class AprielMLP(nn.Module): - def __init__(self, config): - super().__init__() + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -437,19 +439,21 @@ def convolutional_step(self, xBC, conv_state): class AprielDecoderLayer(nn.Module): - def __init__(self, config: AprielSSMConfig, layer_idx: int): - super().__init__() + def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} self.hidden_size = config.hidden_size self.mixer = DiscreteMamba2( d_model=config.hidden_size, layer_idx=layer_idx, **config.ssm_cfg, + **factory_kwargs, ) - self.mlp = AprielMLP(config) - self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) def forward( self, hidden_states: torch.Tensor, inference_params=None, **kwargs @@ -598,16 +602,16 @@ class AprielSSMModel(AprielSSMPreTrainedModel): config: AprielSSMConfig """ - def __init__(self, config: AprielSSMConfig): - super().__init__(config) + def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) self.layers = nn.ModuleList( - [AprielDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -664,11 +668,12 @@ class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config): - super().__init__(config) + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) self.model = AprielSSMModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) # Initialize weights and apply final processing self.post_init() @@ -722,6 +727,12 @@ def forward( last_hidden_state=outputs["last_hidden_state"], ) + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + __all__ = [ "AprielSSMForCausalLM", From a50bc2e1e414b89718f72086056bdbd1cbc01106 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 1 May 2025 20:38:37 +0000 Subject: [PATCH 050/114] cleanup --- fast_llm/layers/ssm/mamba2.py | 354 --- .../models/ssm/external/ariel_to_ssm.ipynb | 2612 ----------------- .../models/ssm/external/discrete_mamba2.py | 382 --- .../ssm/external/eval/apriel_eval_wrapper.py | 59 - .../models/ssm/external/eval/run_lm_eval.py | 6 - 5 files changed, 3413 deletions(-) delete mode 100644 fast_llm/layers/ssm/mamba2.py delete mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb delete mode 100644 fast_llm/models/ssm/external/discrete_mamba2.py delete mode 100644 fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py delete mode 100644 fast_llm/models/ssm/external/eval/run_lm_eval.py diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py deleted file mode 100644 index 5763cb92..00000000 --- a/fast_llm/layers/ssm/mamba2.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -This code is adapted from https://github.com/jxiw/MambaInLlama/blob/main/mamba2/hybrid_mamba_layer.py -""" - -import math - -import causal_conv1d -import einops -import mamba_ssm.ops.triton.ssd_combined -import torch -from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated - -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.tensor import kaiming_init_ - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Mamba2(torch.nn.Module): - def __init__( - self, - config: SSMConfig, - layer_idx: int, - tensor_space: TensorSpace, - ): - # factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.config: SSMConfig = config - bias = config.add_bias_linear - self.layer_idx = layer_idx - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - tensor_space.get_tensor_dim(SSMDimNames.v_heads) - tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) - - # self.d_model = d_model - # self.d_state = d_state - # self.d_conv = d_conv - # self.conv_init = conv_init - # self.expand = expand - # self.process_group = process_group - # self.sequence_parallel = sequence_parallel - # self.world_size = 1 if process_group is None else process_group.size() - # self.local_rank = 0 if process_group is None else process_group.rank() - # self.d_inner = d_inner if d_inner is not None else (self.expand * self.d_model) // self.world_size - # # assert self.d_inner * self.world_size == self.expand * self.d_model - # self.headdim = headdim - # self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size - # assert ngroups % self.world_size == 0 - # self.ngroups = ngroups // self.world_size - # assert self.d_ssm % self.headdim == 0 - # self.nheads = self.d_ssm // self.headdim - # self.D_has_hdim = D_has_hdim - # self.rmsnorm = rmsnorm - # self.norm_before_gate = norm_before_gate - # self.dt_limit = dt_limit - # self.activation = "silu" - # self.chunk_size = chunk_size - # self.use_mem_eff_path = use_mem_eff_path - # self.layer_idx = layer_idx - # self.d_xb = d_xb - # self.repeat_group = self.d_inner // self.d_xb - # self.repeat_kv_before_conv = repeat_kv_before_conv - - assert self.d_inner == self.ngroups * self.d_state - assert self.d_inner == self.d_ssm - - self.nheads = self.ngroups - self.headdim = self.d_state - - # Order: [z, x, B, C, dt] - # [hidden_dim, hidden_dim, d_state] - d_in_proj = self.d_inner + self.d_xb + self.d_xb + self.d_inner + self.nheads - # d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads - if self.process_group is None: - self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs) - else: - self.in_proj = ColumnParallelLinear( - self.d_model, - d_in_proj * self.world_size, - bias=bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - **factory_kwargs, - ) - - # conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state - - if self.repeat_kv_before_conv: - conv_dim = self.d_inner + self.d_inner + self.d_inner - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - else: - conv_dim = self.d_inner + self.d_xb + self.d_xb - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - if self.conv_init is not None: - nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) - - self.act = nn.SiLU() - - # Initialize log dt bias - dt = torch.exp( - torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - ) - dt = torch.clamp(dt, min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - self.dt_bias = nn.Parameter(inv_dt) - # Just to be explicit. Without this we already don't put wd on dt_bias because of the check - # name.endswith("bias") in param_grouping.py - self.dt_bias._no_weight_decay = True - - assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] - A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range) - A_log = torch.log(A).to(dtype=dtype) - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)) - self.D._no_weight_decay = True - - if self.rmsnorm: - assert RMSNormGated is not None - self.norm = RMSNormGated( - self.d_ssm, - eps=1e-5, - norm_before_gate=self.norm_before_gate, - group_size=self.d_ssm // ngroups, - **factory_kwargs, - ) - - # self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - self.out_proj = Linear( - td_inner, - td_model, - bias=bias, - weight_init_method=kaiming_init_(td_inner.size), - ) - - def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None): - """ - u: (batch, seqlen, hidden_dim) if seqlen=None. - If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we - split u during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - Returns: same shape as u - """ - seqlen_og = seqlen - if seqlen is None: - batch, seqlen, dim = u.shape - else: - batch_seqlen, dim = u.shape - batch = batch_seqlen // seqlen - - conv_state, ssm_state = None, None - if inference_params is not None: - inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch - conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(u, conv_state, ssm_state) - return out - - zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj) - if seqlen_og is not None: - zxbcdt = einops.rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) - # If the model is loaded in fp16, without the .float() here, A might be -inf - A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state) - dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit) - - # [z, x, B, C, dt] - d_mlp = (zxbcdt.shape[-1] - 2 * self.d_inner - 2 * self.d_xb - self.nheads) // 2 - z0, x0, z, xBC, dt = torch.split( - zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.d_xb, self.nheads], dim=-1 - ) - - if self.repeat_kv_before_conv: - x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) - # minic the GQA - x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - # x shape: (bsz, n_group, l, dim) - B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) - # combine x, B, C - x = einops.rearrange(x, "b g l p -> b l (g p)") - B = einops.rearrange(B, "b g l p -> b l (g p)") - xBC = torch.cat((x, B, C), dim=-1) - - if conv_state is not None: - if cu_seqlens is None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = einops.rearrange(xBC, "b l d -> b d l") - conv_state.copy_( - torch.nn.functional.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0)) - ) # Update state (B D W) - else: - assert ( - causal_conv1d.causal_conv1d_varlen_states is not None - ), "varlen inference requires causal_conv1d package" - assert batch == 1, "varlen inference only supports batch dimension 1" - conv_varlen_states = causal_conv1d.causal_conv1d_varlen_states( - xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] - ) - conv_state.copy_(conv_varlen_states) - assert self.activation in ["silu", "swish"] - - if causal_conv1d.causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: - assert seq_idx is None, "varlen conv1d requires the causal_conv1d package" - xBC = self.act( - self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.d_conv - 1) :] - ) # (B, L, self.d_ssm + 2 * ngroups * d_state) - else: - xBC = causal_conv1d.causal_conv1d_fn( - xBC.transpose(1, 2), - einops.rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - seq_idx=seq_idx, - ).transpose(1, 2) - - if self.repeat_kv_before_conv: - x, B, C = torch.split( - xBC, [self.ngroups * self.d_state, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1 - ) - - y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( - einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), - dt, - A, - einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), - einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), - chunk_size=self.chunk_size, - D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, - z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, - dt_bias=self.dt_bias, - dt_softplus=True, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - **dt_limit_kwargs, - return_final_states=ssm_state is not None, - return_varlen_states=cu_seqlens is not None and inference_params is not None, - ) - - else: - # self.d_xb + self.d_xb + self.d_inner - x, B, C = torch.split(xBC, [self.d_xb, self.d_xb, self.ngroups * self.d_state], dim=-1) - - # minic the GQA - x = einops.rearrange(x, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - x = repeat_kv(x, self.repeat_group) - # x shape: (bsz, n_group, l, dim) - - B = einops.rearrange(B, "b l (xb_group dstate) -> b xb_group l dstate", dstate=self.d_state) - B = repeat_kv(B, self.repeat_group) - - y = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( - # einops.rearrange(x, "b l (h p) -> b l h p", p=self.headdim), - einops.rearrange(x, "b g l p -> b l g p"), - dt, - A, - # einops.rearrange(B, "b l (g n) -> b l g n", g=self.ngroups), - einops.rearrange(B, "b g l n -> b l g n"), - einops.rearrange(C, "b l (g n) -> b l g n", g=self.ngroups), - chunk_size=self.chunk_size, - D=einops.rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D, - z=einops.rearrange(z, "b l (h p) -> b l h p", p=self.headdim) if not self.rmsnorm else None, - dt_bias=self.dt_bias, - dt_softplus=True, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - **dt_limit_kwargs, - return_final_states=ssm_state is not None, - return_varlen_states=cu_seqlens is not None and inference_params is not None, - ) - - if ssm_state is not None: - y, last_state, *rest = y - if cu_seqlens is None: - ssm_state.copy_(last_state) - else: - varlen_states = rest[0] - ssm_state.copy_(varlen_states) - y = einops.rearrange(y, "b l h p -> b l (h p)") - if self.rmsnorm: - y = self.norm(y, z) - if d_mlp > 0: - y = torch.cat([torch.nn.functional.silu(z0) * x0, y], dim=-1) - if seqlen_og is not None: - y = einops.rearrange(y, "b l d -> (b l) d") - out = self.out_proj(y) - return out - - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - (batch_size,) - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ).transpose(1, 2) - ssm_state = torch.zeros( - batch_size, - self.nheads, - self.headdim, - self.d_state, - device=self.in_proj.weight.device, - dtype=self.in_proj.weight.dtype, - ) - inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb deleted file mode 100644 index a8390fa3..00000000 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ /dev/null @@ -1,2612 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from mamba_ssm.models.config_mamba import MambaConfig\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "AprielForCausalLM(\n", - " (model): AprielModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.bfloat16" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_model.config.torch_dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.83207168" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "n_params/1e9" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" - ] - } - ], - "source": [ - "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", - "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" - ] - }, - { - "ename": "KeyError", - "evalue": "'n_qk_heads'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" - ] - } - ], - "source": [ - "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", - "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", - " hidden_size=config.hidden_size,\n", - " intermediate_size=config.intermediate_size,\n", - " num_hidden_layers=config.num_hidden_layers,\n", - " hidden_act=config.hidden_act,\n", - " initializer_range=config.initializer_range,\n", - " use_cache=config.use_cache,\n", - " mlp_bias=config.mlp_bias,\n", - " tie_word_embeddings=config.tie_word_embeddings,\n", - " pad_token_id=config.pad_token_id,\n", - " bos_token_id=config.bos_token_id,\n", - " eos_token_id=config.eos_token_id,\n", - " head_dim=config.head_dim,\n", - " rms_norm_eps=config.rms_norm_eps)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('model.embed_tokens.weight',\n", - " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", - " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", - " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", - " ...,\n", - " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", - " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", - " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", - " ('model.layers.0.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.0.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.0.mixer.in_proj.weight',\n", - " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", - " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", - " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", - " ...,\n", - " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", - " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", - " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", - " ('model.layers.0.mixer.conv1d.weight',\n", - " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", - " \n", - " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", - " \n", - " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", - " \n", - " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", - " \n", - " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", - " ('model.layers.0.mixer.conv1d.bias',\n", - " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", - " ('model.layers.0.mixer.out_proj.weight',\n", - " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", - " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", - " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", - " ...,\n", - " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", - " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", - " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", - " ('model.layers.0.mlp.gate_proj.weight',\n", - " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", - " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", - " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", - " ...,\n", - " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", - " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", - " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", - " ('model.layers.0.mlp.up_proj.weight',\n", - " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", - " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", - " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", - " ...,\n", - " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", - " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", - " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", - " ('model.layers.0.mlp.down_proj.weight',\n", - " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", - " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", - " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", - " ...,\n", - " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", - " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", - " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", - " ('model.layers.0.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.0.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.1.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.in_proj.weight',\n", - " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", - " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", - " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", - " ...,\n", - " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", - " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", - " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", - " ('model.layers.1.mixer.conv1d.weight',\n", - " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", - " \n", - " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", - " \n", - " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", - " \n", - " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", - " \n", - " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", - " ('model.layers.1.mixer.conv1d.bias',\n", - " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", - " ('model.layers.1.mixer.out_proj.weight',\n", - " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", - " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", - " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", - " ...,\n", - " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", - " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", - " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", - " ('model.layers.1.mlp.gate_proj.weight',\n", - " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", - " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", - " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", - " ...,\n", - " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", - " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", - " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", - " ('model.layers.1.mlp.up_proj.weight',\n", - " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", - " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", - " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", - " ...,\n", - " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", - " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", - " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", - " ('model.layers.1.mlp.down_proj.weight',\n", - " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", - " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", - " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", - " ...,\n", - " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", - " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", - " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", - " ('model.layers.1.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.2.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.in_proj.weight',\n", - " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", - " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", - " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", - " ...,\n", - " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", - " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", - " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", - " ('model.layers.2.mixer.conv1d.weight',\n", - " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", - " \n", - " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", - " \n", - " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", - " \n", - " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", - " \n", - " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", - " ('model.layers.2.mixer.conv1d.bias',\n", - " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", - " ('model.layers.2.mixer.out_proj.weight',\n", - " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", - " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", - " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", - " ...,\n", - " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", - " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", - " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", - " ('model.layers.2.mlp.gate_proj.weight',\n", - " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", - " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", - " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", - " ...,\n", - " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", - " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", - " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", - " ('model.layers.2.mlp.up_proj.weight',\n", - " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", - " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", - " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", - " ...,\n", - " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", - " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", - " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", - " ('model.layers.2.mlp.down_proj.weight',\n", - " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", - " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", - " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", - " ...,\n", - " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", - " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", - " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", - " ('model.layers.2.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.3.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.in_proj.weight',\n", - " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", - " -2.1648e-03, 2.8147e-03],\n", - " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", - " 4.9989e-03, 6.4200e-03],\n", - " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", - " 9.7000e-03, 3.1136e-05],\n", - " ...,\n", - " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", - " 4.6152e-04, 5.5341e-02],\n", - " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", - " -1.0645e-02, -1.8442e-03],\n", - " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", - " -4.7119e-03, 3.2352e-02]])),\n", - " ('model.layers.3.mixer.conv1d.weight',\n", - " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", - " \n", - " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", - " \n", - " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", - " \n", - " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", - " \n", - " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", - " ('model.layers.3.mixer.conv1d.bias',\n", - " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", - " ('model.layers.3.mixer.out_proj.weight',\n", - " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", - " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", - " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", - " ...,\n", - " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", - " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", - " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", - " ('model.layers.3.mlp.gate_proj.weight',\n", - " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", - " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", - " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", - " ...,\n", - " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", - " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", - " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", - " ('model.layers.3.mlp.up_proj.weight',\n", - " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", - " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", - " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", - " ...,\n", - " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", - " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", - " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", - " ('model.layers.3.mlp.down_proj.weight',\n", - " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", - " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", - " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", - " ...,\n", - " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", - " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", - " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", - " ('model.layers.3.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.4.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.in_proj.weight',\n", - " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", - " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", - " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", - " ...,\n", - " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", - " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", - " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", - " ('model.layers.4.mixer.conv1d.weight',\n", - " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", - " \n", - " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", - " \n", - " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", - " \n", - " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", - " \n", - " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", - " ('model.layers.4.mixer.conv1d.bias',\n", - " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", - " ('model.layers.4.mixer.out_proj.weight',\n", - " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", - " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", - " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", - " ...,\n", - " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", - " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", - " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", - " ('model.layers.4.mlp.gate_proj.weight',\n", - " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", - " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", - " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", - " ...,\n", - " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", - " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", - " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", - " ('model.layers.4.mlp.up_proj.weight',\n", - " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", - " 3.6365e-02, -8.1017e-03],\n", - " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", - " -3.1505e-02, 2.1047e-03],\n", - " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", - " -2.7856e-04, -8.7440e-04],\n", - " ...,\n", - " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", - " -1.9620e-02, 2.5112e-02],\n", - " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", - " -1.6838e-03, 2.0189e-02],\n", - " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", - " -1.3502e-02, -1.8923e-02]])),\n", - " ('model.layers.4.mlp.down_proj.weight',\n", - " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", - " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", - " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", - " ...,\n", - " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", - " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", - " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", - " ('model.layers.4.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.5.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.in_proj.weight',\n", - " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", - " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", - " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", - " ...,\n", - " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", - " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", - " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", - " ('model.layers.5.mixer.conv1d.weight',\n", - " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", - " \n", - " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", - " \n", - " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", - " \n", - " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", - " \n", - " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", - " ('model.layers.5.mixer.conv1d.bias',\n", - " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", - " ('model.layers.5.mixer.out_proj.weight',\n", - " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", - " 7.0706e-03, -4.7970e-03],\n", - " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", - " 2.2802e-02, -2.9729e-03],\n", - " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", - " -4.2582e-03, 2.4007e-02],\n", - " ...,\n", - " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", - " -2.0391e-02, 5.3912e-03],\n", - " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", - " -2.0596e-02, -8.2915e-03],\n", - " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", - " -3.8566e-02, 2.7005e-02]])),\n", - " ('model.layers.5.mlp.gate_proj.weight',\n", - " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", - " 3.6159e-02, -1.7253e-02],\n", - " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", - " 4.4487e-03, -9.6723e-03],\n", - " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", - " 1.2872e-02, -1.0443e-02],\n", - " ...,\n", - " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", - " 2.8752e-02, -8.0048e-03],\n", - " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", - " -3.2179e-02, -1.6437e-02],\n", - " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", - " 7.7639e-03, -2.2494e-03]])),\n", - " ('model.layers.5.mlp.up_proj.weight',\n", - " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", - " -1.8538e-02, 6.7395e-03],\n", - " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", - " -3.0374e-02, 2.5795e-02],\n", - " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", - " 1.4115e-02, -2.6438e-02],\n", - " ...,\n", - " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", - " -4.9858e-02, -1.1575e-02],\n", - " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", - " 1.8568e-02, 5.2615e-03],\n", - " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", - " 1.2643e-02, -1.1568e-05]])),\n", - " ('model.layers.5.mlp.down_proj.weight',\n", - " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", - " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", - " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", - " ...,\n", - " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", - " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", - " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", - " ('model.layers.5.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.6.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.in_proj.weight',\n", - " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", - " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", - " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", - " ...,\n", - " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", - " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", - " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", - " ('model.layers.6.mixer.conv1d.weight',\n", - " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", - " \n", - " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", - " \n", - " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", - " \n", - " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", - " \n", - " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", - " ('model.layers.6.mixer.conv1d.bias',\n", - " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", - " ('model.layers.6.mixer.out_proj.weight',\n", - " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", - " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", - " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", - " ...,\n", - " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", - " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", - " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", - " ('model.layers.6.mlp.gate_proj.weight',\n", - " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", - " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", - " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", - " ...,\n", - " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", - " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", - " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", - " ('model.layers.6.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", - " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", - " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", - " ...,\n", - " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", - " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", - " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", - " ('model.layers.6.mlp.down_proj.weight',\n", - " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", - " 4.8023e-03, 9.8474e-03],\n", - " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", - " -9.5275e-04, -1.8121e-02],\n", - " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", - " 2.8131e-03, 5.8219e-03],\n", - " ...,\n", - " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", - " -8.3965e-03, 2.0074e-02],\n", - " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", - " 3.4947e-02, -2.2158e-02],\n", - " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", - " 4.0735e-05, 3.3651e-02]])),\n", - " ('model.layers.6.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.7.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.in_proj.weight',\n", - " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", - " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", - " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", - " ...,\n", - " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", - " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", - " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", - " ('model.layers.7.mixer.conv1d.weight',\n", - " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", - " \n", - " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", - " \n", - " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", - " \n", - " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", - " \n", - " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", - " ('model.layers.7.mixer.conv1d.bias',\n", - " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", - " ('model.layers.7.mixer.out_proj.weight',\n", - " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", - " 1.4225e-02, 2.8166e-02],\n", - " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", - " 3.7454e-03, 1.7579e-02],\n", - " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", - " 1.4766e-02, -1.8002e-02],\n", - " ...,\n", - " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", - " -1.3072e-02, -1.0412e-02],\n", - " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", - " -1.1047e-02, 1.5815e-02],\n", - " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", - " -2.1834e-02, -2.1826e-02]])),\n", - " ('model.layers.7.mlp.gate_proj.weight',\n", - " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", - " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", - " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", - " ...,\n", - " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", - " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", - " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", - " ('model.layers.7.mlp.up_proj.weight',\n", - " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", - " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", - " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", - " ...,\n", - " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", - " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", - " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", - " ('model.layers.7.mlp.down_proj.weight',\n", - " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", - " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", - " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", - " ...,\n", - " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", - " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", - " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", - " ('model.layers.7.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.8.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.in_proj.weight',\n", - " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", - " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", - " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", - " ...,\n", - " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", - " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", - " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", - " ('model.layers.8.mixer.conv1d.weight',\n", - " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", - " \n", - " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", - " \n", - " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", - " \n", - " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", - " \n", - " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", - " ('model.layers.8.mixer.conv1d.bias',\n", - " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", - " ('model.layers.8.mixer.out_proj.weight',\n", - " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", - " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", - " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", - " ...,\n", - " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", - " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", - " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", - " ('model.layers.8.mlp.gate_proj.weight',\n", - " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", - " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", - " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", - " ...,\n", - " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", - " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", - " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", - " ('model.layers.8.mlp.up_proj.weight',\n", - " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", - " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", - " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", - " ...,\n", - " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", - " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", - " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", - " ('model.layers.8.mlp.down_proj.weight',\n", - " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", - " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", - " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", - " ...,\n", - " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", - " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", - " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", - " ('model.layers.8.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.9.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.in_proj.weight',\n", - " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", - " 7.1750e-04, 1.5247e-02],\n", - " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", - " -1.1202e-02, -2.1476e-02],\n", - " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", - " 1.8544e-02, -2.3633e-02],\n", - " ...,\n", - " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", - " -5.6073e-03, -4.8212e-03],\n", - " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", - " 4.5898e-02, -3.5822e-02],\n", - " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", - " 1.3560e-02, 3.3366e-02]])),\n", - " ('model.layers.9.mixer.conv1d.weight',\n", - " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", - " \n", - " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", - " \n", - " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", - " \n", - " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", - " \n", - " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", - " ('model.layers.9.mixer.conv1d.bias',\n", - " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", - " ('model.layers.9.mixer.out_proj.weight',\n", - " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", - " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", - " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", - " ...,\n", - " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", - " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", - " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", - " ('model.layers.9.mlp.gate_proj.weight',\n", - " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", - " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", - " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", - " ...,\n", - " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", - " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", - " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", - " ('model.layers.9.mlp.up_proj.weight',\n", - " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", - " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", - " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", - " ...,\n", - " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", - " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", - " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", - " ('model.layers.9.mlp.down_proj.weight',\n", - " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", - " -3.0707e-02, -1.3281e-02],\n", - " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", - " -1.1585e-03, -2.7604e-02],\n", - " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", - " 1.1647e-02, 1.1245e-02],\n", - " ...,\n", - " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", - " -4.5122e-03, 4.0012e-03],\n", - " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", - " 2.3422e-02, -3.2098e-02],\n", - " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", - " 2.2118e-02, 1.0696e-02]])),\n", - " ('model.layers.9.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.10.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.in_proj.weight',\n", - " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", - " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", - " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", - " ...,\n", - " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", - " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", - " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", - " ('model.layers.10.mixer.conv1d.weight',\n", - " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", - " \n", - " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", - " \n", - " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", - " \n", - " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", - " \n", - " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", - " ('model.layers.10.mixer.conv1d.bias',\n", - " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", - " ('model.layers.10.mixer.out_proj.weight',\n", - " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", - " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", - " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", - " ...,\n", - " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", - " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", - " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", - " ('model.layers.10.mlp.gate_proj.weight',\n", - " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", - " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", - " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", - " ...,\n", - " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", - " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", - " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", - " ('model.layers.10.mlp.up_proj.weight',\n", - " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", - " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", - " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", - " ...,\n", - " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", - " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", - " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", - " ('model.layers.10.mlp.down_proj.weight',\n", - " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", - " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", - " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", - " ...,\n", - " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", - " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", - " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", - " ('model.layers.10.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.11.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.in_proj.weight',\n", - " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", - " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", - " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", - " ...,\n", - " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", - " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", - " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", - " ('model.layers.11.mixer.conv1d.weight',\n", - " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", - " \n", - " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", - " \n", - " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", - " \n", - " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", - " \n", - " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", - " ('model.layers.11.mixer.conv1d.bias',\n", - " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", - " ('model.layers.11.mixer.out_proj.weight',\n", - " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", - " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", - " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", - " ...,\n", - " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", - " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", - " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", - " ('model.layers.11.mlp.gate_proj.weight',\n", - " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", - " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", - " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", - " ...,\n", - " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", - " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", - " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", - " ('model.layers.11.mlp.up_proj.weight',\n", - " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", - " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", - " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", - " ...,\n", - " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", - " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", - " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", - " ('model.layers.11.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", - " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", - " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", - " ...,\n", - " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", - " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", - " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", - " ('model.layers.11.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.12.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.in_proj.weight',\n", - " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", - " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", - " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", - " ...,\n", - " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", - " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", - " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", - " ('model.layers.12.mixer.conv1d.weight',\n", - " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", - " \n", - " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", - " \n", - " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", - " \n", - " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", - " \n", - " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", - " ('model.layers.12.mixer.conv1d.bias',\n", - " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", - " ('model.layers.12.mixer.out_proj.weight',\n", - " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", - " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", - " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", - " ...,\n", - " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", - " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", - " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", - " ('model.layers.12.mlp.gate_proj.weight',\n", - " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", - " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", - " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", - " ...,\n", - " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", - " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", - " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", - " ('model.layers.12.mlp.up_proj.weight',\n", - " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", - " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", - " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", - " ...,\n", - " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", - " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", - " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", - " ('model.layers.12.mlp.down_proj.weight',\n", - " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", - " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", - " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", - " ...,\n", - " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", - " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", - " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", - " ('model.layers.12.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.13.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.in_proj.weight',\n", - " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", - " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", - " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", - " ...,\n", - " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", - " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", - " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", - " ('model.layers.13.mixer.conv1d.weight',\n", - " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", - " \n", - " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", - " \n", - " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", - " \n", - " ...,\n", - " \n", - " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", - " \n", - " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", - " \n", - " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", - " ('model.layers.13.mixer.conv1d.bias',\n", - " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", - " ('model.layers.13.mixer.out_proj.weight',\n", - " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", - " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", - " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", - " ...,\n", - " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", - " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", - " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", - " ('model.layers.13.mlp.gate_proj.weight',\n", - " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", - " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", - " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", - " ...,\n", - " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", - " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", - " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", - " ('model.layers.13.mlp.up_proj.weight',\n", - " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", - " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", - " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", - " ...,\n", - " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", - " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", - " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", - " ('model.layers.13.mlp.down_proj.weight',\n", - " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", - " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", - " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", - " ...,\n", - " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", - " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", - " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", - " ('model.layers.13.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.14.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.in_proj.weight',\n", - " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", - " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", - " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", - " ...,\n", - " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", - " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", - " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", - " ('model.layers.14.mixer.conv1d.weight',\n", - " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", - " \n", - " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", - " \n", - " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", - " \n", - " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", - " \n", - " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", - " ('model.layers.14.mixer.conv1d.bias',\n", - " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", - " ('model.layers.14.mixer.out_proj.weight',\n", - " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", - " -2.4306e-02, 5.7698e-03],\n", - " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", - " -1.4523e-02, 1.2266e-02],\n", - " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", - " 1.0787e-02, 8.4053e-03],\n", - " ...,\n", - " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", - " 1.5601e-02, -6.3227e-03],\n", - " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", - " 1.2259e-02, 5.5500e-02],\n", - " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", - " 1.8864e-02, -1.2508e-02]])),\n", - " ('model.layers.14.mlp.gate_proj.weight',\n", - " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", - " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", - " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", - " ...,\n", - " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", - " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", - " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", - " ('model.layers.14.mlp.up_proj.weight',\n", - " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", - " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", - " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", - " ...,\n", - " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", - " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", - " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", - " ('model.layers.14.mlp.down_proj.weight',\n", - " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", - " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", - " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", - " ...,\n", - " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", - " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", - " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", - " ('model.layers.14.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.15.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.in_proj.weight',\n", - " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", - " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", - " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", - " ...,\n", - " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", - " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", - " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", - " ('model.layers.15.mixer.conv1d.weight',\n", - " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", - " \n", - " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", - " \n", - " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", - " \n", - " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", - " \n", - " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", - " ('model.layers.15.mixer.conv1d.bias',\n", - " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", - " ('model.layers.15.mixer.out_proj.weight',\n", - " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", - " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", - " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", - " ...,\n", - " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", - " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", - " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", - " ('model.layers.15.mlp.gate_proj.weight',\n", - " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", - " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", - " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", - " ...,\n", - " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", - " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", - " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", - " ('model.layers.15.mlp.up_proj.weight',\n", - " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", - " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", - " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", - " ...,\n", - " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", - " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", - " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", - " ('model.layers.15.mlp.down_proj.weight',\n", - " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", - " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", - " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", - " ...,\n", - " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", - " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", - " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", - " ('model.layers.15.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.16.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.in_proj.weight',\n", - " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", - " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", - " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", - " ...,\n", - " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", - " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", - " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", - " ('model.layers.16.mixer.conv1d.weight',\n", - " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", - " \n", - " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", - " \n", - " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", - " \n", - " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", - " \n", - " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", - " ('model.layers.16.mixer.conv1d.bias',\n", - " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", - " ('model.layers.16.mixer.out_proj.weight',\n", - " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", - " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", - " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", - " ...,\n", - " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", - " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", - " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", - " ('model.layers.16.mlp.gate_proj.weight',\n", - " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", - " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", - " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", - " ...,\n", - " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", - " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", - " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", - " ('model.layers.16.mlp.up_proj.weight',\n", - " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", - " -1.2243e-02, -1.3005e-02],\n", - " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", - " -1.5193e-02, 2.6111e-03],\n", - " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", - " 1.0694e-02, 3.7682e-03],\n", - " ...,\n", - " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", - " -4.8128e-03, 1.6899e-02],\n", - " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", - " -5.5857e-03, 4.0449e-03],\n", - " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", - " -8.0242e-03, -1.0925e-02]])),\n", - " ('model.layers.16.mlp.down_proj.weight',\n", - " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", - " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", - " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", - " ...,\n", - " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", - " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", - " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", - " ('model.layers.16.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.17.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.in_proj.weight',\n", - " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", - " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", - " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", - " ...,\n", - " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", - " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", - " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", - " ('model.layers.17.mixer.conv1d.weight',\n", - " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", - " \n", - " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", - " \n", - " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", - " \n", - " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", - " \n", - " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", - " ('model.layers.17.mixer.conv1d.bias',\n", - " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", - " ('model.layers.17.mixer.out_proj.weight',\n", - " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", - " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", - " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", - " ...,\n", - " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", - " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", - " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", - " ('model.layers.17.mlp.gate_proj.weight',\n", - " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", - " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", - " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", - " ...,\n", - " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", - " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", - " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", - " ('model.layers.17.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", - " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", - " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", - " ...,\n", - " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", - " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", - " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", - " ('model.layers.17.mlp.down_proj.weight',\n", - " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", - " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", - " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", - " ...,\n", - " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", - " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", - " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", - " ('model.layers.17.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.18.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.in_proj.weight',\n", - " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", - " -6.8715e-03, -1.3565e-02],\n", - " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", - " -1.0968e-02, -8.0445e-03],\n", - " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", - " 2.0198e-03, -1.7366e-03],\n", - " ...,\n", - " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", - " 1.4452e-02, -2.5152e-02],\n", - " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", - " 3.8135e-05, 1.7263e-02],\n", - " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", - " -1.1129e-02, -9.6768e-03]])),\n", - " ('model.layers.18.mixer.conv1d.weight',\n", - " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", - " \n", - " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", - " \n", - " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", - " \n", - " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", - " \n", - " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", - " ('model.layers.18.mixer.conv1d.bias',\n", - " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", - " ('model.layers.18.mixer.out_proj.weight',\n", - " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", - " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", - " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", - " ...,\n", - " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", - " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", - " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", - " ('model.layers.18.mlp.gate_proj.weight',\n", - " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", - " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", - " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", - " ...,\n", - " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", - " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", - " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", - " ('model.layers.18.mlp.up_proj.weight',\n", - " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", - " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", - " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", - " ...,\n", - " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", - " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", - " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", - " ('model.layers.18.mlp.down_proj.weight',\n", - " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", - " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", - " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", - " ...,\n", - " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", - " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", - " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", - " ('model.layers.18.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.19.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.in_proj.weight',\n", - " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", - " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", - " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", - " ...,\n", - " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", - " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", - " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", - " ('model.layers.19.mixer.conv1d.weight',\n", - " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", - " \n", - " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", - " \n", - " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", - " \n", - " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", - " \n", - " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", - " ('model.layers.19.mixer.conv1d.bias',\n", - " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", - " ('model.layers.19.mixer.out_proj.weight',\n", - " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", - " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", - " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", - " ...,\n", - " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", - " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", - " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", - " ('model.layers.19.mlp.gate_proj.weight',\n", - " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", - " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", - " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", - " ...,\n", - " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", - " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", - " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", - " ('model.layers.19.mlp.up_proj.weight',\n", - " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", - " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", - " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", - " ...,\n", - " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", - " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", - " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", - " ('model.layers.19.mlp.down_proj.weight',\n", - " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", - " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", - " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", - " ...,\n", - " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", - " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", - " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", - " ('model.layers.19.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.20.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.in_proj.weight',\n", - " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", - " -2.3083e-02, -6.2962e-02],\n", - " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", - " 9.0683e-03, 2.1583e-04],\n", - " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", - " 3.2650e-03, -3.1967e-02],\n", - " ...,\n", - " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", - " 1.0566e-02, -4.9561e-03],\n", - " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", - " -7.5403e-03, -1.3599e-02],\n", - " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", - " -3.3432e-02, 3.8337e-03]])),\n", - " ('model.layers.20.mixer.conv1d.weight',\n", - " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", - " \n", - " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", - " \n", - " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", - " \n", - " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", - " \n", - " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", - " ('model.layers.20.mixer.conv1d.bias',\n", - " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", - " ('model.layers.20.mixer.out_proj.weight',\n", - " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", - " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", - " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", - " ...,\n", - " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", - " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", - " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", - " ('model.layers.20.mlp.gate_proj.weight',\n", - " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", - " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", - " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", - " ...,\n", - " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", - " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", - " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", - " ('model.layers.20.mlp.up_proj.weight',\n", - " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", - " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", - " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", - " ...,\n", - " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", - " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", - " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", - " ('model.layers.20.mlp.down_proj.weight',\n", - " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", - " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", - " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", - " ...,\n", - " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", - " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", - " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", - " ('model.layers.20.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.21.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.in_proj.weight',\n", - " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", - " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", - " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", - " ...,\n", - " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", - " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", - " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", - " ('model.layers.21.mixer.conv1d.weight',\n", - " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", - " \n", - " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", - " \n", - " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", - " \n", - " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", - " \n", - " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", - " ('model.layers.21.mixer.conv1d.bias',\n", - " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", - " ('model.layers.21.mixer.out_proj.weight',\n", - " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", - " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", - " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", - " ...,\n", - " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", - " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", - " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", - " ('model.layers.21.mlp.gate_proj.weight',\n", - " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", - " -5.8295e-03, -1.4284e-02],\n", - " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", - " -3.9076e-02, 5.3379e-03],\n", - " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", - " -2.5180e-02, 6.1512e-04],\n", - " ...,\n", - " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", - " -7.2165e-03, 1.6297e-02],\n", - " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", - " -7.7709e-03, 4.3964e-04],\n", - " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", - " 5.0683e-02, -2.5481e-02]])),\n", - " ('model.layers.21.mlp.up_proj.weight',\n", - " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", - " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", - " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", - " ...,\n", - " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", - " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", - " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", - " ('model.layers.21.mlp.down_proj.weight',\n", - " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", - " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", - " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", - " ...,\n", - " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", - " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", - " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", - " ('model.layers.21.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.22.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.in_proj.weight',\n", - " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", - " -1.5757e-02, -1.5315e-03],\n", - " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", - " 8.9105e-05, -3.8658e-03],\n", - " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", - " -1.6604e-02, -4.5245e-03],\n", - " ...,\n", - " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", - " 2.3792e-02, -5.5165e-03],\n", - " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", - " 5.2496e-03, -6.0055e-03],\n", - " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", - " -1.5446e-02, -2.7164e-02]])),\n", - " ('model.layers.22.mixer.conv1d.weight',\n", - " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", - " \n", - " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", - " \n", - " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", - " \n", - " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", - " \n", - " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", - " ('model.layers.22.mixer.conv1d.bias',\n", - " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", - " ('model.layers.22.mixer.out_proj.weight',\n", - " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", - " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", - " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", - " ...,\n", - " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", - " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", - " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", - " ('model.layers.22.mlp.gate_proj.weight',\n", - " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", - " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", - " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", - " ...,\n", - " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", - " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", - " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", - " ('model.layers.22.mlp.up_proj.weight',\n", - " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", - " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", - " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", - " ...,\n", - " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", - " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", - " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", - " ('model.layers.22.mlp.down_proj.weight',\n", - " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", - " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", - " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", - " ...,\n", - " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", - " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", - " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", - " ('model.layers.22.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.23.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.in_proj.weight',\n", - " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", - " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", - " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", - " ...,\n", - " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", - " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", - " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", - " ('model.layers.23.mixer.conv1d.weight',\n", - " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", - " \n", - " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", - " \n", - " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", - " \n", - " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", - " \n", - " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", - " ('model.layers.23.mixer.conv1d.bias',\n", - " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", - " ('model.layers.23.mixer.out_proj.weight',\n", - " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", - " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", - " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", - " ...,\n", - " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", - " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", - " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", - " ('model.layers.23.mlp.gate_proj.weight',\n", - " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", - " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", - " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", - " ...,\n", - " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", - " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", - " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", - " ('model.layers.23.mlp.up_proj.weight',\n", - " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", - " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", - " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", - " ...,\n", - " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", - " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", - " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", - " ('model.layers.23.mlp.down_proj.weight',\n", - " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", - " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", - " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", - " ...,\n", - " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", - " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", - " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", - " ('model.layers.23.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.24.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.in_proj.weight',\n", - " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", - " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", - " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", - " ...,\n", - " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", - " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", - " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", - " ('model.layers.24.mixer.conv1d.weight',\n", - " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", - " \n", - " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", - " \n", - " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", - " \n", - " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", - " \n", - " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", - " ('model.layers.24.mixer.conv1d.bias',\n", - " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", - " ('model.layers.24.mixer.out_proj.weight',\n", - " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", - " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", - " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", - " ...,\n", - " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", - " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", - " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", - " ('model.layers.24.mlp.gate_proj.weight',\n", - " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", - " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", - " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", - " ...,\n", - " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", - " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", - " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", - " ('model.layers.24.mlp.up_proj.weight',\n", - " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", - " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", - " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", - " ...,\n", - " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", - " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", - " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", - " ('model.layers.24.mlp.down_proj.weight',\n", - " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", - " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", - " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", - " ...,\n", - " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", - " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", - " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", - " ('model.layers.24.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.25.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.in_proj.weight',\n", - " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", - " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", - " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", - " ...,\n", - " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", - " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", - " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", - " ('model.layers.25.mixer.conv1d.weight',\n", - " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", - " \n", - " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", - " \n", - " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", - " \n", - " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", - " \n", - " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", - " ('model.layers.25.mixer.conv1d.bias',\n", - " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", - " ('model.layers.25.mixer.out_proj.weight',\n", - " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", - " -1.8566e-02, 1.3872e-02],\n", - " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", - " 7.3446e-03, -1.6368e-02],\n", - " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", - " -1.6554e-02, 7.2751e-03],\n", - " ...,\n", - " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", - " -1.0262e-02, -4.9292e-03],\n", - " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", - " 2.2522e-02, -6.0243e-03],\n", - " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", - " -5.7889e-03, 1.1855e-02]])),\n", - " ('model.layers.25.mlp.gate_proj.weight',\n", - " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", - " -1.2723e-02, -9.7756e-03],\n", - " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", - " -1.5727e-02, -9.8748e-03],\n", - " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", - " 1.2017e-04, -1.5734e-02],\n", - " ...,\n", - " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", - " 1.8795e-02, 1.0221e-02],\n", - " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", - " 8.5950e-05, 1.0542e-02],\n", - " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", - " -3.0686e-02, -9.6116e-03]])),\n", - " ('model.layers.25.mlp.up_proj.weight',\n", - " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", - " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", - " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", - " ...,\n", - " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", - " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", - " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", - " ('model.layers.25.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", - " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", - " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", - " ...,\n", - " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", - " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", - " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", - " ('model.layers.25.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.26.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.in_proj.weight',\n", - " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", - " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", - " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", - " ...,\n", - " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", - " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", - " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", - " ('model.layers.26.mixer.conv1d.weight',\n", - " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", - " \n", - " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", - " \n", - " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", - " \n", - " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", - " \n", - " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", - " ('model.layers.26.mixer.conv1d.bias',\n", - " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", - " ('model.layers.26.mixer.out_proj.weight',\n", - " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", - " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", - " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", - " ...,\n", - " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", - " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", - " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", - " ('model.layers.26.mlp.gate_proj.weight',\n", - " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", - " 1.3550e-02, -1.1203e-02],\n", - " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", - " 5.3891e-03, -2.8310e-02],\n", - " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", - " -2.8479e-03, -1.3105e-02],\n", - " ...,\n", - " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", - " -1.8015e-02, -1.2234e-02],\n", - " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", - " 1.7886e-02, -4.4699e-02],\n", - " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", - " 3.3468e-03, 2.7677e-02]])),\n", - " ('model.layers.26.mlp.up_proj.weight',\n", - " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", - " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", - " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", - " ...,\n", - " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", - " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", - " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", - " ('model.layers.26.mlp.down_proj.weight',\n", - " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", - " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", - " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", - " ...,\n", - " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", - " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", - " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", - " ('model.layers.26.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.27.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.in_proj.weight',\n", - " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", - " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", - " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", - " ...,\n", - " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", - " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", - " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", - " ('model.layers.27.mixer.conv1d.weight',\n", - " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", - " \n", - " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", - " \n", - " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", - " \n", - " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", - " \n", - " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", - " ('model.layers.27.mixer.conv1d.bias',\n", - " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", - " ('model.layers.27.mixer.out_proj.weight',\n", - " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", - " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", - " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", - " ...,\n", - " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", - " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", - " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", - " ('model.layers.27.mlp.gate_proj.weight',\n", - " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", - " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", - " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", - " ...,\n", - " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", - " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", - " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", - " ('model.layers.27.mlp.up_proj.weight',\n", - " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", - " 3.2845e-02, 2.1107e-02],\n", - " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", - " -1.1685e-02, 1.2504e-02],\n", - " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", - " -3.3912e-04, 3.0527e-02],\n", - " ...,\n", - " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", - " 1.9532e-02, 1.3339e-02],\n", - " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", - " 2.8262e-02, -1.2526e-02],\n", - " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", - " -2.0571e-03, -8.3943e-03]])),\n", - " ('model.layers.27.mlp.down_proj.weight',\n", - " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", - " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", - " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", - " ...,\n", - " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", - " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", - " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", - " ('model.layers.27.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('lm_head.weight',\n", - " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", - " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", - " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", - " ...,\n", - " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", - " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", - " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.state_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "N params SSM: 5.305533088\n" - ] - } - ], - "source": [ - "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load State dict into SSM" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# apriel_ssm.state_dict()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Save checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'apriel_ssm' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" - ] - } - ], - "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", - " save_config=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "24" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.model.layers[0].mixer.n_v_heads" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Try a forward pass" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", - "batch_size = 1\n", - "max_length = 128\n", - "state = SimpleNamespace()\n", - "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", - "state.batch_size = batch_size\n", - "state.seqlen_offset = 0\n", - "static_inputs = {\"inference_params\": state,\n", - " \"input_ids\": input_ids,\n", - " \"use_cache\": True,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", - " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", - " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", - " ...,\n", - " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", - " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", - " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", - " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", - " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", - " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", - " ...,\n", - " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", - " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", - " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", - " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.forward(**static_inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "import enum" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "class SSMBlockType(str, enum.Enum):\n", - " \"\"\"\n", - " An enum for the available mamba types for the MLP layer.\n", - " \"\"\"\n", - "\n", - " mamba = \"m\"\n", - " mamba2_discrete = \"m2d\"\n", - " mamba2 = \"m2\"\n", - " transformer = \"t\"" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_values([, , , ])" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'m' in SSMBlockType.__members__.values()" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "ename": "KeyError", - "evalue": "'m'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[21], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mm\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[43mSSMBlockType\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mm\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241m.\u001b[39mname\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/enum.py:808\u001b[0m, in \u001b[0;36mEnumType.__getitem__\u001b[0;34m(cls, name)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, name):\n\u001b[1;32m 805\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 806\u001b[0m \u001b[38;5;124;03m Return the member matching `name`.\u001b[39;00m\n\u001b[1;32m 807\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 808\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_member_map_\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'm'" - ] - } - ], - "source": [ - "\"m\" == SSMBlockType[\"m\"].name\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'m2d'" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "SSMBlockType.mamba2_discrete.value" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hymba2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/fast_llm/models/ssm/external/discrete_mamba2.py b/fast_llm/models/ssm/external/discrete_mamba2.py deleted file mode 100644 index bb8afaa7..00000000 --- a/fast_llm/models/ssm/external/discrete_mamba2.py +++ /dev/null @@ -1,382 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat -from mamba_ssm.ops.triton.selective_state_update import selective_state_update -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined - -from .configuration_mtp_llamba import StateUpdateKernel - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None, None - - -class DiscreteMamba2(nn.Module): - """DiscreteMamba2 (taken github.com/goombalab/phi-mamba.git).""" - - def __init__( - self, - d_model, - d_state=64, - n_qk_heads=32, - n_v_heads=32, - d_conv=4, - expand=1, - activation="identity", - bias=False, - conv_bias=True, - chunk_size=128, - layer_idx=None, - device=None, - dtype=None, - verification_mode: StateUpdateKernel = StateUpdateKernel.cs, - **kwargs, - ): - """ - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. - Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr". - - Other options are all experimental and should not need to be configured. - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.expand = expand - self.d_inner = self.expand * self.d_model - self.n_qk_heads = n_qk_heads - self.n_v_heads = n_v_heads - self.headdim = self.d_inner // self.n_v_heads - assert self.n_v_heads == self.d_inner // self.headdim - assert self.d_inner % self.headdim == 0 - assert self.n_v_heads % self.n_qk_heads == 0 - self.activation = activation - self.chunk_size = chunk_size - self.layer_idx = layer_idx - self.bias = bias - self.kwargs = kwargs - self.inference_mode = verification_mode - assert verification_mode in [ - StateUpdateKernel.cs, - StateUpdateKernel.standard, - ], "Only chunk scan and standard selective scan are supported for now" - - # Projections - self.in_proj = nn.Linear( - self.d_model, - 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, - bias=bias, - **factory_kwargs, - ) - self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 - ) # make sure z_bias always exists - - # Convolutional layer - conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state - self.conv_bias = conv_bias - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - **factory_kwargs, - ) - - # Activation after conv - if self.activation == "identity": - self.act = nn.Identity() - elif self.activation in ["silu", "swish"]: - self.act = nn.SiLU() - else: - raise ValueError(f"Unknown activation {self.activation}") - - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) - - # out_proj - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) - - @property - def d_output(self): - """Returns the output dimension of the model.""" - return self.d_model - - @property - def state_to_tensor(self): - """Returns the state of the model as a tensor.""" - return self.layer.state_to_tensor - - def forward(self, u, inference_params=None, **kwargs): - """ - Args: - u: (B, L, D), - inference_params: dict.. Here we assume it contains a mask tensor of shape (B, L) with 1s for valid tokens and 0s for no-op tokens. - - Returns: - outputs: dict. - outputs["hidden_states"]: (B, L, D). - outputs["state"]: inference cache. - """ - outputs = {} - # assert state is None - batch, seqlen, dim = u.shape - - state = None - if inference_params is not None: - state = self._get_states_from_cache(inference_params, batch) - - if ( - state is not None - and inference_params.seqlen_offset > 0 # meaning we are in the middle of the sequence - and seqlen == 1 - and self.inference_mode != StateUpdateKernel.cs - ): - # we go in here for standard 1 token per time-step inference. - # seqlen_offset > 0 means we are in the middle of a sequence - # States are updated inplace - u = u.squeeze(1) if len(u.shape) == 3 else u - out, _ = self.step(u, state) - out = out.unsqueeze(1) if len(u.shape) == 2 else out - return {"hidden_states": out} - - # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen - - # Pad input to nearest multiple of chunklen - padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size - u = F.pad(u, (0, 0, 0, padded_len - seqlen)) - - # Project input - xBCzA_log = self.in_proj(u) - - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - if state is not None: - # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) - - # Convolutional layer - xBC = self.convolutional_forward( - xBC, padded_len, mask=inference_params.mask if inference_params is not None else None - ) - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) - B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) - C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) - - # SSM forward - # TODO: this kernel needs to be aupdated to use the mask! If used solely for throughout benchmarking, it is enough to call it as is. - result = mamba_chunk_scan_combined( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=A_log, - dt_softplus=True, - A=-torch.ones(self.n_v_heads, device=A_log.device), - B=B, - C=C, - chunk_size=chunk_size, - # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), - ) - - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) - else: - y = result - - Du = torch.einsum("h,blhp->blhp", self.D, x) - y = rearrange(y + Du, "b l h p -> b l (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - outputs["hidden_states"] = out[:, :seqlen, :] - - return outputs - - def step(self, u, state, **kwargs): - """ - Args: - u: (B, D), - state: dict. - - Returns: - out: (B, D), - state: dict. - - """ - # Project input - xBCzA_log = self.in_proj(u) - xBC, z, A_log = torch.split( - xBCzA_log, - [ - self.d_inner + 2 * self.n_qk_heads * self.d_state, - self.d_inner, - self.n_v_heads, - ], - dim=-1, - ) - - xBC, conv_state = self.convolutional_step(xBC, state["conv"]) - state["conv"].copy_(conv_state) # update state in place - - x, B, C = torch.split( - xBC, - [ - self.d_inner, - self.n_qk_heads * self.d_state, - self.n_qk_heads * self.d_state, - ], - dim=-1, - ) - - x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) - B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) - C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - - state["ssm"] = state["ssm"].to(x.dtype) - zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) - ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) - y = selective_state_update( - x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), - dt=repeat(A_log, "b h -> b h p", p=self.headdim), - dt_softplus=True, - A=-ones, - B=B, - C=C, - state=state["ssm"], # will be updated in place - dt_bias=zeros, - D=zeros, - ) - - y = y + self.D[:, None] * x - y = rearrange(y, "b h p -> b (h p)") - - # Norm and gate - out = self.out_proj(y * F.silu(z + self.z_bias)) - - return out, state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate memory for inference cache.""" - device = self.in_proj.weight.device - # conv_state: - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=device, - dtype=conv_dtype, - ).transpose(1, 2) - # ssm_state: - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - ssm_state = torch.zeros( - batch_size, - self.n_v_heads, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return {"conv": conv_state, "ssm": ssm_state} - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - """ - Get states from cache. - - conv_state: (batch, d_conv, conv1d.weight.shape[0]) - ssm_state: (batch, n_qk_heads, headdim, d_state) - """ - assert self.layer_idx is not None - # Allocate memory if not exists - if self.layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - batch_size, inference_params.max_seqlen, dtype=torch.float32 - ) - # Get states - states = inference_params.key_value_memory_dict[self.layer_idx] - if initialize_states: - states["conv"].zero_() - states["ssm"].zero_() - return states - - def convolutional_forward(self, xBC, padded_len, mask=None): - """Convolutional layer forward pass for the full sequence.""" - seqlen = xBC.shape[1] - mask_seql = -1 if mask is None else mask.shape[1] - # If seqlen != mask_seql, this likely means we preallocated mask for static generation, - # but here we are in the prefill phase. - # Note, mask is needed to prevent state upodate for no-op tokens as described in https://proceedings.mlr.press/v262/wu24a.html - # Note, if we want to use joint attanimnet and advancement in selective-scan mode, we would need to implement masking into the kernel of causal_conv1d_fn and mamba_chunk_scan_combined - if causal_conv1d_fn is None or self.activation not in [ - "silu", - "swish", - "identity", - ]: - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - - xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - else: - # TODO: note, this only works for chunked inference, for autoregressive mode we need to update the kernel to make sure conv state is not poluted - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - xBC = causal_conv1d_fn( - xBC.transpose(1, 2), - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - activation=None if self.activation == "identity" else self.activation, - ).transpose(1, 2) - - if mask_seql == seqlen: - xBC = xBC * mask.unsqueeze(-1) - return xBC - - def convolutional_step(self, xBC, conv_state): - """Convolutional layer forward pass for a single step.""" - conv_state = conv_state.to(xBC.dtype) - if causal_conv1d_update: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation if self.activation != "identity" else None, - ) - else: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv_bias: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype - - return xBC, conv_state diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py deleted file mode 100644 index 94537c33..00000000 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Optional, Union - -import lm_eval.models.utils -import torch -from lm_eval.api.registry import register_model -from lm_eval.models.huggingface import HFLM - - -@register_model("apriel_ssm") -class AprielSSMWrapper(HFLM): - """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" - - def __init__(self, pretrained, **kwargs) -> None: - if "backend" in kwargs: - # rene currently only supports causal models - assert kwargs["backend"] == "causal" - - super().__init__( - pretrained=pretrained, - backend=kwargs.pop("backend", "causal"), - tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), - max_length=kwargs.pop("max_length", 4096), - **kwargs, - ) - - def _get_config(self, pretrained: str, **kwargs) -> None: - """Get the model configuration.""" - from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig - - self._config = AprielSSMConfig.from_pretrained(pretrained) - - def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: - """Create the model.""" - from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM - - self._model = AprielSSMForCausalLM.from_pretrained( - pretrained, - device=self._device, - dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), - trust_remote_code=True, - ) - - def _model_generate(self, context, max_length, stop, **generation_kwargs): - """Generate text from the model.""" - for key in ("do_sample", "attention_mask"): - if key in generation_kwargs: - generation_kwargs.pop(key) - - # The custom GenerationMixin imported from mamba_ssm currently does not support - # passing stopping criteria. - # For the time being, we simply generate to max length, then truncate (equivalent result). - # This should be revisited to speed up generation - # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) - - return self.model.generate( - input_ids=context, - max_length=max_length, - **generation_kwargs, - ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py deleted file mode 100644 index af07869a..00000000 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ /dev/null @@ -1,6 +0,0 @@ -from lm_eval.__main__ import cli_evaluate - -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 - -if __name__ == "__main__": - cli_evaluate() From 6532c5f94c3aea38018b2a06e413a59ad98fb4aa Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 12:23:57 +0000 Subject: [PATCH 051/114] hybrid config --- .../models/ssm/external/ariel_to_ssm.ipynb | 3526 +++++++++++++++++ .../configuration_ssm_hybrid_apriel.py | 446 +++ .../external/modeling_ssm_hybrid_apriel.py | 1203 ++++++ 3 files changed, 5175 insertions(+) create mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb create mode 100644 fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py create mode 100644 fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb new file mode 100644 index 00000000..496338cb --- /dev/null +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -0,0 +1,3526 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from mamba_ssm.models.config_mamba import MambaConfig\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "AprielForCausalLM(\n", + " (model): AprielModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()\n", + "apriel_model.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.bfloat16" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_model.config.torch_dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4.83207168" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_params/1e9" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" + ] + } + ], + "source": [ + "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", + "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'n_qk_heads'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", + "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" + ] + } + ], + "source": [ + "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", + "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", + " hidden_size=config.hidden_size,\n", + " intermediate_size=config.intermediate_size,\n", + " num_hidden_layers=config.num_hidden_layers,\n", + " hidden_act=config.hidden_act,\n", + " initializer_range=config.initializer_range,\n", + " use_cache=config.use_cache,\n", + " mlp_bias=config.mlp_bias,\n", + " tie_word_embeddings=config.tie_word_embeddings,\n", + " pad_token_id=config.pad_token_id,\n", + " bos_token_id=config.bos_token_id,\n", + " eos_token_id=config.eos_token_id,\n", + " head_dim=config.head_dim,\n", + " rms_norm_eps=config.rms_norm_eps)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict([('model.embed_tokens.weight',\n", + " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", + " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", + " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", + " ...,\n", + " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", + " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", + " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", + " ('model.layers.0.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.0.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.0.mixer.in_proj.weight',\n", + " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", + " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", + " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", + " ...,\n", + " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", + " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", + " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", + " ('model.layers.0.mixer.conv1d.weight',\n", + " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", + " \n", + " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", + " \n", + " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", + " \n", + " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", + " \n", + " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", + " ('model.layers.0.mixer.conv1d.bias',\n", + " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", + " ('model.layers.0.mixer.out_proj.weight',\n", + " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", + " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", + " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", + " ...,\n", + " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", + " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", + " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", + " ('model.layers.0.mlp.gate_proj.weight',\n", + " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", + " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", + " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", + " ...,\n", + " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", + " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", + " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", + " ('model.layers.0.mlp.up_proj.weight',\n", + " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", + " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", + " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", + " ...,\n", + " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", + " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", + " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", + " ('model.layers.0.mlp.down_proj.weight',\n", + " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", + " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", + " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", + " ...,\n", + " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", + " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", + " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", + " ('model.layers.0.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.0.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.1.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.1.mixer.in_proj.weight',\n", + " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", + " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", + " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", + " ...,\n", + " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", + " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", + " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", + " ('model.layers.1.mixer.conv1d.weight',\n", + " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", + " \n", + " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", + " \n", + " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", + " \n", + " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", + " \n", + " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", + " ('model.layers.1.mixer.conv1d.bias',\n", + " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", + " ('model.layers.1.mixer.out_proj.weight',\n", + " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", + " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", + " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", + " ...,\n", + " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", + " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", + " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", + " ('model.layers.1.mlp.gate_proj.weight',\n", + " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", + " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", + " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", + " ...,\n", + " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", + " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", + " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", + " ('model.layers.1.mlp.up_proj.weight',\n", + " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", + " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", + " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", + " ...,\n", + " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", + " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", + " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", + " ('model.layers.1.mlp.down_proj.weight',\n", + " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", + " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", + " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", + " ...,\n", + " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", + " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", + " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", + " ('model.layers.1.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.1.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.2.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.2.mixer.in_proj.weight',\n", + " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", + " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", + " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", + " ...,\n", + " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", + " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", + " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", + " ('model.layers.2.mixer.conv1d.weight',\n", + " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", + " \n", + " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", + " \n", + " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", + " \n", + " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", + " \n", + " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", + " ('model.layers.2.mixer.conv1d.bias',\n", + " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", + " ('model.layers.2.mixer.out_proj.weight',\n", + " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", + " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", + " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", + " ...,\n", + " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", + " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", + " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", + " ('model.layers.2.mlp.gate_proj.weight',\n", + " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", + " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", + " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", + " ...,\n", + " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", + " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", + " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", + " ('model.layers.2.mlp.up_proj.weight',\n", + " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", + " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", + " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", + " ...,\n", + " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", + " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", + " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", + " ('model.layers.2.mlp.down_proj.weight',\n", + " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", + " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", + " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", + " ...,\n", + " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", + " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", + " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", + " ('model.layers.2.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.2.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.3.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.3.mixer.in_proj.weight',\n", + " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", + " -2.1648e-03, 2.8147e-03],\n", + " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", + " 4.9989e-03, 6.4200e-03],\n", + " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", + " 9.7000e-03, 3.1136e-05],\n", + " ...,\n", + " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", + " 4.6152e-04, 5.5341e-02],\n", + " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", + " -1.0645e-02, -1.8442e-03],\n", + " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", + " -4.7119e-03, 3.2352e-02]])),\n", + " ('model.layers.3.mixer.conv1d.weight',\n", + " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", + " \n", + " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", + " \n", + " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", + " \n", + " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", + " \n", + " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", + " ('model.layers.3.mixer.conv1d.bias',\n", + " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", + " ('model.layers.3.mixer.out_proj.weight',\n", + " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", + " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", + " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", + " ...,\n", + " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", + " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", + " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", + " ('model.layers.3.mlp.gate_proj.weight',\n", + " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", + " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", + " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", + " ...,\n", + " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", + " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", + " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", + " ('model.layers.3.mlp.up_proj.weight',\n", + " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", + " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", + " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", + " ...,\n", + " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", + " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", + " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", + " ('model.layers.3.mlp.down_proj.weight',\n", + " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", + " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", + " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", + " ...,\n", + " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", + " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", + " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", + " ('model.layers.3.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.3.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.4.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.4.mixer.in_proj.weight',\n", + " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", + " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", + " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", + " ...,\n", + " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", + " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", + " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", + " ('model.layers.4.mixer.conv1d.weight',\n", + " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", + " \n", + " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", + " \n", + " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", + " \n", + " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", + " \n", + " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", + " ('model.layers.4.mixer.conv1d.bias',\n", + " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", + " ('model.layers.4.mixer.out_proj.weight',\n", + " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", + " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", + " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", + " ...,\n", + " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", + " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", + " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", + " ('model.layers.4.mlp.gate_proj.weight',\n", + " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", + " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", + " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", + " ...,\n", + " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", + " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", + " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", + " ('model.layers.4.mlp.up_proj.weight',\n", + " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", + " 3.6365e-02, -8.1017e-03],\n", + " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", + " -3.1505e-02, 2.1047e-03],\n", + " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", + " -2.7856e-04, -8.7440e-04],\n", + " ...,\n", + " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", + " -1.9620e-02, 2.5112e-02],\n", + " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", + " -1.6838e-03, 2.0189e-02],\n", + " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", + " -1.3502e-02, -1.8923e-02]])),\n", + " ('model.layers.4.mlp.down_proj.weight',\n", + " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", + " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", + " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", + " ...,\n", + " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", + " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", + " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", + " ('model.layers.4.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.4.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.5.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.5.mixer.in_proj.weight',\n", + " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", + " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", + " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", + " ...,\n", + " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", + " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", + " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", + " ('model.layers.5.mixer.conv1d.weight',\n", + " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", + " \n", + " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", + " \n", + " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", + " \n", + " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", + " \n", + " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", + " ('model.layers.5.mixer.conv1d.bias',\n", + " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", + " ('model.layers.5.mixer.out_proj.weight',\n", + " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", + " 7.0706e-03, -4.7970e-03],\n", + " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", + " 2.2802e-02, -2.9729e-03],\n", + " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", + " -4.2582e-03, 2.4007e-02],\n", + " ...,\n", + " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", + " -2.0391e-02, 5.3912e-03],\n", + " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", + " -2.0596e-02, -8.2915e-03],\n", + " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", + " -3.8566e-02, 2.7005e-02]])),\n", + " ('model.layers.5.mlp.gate_proj.weight',\n", + " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", + " 3.6159e-02, -1.7253e-02],\n", + " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", + " 4.4487e-03, -9.6723e-03],\n", + " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", + " 1.2872e-02, -1.0443e-02],\n", + " ...,\n", + " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", + " 2.8752e-02, -8.0048e-03],\n", + " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", + " -3.2179e-02, -1.6437e-02],\n", + " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", + " 7.7639e-03, -2.2494e-03]])),\n", + " ('model.layers.5.mlp.up_proj.weight',\n", + " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", + " -1.8538e-02, 6.7395e-03],\n", + " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", + " -3.0374e-02, 2.5795e-02],\n", + " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", + " 1.4115e-02, -2.6438e-02],\n", + " ...,\n", + " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", + " -4.9858e-02, -1.1575e-02],\n", + " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", + " 1.8568e-02, 5.2615e-03],\n", + " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", + " 1.2643e-02, -1.1568e-05]])),\n", + " ('model.layers.5.mlp.down_proj.weight',\n", + " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", + " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", + " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", + " ...,\n", + " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", + " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", + " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", + " ('model.layers.5.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.5.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.6.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.6.mixer.in_proj.weight',\n", + " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", + " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", + " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", + " ...,\n", + " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", + " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", + " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", + " ('model.layers.6.mixer.conv1d.weight',\n", + " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", + " \n", + " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", + " \n", + " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", + " \n", + " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", + " \n", + " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", + " ('model.layers.6.mixer.conv1d.bias',\n", + " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", + " ('model.layers.6.mixer.out_proj.weight',\n", + " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", + " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", + " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", + " ...,\n", + " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", + " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", + " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", + " ('model.layers.6.mlp.gate_proj.weight',\n", + " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", + " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", + " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", + " ...,\n", + " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", + " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", + " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", + " ('model.layers.6.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", + " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", + " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", + " ...,\n", + " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", + " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", + " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", + " ('model.layers.6.mlp.down_proj.weight',\n", + " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", + " 4.8023e-03, 9.8474e-03],\n", + " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", + " -9.5275e-04, -1.8121e-02],\n", + " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", + " 2.8131e-03, 5.8219e-03],\n", + " ...,\n", + " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", + " -8.3965e-03, 2.0074e-02],\n", + " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", + " 3.4947e-02, -2.2158e-02],\n", + " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", + " 4.0735e-05, 3.3651e-02]])),\n", + " ('model.layers.6.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.6.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.7.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.7.mixer.in_proj.weight',\n", + " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", + " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", + " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", + " ...,\n", + " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", + " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", + " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", + " ('model.layers.7.mixer.conv1d.weight',\n", + " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", + " \n", + " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", + " \n", + " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", + " \n", + " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", + " \n", + " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", + " ('model.layers.7.mixer.conv1d.bias',\n", + " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", + " ('model.layers.7.mixer.out_proj.weight',\n", + " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", + " 1.4225e-02, 2.8166e-02],\n", + " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", + " 3.7454e-03, 1.7579e-02],\n", + " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", + " 1.4766e-02, -1.8002e-02],\n", + " ...,\n", + " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", + " -1.3072e-02, -1.0412e-02],\n", + " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", + " -1.1047e-02, 1.5815e-02],\n", + " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", + " -2.1834e-02, -2.1826e-02]])),\n", + " ('model.layers.7.mlp.gate_proj.weight',\n", + " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", + " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", + " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", + " ...,\n", + " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", + " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", + " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", + " ('model.layers.7.mlp.up_proj.weight',\n", + " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", + " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", + " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", + " ...,\n", + " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", + " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", + " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", + " ('model.layers.7.mlp.down_proj.weight',\n", + " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", + " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", + " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", + " ...,\n", + " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", + " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", + " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", + " ('model.layers.7.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.7.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.8.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.8.mixer.in_proj.weight',\n", + " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", + " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", + " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", + " ...,\n", + " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", + " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", + " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", + " ('model.layers.8.mixer.conv1d.weight',\n", + " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", + " \n", + " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", + " \n", + " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", + " \n", + " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", + " \n", + " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", + " ('model.layers.8.mixer.conv1d.bias',\n", + " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", + " ('model.layers.8.mixer.out_proj.weight',\n", + " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", + " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", + " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", + " ...,\n", + " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", + " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", + " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", + " ('model.layers.8.mlp.gate_proj.weight',\n", + " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", + " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", + " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", + " ...,\n", + " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", + " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", + " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", + " ('model.layers.8.mlp.up_proj.weight',\n", + " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", + " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", + " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", + " ...,\n", + " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", + " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", + " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", + " ('model.layers.8.mlp.down_proj.weight',\n", + " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", + " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", + " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", + " ...,\n", + " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", + " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", + " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", + " ('model.layers.8.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.8.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.9.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.9.mixer.in_proj.weight',\n", + " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", + " 7.1750e-04, 1.5247e-02],\n", + " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", + " -1.1202e-02, -2.1476e-02],\n", + " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", + " 1.8544e-02, -2.3633e-02],\n", + " ...,\n", + " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", + " -5.6073e-03, -4.8212e-03],\n", + " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", + " 4.5898e-02, -3.5822e-02],\n", + " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", + " 1.3560e-02, 3.3366e-02]])),\n", + " ('model.layers.9.mixer.conv1d.weight',\n", + " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", + " \n", + " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", + " \n", + " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", + " \n", + " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", + " \n", + " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", + " ('model.layers.9.mixer.conv1d.bias',\n", + " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", + " ('model.layers.9.mixer.out_proj.weight',\n", + " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", + " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", + " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", + " ...,\n", + " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", + " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", + " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", + " ('model.layers.9.mlp.gate_proj.weight',\n", + " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", + " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", + " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", + " ...,\n", + " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", + " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", + " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", + " ('model.layers.9.mlp.up_proj.weight',\n", + " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", + " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", + " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", + " ...,\n", + " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", + " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", + " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", + " ('model.layers.9.mlp.down_proj.weight',\n", + " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", + " -3.0707e-02, -1.3281e-02],\n", + " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", + " -1.1585e-03, -2.7604e-02],\n", + " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", + " 1.1647e-02, 1.1245e-02],\n", + " ...,\n", + " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", + " -4.5122e-03, 4.0012e-03],\n", + " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", + " 2.3422e-02, -3.2098e-02],\n", + " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", + " 2.2118e-02, 1.0696e-02]])),\n", + " ('model.layers.9.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.9.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.10.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.10.mixer.in_proj.weight',\n", + " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", + " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", + " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", + " ...,\n", + " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", + " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", + " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", + " ('model.layers.10.mixer.conv1d.weight',\n", + " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", + " \n", + " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", + " \n", + " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", + " \n", + " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", + " \n", + " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", + " ('model.layers.10.mixer.conv1d.bias',\n", + " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", + " ('model.layers.10.mixer.out_proj.weight',\n", + " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", + " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", + " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", + " ...,\n", + " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", + " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", + " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", + " ('model.layers.10.mlp.gate_proj.weight',\n", + " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", + " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", + " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", + " ...,\n", + " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", + " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", + " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", + " ('model.layers.10.mlp.up_proj.weight',\n", + " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", + " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", + " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", + " ...,\n", + " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", + " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", + " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", + " ('model.layers.10.mlp.down_proj.weight',\n", + " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", + " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", + " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", + " ...,\n", + " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", + " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", + " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", + " ('model.layers.10.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.10.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.11.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.11.mixer.in_proj.weight',\n", + " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", + " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", + " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", + " ...,\n", + " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", + " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", + " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", + " ('model.layers.11.mixer.conv1d.weight',\n", + " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", + " \n", + " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", + " \n", + " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", + " \n", + " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", + " \n", + " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", + " ('model.layers.11.mixer.conv1d.bias',\n", + " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", + " ('model.layers.11.mixer.out_proj.weight',\n", + " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", + " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", + " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", + " ...,\n", + " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", + " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", + " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", + " ('model.layers.11.mlp.gate_proj.weight',\n", + " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", + " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", + " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", + " ...,\n", + " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", + " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", + " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", + " ('model.layers.11.mlp.up_proj.weight',\n", + " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", + " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", + " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", + " ...,\n", + " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", + " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", + " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", + " ('model.layers.11.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", + " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", + " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", + " ...,\n", + " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", + " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", + " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", + " ('model.layers.11.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.11.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.12.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.12.mixer.in_proj.weight',\n", + " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", + " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", + " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", + " ...,\n", + " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", + " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", + " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", + " ('model.layers.12.mixer.conv1d.weight',\n", + " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", + " \n", + " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", + " \n", + " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", + " \n", + " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", + " \n", + " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", + " ('model.layers.12.mixer.conv1d.bias',\n", + " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", + " ('model.layers.12.mixer.out_proj.weight',\n", + " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", + " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", + " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", + " ...,\n", + " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", + " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", + " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", + " ('model.layers.12.mlp.gate_proj.weight',\n", + " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", + " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", + " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", + " ...,\n", + " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", + " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", + " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", + " ('model.layers.12.mlp.up_proj.weight',\n", + " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", + " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", + " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", + " ...,\n", + " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", + " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", + " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", + " ('model.layers.12.mlp.down_proj.weight',\n", + " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", + " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", + " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", + " ...,\n", + " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", + " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", + " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", + " ('model.layers.12.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.12.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.13.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.13.mixer.in_proj.weight',\n", + " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", + " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", + " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", + " ...,\n", + " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", + " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", + " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", + " ('model.layers.13.mixer.conv1d.weight',\n", + " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", + " \n", + " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", + " \n", + " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", + " \n", + " ...,\n", + " \n", + " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", + " \n", + " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", + " \n", + " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", + " ('model.layers.13.mixer.conv1d.bias',\n", + " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", + " ('model.layers.13.mixer.out_proj.weight',\n", + " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", + " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", + " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", + " ...,\n", + " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", + " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", + " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", + " ('model.layers.13.mlp.gate_proj.weight',\n", + " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", + " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", + " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", + " ...,\n", + " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", + " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", + " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", + " ('model.layers.13.mlp.up_proj.weight',\n", + " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", + " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", + " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", + " ...,\n", + " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", + " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", + " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", + " ('model.layers.13.mlp.down_proj.weight',\n", + " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", + " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", + " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", + " ...,\n", + " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", + " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", + " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", + " ('model.layers.13.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.13.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.14.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.14.mixer.in_proj.weight',\n", + " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", + " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", + " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", + " ...,\n", + " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", + " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", + " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", + " ('model.layers.14.mixer.conv1d.weight',\n", + " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", + " \n", + " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", + " \n", + " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", + " \n", + " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", + " \n", + " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", + " ('model.layers.14.mixer.conv1d.bias',\n", + " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", + " ('model.layers.14.mixer.out_proj.weight',\n", + " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", + " -2.4306e-02, 5.7698e-03],\n", + " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", + " -1.4523e-02, 1.2266e-02],\n", + " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", + " 1.0787e-02, 8.4053e-03],\n", + " ...,\n", + " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", + " 1.5601e-02, -6.3227e-03],\n", + " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", + " 1.2259e-02, 5.5500e-02],\n", + " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", + " 1.8864e-02, -1.2508e-02]])),\n", + " ('model.layers.14.mlp.gate_proj.weight',\n", + " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", + " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", + " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", + " ...,\n", + " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", + " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", + " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", + " ('model.layers.14.mlp.up_proj.weight',\n", + " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", + " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", + " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", + " ...,\n", + " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", + " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", + " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", + " ('model.layers.14.mlp.down_proj.weight',\n", + " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", + " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", + " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", + " ...,\n", + " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", + " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", + " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", + " ('model.layers.14.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.14.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.15.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.15.mixer.in_proj.weight',\n", + " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", + " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", + " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", + " ...,\n", + " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", + " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", + " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", + " ('model.layers.15.mixer.conv1d.weight',\n", + " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", + " \n", + " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", + " \n", + " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", + " \n", + " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", + " \n", + " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", + " ('model.layers.15.mixer.conv1d.bias',\n", + " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", + " ('model.layers.15.mixer.out_proj.weight',\n", + " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", + " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", + " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", + " ...,\n", + " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", + " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", + " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", + " ('model.layers.15.mlp.gate_proj.weight',\n", + " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", + " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", + " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", + " ...,\n", + " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", + " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", + " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", + " ('model.layers.15.mlp.up_proj.weight',\n", + " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", + " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", + " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", + " ...,\n", + " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", + " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", + " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", + " ('model.layers.15.mlp.down_proj.weight',\n", + " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", + " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", + " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", + " ...,\n", + " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", + " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", + " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", + " ('model.layers.15.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.15.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.16.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.16.mixer.in_proj.weight',\n", + " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", + " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", + " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", + " ...,\n", + " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", + " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", + " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", + " ('model.layers.16.mixer.conv1d.weight',\n", + " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", + " \n", + " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", + " \n", + " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", + " \n", + " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", + " \n", + " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", + " ('model.layers.16.mixer.conv1d.bias',\n", + " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", + " ('model.layers.16.mixer.out_proj.weight',\n", + " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", + " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", + " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", + " ...,\n", + " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", + " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", + " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", + " ('model.layers.16.mlp.gate_proj.weight',\n", + " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", + " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", + " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", + " ...,\n", + " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", + " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", + " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", + " ('model.layers.16.mlp.up_proj.weight',\n", + " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", + " -1.2243e-02, -1.3005e-02],\n", + " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", + " -1.5193e-02, 2.6111e-03],\n", + " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", + " 1.0694e-02, 3.7682e-03],\n", + " ...,\n", + " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", + " -4.8128e-03, 1.6899e-02],\n", + " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", + " -5.5857e-03, 4.0449e-03],\n", + " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", + " -8.0242e-03, -1.0925e-02]])),\n", + " ('model.layers.16.mlp.down_proj.weight',\n", + " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", + " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", + " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", + " ...,\n", + " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", + " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", + " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", + " ('model.layers.16.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.16.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.17.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.17.mixer.in_proj.weight',\n", + " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", + " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", + " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", + " ...,\n", + " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", + " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", + " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", + " ('model.layers.17.mixer.conv1d.weight',\n", + " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", + " \n", + " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", + " \n", + " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", + " \n", + " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", + " \n", + " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", + " ('model.layers.17.mixer.conv1d.bias',\n", + " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", + " ('model.layers.17.mixer.out_proj.weight',\n", + " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", + " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", + " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", + " ...,\n", + " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", + " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", + " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", + " ('model.layers.17.mlp.gate_proj.weight',\n", + " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", + " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", + " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", + " ...,\n", + " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", + " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", + " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", + " ('model.layers.17.mlp.up_proj.weight',\n", + " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", + " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", + " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", + " ...,\n", + " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", + " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", + " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", + " ('model.layers.17.mlp.down_proj.weight',\n", + " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", + " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", + " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", + " ...,\n", + " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", + " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", + " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", + " ('model.layers.17.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.17.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.18.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.18.mixer.in_proj.weight',\n", + " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", + " -6.8715e-03, -1.3565e-02],\n", + " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", + " -1.0968e-02, -8.0445e-03],\n", + " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", + " 2.0198e-03, -1.7366e-03],\n", + " ...,\n", + " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", + " 1.4452e-02, -2.5152e-02],\n", + " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", + " 3.8135e-05, 1.7263e-02],\n", + " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", + " -1.1129e-02, -9.6768e-03]])),\n", + " ('model.layers.18.mixer.conv1d.weight',\n", + " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", + " \n", + " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", + " \n", + " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", + " \n", + " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", + " \n", + " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", + " ('model.layers.18.mixer.conv1d.bias',\n", + " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", + " ('model.layers.18.mixer.out_proj.weight',\n", + " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", + " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", + " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", + " ...,\n", + " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", + " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", + " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", + " ('model.layers.18.mlp.gate_proj.weight',\n", + " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", + " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", + " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", + " ...,\n", + " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", + " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", + " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", + " ('model.layers.18.mlp.up_proj.weight',\n", + " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", + " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", + " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", + " ...,\n", + " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", + " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", + " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", + " ('model.layers.18.mlp.down_proj.weight',\n", + " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", + " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", + " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", + " ...,\n", + " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", + " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", + " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", + " ('model.layers.18.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.18.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.19.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.19.mixer.in_proj.weight',\n", + " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", + " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", + " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", + " ...,\n", + " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", + " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", + " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", + " ('model.layers.19.mixer.conv1d.weight',\n", + " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", + " \n", + " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", + " \n", + " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", + " \n", + " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", + " \n", + " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", + " ('model.layers.19.mixer.conv1d.bias',\n", + " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", + " ('model.layers.19.mixer.out_proj.weight',\n", + " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", + " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", + " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", + " ...,\n", + " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", + " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", + " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", + " ('model.layers.19.mlp.gate_proj.weight',\n", + " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", + " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", + " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", + " ...,\n", + " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", + " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", + " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", + " ('model.layers.19.mlp.up_proj.weight',\n", + " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", + " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", + " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", + " ...,\n", + " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", + " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", + " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", + " ('model.layers.19.mlp.down_proj.weight',\n", + " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", + " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", + " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", + " ...,\n", + " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", + " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", + " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", + " ('model.layers.19.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.19.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.20.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.20.mixer.in_proj.weight',\n", + " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", + " -2.3083e-02, -6.2962e-02],\n", + " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", + " 9.0683e-03, 2.1583e-04],\n", + " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", + " 3.2650e-03, -3.1967e-02],\n", + " ...,\n", + " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", + " 1.0566e-02, -4.9561e-03],\n", + " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", + " -7.5403e-03, -1.3599e-02],\n", + " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", + " -3.3432e-02, 3.8337e-03]])),\n", + " ('model.layers.20.mixer.conv1d.weight',\n", + " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", + " \n", + " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", + " \n", + " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", + " \n", + " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", + " \n", + " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", + " ('model.layers.20.mixer.conv1d.bias',\n", + " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", + " ('model.layers.20.mixer.out_proj.weight',\n", + " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", + " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", + " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", + " ...,\n", + " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", + " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", + " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", + " ('model.layers.20.mlp.gate_proj.weight',\n", + " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", + " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", + " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", + " ...,\n", + " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", + " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", + " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", + " ('model.layers.20.mlp.up_proj.weight',\n", + " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", + " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", + " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", + " ...,\n", + " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", + " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", + " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", + " ('model.layers.20.mlp.down_proj.weight',\n", + " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", + " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", + " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", + " ...,\n", + " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", + " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", + " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", + " ('model.layers.20.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.20.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.21.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.21.mixer.in_proj.weight',\n", + " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", + " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", + " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", + " ...,\n", + " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", + " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", + " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", + " ('model.layers.21.mixer.conv1d.weight',\n", + " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", + " \n", + " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", + " \n", + " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", + " \n", + " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", + " \n", + " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", + " ('model.layers.21.mixer.conv1d.bias',\n", + " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", + " ('model.layers.21.mixer.out_proj.weight',\n", + " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", + " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", + " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", + " ...,\n", + " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", + " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", + " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", + " ('model.layers.21.mlp.gate_proj.weight',\n", + " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", + " -5.8295e-03, -1.4284e-02],\n", + " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", + " -3.9076e-02, 5.3379e-03],\n", + " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", + " -2.5180e-02, 6.1512e-04],\n", + " ...,\n", + " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", + " -7.2165e-03, 1.6297e-02],\n", + " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", + " -7.7709e-03, 4.3964e-04],\n", + " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", + " 5.0683e-02, -2.5481e-02]])),\n", + " ('model.layers.21.mlp.up_proj.weight',\n", + " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", + " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", + " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", + " ...,\n", + " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", + " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", + " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", + " ('model.layers.21.mlp.down_proj.weight',\n", + " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", + " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", + " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", + " ...,\n", + " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", + " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", + " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", + " ('model.layers.21.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.21.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.22.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.22.mixer.in_proj.weight',\n", + " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", + " -1.5757e-02, -1.5315e-03],\n", + " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", + " 8.9105e-05, -3.8658e-03],\n", + " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", + " -1.6604e-02, -4.5245e-03],\n", + " ...,\n", + " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", + " 2.3792e-02, -5.5165e-03],\n", + " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", + " 5.2496e-03, -6.0055e-03],\n", + " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", + " -1.5446e-02, -2.7164e-02]])),\n", + " ('model.layers.22.mixer.conv1d.weight',\n", + " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", + " \n", + " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", + " \n", + " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", + " \n", + " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", + " \n", + " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", + " ('model.layers.22.mixer.conv1d.bias',\n", + " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", + " ('model.layers.22.mixer.out_proj.weight',\n", + " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", + " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", + " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", + " ...,\n", + " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", + " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", + " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", + " ('model.layers.22.mlp.gate_proj.weight',\n", + " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", + " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", + " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", + " ...,\n", + " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", + " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", + " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", + " ('model.layers.22.mlp.up_proj.weight',\n", + " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", + " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", + " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", + " ...,\n", + " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", + " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", + " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", + " ('model.layers.22.mlp.down_proj.weight',\n", + " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", + " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", + " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", + " ...,\n", + " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", + " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", + " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", + " ('model.layers.22.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.22.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.23.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.23.mixer.in_proj.weight',\n", + " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", + " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", + " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", + " ...,\n", + " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", + " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", + " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", + " ('model.layers.23.mixer.conv1d.weight',\n", + " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", + " \n", + " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", + " \n", + " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", + " \n", + " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", + " \n", + " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", + " ('model.layers.23.mixer.conv1d.bias',\n", + " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", + " ('model.layers.23.mixer.out_proj.weight',\n", + " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", + " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", + " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", + " ...,\n", + " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", + " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", + " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", + " ('model.layers.23.mlp.gate_proj.weight',\n", + " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", + " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", + " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", + " ...,\n", + " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", + " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", + " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", + " ('model.layers.23.mlp.up_proj.weight',\n", + " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", + " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", + " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", + " ...,\n", + " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", + " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", + " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", + " ('model.layers.23.mlp.down_proj.weight',\n", + " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", + " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", + " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", + " ...,\n", + " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", + " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", + " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", + " ('model.layers.23.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.23.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.24.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.24.mixer.in_proj.weight',\n", + " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", + " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", + " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", + " ...,\n", + " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", + " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", + " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", + " ('model.layers.24.mixer.conv1d.weight',\n", + " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", + " \n", + " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", + " \n", + " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", + " \n", + " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", + " \n", + " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", + " ('model.layers.24.mixer.conv1d.bias',\n", + " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", + " ('model.layers.24.mixer.out_proj.weight',\n", + " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", + " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", + " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", + " ...,\n", + " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", + " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", + " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", + " ('model.layers.24.mlp.gate_proj.weight',\n", + " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", + " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", + " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", + " ...,\n", + " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", + " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", + " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", + " ('model.layers.24.mlp.up_proj.weight',\n", + " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", + " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", + " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", + " ...,\n", + " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", + " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", + " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", + " ('model.layers.24.mlp.down_proj.weight',\n", + " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", + " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", + " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", + " ...,\n", + " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", + " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", + " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", + " ('model.layers.24.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.24.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.25.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.25.mixer.in_proj.weight',\n", + " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", + " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", + " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", + " ...,\n", + " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", + " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", + " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", + " ('model.layers.25.mixer.conv1d.weight',\n", + " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", + " \n", + " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", + " \n", + " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", + " \n", + " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", + " \n", + " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", + " ('model.layers.25.mixer.conv1d.bias',\n", + " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", + " ('model.layers.25.mixer.out_proj.weight',\n", + " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", + " -1.8566e-02, 1.3872e-02],\n", + " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", + " 7.3446e-03, -1.6368e-02],\n", + " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", + " -1.6554e-02, 7.2751e-03],\n", + " ...,\n", + " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", + " -1.0262e-02, -4.9292e-03],\n", + " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", + " 2.2522e-02, -6.0243e-03],\n", + " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", + " -5.7889e-03, 1.1855e-02]])),\n", + " ('model.layers.25.mlp.gate_proj.weight',\n", + " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", + " -1.2723e-02, -9.7756e-03],\n", + " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", + " -1.5727e-02, -9.8748e-03],\n", + " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", + " 1.2017e-04, -1.5734e-02],\n", + " ...,\n", + " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", + " 1.8795e-02, 1.0221e-02],\n", + " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", + " 8.5950e-05, 1.0542e-02],\n", + " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", + " -3.0686e-02, -9.6116e-03]])),\n", + " ('model.layers.25.mlp.up_proj.weight',\n", + " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", + " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", + " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", + " ...,\n", + " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", + " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", + " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", + " ('model.layers.25.mlp.down_proj.weight',\n", + " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", + " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", + " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", + " ...,\n", + " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", + " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", + " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", + " ('model.layers.25.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.25.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.26.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.26.mixer.in_proj.weight',\n", + " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", + " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", + " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", + " ...,\n", + " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", + " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", + " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", + " ('model.layers.26.mixer.conv1d.weight',\n", + " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", + " \n", + " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", + " \n", + " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", + " \n", + " ...,\n", + " \n", + " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", + " \n", + " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", + " \n", + " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", + " ('model.layers.26.mixer.conv1d.bias',\n", + " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", + " ('model.layers.26.mixer.out_proj.weight',\n", + " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", + " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", + " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", + " ...,\n", + " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", + " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", + " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", + " ('model.layers.26.mlp.gate_proj.weight',\n", + " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", + " 1.3550e-02, -1.1203e-02],\n", + " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", + " 5.3891e-03, -2.8310e-02],\n", + " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", + " -2.8479e-03, -1.3105e-02],\n", + " ...,\n", + " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", + " -1.8015e-02, -1.2234e-02],\n", + " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", + " 1.7886e-02, -4.4699e-02],\n", + " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", + " 3.3468e-03, 2.7677e-02]])),\n", + " ('model.layers.26.mlp.up_proj.weight',\n", + " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", + " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", + " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", + " ...,\n", + " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", + " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", + " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", + " ('model.layers.26.mlp.down_proj.weight',\n", + " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", + " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", + " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", + " ...,\n", + " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", + " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", + " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", + " ('model.layers.26.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.26.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.z_bias',\n", + " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", + " ('model.layers.27.mixer.D',\n", + " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1.])),\n", + " ('model.layers.27.mixer.in_proj.weight',\n", + " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", + " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", + " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", + " ...,\n", + " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", + " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", + " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", + " ('model.layers.27.mixer.conv1d.weight',\n", + " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", + " \n", + " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", + " \n", + " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", + " \n", + " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", + " \n", + " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", + " ('model.layers.27.mixer.conv1d.bias',\n", + " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", + " ('model.layers.27.mixer.out_proj.weight',\n", + " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", + " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", + " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", + " ...,\n", + " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", + " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", + " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", + " ('model.layers.27.mlp.gate_proj.weight',\n", + " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", + " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", + " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", + " ...,\n", + " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", + " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", + " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", + " ('model.layers.27.mlp.up_proj.weight',\n", + " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", + " 3.2845e-02, 2.1107e-02],\n", + " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", + " -1.1685e-02, 1.2504e-02],\n", + " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", + " -3.3912e-04, 3.0527e-02],\n", + " ...,\n", + " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", + " 1.9532e-02, 1.3339e-02],\n", + " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", + " 2.8262e-02, -1.2526e-02],\n", + " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", + " -2.0571e-03, -8.3943e-03]])),\n", + " ('model.layers.27.mlp.down_proj.weight',\n", + " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", + " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", + " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", + " ...,\n", + " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", + " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", + " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", + " ('model.layers.27.input_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.layers.27.post_attention_layernorm.weight',\n", + " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", + " ('lm_head.weight',\n", + " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", + " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", + " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", + " ...,\n", + " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", + " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", + " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "N params SSM: 5.305533088\n" + ] + } + ], + "source": [ + "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load State dict into SSM" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "apriel_ssm.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# apriel_ssm.state_dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'apriel_ssm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" + ] + } + ], + "source": [ + "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", + " save_config=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "24" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.model.layers[0].mixer.n_v_heads" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Try a forward pass" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", + "batch_size = 1\n", + "max_length = 128\n", + "state = SimpleNamespace()\n", + "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", + "state.batch_size = batch_size\n", + "state.seqlen_offset = 0\n", + "static_inputs = {\"inference_params\": state,\n", + " \"input_ids\": input_ids,\n", + " \"use_cache\": True,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", + " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", + " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", + " ...,\n", + " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", + " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", + " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", + " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", + " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", + " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", + " ...,\n", + " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", + " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", + " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm.forward(**static_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load mdoel" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from mamba_ssm.models.config_mamba import MambaConfig\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "import os\n", + "import shutil\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/apriel_ssminstr-distil-randinit-bs768-lr0.0003-sl4096_ti5000_luke_mix1/export/apriel_ssm/5000\"\n", + "modeling_path = \"/home/toolkit/dev/Fast-LLM/fast_llm/models/ssm/external\"\n", + "# # copy the config.json to the model path\n", + "shutil.copy(os.path.join(modeling_path, \"modeling_ssm_apriel.py\"), os.path.join(model_path, \"modeling_ssm_apriel.py\"))\n", + "shutil.copy(os.path.join(modeling_path, \"configuration_ssm_apriel.py\"), os.path.join(model_path, \"configuration_ssm_apriel.py\"))\n", + "\n", + "tokenizer_path = \"/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/\"\n", + "# # cp tokenizer*\n", + "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer.json\"), os.path.join(model_path, \"tokenizer.json\"))\n", + "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer_config.json\"), os.path.join(model_path, \"tokenizer_config.json\"))\n", + "# shutil.copy(os.path.join(tokenizer_path, \"special_tokens_map.json\"), os.path.join(model_path, \"special_tokens_map.json\"))\n", + "# shutil.copy(os.path.join(tokenizer_path, \"vocab.json\"), os.path.join(model_path, \"vocab.json\"))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", + "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.08s/it]\n" + ] + } + ], + "source": [ + "\n", + "apriel_ssm = AprielSSMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device=\"cuda\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMForCausalLM(\n", + " (model): AprielSSMModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "apriel_ssm" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "config = apriel_ssm.config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Mamba in Llama" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "import torch\n", + "from mamba_ssm import MambaLMHeadModel\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", + "from transformers.cache_utils import StaticCache\n", + "from types import SimpleNamespace\n", + "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel\n", + "# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper\n", + "# make sure the code changes reflected without reload\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridConfig {\n", + " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", + " \"architectures\": [\n", + " \"AprielForCausalLM\"\n", + " ],\n", + " \"attention_bias\": false,\n", + " \"attention_dropout\": 0.0,\n", + " \"auto_map\": {\n", + " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", + " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", + " },\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 8192,\n", + " \"max_position_embeddings\": 16384,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"apriel\",\n", + " \"num_attention_heads\": 24,\n", + " \"num_hidden_layers\": 28,\n", + " \"num_key_value_heads\": 8,\n", + " \"pretraining_tp\": 1,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_scaling\": {\n", + " \"attention_factor\": null,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", + " \"factor\": 32.0,\n", + " \"original_max_position_embeddings\": 4096,\n", + " \"rope_type\": \"yarn\"\n", + " },\n", + " \"rope_theta\": 1000000.0,\n", + " \"ssm_block_pattern\": [\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\",\n", + " \"m2d\",\n", + " \"t\"\n", + " ],\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"d_inner\": 3072,\n", + " \"d_state\": 64,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 24,\n", + " \"n_v_heads\": 24\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.48.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", + " ssm_block_pattern=[\"m2d\", \"t\"] * 14,\n", + " ssm_cfg=None)\n", + "hybrdif_apriel_config" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "28" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.num_hidden_layers" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", + "batch_size = 1\n", + "max_length = 128\n", + "state = SimpleNamespace()\n", + "state.key_value_memory_dict = hybrid_apriel_model.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", + "state.batch_size = batch_size\n", + "state.seqlen_offset = 0\n", + "static_inputs = {\"inference_params\": state,\n", + " \"input_ids\": input_ids,\n", + " \"use_cache\": True,\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (1): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (2): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (3): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (4): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (5): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (6): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (7): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (8): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (9): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (10): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (11): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (12): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (13): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (14): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (15): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (16): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (17): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (18): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (19): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (20): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (21): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (22): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (23): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (24): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (25): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (26): AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " (27): AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + ")" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BaseModelOutputWithPast(last_hidden_state=tensor([[[ 2.2031, 0.1777, 0.4258, ..., -2.0312, 0.2246, 0.5664],\n", + " [ 0.0562, -1.1016, 0.4590, ..., -2.1719, -0.1455, -0.6992],\n", + " [-1.5078, -1.3516, 0.8789, ..., -1.9141, 1.3672, -1.0391],\n", + " ...,\n", + " [-1.4453, 0.1260, 0.6992, ..., 0.4746, -0.1729, -0.5938],\n", + " [-0.4961, -0.4160, -0.4551, ..., -0.1328, 0.7461, -0.0376],\n", + " [ 0.3184, 0.4355, -0.7578, ..., 1.5547, 0.8555, -0.8711]]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=), past_key_values=DynamicCache(), hidden_states=None, attentions=None)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "hybrid_apriel_model.forward(**static_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.73it/s]\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m apriel_model \u001b[38;5;241m=\u001b[39m AutoModelForCausalLM\u001b[38;5;241m.\u001b[39mfrom_pretrained(checkpoint, torch_dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16, trust_remote_code\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5\u001b[0m apriel_state_dict \u001b[38;5;241m=\u001b[39m apriel_model\u001b[38;5;241m.\u001b[39mstate_dict()\n\u001b[0;32m----> 6\u001b[0m \u001b[43mapriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" + ] + } + ], + "source": [ + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()\n", + "apriel_model.to(device).to(dtype=torch.bfloat16)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielConfig {\n", + " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", + " \"architectures\": [\n", + " \"AprielForCausalLM\"\n", + " ],\n", + " \"attention_bias\": false,\n", + " \"attention_dropout\": 0.0,\n", + " \"auto_map\": {\n", + " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", + " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", + " },\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 4096,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 8192,\n", + " \"max_position_embeddings\": 16384,\n", + " \"mlp_bias\": false,\n", + " \"model_type\": \"apriel\",\n", + " \"num_attention_heads\": 24,\n", + " \"num_hidden_layers\": 28,\n", + " \"num_key_value_heads\": 8,\n", + " \"pretraining_tp\": 1,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_scaling\": {\n", + " \"attention_factor\": null,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", + " \"factor\": 32.0,\n", + " \"original_max_position_embeddings\": 4096,\n", + " \"rope_type\": \"yarn\"\n", + " },\n", + " \"rope_theta\": 1000000.0,\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.48.1\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "d_xb = config.num_key_value_heads * config.head_dim\n", + "ssm_layers = [2,4,8]\n", + "attn_layers = [i for i in range(config.num_hidden_layers) if i not in ssm_layers]\n", + "model_name = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "ngroups = config.num_attention_heads # n heads\n", + "d_inner = config.head_dim * config.num_attention_heads\n", + "headdim = 128 # d_state\n", + "d_state = config.head_dim\n", + "d_model = config.hidden_size \n", + "assert d_inner == ngroups * d_state\n", + "\n", + "mamba_config = AprielSSMConfig(\n", + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * headdim, # num_heads * head_dim\n", + " },\n", + " vocab_size=config.vocab_size, \n", + " hidden_size=config.hidden_size,\n", + " intermediate_size=config.intermediate_size,\n", + " num_hidden_layers=config.num_hidden_layers,\n", + " hidden_act=config.hidden_act,\n", + " initializer_range=config.initializer_range,\n", + " use_cache=config.use_cache,\n", + " mlp_bias=config.mlp_bias,\n", + " tie_word_embeddings=config.tie_word_embeddings,\n", + " pad_token_id=config.pad_token_id,\n", + " bos_token_id=config.bos_token_id,\n", + " eos_token_id=config.eos_token_id,\n", + " head_dim=config.head_dim,\n", + " rms_norm_eps=config.rms_norm_eps\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "student_model = MambaTransformerHybridModelWrapper.init_distillation(None, model_name, \n", + " mamba_config, \n", + " attn_layers=attn_layers, \n", + " init_with_kqvo=True, \n", + " attn_implementation=\"flash_attention_2\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hymba2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py new file mode 100644 index 00000000..58891802 --- /dev/null +++ b/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py @@ -0,0 +1,446 @@ +import math +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + # Apriel: Use original max_position_embeddings instead of max_position_embeddings + max_position_embeddings = config.rope_scaling.get( + "original_max_position_embeddings", config.max_position_embeddings + ) + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "yarn": _compute_yarn_parameters, +} + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "yarn": _validate_yarn_parameters, +} + + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) + + +class AprielSSMHybridConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Apriel model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AprielModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Apriel-5B-Base supports up to 16384 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'yarn', 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + ```python + >>> from transformers import AprielModel, AprielConfig + >>> # Initializing an Apriel Apriel-5B-Base style configuration + >>> configuration = AprielConfig() + >>> # Initializing a model from the Apriel-5B-Base style configuration + >>> model = AprielModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "apriel" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `AprielModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + ssm_block_pattern=["m2d"], + ssm_cfg=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.ssm_block_pattern = ssm_block_pattern + if len(ssm_block_pattern) == 1: + self.ssm_block_pattern = [ssm_block_pattern[0]] * self.num_hidden_layers + assert len(self.ssm_block_pattern) == self.num_hidden_layers + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } + + +__all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py new file mode 100644 index 00000000..49b00986 --- /dev/null +++ b/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py @@ -0,0 +1,1203 @@ +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.utils.generation import GenerationMixin +from torch import nn +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.configuration_ssm_hybrid_apriel import ROPE_INIT_FUNCTIONS, AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class AprielRotaryEmbedding(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class AprielAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # States are updated inplace + out, _ = self.step(u, state) + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u.squeeze(1)) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = AprielAttention(config=config, layer_idx=layer_idx) + + self.mlp = AprielMLP(config) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + inference_params=None, # just to be compatible with SSM block + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + + def forward( + self, hidden_states: torch.Tensor, inference_params=None, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMHybridConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMHybridModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) + blocks = [] + for layer_idx, type in enumerate(config.ssm_block_pattern): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) + elif type == "t": + blocks.append(AprielDecoderLayer(config, layer_idx)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.gradient_checkpointing = False + self.rotary_emb = AprielRotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + cache = {} + for i, layer in enumerate(self.layers): + if isinstance(layer, AprielSSMDecoderLayer): + cache[i] = layer.allocate_inference_cache(*args, **kwargs) + return cache + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + inference_params=None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + inference_params=inference_params, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions and isinstance(decoder_layer, AprielDecoderLayer): + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=outputs["last_hidden_state"], + ) + + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + + +__all__ = [ + "AprielSSMForCausalLM", + "AprielModel", + "AprielSSMPreTrainedModel", +] From 9a678df83ff6fe1a5a1f5b447b616836c0e6b5c3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 12:59:54 +0000 Subject: [PATCH 052/114] sft distill --- fast_llm/models/gpt/config.py | 6 +++--- fast_llm/models/ssm/config.py | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b82dd3e8..1ddb7ed2 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -193,9 +193,9 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: - # TODO: Support loss masking for distillation? - assert not self.batch.use_loss_masking_spans + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index d77d206b..1d8ac007 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -172,9 +172,7 @@ class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) - reference_models: dict[str, PretrainedGPTModelConfig] = ( - FieldUpdate() - ) # TODO: make sure any reference mdoel can be suported + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: @@ -190,9 +188,9 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: - # TODO: Support loss masking for distillation? - assert not self.batch.use_loss_masking_spans + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From a7abe53383286f92fd26ad5ed93f11010e52e4c9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 14:21:32 +0000 Subject: [PATCH 053/114] conversion --- fast_llm/models/ssm/conversion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index c2e54ca0..3d3aa728 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -52,11 +52,11 @@ def _import_config(cls, config, architecture_only: bool = False): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: if cls.block_pattern is not None: - block_converter = MappedConfigParamConverter( - fast_llm_names=(("hybrid_block_layout",),), - export_names=(("hybrid_block_layout",),), - fast_llm_value=cls.block_pattern, - export_value=cls.block_pattern, + block_converter = ( + RenameParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), + ), ) else: block_converter = ConstantImportParamConverter( From a68c0b7318ec07abebea8acc0c035f9f8877017e Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 14:30:10 +0000 Subject: [PATCH 054/114] conversion --- fast_llm/models/ssm/conversion.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 3d3aa728..675c709f 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -52,11 +52,9 @@ def _import_config(cls, config, architecture_only: bool = False): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: if cls.block_pattern is not None: - block_converter = ( - RenameParamConverter( - fast_llm_names=(("hybrid_block_layout",),), - export_names=(("hybrid_block_layout",),), - ), + block_converter = RenameParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), ) else: block_converter = ConstantImportParamConverter( From 9cfef449bb232e6c3895f7406f04e67b7a62fea9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 15:30:33 +0000 Subject: [PATCH 055/114] lr stage definition as string --- fast_llm/engine/optimizer/learning_rate.py | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index bf11038a..c6912e4f 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR begin_step = 0 for stage_arg_str in config.schedule.split(";"): try: - for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","): - assert begin_step is not None - num_steps = int(num_steps) - end_step = None if num_steps < 0 else begin_step + num_steps - kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} - if len(stage_args) > 0: - kwargs["end_lr"] = float(stage_args[0]) - if len(stage_args) > 1: - kwargs["power"] = float(stage_args[1]) - if len(stage_args) > 2: - raise ValueError(stage_args[2:]) - stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) - begin_step = end_step + stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",") + assert begin_step is not None + num_steps = int(num_steps) + end_step = None if num_steps < 0 else begin_step + num_steps + kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} + if len(stage_args) > 0: + kwargs["end_lr"] = float(stage_args[0]) + if len(stage_args) > 1: + kwargs["power"] = float(stage_args[1]) + if len(stage_args) > 2: + raise ValueError(stage_args[2:]) + stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) + begin_step = end_step except Exception: raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"') return LearningRateSchedule(stages) From 005e623e08936f7addad87a083fde3aab176ceef Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 11:59:15 -0400 Subject: [PATCH 056/114] fixes --- fast_llm/functional/triton/mlp.py | 12 +++++++----- fast_llm/layers/ssm/llamba_block.py | 2 +- fast_llm/layers/transformer/transformer.py | 3 +-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 5b220b1a..ee3ba304 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -50,19 +50,21 @@ def triton_mlp_activation_forward_kernel( input_ = tl.load(input_ptr, mask=mask).to(tl.float32) - if activation_type == _TritonActivationType.gelu.value: + if activation_type == _TritonActivationType.gelu: tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_) tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input)) out = input_ * 0.5 * (1.0 + tanh) - elif activation_type == _TritonActivationType.silu.value: + elif activation_type == _TritonActivationType.silu: out = input_ / (1 + tl.exp(-input_)) - elif activation_type == _TritonActivationType.relu.value: + elif activation_type == _TritonActivationType.relu: out = tl.where(input_ > 0, input_, 0) elif activation_type == _TritonActivationType.squared_relu: relu_out = tl.where(input_ > 0, input_, 0) out = relu_out * relu_out + elif activation_type == _TritonActivationType.identity: + out = input_ else: - raise NotImplementedError() + tl.static_assert(False, activation_type) if gated: other = tl.load(input_ptr + n_cols, mask=mask) @@ -124,7 +126,7 @@ def triton_mlp_activation_backward_kernel( if gated or recompute: out = input_ else: - raise NotImplementedError() + tl.static_assert(False, activation_type) if gated: other = tl.load(input_ptr + n_cols, mask=mask) diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index 22135638..ee222d6d 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -13,7 +13,7 @@ class LlambaBlock(BaseBlock): A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ - name = "Llamba block" + _name = "Llamba block" _mixer_module_name = "mixer" def __init__( diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 92df1893..40dd2e00 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -23,7 +23,6 @@ class BaseBlock(Layer, abc.ABC): A transformer-like decoder base block block with abstract mixer. """ - name = "Transformer layer" _mixer_module_name = "self_attn" def __init__( @@ -137,7 +136,7 @@ def forward( class TransformerLayer(BaseBlock): - name = "Transformer layer" + _name = "Transformer layer" _mixer_module_name = "self_attn" def __init__( From cad951aafd0f2ac8048929c8fea5b5b67c0f6a42 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 12:31:07 -0400 Subject: [PATCH 057/114] fix --- fast_llm/layers/language_model/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4fb471fb..b4b4e187 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -215,8 +215,6 @@ def _validate(self) -> None: super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) - if self.prediction_heads > 1: - Assert.gt(self.transformer.num_layers, 1) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") From bce916d01566541ee364bcd2fd395db8f4010ff6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 17:07:43 +0000 Subject: [PATCH 058/114] loss maks --- fast_llm/functional/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 34c69d79..401cfe07 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -187,9 +187,9 @@ def cross_entropy_forward_backward( if group: Assert.eq(implementation, CrossEntropyImpl.fused) return _fused_cross_entropy_forward_backward( - logits, target, grad_output, logits_scale_factor, target_format, group + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor, target_format + logits, target, loss_mask, grad_output, logits_scale_factor, target_format ) From 9d9506418cc290c549af6aa6c9b8040f0ea7e1e5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 13:14:02 -0400 Subject: [PATCH 059/114] fix --- fast_llm/functional/cross_entropy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 34c69d79..401cfe07 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -187,9 +187,9 @@ def cross_entropy_forward_backward( if group: Assert.eq(implementation, CrossEntropyImpl.fused) return _fused_cross_entropy_forward_backward( - logits, target, grad_output, logits_scale_factor, target_format, group + logits, target, loss_mask, grad_output, logits_scale_factor, target_format, group ) else: return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( - logits, target, grad_output, logits_scale_factor, target_format + logits, target, loss_mask, grad_output, logits_scale_factor, target_format ) From 935c470b4ade91462a92cfe26aae9be1e32e2154 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 13:30:28 -0400 Subject: [PATCH 060/114] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f3633a76..065eb94d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -187,7 +187,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) + self._load_yaml_data(loaded_yaml_data) if not self._truncate_documents: del loaded_yaml_data["unshuffled_tokens"] From 9aff3b70cb13d57920b98d4d1b65926fdea5fa5f Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 17:39:10 +0000 Subject: [PATCH 061/114] fix shuffled tokens --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f3633a76..065eb94d 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -187,7 +187,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) + self._load_yaml_data(loaded_yaml_data) if not self._truncate_documents: del loaded_yaml_data["unshuffled_tokens"] From ae4d111a26105362ec4f49aeea8da77e522f98c0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 15:08:40 -0400 Subject: [PATCH 062/114] fixes --- tests/layers/test_lm_head.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 14edecff..b32292bd 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -126,15 +126,13 @@ def test_lm_head( else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) ) if loss_masking: - loss_mask = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.bool, - device=distributed.device, - ) + loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) else: loss_mask = None + kwargs = { + TransformerKwargs.sequence_first: sequence_first, + TransformerKwargs.grad_output: 1.0, + } if config.distillation_model is None: target = torch.randint( 0, @@ -145,14 +143,17 @@ def test_lm_head( ) if loss_mask is not None: target *= loss_mask + + kwargs[LanguageModelKwargs.labels] = target else: assert config.prediction_heads == 1 - target = torch.randn_like(input_) - kwargs = { - TransformerKwargs.sequence_first: sequence_first, - LanguageModelKwargs.labels: target, - TransformerKwargs.grad_output: 1.0, - } + target = torch.randn( + input_.shape[:-1] + (VOCAB_SIZE,), + dtype=input_.dtype, + device=distributed.device, + ) + kwargs[f"{config.distillation_model}_logits"] = target + if config.tie_word_embeddings or config.prediction_heads > 1: logit_weight = ( torch.empty( From deb7ce66a6bb15e60a493845a71b4f6ee9366ea2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 2 May 2025 18:00:13 -0400 Subject: [PATCH 063/114] fixes --- fast_llm/functional/cross_entropy.py | 12 +++++++----- fast_llm/functional/triton/cross_entropy.py | 16 ++++++++++++---- fast_llm/layers/language_model/head.py | 2 ++ tests/layers/test_lm_head.py | 10 +++++++--- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 401cfe07..513510ec 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -38,7 +38,7 @@ def _torch_cross_entropy_forward_backward( torch.nn.functional.cross_entropy( logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" ) - * loss_mask.unsqueeze(-1) + * loss_mask ).mean() if grad_output is None: grad = None @@ -48,7 +48,7 @@ def _torch_cross_entropy_forward_backward( return loss.detach_(), grad -# @torch.compile +@torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -74,7 +74,7 @@ def _fused_softmax( return exp_logits / sum_exp_logits -@torch.compile +# @torch.compile def _fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -113,6 +113,8 @@ def _fused_cross_entropy_forward_backward( else: # Target should be tensor-parallel already, no further manipulation needed. target_mask = None + if loss_mask is not None: + loss_mask = loss_mask.unsqueeze(-1) if grad_output is None: grad = None @@ -128,9 +130,9 @@ def _fused_cross_entropy_forward_backward( grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) if logits_scale_factor != 1.0: grad *= logits_scale_factor - grad = grad.to(logits.dtype) if loss_mask is not None: - grad = torch.where(loss_mask, grad.to(logits.dtype), 0) + grad *= loss_mask + grad = grad.to(logits.dtype) # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) if target_format == TargetFormat.labels: diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 02dc1ce7..8cb59c85 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -64,7 +64,6 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( n_cols: tl_constexpr, logits_stride_0: tl_constexpr, target_stride_0: tl_constexpr, - loss_mask_stride_0: tl_constexpr, grad_logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, from_logits: tl_constexpr, @@ -75,6 +74,14 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( col_offsets = tl.arange(0, block_size) mask = col_offsets < n_cols + if loss_mask_ptr is not None: + loss_mask = tl.load(loss_mask_ptr + block_idx) + if loss_mask == 0: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=mask) + return + logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( tl.float32 ) @@ -89,8 +96,6 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( tl.float32 ) - if loss_mask_ptr is not None: - loss_mask = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=0) if from_logits: if logits_scale_factor != 1.0: target *= logits_scale_factor @@ -108,6 +113,8 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) if logits_scale_factor != 1.0: grad_logits *= logits_scale_factor + if loss_mask_ptr is not None: + grad_logits = grad_logits tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) @@ -151,6 +158,8 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, ) else: + if loss_mask is not None: + assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, @@ -161,7 +170,6 @@ def triton_cross_entropy_forward_backward( n_cols, logits.stride(0), target.stride(0), - None if loss_mask is None else loss_mask.stride(0), None if grad_output is None else grad_logits.stride(0), logits_scale_factor, block_size=block_size, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 9b1dd4d8..813dcc07 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -163,6 +163,8 @@ def _forward_backward( # Target is reference model logits. target = target.flatten(0, -2) loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if loss_mask is not None: + loss_mask = loss_mask.flatten() if self._sequence_parallel_logits: target = split_op(target, self._tensor_space.distributed.tensor_group, 0) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index b32292bd..7578a5f0 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -40,14 +40,16 @@ def _lm_head( rms_weight, 1e-5, ) - logits = torch.nn.functional.linear(hidden, logit_weight) + logits = torch.nn.functional.linear(hidden, logit_weight).float() if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None if target.ndim == logits.ndim: - loss = torch.nn.functional.cross_entropy(logits, target, reduction="none") + loss = torch.nn.functional.cross_entropy( + logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" + ) if loss_mask is not None: - loss = loss * loss_mask.unsqueeze(-1) + loss = loss * loss_mask.flatten() loss = loss.mean() else: loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) @@ -153,6 +155,8 @@ def test_lm_head( device=distributed.device, ) kwargs[f"{config.distillation_model}_logits"] = target + if loss_mask is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask if config.tie_word_embeddings or config.prediction_heads > 1: logit_weight = ( From eaba34f66730d06e880b209f97f0560eead0e510 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 2 May 2025 22:01:04 +0000 Subject: [PATCH 064/114] innit like in mamba in llama --- .../models/ssm/external/ariel_to_ssm.ipynb | 963 ++++-------------- 1 file changed, 213 insertions(+), 750 deletions(-) diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb index 496338cb..664d927f 100644 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb @@ -29,6 +29,13 @@ "%autoreload 2\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Apriel SSM for distillation" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -115,35 +122,6 @@ "apriel_model.config.torch_dtype" ] }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "n_params = sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.83207168" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "n_params/1e9" - ] - }, { "cell_type": "code", "execution_count": 8, @@ -161,62 +139,6 @@ "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" ] }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", - "You are using a model of type llamba to instantiate a model of type apriel_ssm. This is not supported for all configurations of models and can yield errors.\n" - ] - }, - { - "ename": "KeyError", - "evalue": "'n_qk_heads'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m stage2_checkpoint \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 2\u001b[0m stage2_apriel_ssm \u001b[38;5;241m=\u001b[39m \u001b[43mAprielSSMForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstage2_checkpoint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbfloat16\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3571\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3569\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 3570\u001b[0m config_path \u001b[38;5;241m=\u001b[39m config \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 3571\u001b[0m config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3572\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3573\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3574\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_unused_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 3575\u001b[0m \u001b[43m \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3576\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3577\u001b[0m \u001b[43m \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3578\u001b[0m \u001b[43m \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3579\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3580\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3581\u001b[0m \u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3582\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_auto\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_auto_class\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3583\u001b[0m \u001b[43m \u001b[49m\u001b[43m_from_pipeline\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfrom_pipeline\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3584\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3585\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3587\u001b[0m \u001b[38;5;66;03m# In case one passes a config to `from_pretrained` + \"attn_implementation\"\u001b[39;00m\n\u001b[1;32m 3588\u001b[0m \u001b[38;5;66;03m# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3592\u001b[0m \u001b[38;5;66;03m# we pop attn_implementation from the kwargs but this handles the case where users\u001b[39;00m\n\u001b[1;32m 3593\u001b[0m \u001b[38;5;66;03m# passes manually the config to `from_pretrained`.\u001b[39;00m\n\u001b[1;32m 3594\u001b[0m config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(config)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:569\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 563\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type:\n\u001b[1;32m 564\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 565\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are using a model of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig_dict[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel_type\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to instantiate a model of type \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 566\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 567\u001b[0m )\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/configuration_utils.py:740\u001b[0m, in \u001b[0;36mPretrainedConfig.from_dict\u001b[0;34m(cls, config_dict, **kwargs)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;66;03m# We remove it from kwargs so that it does not appear in `return_unused_kwargs`.\u001b[39;00m\n\u001b[1;32m 738\u001b[0m config_dict[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn_implementation\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 740\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(config, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 743\u001b[0m config\u001b[38;5;241m.\u001b[39mpruned_heads \u001b[38;5;241m=\u001b[39m {\u001b[38;5;28mint\u001b[39m(key): value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m config\u001b[38;5;241m.\u001b[39mpruned_heads\u001b[38;5;241m.\u001b[39mitems()}\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/configuration_ssm_apriel.py:99\u001b[0m, in \u001b[0;36mAprielSSMConfig.__init__\u001b[0;34m(self, vocab_size, hidden_size, intermediate_size, num_hidden_layers, hidden_act, initializer_range, use_cache, pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings, mlp_bias, rms_norm_eps, ssm_cfg, head_dim, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m 82\u001b[0m pad_token_id\u001b[38;5;241m=\u001b[39mpad_token_id,\n\u001b[1;32m 83\u001b[0m bos_token_id\u001b[38;5;241m=\u001b[39mbos_token_id,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 87\u001b[0m )\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg \u001b[38;5;241m=\u001b[39m ssm_cfg \u001b[38;5;129;01mor\u001b[39;00m {\n\u001b[1;32m 90\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_state\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m,\n\u001b[1;32m 91\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_v_heads\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m24\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim, \u001b[38;5;66;03m# num_heads * head_dim\u001b[39;00m\n\u001b[1;32m 98\u001b[0m }\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhead_dim \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mssm_cfg[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124md_inner\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mssm_cfg\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_qk_heads\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\n", - "\u001b[0;31mKeyError\u001b[0m: 'n_qk_heads'" - ] - } - ], - "source": [ - "stage2_checkpoint = \"/mnt/checkpoints_fml/pretrained_models/ssm/mohawk_final\"\n", - "stage2_apriel_ssm = AprielSSMForCausalLM.from_pretrained(stage2_checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, \n", - " hidden_size=config.hidden_size,\n", - " intermediate_size=config.intermediate_size,\n", - " num_hidden_layers=config.num_hidden_layers,\n", - " hidden_act=config.hidden_act,\n", - " initializer_range=config.initializer_range,\n", - " use_cache=config.use_cache,\n", - " mlp_bias=config.mlp_bias,\n", - " tie_word_embeddings=config.tie_word_embeddings,\n", - " pad_token_id=config.pad_token_id,\n", - " bos_token_id=config.bos_token_id,\n", - " eos_token_id=config.eos_token_id,\n", - " head_dim=config.head_dim,\n", - " rms_norm_eps=config.rms_norm_eps)" - ] - }, { "cell_type": "code", "execution_count": 10, @@ -2330,15 +2252,6 @@ "apriel_ssm.to(device).to(dtype=torch.bfloat16)" ] }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# apriel_ssm.state_dict()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -2503,20 +2416,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Load mdoel" + "## Load Apriel SSM into HF class" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 130, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], @@ -2632,12 +2545,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Mamba in Llama" + "# Mamba in Llama: SSM hybrid " ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 90, "metadata": {}, "outputs": [ { @@ -2660,7 +2573,7 @@ "from transformers.cache_utils import StaticCache\n", "from types import SimpleNamespace\n", "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel\n", + "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer\n", "# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper\n", "# make sure the code changes reflected without reload\n", "%load_ext autoreload\n", @@ -2669,146 +2582,104 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 81, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMHybridConfig {\n", - " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", - " \"architectures\": [\n", - " \"AprielForCausalLM\"\n", - " ],\n", - " \"attention_bias\": false,\n", - " \"attention_dropout\": 0.0,\n", - " \"auto_map\": {\n", - " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", - " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", - " },\n", - " \"bos_token_id\": 1,\n", - " \"eos_token_id\": 2,\n", - " \"head_dim\": 128,\n", - " \"hidden_act\": \"silu\",\n", - " \"hidden_size\": 4096,\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 8192,\n", - " \"max_position_embeddings\": 16384,\n", - " \"mlp_bias\": false,\n", - " \"model_type\": \"apriel\",\n", - " \"num_attention_heads\": 24,\n", - " \"num_hidden_layers\": 28,\n", - " \"num_key_value_heads\": 8,\n", - " \"pretraining_tp\": 1,\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"rope_scaling\": {\n", - " \"attention_factor\": null,\n", - " \"beta_fast\": 32.0,\n", - " \"beta_slow\": 1.0,\n", - " \"factor\": 32.0,\n", - " \"original_max_position_embeddings\": 4096,\n", - " \"rope_type\": \"yarn\"\n", - " },\n", - " \"rope_theta\": 1000000.0,\n", - " \"ssm_block_pattern\": [\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\",\n", - " \"m2d\",\n", - " \"t\"\n", - " ],\n", - " \"ssm_cfg\": {\n", - " \"activation\": \"identity\",\n", - " \"bias\": false,\n", - " \"chunk_size\": 128,\n", - " \"d_inner\": 3072,\n", - " \"d_state\": 64,\n", - " \"expand\": 1,\n", - " \"n_qk_heads\": 24,\n", - " \"n_v_heads\": 24\n", - " },\n", - " \"tie_word_embeddings\": false,\n", - " \"torch_dtype\": \"bfloat16\",\n", - " \"transformers_version\": \"4.48.1\",\n", - " \"use_cache\": true,\n", - " \"vocab_size\": 131072\n", - "}" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "\n", "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "\n", + "# d_xb = config.num_key_value_heads * config.head_dim\n", + "d_inner = config.num_attention_heads * config.head_dim\n", + "d_state = config.head_dim\n", "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", " ssm_block_pattern=[\"m2d\", \"t\"] * 14,\n", - " ssm_cfg=None)\n", - "hybrdif_apriel_config" + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " # \"d_xb\": d_xb,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * 128, # num_heads * head_dim\n", + " })\n", + "# hybrdif_apriel_config" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "28" + "AprielSSMDecoderLayer(\n", + " (mixer): DiscreteMamba2(\n", + " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", + " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", + " (act): Identity()\n", + " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + ")" ] }, - "execution_count": 15, + "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config.num_hidden_layers" + "hybrid_apriel_model.layers[0]" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 91, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" + "isinstance(hybrid_apriel_model.layers[0], AprielSSMDecoderLayer)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "device = \"cpu\" #if torch.cuda.is_available() else \"cpu\"\n", "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", "batch_size = 1\n", "max_length = 128\n", @@ -2824,472 +2695,24 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 73, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "AprielSSMHybridModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (1): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (2): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (3): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (4): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (5): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (6): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (7): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (8): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (9): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (10): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (11): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (12): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (13): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (14): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (15): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (16): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (17): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (18): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (19): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (20): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (21): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (22): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (23): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (24): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (25): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (26): AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (27): AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - ")" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" + "ename": "OutOfMemoryError", + "evalue": "CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[73], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", + "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)" + ] } ], "source": [ @@ -3298,25 +2721,28 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 79, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "BaseModelOutputWithPast(last_hidden_state=tensor([[[ 2.2031, 0.1777, 0.4258, ..., -2.0312, 0.2246, 0.5664],\n", - " [ 0.0562, -1.1016, 0.4590, ..., -2.1719, -0.1455, -0.6992],\n", - " [-1.5078, -1.3516, 0.8789, ..., -1.9141, 1.3672, -1.0391],\n", - " ...,\n", - " [-1.4453, 0.1260, 0.6992, ..., 0.4746, -0.1729, -0.5938],\n", - " [-0.4961, -0.4160, -0.4551, ..., -0.1328, 0.7461, -0.0376],\n", - " [ 0.3184, 0.4355, -0.7578, ..., 1.5547, 0.8555, -0.8711]]],\n", - " device='cuda:0', dtype=torch.bfloat16, grad_fn=), past_key_values=DynamicCache(), hidden_states=None, attentions=None)" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" + "ename": "RuntimeError", + "evalue": "split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[79], line 2\u001b[0m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstatic_inputs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:1043\u001b[0m, in \u001b[0;36mAprielSSMHybridModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, inference_params, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[1;32m 1042\u001b[0m all_hidden_states \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (hidden_states,)\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1056\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(decoder_layer, AprielDecoderLayer):\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:805\u001b[0m, in \u001b[0;36mAprielSSMDecoderLayer.forward\u001b[0;34m(self, hidden_states, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 801\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 803\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[0;32m--> 805\u001b[0m mixer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmixer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 806\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 807\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 808\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 810\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m mixer_outputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mto(residual\u001b[38;5;241m.\u001b[39mdtype) \u001b[38;5;241m+\u001b[39m residual\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:460\u001b[0m, in \u001b[0;36mDiscreteMamba2.forward\u001b[0;34m(self, u, return_mixer_matrix, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# Project input\u001b[39;00m\n\u001b[1;32m 459\u001b[0m xBCzA_log \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[0;32m--> 460\u001b[0m xBC, z, A_log \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 461\u001b[0m \u001b[43m \u001b[49m\u001b[43mxBCzA_log\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 462\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 463\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_qk_heads\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 464\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 465\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_v_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 467\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 468\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 471\u001b[0m \u001b[38;5;66;03m# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\u001b[39;00m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\u001b[39;00m\n\u001b[1;32m 473\u001b[0m xBC_t \u001b[38;5;241m=\u001b[39m rearrange(xBC[:, :seqlen, :], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/functional.py:196\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(tensor, split_size_or_sections, dim)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 191\u001b[0m split, (tensor,), tensor, split_size_or_sections, dim\u001b[38;5;241m=\u001b[39mdim)\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# Overwriting reason:\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;66;03m# This dispatches to two ATen functions depending on the type of\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# split_size_or_sections. The branching code is in _tensor.py, which we\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# call here.\u001b[39;00m\n\u001b[0;32m--> 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_size_or_sections\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/_tensor.py:917\u001b[0m, in \u001b[0;36mTensor.split\u001b[0;34m(self, split_size, dim)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_VF\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m, split_size, dim) \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[1;32m 916\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 917\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit_with_sizes\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]" + ] } ], "source": [ @@ -3326,102 +2752,139 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 9.73it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.44it/s]\n" ] }, { - "ename": "RuntimeError", - "evalue": "CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m apriel_model \u001b[38;5;241m=\u001b[39m AutoModelForCausalLM\u001b[38;5;241m.\u001b[39mfrom_pretrained(checkpoint, torch_dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16, trust_remote_code\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5\u001b[0m apriel_state_dict \u001b[38;5;241m=\u001b[39m apriel_model\u001b[38;5;241m.\u001b[39mstate_dict()\n\u001b[0;32m----> 6\u001b[0m \u001b[43mapriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" - ] + "data": { + "text/plain": [ + "AprielForCausalLM(\n", + " (model): AprielModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)\n" + "apriel_model.to(device).to(dtype=torch.bfloat16)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 129, + "metadata": {}, + "outputs": [], + "source": [ + "# Innitialization using k, q, v from Apriel transformer\n", + "def expand_k_q(k):\n", + " Hq = config.num_attention_heads\n", + " Hk = config.num_key_value_heads\n", + " d_head = config.head_dim\n", + " d = k.shape[-1]\n", + " \n", + " # Expand k\n", + " repeat_factor = Hq // Hk\n", + " k_expanded = k.view(Hk, d_head, d)\n", + " k_expanded = k_expanded.repeat_interleave(repeat_factor, dim=0)\n", + " k_expanded = k_expanded.view(d_head * Hq, d)\n", + " return k_expanded\n", + "\n", + "for block_h, block_t in zip(hybrid_apriel_model.layers, apriel_model.model.layers):\n", + " # print(isinstance(block_h, AprielSSMDecoderLayer))\n", + " if isinstance(block_h, AprielSSMDecoderLayer):\n", + " # print(block_h.mixer.n_v_heads)\n", + " # print(block_t.self_attn.v_proj.weight.shape)\n", + " # print(block_h.mixer.in_proj.weight.shape)\n", + "\n", + " # print(block_h.mixer.in_proj.weight.shape)\n", + " # print(block_t.self_attn.v_proj.weight.shape)\n", + " block_h.mlp.load_state_dict(block_t.mlp.state_dict())\n", + " block_h.input_layernorm.load_state_dict(block_t.input_layernorm.state_dict())\n", + " block_h.post_attention_layernorm.load_state_dict(block_t.post_attention_layernorm.state_dict())\n", + " block_h.mixer.out_proj.load_state_dict(block_t.self_attn.o_proj.state_dict())\n", + " # [x B C z A_log]\n", + " # print(block_h.mixer.d_inner)\n", + " # init x, but interleave to address GQA\n", + " v_expended = expand_k_q(block_t.self_attn.v_proj.weight.data)\n", + " block_h.mixer.in_proj.weight.data[:block_h.mixer.d_inner, : ].copy_(v_expended)\n", + " # init k, but interleave to address GQA\n", + " k_expended = expand_k_q(block_t.self_attn.k_proj.weight.data)\n", + " block_h.mixer.in_proj.weight.data[block_h.mixer.d_inner: 2*block_h.mixer.d_inner, : ].copy_(k_expended)\n", + " # init C ewith Q\n", + " block_h.mixer.in_proj.weight.data[2*block_h.mixer.d_inner: 3*block_h.mixer.d_inner, : ].copy_(block_t.self_attn.q_proj.weight.data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AprielConfig {\n", - " \"_name_or_path\": \"ServiceNow-AI/Apriel-5B-Instruct\",\n", - " \"architectures\": [\n", - " \"AprielForCausalLM\"\n", - " ],\n", - " \"attention_bias\": false,\n", - " \"attention_dropout\": 0.0,\n", - " \"auto_map\": {\n", - " \"AutoConfig\": \"ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig\",\n", - " \"AutoModelForCausalLM\": \"ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM\"\n", - " },\n", - " \"bos_token_id\": 1,\n", - " \"eos_token_id\": 2,\n", - " \"head_dim\": 128,\n", - " \"hidden_act\": \"silu\",\n", - " \"hidden_size\": 4096,\n", - " \"initializer_range\": 0.02,\n", - " \"intermediate_size\": 8192,\n", - " \"max_position_embeddings\": 16384,\n", - " \"mlp_bias\": false,\n", - " \"model_type\": \"apriel\",\n", - " \"num_attention_heads\": 24,\n", - " \"num_hidden_layers\": 28,\n", - " \"num_key_value_heads\": 8,\n", - " \"pretraining_tp\": 1,\n", - " \"rms_norm_eps\": 1e-05,\n", - " \"rope_scaling\": {\n", - " \"attention_factor\": null,\n", - " \"beta_fast\": 32.0,\n", - " \"beta_slow\": 1.0,\n", - " \"factor\": 32.0,\n", - " \"original_max_position_embeddings\": 4096,\n", - " \"rope_type\": \"yarn\"\n", - " },\n", - " \"rope_theta\": 1000000.0,\n", - " \"tie_word_embeddings\": false,\n", - " \"torch_dtype\": \"bfloat16\",\n", - " \"transformers_version\": \"4.48.1\",\n", - " \"use_cache\": true,\n", - " \"vocab_size\": 131072\n", - "}" + "torch.Size([1024, 4096])" ] }, - "execution_count": 4, + "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config" + "block_t.self_attn.v_proj.weight.data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#" ] }, { From f8ca1222a1f938a77338272d6c7f32bb9769d1a9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 5 May 2025 01:22:06 +0000 Subject: [PATCH 065/114] embeddings_lr_scale --- fast_llm/layers/language_model/config.py | 7 +++++++ fast_llm/layers/language_model/embedding.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 0371eff4..1b7e2d94 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -203,6 +203,13 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def _validate(self) -> None: self.transformer.validate() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed..e0386d8d 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -62,6 +62,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,6 +73,7 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) # PEFT. From 2db740bcc1d23530a8db15b794b2f6778d8a61ca Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 5 May 2025 13:15:02 -0400 Subject: [PATCH 066/114] fix --- fast_llm/models/gpt/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index e5afac16..418f948e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -179,9 +179,6 @@ def _validate(self) -> None: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: - # TODO: Support loss masking for distillation? - assert not self.batch.use_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From 41d4da3491faa7ac7c1a6bd599be4fa41b97feeb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Mon, 5 May 2025 21:50:05 +0000 Subject: [PATCH 067/114] disable freezing --- fast_llm/tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 84930756..611eb9f4 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,9 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # TODO: re-enable when fixed? + # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 4160b1f3a189174502f27942d9f9ff995f34eb4d Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 6 May 2025 13:24:29 +0000 Subject: [PATCH 068/114] hybrid model loading and exporting --- fast_llm/models/ssm/config.py | 22 +- fast_llm/models/ssm/conversion.py | 247 +- .../configuration_ssm_apriel.py | 3 +- .../{ => aperiel_ssm}/modeling_ssm_apriel.py | 8 +- .../configuration_ssm_hybrid_apriel.py | 12 +- .../modeling_ssm_hybrid_apriel.py | 14 +- .../models/ssm/external/ariel_to_ssm.ipynb | 2989 ----------------- .../ssm/external/eval/apriel_eval_wrapper.py | 57 +- .../{ => llamba}/configuration_mtp_llamba.py | 0 .../{ => llamba}/modeling_mtp_llamba.py | 0 tests/test_ssms.py | 119 +- 11 files changed, 365 insertions(+), 3106 deletions(-) rename fast_llm/models/ssm/external/{ => aperiel_ssm}/configuration_ssm_apriel.py (95%) rename fast_llm/models/ssm/external/{ => aperiel_ssm}/modeling_ssm_apriel.py (98%) rename fast_llm/models/ssm/external/{ => apriel_hybrid}/configuration_ssm_hybrid_apriel.py (98%) rename fast_llm/models/ssm/external/{ => apriel_hybrid}/modeling_ssm_hybrid_apriel.py (99%) delete mode 100644 fast_llm/models/ssm/external/ariel_to_ssm.ipynb rename fast_llm/models/ssm/external/{ => llamba}/configuration_mtp_llamba.py (100%) rename fast_llm/models/ssm/external/{ => llamba}/modeling_mtp_llamba.py (100%) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 1d8ac007..44093d1f 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -23,12 +23,18 @@ class HybridSSMArchitectureConfig(LanguageModelArchitectureConfig): _abstract = False - hybrid_block_layout: list[str] = Field( - default_factory=lambda: [SSMBlockType.mamba2_discrete.value], + hybrid_block_layout: list[str] | None = Field( + default=None, desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, ) + def _validate(self): + if self.hybrid_block_layout is None: + with self._set_implicit_default(): + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + super()._validate() + @config_class() class HybridSSMBaseModelConfig(LanguageModelBaseConfig, HybridSSMArchitectureConfig): @@ -133,6 +139,17 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler +class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.ssm.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler + + return AprielSSMHHybridHuggingfaceCheckpointHandler + + @config_class() class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False @@ -141,6 +158,7 @@ class HybridSSMModelConfig(FastLLMModelConfig): checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( LLambaHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, + AprielSSMHHybridHuggingfaceCheckpointFormat, ) @classmethod diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 675c709f..2e84cd10 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -19,8 +19,9 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.ssm.config import SSMBlockType -from fast_llm.models.gpt.conversion import MLPLayer2Converter +from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, @@ -72,10 +73,6 @@ class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandle @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), RenameParamConverter( fast_llm_names=(("ssm", "state_size"),), export_names=( @@ -143,12 +140,79 @@ def _create_config_converters(cls) -> list[ParamConverter]: ), ] + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() + + num_layers = self._model.config.base_model.transformer.num_layers + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"model.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"model.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + return converters + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + -class LLambaHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): +class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value + _hf_prefix: str = "backbone" @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -156,6 +220,10 @@ def _create_config_converters(cls) -> list[ParamConverter]: Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json """ return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), RenameParamConverter( fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) ), @@ -208,6 +276,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: + # not using super() because LLamba model is called backbone in the checkpoints converters = [] num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False @@ -215,58 +284,68 @@ def _create_weight_converters(self) -> list[WeightConverter]: # Embedding and output if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append( + WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") + ) # Final norm converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias + f"layers.{num_layers + 1}.final_norm", f"{self._hf_prefix}.final_layernorm", norm_bias ) for i in range(num_layers): # SSM converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias + f"layers.{i+1}.mixer.in_proj", f"{self._hf_prefix}.layers.{i}.mixer.in_proj", ssm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias + f"layers.{i+1}.mixer.out_proj", f"{self._hf_prefix}.layers.{i}.mixer.out_proj", ssm_bias ) converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) + WeightConverter( + f"layers.{i+1}.mixer.D", f"{self._hf_prefix}.layers.{i}.mixer.D", self._model.config.base_model + ) ) converters.append( WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model + f"layers.{i+1}.mixer.z_bias", + f"{self._hf_prefix}.layers.{i}.mixer.z_bias", + self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_weight", - f"backbone.layers.{i}.mixer.conv1d.weight", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.weight", self._model.config.base_model, ) ) converters.append( WeightConverter( f"layers.{i+1}.mixer.conv1d_bias", - f"backbone.layers.{i}.mixer.conv1d.bias", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.bias", self._model.config.base_model, ) ) # Norm converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias + f"layers.{i+1}.norm_1", f"{self._hf_prefix}.layers.{i}.input_layernorm", norm_bias ) converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias + f"layers.{i+1}.norm_2", f"{self._hf_prefix}.layers.{i}.post_attention_layernorm", norm_bias ) # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") + converters += self._get_mlp_converters(f"layers.{i+1}", f"{self._hf_prefix}.layers.{i}") return converters @@ -330,14 +409,22 @@ def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.An json.dump(config, f) -class AprielSSMHuggingfaceCheckpointHandler(HybridModelCheckpointHandler, CommonSSMHuggingfaceCheckpointHandler): +class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + """ + Lamba-like configs, pure SSM models. + """ + + _model: HybridSSMModel _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat - _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod def _create_config_converters(cls) -> list[ParamConverter]: return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), RenameParamConverter( fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), @@ -377,10 +464,9 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] + converters = super()._create_weight_converters() num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear # Embedding and output if self._model.config.base_model.tie_word_embeddings: @@ -396,36 +482,6 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"model.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"model.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - # Norm converters += self._get_weight_and_bias_converters( f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias @@ -456,33 +512,62 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig ), ] - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHHybridHuggingfaceCheckpointHandler( + HybridModelCheckpointHandler, # handles the block structure parameter + CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers + CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers +): + """ + Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. + """ + + _model: HybridSSMModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters @classmethod def _load_config(cls, directory: pathlib.Path | str) -> dict: diff --git a/fast_llm/models/ssm/external/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py similarity index 95% rename from fast_llm/models/ssm/external/configuration_ssm_apriel.py rename to fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py index 2e5d5810..c3f7ef38 100644 --- a/fast_llm/models/ssm/external/configuration_ssm_apriel.py +++ b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py @@ -96,7 +96,8 @@ def __init__( "bias": False, "d_inner": 24 * self.head_dim, # num_heads * head_dim } - assert self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"] + if self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: + logger.warning("Head dim is equal to d_inner // n_qk_heads.") __all__ = ["AprielConfig"] diff --git a/fast_llm/models/ssm/external/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py similarity index 98% rename from fast_llm/models/ssm/external/modeling_ssm_apriel.py rename to fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py index 5a1b8db4..dd228024 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) @@ -172,7 +172,7 @@ def __init__( **factory_kwargs, ) self.z_bias = ( - nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 ) # make sure z_bias always exists # Convolutional layer @@ -197,7 +197,7 @@ def __init__( raise ValueError(f"Unknown activation {self.activation}") # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) self.D._optim = {"weight_decay": 0.0} # out_proj @@ -670,7 +670,7 @@ class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): def __init__(self, config, device=None, dtype=None, **kwargs): super().__init__(config, device=device, dtype=dtype, **kwargs) - self.model = AprielSSMModel(config) + self.model = AprielSSMModel(config, device=device, dtype=dtype) self.vocab_size = config.vocab_size factory_kwargs = {"device": device, "dtype": dtype} self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) diff --git a/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py similarity index 98% rename from fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py rename to fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py index 58891802..b030150c 100644 --- a/fast_llm/models/ssm/external/configuration_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py @@ -344,7 +344,7 @@ class AprielSSMHybridConfig(PretrainedConfig): >>> configuration = model.config ```""" - model_type = "apriel" + model_type = "apriel_ssm_hybrid" keys_to_ignore_at_inference = ["past_key_values"] # Default tensor parallel plan for base model `AprielModel` base_model_tp_plan = { @@ -386,7 +386,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, head_dim=None, - ssm_block_pattern=["m2d"], + hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs, ): @@ -413,10 +413,10 @@ def __init__( self.attention_dropout = attention_dropout self.mlp_bias = mlp_bias self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads - self.ssm_block_pattern = ssm_block_pattern - if len(ssm_block_pattern) == 1: - self.ssm_block_pattern = [ssm_block_pattern[0]] * self.num_hidden_layers - assert len(self.ssm_block_pattern) == self.num_hidden_layers + self.hybrid_block_layout = hybrid_block_layout + if len(hybrid_block_layout) == 1: + self.hybrid_block_layout = [hybrid_block_layout[0]] * self.num_hidden_layers + assert len(self.hybrid_block_layout) == self.num_hidden_layers # Validate the correctness of rotary position embeddings parameters # BC: if there is a 'type' field, copy it it to 'rope_type'. if self.rope_scaling is not None and "type" in self.rope_scaling: diff --git a/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py similarity index 99% rename from fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py rename to fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 49b00986..950327df 100644 --- a/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -21,7 +21,10 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.configuration_ssm_hybrid_apriel import ROPE_INIT_FUNCTIONS, AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( + ROPE_INIT_FUNCTIONS, + AprielSSMHybridConfig, +) logger = logging.get_logger(__name__) @@ -875,7 +878,7 @@ def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwa factory_kwargs = {"device": device, "dtype": dtype} self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) blocks = [] - for layer_idx, type in enumerate(config.ssm_block_pattern): + for layer_idx, type in enumerate(config.hybrid_block_layout): if type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) elif type == "t": @@ -1169,11 +1172,12 @@ def forward( **kwargs: Unpack[KwargsForCausalLM], ) -> Union[tuple, CausalLMOutputWithPast]: - outputs = self.model( + outputs: BaseModelOutputWithPast = self.model( input_ids, return_hidden_states=return_hidden_states, inference_params=inference_params, position_ids=position_ids, + return_dict=True, ) if outputs["last_hidden_state"] is not None and return_logits: @@ -1185,8 +1189,8 @@ def forward( return CustomMambaCausalLMOutput( loss=None, logits=outputs["logits"], - all_hidden_states=outputs["all_hidden_states"], - last_hidden_state=outputs["last_hidden_state"], + all_hidden_states=outputs.hidden_states, + last_hidden_state=outputs.last_hidden_state, ) def generate(self, *args, **kwargs): diff --git a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb b/fast_llm/models/ssm/external/ariel_to_ssm.ipynb deleted file mode 100644 index 664d927f..00000000 --- a/fast_llm/models/ssm/external/ariel_to_ssm.ipynb +++ /dev/null @@ -1,2989 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from mamba_ssm.models.config_mamba import MambaConfig\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Apriel SSM for distillation" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 8.90it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "AprielForCausalLM(\n", - " (model): AprielModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.bfloat16" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_model.config.torch_dtype" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n" - ] - } - ], - "source": [ - "config_apriel = AprielSSMConfig.from_pretrained(\"/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base\", trust_remote_code=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "OrderedDict([('model.embed_tokens.weight',\n", - " tensor([[ 0.0105, 0.0330, -0.0032, ..., 0.0076, -0.0051, 0.0112],\n", - " [-0.0111, -0.0101, 0.0064, ..., 0.0144, 0.0098, -0.0194],\n", - " [ 0.0301, 0.0228, 0.0105, ..., -0.0159, 0.0112, -0.0009],\n", - " ...,\n", - " [ 0.0266, 0.0224, -0.0150, ..., 0.0189, -0.0253, -0.0300],\n", - " [-0.0304, 0.0249, 0.0140, ..., -0.0235, 0.0315, -0.0188],\n", - " [-0.0215, -0.0034, 0.0035, ..., -0.0125, 0.0084, 0.0246]])),\n", - " ('model.layers.0.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.0.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.0.mixer.in_proj.weight',\n", - " tensor([[ 0.0104, 0.0055, -0.0148, ..., 0.0208, -0.0074, 0.0015],\n", - " [ 0.0102, 0.0148, 0.0148, ..., -0.0041, 0.0224, -0.0336],\n", - " [ 0.0129, -0.0179, -0.0120, ..., 0.0175, 0.0300, -0.0234],\n", - " ...,\n", - " [-0.0215, 0.0002, 0.0093, ..., -0.0424, 0.0016, -0.0162],\n", - " [-0.0178, -0.0093, 0.0226, ..., 0.0005, 0.0062, 0.0150],\n", - " [-0.0204, 0.0039, -0.0364, ..., -0.0128, 0.0002, 0.0134]])),\n", - " ('model.layers.0.mixer.conv1d.weight',\n", - " tensor([[[-0.1064, -0.3782, -0.3080, -0.3179]],\n", - " \n", - " [[-0.3493, 0.2230, 0.1062, 0.0614]],\n", - " \n", - " [[-0.4650, 0.0300, 0.3021, 0.1197]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.3686, 0.0679, 0.1440, 0.4445]],\n", - " \n", - " [[-0.1480, 0.3750, -0.0552, -0.0297]],\n", - " \n", - " [[ 0.0677, 0.0925, -0.0268, -0.0232]]])),\n", - " ('model.layers.0.mixer.conv1d.bias',\n", - " tensor([ 0.1379, 0.0862, -0.0723, ..., -0.2628, -0.1867, -0.1233])),\n", - " ('model.layers.0.mixer.out_proj.weight',\n", - " tensor([[ 0.0208, -0.0106, -0.0016, ..., 0.0117, 0.0140, -0.0040],\n", - " [-0.0147, 0.0419, 0.0327, ..., -0.0073, -0.0127, 0.0190],\n", - " [-0.0218, 0.0030, 0.0115, ..., -0.0062, 0.0214, 0.0105],\n", - " ...,\n", - " [ 0.0089, 0.0154, -0.0178, ..., -0.0206, -0.0378, 0.0102],\n", - " [ 0.0153, -0.0249, 0.0219, ..., 0.0119, 0.0019, 0.0383],\n", - " [-0.0126, 0.0284, -0.0035, ..., 0.0118, -0.0186, -0.0232]])),\n", - " ('model.layers.0.mlp.gate_proj.weight',\n", - " tensor([[-0.0032, -0.0405, 0.0180, ..., -0.0030, -0.0222, 0.0069],\n", - " [-0.0071, -0.0064, -0.0207, ..., 0.0037, -0.0077, 0.0261],\n", - " [ 0.0236, 0.0167, 0.0065, ..., 0.0064, 0.0035, -0.0092],\n", - " ...,\n", - " [-0.0357, 0.0192, 0.0099, ..., -0.0067, -0.0181, 0.0082],\n", - " [-0.0139, -0.0161, -0.0015, ..., -0.0052, -0.0337, 0.0514],\n", - " [ 0.0105, -0.0205, 0.0198, ..., 0.0090, 0.0315, 0.0066]])),\n", - " ('model.layers.0.mlp.up_proj.weight',\n", - " tensor([[ 0.0074, 0.0237, -0.0300, ..., 0.0343, 0.0016, 0.0395],\n", - " [ 0.0270, 0.0085, 0.0193, ..., 0.0199, -0.0139, 0.0094],\n", - " [ 0.0036, 0.0073, 0.0149, ..., 0.0094, 0.0346, -0.0111],\n", - " ...,\n", - " [ 0.0159, -0.0346, -0.0128, ..., 0.0377, -0.0531, -0.0305],\n", - " [ 0.0283, 0.0162, -0.0377, ..., -0.0254, 0.0110, -0.0167],\n", - " [-0.0277, 0.0130, 0.0161, ..., 0.0089, -0.0190, 0.0214]])),\n", - " ('model.layers.0.mlp.down_proj.weight',\n", - " tensor([[ 0.0157, 0.0105, 0.0036, ..., 0.0229, 0.0080, 0.0303],\n", - " [-0.0143, -0.0067, 0.0016, ..., 0.0494, -0.0043, 0.0072],\n", - " [-0.0148, 0.0113, 0.0025, ..., -0.0186, 0.0206, -0.0119],\n", - " ...,\n", - " [-0.0226, 0.0099, 0.0010, ..., 0.0123, -0.0170, 0.0024],\n", - " [-0.0120, -0.0015, -0.0355, ..., 0.0064, 0.0175, -0.0065],\n", - " [ 0.0364, 0.0364, 0.0265, ..., -0.0222, 0.0030, 0.0296]])),\n", - " ('model.layers.0.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.0.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.1.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.1.mixer.in_proj.weight',\n", - " tensor([[-0.0116, -0.0182, -0.0017, ..., -0.0216, -0.0136, -0.0203],\n", - " [-0.0142, -0.0106, -0.0334, ..., 0.0287, -0.0273, 0.0050],\n", - " [ 0.0131, -0.0106, -0.0012, ..., 0.0261, -0.0228, -0.0026],\n", - " ...,\n", - " [-0.0029, 0.0023, 0.0360, ..., -0.0195, 0.0018, -0.0227],\n", - " [ 0.0004, 0.0015, -0.0051, ..., -0.0095, 0.0269, 0.0179],\n", - " [ 0.0295, -0.0520, 0.0009, ..., 0.0019, 0.0255, 0.0478]])),\n", - " ('model.layers.1.mixer.conv1d.weight',\n", - " tensor([[[-0.4725, -0.2938, -0.3816, -0.1239]],\n", - " \n", - " [[-0.2002, 0.3790, 0.1908, -0.4679]],\n", - " \n", - " [[-0.3674, 0.3774, -0.2479, 0.4324]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4181, 0.2263, -0.1937, 0.3585]],\n", - " \n", - " [[ 0.0704, 0.0913, 0.4217, 0.3004]],\n", - " \n", - " [[ 0.3175, -0.3239, -0.0614, -0.3978]]])),\n", - " ('model.layers.1.mixer.conv1d.bias',\n", - " tensor([ 0.4302, 0.0269, -0.3462, ..., 0.4887, 0.2848, 0.0745])),\n", - " ('model.layers.1.mixer.out_proj.weight',\n", - " tensor([[-0.0069, 0.0233, 0.0133, ..., -0.0064, -0.0085, 0.0166],\n", - " [-0.0302, 0.0129, -0.0042, ..., 0.0109, 0.0009, -0.0087],\n", - " [-0.0373, -0.0233, -0.0043, ..., -0.0017, 0.0384, -0.0114],\n", - " ...,\n", - " [-0.0219, 0.0330, -0.0341, ..., 0.0080, 0.0089, 0.0268],\n", - " [-0.0019, -0.0069, 0.0276, ..., 0.0182, -0.0240, 0.0163],\n", - " [ 0.0081, 0.0070, 0.0156, ..., -0.0135, 0.0469, -0.0221]])),\n", - " ('model.layers.1.mlp.gate_proj.weight',\n", - " tensor([[ 0.0175, -0.0074, -0.0028, ..., 0.0197, 0.0034, 0.0221],\n", - " [ 0.0063, 0.0339, -0.0047, ..., 0.0037, -0.0126, -0.0342],\n", - " [-0.0093, -0.0148, -0.0236, ..., 0.0190, -0.0451, -0.0173],\n", - " ...,\n", - " [ 0.0167, 0.0161, 0.0019, ..., -0.0083, -0.0133, 0.0141],\n", - " [-0.0163, 0.0383, -0.0203, ..., 0.0336, -0.0148, 0.0013],\n", - " [-0.0138, -0.0275, -0.0268, ..., -0.0243, -0.0031, -0.0227]])),\n", - " ('model.layers.1.mlp.up_proj.weight',\n", - " tensor([[ 0.0054, 0.0031, 0.0256, ..., 0.0002, 0.0020, -0.0050],\n", - " [ 0.0247, -0.0298, -0.0218, ..., -0.0161, 0.0253, 0.0128],\n", - " [-0.0231, -0.0012, 0.0130, ..., 0.0031, -0.0324, 0.0107],\n", - " ...,\n", - " [ 0.0359, -0.0202, 0.0386, ..., -0.0104, 0.0274, 0.0161],\n", - " [ 0.0062, -0.0111, 0.0338, ..., 0.0041, 0.0001, -0.0019],\n", - " [ 0.0105, -0.0258, 0.0184, ..., -0.0270, -0.0138, -0.0367]])),\n", - " ('model.layers.1.mlp.down_proj.weight',\n", - " tensor([[-0.0163, -0.0308, -0.0203, ..., 0.0002, -0.0227, 0.0019],\n", - " [ 0.0206, 0.0037, 0.0064, ..., -0.0261, -0.0206, 0.0063],\n", - " [ 0.0044, -0.0073, -0.0576, ..., -0.0015, -0.0082, 0.0022],\n", - " ...,\n", - " [-0.0034, 0.0142, -0.0547, ..., -0.0106, -0.0090, 0.0249],\n", - " [-0.0068, 0.0127, -0.0066, ..., -0.0255, 0.0004, 0.0106],\n", - " [-0.0293, 0.0146, -0.0142, ..., -0.0073, -0.0284, -0.0069]])),\n", - " ('model.layers.1.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.1.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.2.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.2.mixer.in_proj.weight',\n", - " tensor([[ 0.0337, -0.0055, -0.0538, ..., -0.0051, 0.0107, -0.0338],\n", - " [ 0.0227, -0.0008, 0.0003, ..., -0.0312, 0.0090, -0.0126],\n", - " [-0.0238, 0.0146, 0.0240, ..., -0.0114, -0.0180, 0.0025],\n", - " ...,\n", - " [-0.0208, -0.0261, 0.0227, ..., 0.0071, 0.0014, 0.0237],\n", - " [ 0.0356, 0.0372, 0.0186, ..., 0.0052, 0.0049, -0.0195],\n", - " [ 0.0023, -0.0159, -0.0238, ..., 0.0194, -0.0056, -0.0275]])),\n", - " ('model.layers.2.mixer.conv1d.weight',\n", - " tensor([[[ 0.1054, -0.4185, 0.4229, 0.3289]],\n", - " \n", - " [[-0.0081, 0.0321, 0.1334, -0.1055]],\n", - " \n", - " [[ 0.1587, -0.3806, -0.1336, -0.2662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2830, -0.3875, -0.2972, 0.0030]],\n", - " \n", - " [[ 0.4210, 0.2190, -0.4942, 0.0465]],\n", - " \n", - " [[-0.1830, -0.3686, 0.2928, -0.0313]]])),\n", - " ('model.layers.2.mixer.conv1d.bias',\n", - " tensor([-0.2931, -0.3513, -0.3013, ..., -0.1934, -0.3115, 0.3889])),\n", - " ('model.layers.2.mixer.out_proj.weight',\n", - " tensor([[-0.0038, -0.0160, -0.0042, ..., 0.0062, 0.0059, -0.0126],\n", - " [-0.0027, -0.0012, -0.0065, ..., -0.0032, 0.0129, -0.0298],\n", - " [ 0.0394, -0.0096, 0.0107, ..., -0.0290, 0.0248, 0.0308],\n", - " ...,\n", - " [ 0.0087, 0.0067, -0.0261, ..., -0.0038, -0.0168, 0.0485],\n", - " [ 0.0118, 0.0042, -0.0186, ..., 0.0104, 0.0281, 0.0028],\n", - " [ 0.0304, -0.0382, -0.0028, ..., -0.0264, -0.0050, 0.0050]])),\n", - " ('model.layers.2.mlp.gate_proj.weight',\n", - " tensor([[-0.0169, 0.0036, 0.0024, ..., 0.0429, 0.0313, 0.0167],\n", - " [-0.0100, 0.0011, -0.0024, ..., -0.0065, 0.0090, 0.0123],\n", - " [ 0.0102, 0.0282, 0.0166, ..., -0.0082, 0.0123, 0.0253],\n", - " ...,\n", - " [ 0.0168, -0.0056, -0.0096, ..., -0.0090, 0.0150, 0.0209],\n", - " [ 0.0258, 0.0113, -0.0093, ..., 0.0335, 0.0386, -0.0156],\n", - " [ 0.0129, 0.0338, -0.0006, ..., -0.0346, 0.0135, -0.0213]])),\n", - " ('model.layers.2.mlp.up_proj.weight',\n", - " tensor([[-0.0029, 0.0416, -0.0102, ..., -0.0413, 0.0019, 0.0063],\n", - " [ 0.0054, 0.0138, 0.0031, ..., -0.0077, -0.0070, -0.0016],\n", - " [ 0.0128, 0.0153, -0.0147, ..., -0.0131, -0.0244, 0.0097],\n", - " ...,\n", - " [-0.0190, -0.0025, 0.0322, ..., -0.0106, -0.0323, -0.0144],\n", - " [-0.0269, -0.0007, 0.0070, ..., 0.0191, -0.0025, 0.0033],\n", - " [-0.0311, 0.0217, -0.0021, ..., 0.0302, -0.0131, 0.0388]])),\n", - " ('model.layers.2.mlp.down_proj.weight',\n", - " tensor([[ 0.0150, -0.0127, 0.0372, ..., 0.0018, 0.0018, 0.0187],\n", - " [-0.0262, 0.0164, 0.0281, ..., 0.0120, -0.0187, -0.0177],\n", - " [ 0.0129, -0.0042, 0.0018, ..., -0.0136, 0.0278, 0.0284],\n", - " ...,\n", - " [ 0.0048, 0.0421, -0.0018, ..., 0.0002, -0.0064, 0.0085],\n", - " [ 0.0276, 0.0146, 0.0228, ..., 0.0055, -0.0288, -0.0081],\n", - " [-0.0133, 0.0102, 0.0318, ..., 0.0209, -0.0270, 0.0128]])),\n", - " ('model.layers.2.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.2.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.3.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.3.mixer.in_proj.weight',\n", - " tensor([[ 7.4766e-03, -9.8698e-03, -1.9172e-02, ..., 3.7842e-02,\n", - " -2.1648e-03, 2.8147e-03],\n", - " [ 2.4954e-02, -1.2659e-02, 8.0447e-04, ..., 3.1716e-02,\n", - " 4.9989e-03, 6.4200e-03],\n", - " [-3.3345e-02, -1.5256e-02, 2.7295e-02, ..., -1.1240e-02,\n", - " 9.7000e-03, 3.1136e-05],\n", - " ...,\n", - " [-2.0807e-04, -2.5132e-02, -1.9983e-02, ..., -2.9541e-02,\n", - " 4.6152e-04, 5.5341e-02],\n", - " [ 2.0498e-03, 2.2021e-02, -7.6882e-03, ..., 1.6469e-02,\n", - " -1.0645e-02, -1.8442e-03],\n", - " [ 2.0949e-03, -1.2398e-02, 1.2922e-02, ..., 1.1862e-02,\n", - " -4.7119e-03, 3.2352e-02]])),\n", - " ('model.layers.3.mixer.conv1d.weight',\n", - " tensor([[[ 0.2590, 0.1670, 0.3987, -0.1694]],\n", - " \n", - " [[-0.4425, 0.1468, 0.3060, -0.0764]],\n", - " \n", - " [[-0.3638, -0.0575, 0.2156, -0.2468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0111, -0.0182, -0.3816, 0.0382]],\n", - " \n", - " [[-0.4723, -0.3712, 0.1963, 0.2877]],\n", - " \n", - " [[-0.4890, 0.1197, 0.1361, 0.3282]]])),\n", - " ('model.layers.3.mixer.conv1d.bias',\n", - " tensor([-0.4712, -0.3272, 0.4587, ..., -0.3145, 0.4086, 0.4005])),\n", - " ('model.layers.3.mixer.out_proj.weight',\n", - " tensor([[-0.0362, 0.0137, -0.0296, ..., -0.0028, 0.0104, 0.0393],\n", - " [ 0.0130, 0.0246, -0.0132, ..., 0.0082, -0.0044, -0.0054],\n", - " [-0.0081, -0.0115, -0.0064, ..., 0.0250, -0.0076, -0.0021],\n", - " ...,\n", - " [ 0.0230, -0.0055, 0.0056, ..., 0.0076, 0.0016, -0.0068],\n", - " [ 0.0472, -0.0068, 0.0336, ..., 0.0079, 0.0211, 0.0031],\n", - " [-0.0450, -0.0005, 0.0219, ..., 0.0044, -0.0006, -0.0278]])),\n", - " ('model.layers.3.mlp.gate_proj.weight',\n", - " tensor([[ 0.0034, 0.0445, -0.0132, ..., 0.0290, 0.0019, 0.0048],\n", - " [ 0.0271, 0.0109, 0.0028, ..., -0.0304, -0.0237, -0.0017],\n", - " [ 0.0098, 0.0252, 0.0392, ..., 0.0486, 0.0326, -0.0171],\n", - " ...,\n", - " [-0.0015, 0.0080, 0.0005, ..., -0.0158, -0.0067, 0.0347],\n", - " [-0.0638, 0.0120, 0.0076, ..., 0.0007, 0.0052, -0.0109],\n", - " [-0.0303, -0.0168, -0.0537, ..., -0.0163, -0.0030, -0.0068]])),\n", - " ('model.layers.3.mlp.up_proj.weight',\n", - " tensor([[-0.0074, -0.0101, 0.0073, ..., -0.0012, -0.0208, -0.0239],\n", - " [ 0.0035, 0.0010, 0.0157, ..., -0.0228, -0.0224, 0.0194],\n", - " [ 0.0457, -0.0129, -0.0063, ..., -0.0312, 0.0261, -0.0018],\n", - " ...,\n", - " [ 0.0012, 0.0093, 0.0121, ..., -0.0035, -0.0367, -0.0454],\n", - " [ 0.0308, -0.0334, 0.0062, ..., 0.0043, -0.0031, -0.0406],\n", - " [-0.0175, -0.0089, -0.0137, ..., -0.0322, -0.0070, -0.0219]])),\n", - " ('model.layers.3.mlp.down_proj.weight',\n", - " tensor([[ 0.0226, 0.0074, -0.0170, ..., 0.0035, 0.0420, -0.0085],\n", - " [ 0.0116, 0.0173, -0.0009, ..., -0.0302, 0.0075, 0.0153],\n", - " [-0.0092, 0.0119, 0.0164, ..., 0.0233, -0.0177, -0.0397],\n", - " ...,\n", - " [-0.0006, -0.0275, 0.0127, ..., -0.0185, 0.0335, -0.0133],\n", - " [ 0.0064, -0.0200, 0.0296, ..., 0.0041, -0.0114, -0.0221],\n", - " [ 0.0317, 0.0392, 0.0553, ..., 0.0191, 0.0188, -0.0176]])),\n", - " ('model.layers.3.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.3.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.4.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.4.mixer.in_proj.weight',\n", - " tensor([[-0.0266, 0.0092, -0.0260, ..., -0.0121, -0.0286, 0.0267],\n", - " [ 0.0144, -0.0053, -0.0060, ..., -0.0065, 0.0201, -0.0025],\n", - " [-0.0092, -0.0465, -0.0032, ..., 0.0192, -0.0026, 0.0104],\n", - " ...,\n", - " [-0.0210, -0.0286, -0.0148, ..., 0.0593, 0.0130, 0.0118],\n", - " [ 0.0361, -0.0070, 0.0054, ..., -0.0073, 0.0004, 0.0287],\n", - " [ 0.0450, -0.0286, 0.0191, ..., -0.0180, 0.0039, -0.0033]])),\n", - " ('model.layers.4.mixer.conv1d.weight',\n", - " tensor([[[ 0.1450, 0.2065, -0.1750, -0.4560]],\n", - " \n", - " [[-0.2889, -0.4707, -0.0741, 0.1254]],\n", - " \n", - " [[-0.4665, 0.1876, -0.4049, 0.1143]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0709, 0.2021, -0.0053, -0.1558]],\n", - " \n", - " [[-0.0195, -0.4046, -0.2437, -0.4405]],\n", - " \n", - " [[-0.3615, -0.4314, 0.1667, 0.3139]]])),\n", - " ('model.layers.4.mixer.conv1d.bias',\n", - " tensor([-0.3220, -0.4181, -0.0623, ..., 0.2788, 0.0518, 0.4607])),\n", - " ('model.layers.4.mixer.out_proj.weight',\n", - " tensor([[-0.0011, -0.0279, -0.0160, ..., -0.0222, 0.0262, 0.0234],\n", - " [ 0.0024, 0.0178, -0.0142, ..., 0.0048, -0.0145, 0.0332],\n", - " [-0.0084, -0.0037, 0.0054, ..., -0.0201, -0.0341, -0.0053],\n", - " ...,\n", - " [-0.0120, -0.0440, 0.0097, ..., -0.0070, -0.0129, 0.0170],\n", - " [ 0.0096, -0.0034, -0.0025, ..., 0.0242, 0.0047, 0.0093],\n", - " [ 0.0254, 0.0207, 0.0135, ..., 0.0204, -0.0185, -0.0026]])),\n", - " ('model.layers.4.mlp.gate_proj.weight',\n", - " tensor([[ 0.0049, 0.0087, 0.0081, ..., 0.0145, 0.0188, 0.0441],\n", - " [-0.0103, 0.0147, 0.0180, ..., -0.0190, 0.0182, 0.0160],\n", - " [-0.0041, 0.0289, 0.0106, ..., 0.0144, -0.0070, 0.0104],\n", - " ...,\n", - " [ 0.0086, 0.0079, 0.0155, ..., 0.0037, -0.0242, 0.0091],\n", - " [-0.0320, 0.0084, -0.0508, ..., 0.0003, -0.0120, 0.0129],\n", - " [ 0.0079, 0.0185, 0.0285, ..., -0.0324, 0.0444, -0.0147]])),\n", - " ('model.layers.4.mlp.up_proj.weight',\n", - " tensor([[ 3.4382e-03, 1.9171e-02, 4.1226e-03, ..., 1.3158e-02,\n", - " 3.6365e-02, -8.1017e-03],\n", - " [ 1.8713e-02, -2.7732e-03, 3.1982e-02, ..., -8.5724e-03,\n", - " -3.1505e-02, 2.1047e-03],\n", - " [ 1.2329e-02, 1.8352e-03, 9.2540e-03, ..., 2.9880e-02,\n", - " -2.7856e-04, -8.7440e-04],\n", - " ...,\n", - " [-2.2330e-02, -2.0716e-02, 9.0004e-05, ..., -1.6298e-02,\n", - " -1.9620e-02, 2.5112e-02],\n", - " [ 7.1659e-03, 1.2942e-02, 1.0291e-03, ..., -1.0113e-02,\n", - " -1.6838e-03, 2.0189e-02],\n", - " [ 7.2108e-03, 3.1229e-02, 2.2533e-03, ..., -2.0148e-02,\n", - " -1.3502e-02, -1.8923e-02]])),\n", - " ('model.layers.4.mlp.down_proj.weight',\n", - " tensor([[ 0.0140, -0.0129, 0.0005, ..., -0.0068, -0.0335, 0.0172],\n", - " [-0.0175, -0.0011, 0.0114, ..., -0.0087, -0.0048, -0.0231],\n", - " [-0.0053, -0.0079, -0.0172, ..., -0.0125, -0.0200, 0.0127],\n", - " ...,\n", - " [ 0.0321, -0.0039, 0.0142, ..., 0.0384, 0.0054, 0.0321],\n", - " [ 0.0041, -0.0150, 0.0141, ..., 0.0049, -0.0348, -0.0028],\n", - " [ 0.0176, 0.0132, 0.0090, ..., -0.0117, 0.0241, 0.0417]])),\n", - " ('model.layers.4.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.4.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.5.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.5.mixer.in_proj.weight',\n", - " tensor([[ 0.0270, 0.0124, 0.0098, ..., 0.0170, -0.0225, 0.0032],\n", - " [ 0.0245, -0.0008, 0.0226, ..., 0.0219, -0.0219, 0.0087],\n", - " [-0.0175, 0.0181, 0.0124, ..., 0.0038, -0.0094, 0.0079],\n", - " ...,\n", - " [-0.0080, -0.0011, 0.0316, ..., -0.0012, 0.0254, 0.0251],\n", - " [-0.0141, -0.0159, -0.0069, ..., 0.0147, -0.0161, -0.0093],\n", - " [ 0.0252, 0.0125, 0.0174, ..., -0.0065, 0.0110, 0.0272]])),\n", - " ('model.layers.5.mixer.conv1d.weight',\n", - " tensor([[[ 0.0684, -0.4353, 0.3899, 0.3199]],\n", - " \n", - " [[ 0.4136, 0.4306, -0.4871, 0.4781]],\n", - " \n", - " [[-0.2516, 0.2109, 0.3891, 0.1501]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0781, -0.0675, -0.2995, -0.1805]],\n", - " \n", - " [[-0.3360, -0.4148, 0.1846, -0.1013]],\n", - " \n", - " [[ 0.1725, 0.1929, -0.0337, 0.1375]]])),\n", - " ('model.layers.5.mixer.conv1d.bias',\n", - " tensor([-0.4975, -0.0629, -0.2420, ..., -0.2253, 0.2512, 0.2788])),\n", - " ('model.layers.5.mixer.out_proj.weight',\n", - " tensor([[ 1.4306e-02, 1.3230e-02, -2.4141e-02, ..., 1.1763e-02,\n", - " 7.0706e-03, -4.7970e-03],\n", - " [ 2.7478e-02, 1.5179e-03, 1.9229e-02, ..., 1.0928e-02,\n", - " 2.2802e-02, -2.9729e-03],\n", - " [ 1.0169e-02, -1.0741e-02, 2.0628e-02, ..., -1.8109e-02,\n", - " -4.2582e-03, 2.4007e-02],\n", - " ...,\n", - " [-3.2843e-03, 3.7835e-03, -6.7958e-03, ..., -2.6205e-02,\n", - " -2.0391e-02, 5.3912e-03],\n", - " [ 1.2515e-02, -6.4975e-03, 9.9616e-05, ..., 1.0444e-02,\n", - " -2.0596e-02, -8.2915e-03],\n", - " [ 1.7899e-02, 2.0418e-02, -1.9891e-02, ..., -6.6709e-03,\n", - " -3.8566e-02, 2.7005e-02]])),\n", - " ('model.layers.5.mlp.gate_proj.weight',\n", - " tensor([[-2.3807e-03, 2.2714e-03, 2.2736e-05, ..., -2.3039e-03,\n", - " 3.6159e-02, -1.7253e-02],\n", - " [ 3.6929e-02, -6.2031e-03, 1.3606e-02, ..., 2.3592e-02,\n", - " 4.4487e-03, -9.6723e-03],\n", - " [ 4.7507e-02, 2.6413e-02, 1.6759e-02, ..., 1.1910e-02,\n", - " 1.2872e-02, -1.0443e-02],\n", - " ...,\n", - " [-2.0354e-02, -3.9074e-03, 9.7952e-03, ..., 1.0730e-02,\n", - " 2.8752e-02, -8.0048e-03],\n", - " [ 2.5331e-02, -9.9732e-03, 1.0772e-02, ..., 2.0420e-02,\n", - " -3.2179e-02, -1.6437e-02],\n", - " [-3.4425e-02, -1.4578e-02, 2.9686e-03, ..., 4.5907e-02,\n", - " 7.7639e-03, -2.2494e-03]])),\n", - " ('model.layers.5.mlp.up_proj.weight',\n", - " tensor([[ 1.5868e-02, -1.9222e-02, -1.2880e-03, ..., 8.3353e-03,\n", - " -1.8538e-02, 6.7395e-03],\n", - " [-1.8051e-02, -5.0142e-02, -2.2177e-03, ..., -9.3852e-03,\n", - " -3.0374e-02, 2.5795e-02],\n", - " [-1.1737e-02, 2.6278e-02, -2.3205e-02, ..., -1.8399e-03,\n", - " 1.4115e-02, -2.6438e-02],\n", - " ...,\n", - " [ 2.7706e-02, -2.5067e-03, -8.7058e-03, ..., 2.1662e-03,\n", - " -4.9858e-02, -1.1575e-02],\n", - " [-9.5670e-04, 2.1698e-02, -5.4794e-03, ..., -1.0661e-02,\n", - " 1.8568e-02, 5.2615e-03],\n", - " [ 1.0739e-03, 2.2945e-02, 3.0835e-02, ..., 4.1212e-03,\n", - " 1.2643e-02, -1.1568e-05]])),\n", - " ('model.layers.5.mlp.down_proj.weight',\n", - " tensor([[ 0.0052, -0.0343, 0.0072, ..., 0.0004, 0.0320, 0.0362],\n", - " [ 0.0171, -0.0238, -0.0316, ..., 0.0231, 0.0377, 0.0141],\n", - " [-0.0205, 0.0152, 0.0002, ..., -0.0061, -0.0353, -0.0138],\n", - " ...,\n", - " [-0.0039, -0.0039, 0.0326, ..., -0.0208, 0.0160, 0.0185],\n", - " [ 0.0176, -0.0300, -0.0024, ..., -0.0292, -0.0254, -0.0366],\n", - " [ 0.0361, 0.0243, -0.0253, ..., -0.0036, -0.0099, -0.0133]])),\n", - " ('model.layers.5.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.5.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.6.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.6.mixer.in_proj.weight',\n", - " tensor([[-0.0505, -0.0650, 0.0059, ..., 0.0060, 0.0347, 0.0149],\n", - " [-0.0216, 0.0057, -0.0281, ..., -0.0162, 0.0081, 0.0016],\n", - " [-0.0339, -0.0314, 0.0253, ..., 0.0030, 0.0139, -0.0039],\n", - " ...,\n", - " [ 0.0355, -0.0238, -0.0015, ..., 0.0063, 0.0284, -0.0089],\n", - " [ 0.0093, -0.0381, -0.0261, ..., -0.0170, -0.0170, -0.0288],\n", - " [-0.0228, -0.0110, 0.0107, ..., 0.0300, 0.0010, 0.0141]])),\n", - " ('model.layers.6.mixer.conv1d.weight',\n", - " tensor([[[ 0.4364, 0.2888, 0.2343, 0.3226]],\n", - " \n", - " [[ 0.2804, 0.3558, 0.4061, -0.0480]],\n", - " \n", - " [[ 0.4964, 0.0709, 0.0748, 0.0971]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4291, 0.2445, -0.3121, 0.4013]],\n", - " \n", - " [[-0.1590, -0.1516, 0.0804, 0.2009]],\n", - " \n", - " [[ 0.1686, 0.0492, -0.2932, 0.1381]]])),\n", - " ('model.layers.6.mixer.conv1d.bias',\n", - " tensor([ 0.4241, -0.0500, 0.3393, ..., 0.1598, -0.4924, -0.3241])),\n", - " ('model.layers.6.mixer.out_proj.weight',\n", - " tensor([[ 0.0026, 0.0272, 0.0005, ..., 0.0434, -0.0293, -0.0105],\n", - " [ 0.0323, -0.0515, 0.0107, ..., -0.0406, 0.0252, -0.0038],\n", - " [-0.0156, -0.0078, 0.0173, ..., 0.0312, -0.0014, -0.0014],\n", - " ...,\n", - " [ 0.0014, -0.0522, -0.0154, ..., 0.0090, -0.0050, -0.0049],\n", - " [ 0.0350, 0.0099, -0.0014, ..., -0.0008, -0.0185, -0.0033],\n", - " [ 0.0134, 0.0002, 0.0325, ..., -0.0129, 0.0165, -0.0265]])),\n", - " ('model.layers.6.mlp.gate_proj.weight',\n", - " tensor([[-0.0011, 0.0202, 0.0236, ..., -0.0137, -0.0063, 0.0085],\n", - " [ 0.0163, 0.0261, 0.0120, ..., -0.0003, -0.0254, 0.0001],\n", - " [ 0.0318, -0.0121, 0.0103, ..., -0.0053, 0.0194, 0.0530],\n", - " ...,\n", - " [ 0.0039, 0.0228, -0.0147, ..., 0.0027, 0.0092, -0.0033],\n", - " [-0.0040, 0.0144, 0.0038, ..., -0.0106, -0.0022, 0.0094],\n", - " [ 0.0220, 0.0296, 0.0550, ..., 0.0079, -0.0135, -0.0092]])),\n", - " ('model.layers.6.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0291, -0.0133, ..., 0.0054, -0.0049, -0.0028],\n", - " [-0.0032, -0.0201, 0.0218, ..., -0.0155, -0.0264, 0.0496],\n", - " [-0.0046, 0.0384, -0.0093, ..., 0.0356, -0.0245, 0.0175],\n", - " ...,\n", - " [-0.0111, -0.0092, -0.0143, ..., 0.0010, -0.0453, 0.0024],\n", - " [ 0.0078, -0.0025, 0.0227, ..., -0.0130, 0.0118, 0.0095],\n", - " [ 0.0234, -0.0114, -0.0102, ..., -0.0179, -0.0066, -0.0115]])),\n", - " ('model.layers.6.mlp.down_proj.weight',\n", - " tensor([[ 3.6976e-02, 1.7124e-02, -2.1290e-02, ..., -2.5206e-02,\n", - " 4.8023e-03, 9.8474e-03],\n", - " [-7.2866e-03, -5.4149e-03, -2.2242e-03, ..., -8.1606e-03,\n", - " -9.5275e-04, -1.8121e-02],\n", - " [-8.3493e-03, 1.2509e-02, 1.0773e-02, ..., 2.7061e-02,\n", - " 2.8131e-03, 5.8219e-03],\n", - " ...,\n", - " [ 8.7099e-03, 3.9196e-02, -3.5129e-03, ..., -2.3595e-02,\n", - " -8.3965e-03, 2.0074e-02],\n", - " [-2.7467e-02, -2.8721e-03, -2.2291e-02, ..., 9.7135e-03,\n", - " 3.4947e-02, -2.2158e-02],\n", - " [ 6.1744e-03, -4.7684e-03, 4.6690e-04, ..., -3.2948e-03,\n", - " 4.0735e-05, 3.3651e-02]])),\n", - " ('model.layers.6.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.6.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.7.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.7.mixer.in_proj.weight',\n", - " tensor([[-0.0045, -0.0288, 0.0362, ..., -0.0092, -0.0026, 0.0051],\n", - " [ 0.0160, 0.0139, 0.0057, ..., 0.0121, 0.0071, 0.0134],\n", - " [ 0.0062, 0.0181, 0.0161, ..., -0.0284, -0.0014, -0.0171],\n", - " ...,\n", - " [-0.0053, 0.0067, 0.0095, ..., -0.0175, 0.0235, 0.0125],\n", - " [-0.0048, 0.0041, 0.0038, ..., 0.0099, 0.0194, 0.0124],\n", - " [ 0.0131, 0.0073, -0.0284, ..., 0.0138, -0.0218, 0.0019]])),\n", - " ('model.layers.7.mixer.conv1d.weight',\n", - " tensor([[[ 0.2528, -0.0556, -0.3225, 0.1327]],\n", - " \n", - " [[-0.0437, 0.4941, -0.4075, 0.1062]],\n", - " \n", - " [[-0.3428, 0.2675, 0.1871, 0.0260]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0409, -0.4458, 0.4488, 0.2841]],\n", - " \n", - " [[-0.2370, -0.3965, 0.0656, -0.1339]],\n", - " \n", - " [[ 0.4677, 0.0073, 0.3741, 0.1525]]])),\n", - " ('model.layers.7.mixer.conv1d.bias',\n", - " tensor([-0.1844, -0.1347, 0.0043, ..., -0.3839, -0.2167, -0.4637])),\n", - " ('model.layers.7.mixer.out_proj.weight',\n", - " tensor([[-2.8471e-02, 3.9783e-03, 6.0125e-03, ..., -1.6079e-02,\n", - " 1.4225e-02, 2.8166e-02],\n", - " [ 5.4680e-03, -5.1414e-03, 5.3077e-05, ..., 1.8734e-02,\n", - " 3.7454e-03, 1.7579e-02],\n", - " [-1.2955e-02, 1.4954e-02, 6.4922e-03, ..., -2.6830e-02,\n", - " 1.4766e-02, -1.8002e-02],\n", - " ...,\n", - " [ 1.7150e-02, 4.6781e-02, -1.1136e-02, ..., 4.7242e-03,\n", - " -1.3072e-02, -1.0412e-02],\n", - " [ 5.5498e-03, -3.0803e-02, -2.4880e-02, ..., -4.2644e-03,\n", - " -1.1047e-02, 1.5815e-02],\n", - " [ 1.7242e-02, 2.7994e-02, -4.8186e-04, ..., -2.2003e-02,\n", - " -2.1834e-02, -2.1826e-02]])),\n", - " ('model.layers.7.mlp.gate_proj.weight',\n", - " tensor([[-0.0302, -0.0160, -0.0341, ..., -0.0121, 0.0007, -0.0338],\n", - " [-0.0186, 0.0257, -0.0154, ..., 0.0153, -0.0029, 0.0163],\n", - " [ 0.0170, 0.0223, -0.0185, ..., -0.0020, 0.0061, 0.0174],\n", - " ...,\n", - " [-0.0044, 0.0044, 0.0077, ..., -0.0183, 0.0041, -0.0003],\n", - " [ 0.0168, 0.0149, -0.0221, ..., 0.0112, 0.0357, 0.0042],\n", - " [ 0.0310, -0.0217, 0.0070, ..., -0.0394, -0.0065, 0.0204]])),\n", - " ('model.layers.7.mlp.up_proj.weight',\n", - " tensor([[-0.0031, -0.0110, 0.0091, ..., 0.0152, -0.0013, 0.0096],\n", - " [ 0.0013, 0.0354, -0.0037, ..., 0.0130, 0.0204, 0.0262],\n", - " [-0.0075, -0.0044, 0.0207, ..., 0.0057, 0.0115, 0.0151],\n", - " ...,\n", - " [-0.0015, 0.0095, -0.0100, ..., -0.0150, 0.0105, -0.0350],\n", - " [-0.0300, -0.0092, -0.0176, ..., -0.0113, 0.0164, -0.0117],\n", - " [-0.0291, -0.0085, 0.0058, ..., 0.0386, -0.0174, -0.0092]])),\n", - " ('model.layers.7.mlp.down_proj.weight',\n", - " tensor([[-0.0276, 0.0017, -0.0217, ..., 0.0302, -0.0079, -0.0003],\n", - " [ 0.0379, 0.0052, 0.0052, ..., 0.0145, 0.0139, -0.0143],\n", - " [ 0.0176, -0.0028, 0.0172, ..., -0.0205, -0.0165, -0.0040],\n", - " ...,\n", - " [ 0.0095, -0.0139, 0.0077, ..., -0.0080, 0.0339, 0.0172],\n", - " [-0.0177, 0.0009, -0.0245, ..., 0.0040, 0.0258, 0.0202],\n", - " [-0.0064, -0.0270, 0.0041, ..., -0.0133, -0.0040, 0.0038]])),\n", - " ('model.layers.7.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.7.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.8.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.8.mixer.in_proj.weight',\n", - " tensor([[ 0.0050, 0.0270, -0.0196, ..., -0.0121, -0.0090, 0.0083],\n", - " [-0.0083, -0.0177, 0.0159, ..., 0.0298, -0.0202, -0.0265],\n", - " [ 0.0058, 0.0186, 0.0125, ..., -0.0067, -0.0255, 0.0298],\n", - " ...,\n", - " [-0.0164, 0.0012, 0.0023, ..., -0.0355, 0.0347, -0.0011],\n", - " [-0.0371, 0.0033, 0.0345, ..., -0.0097, 0.0019, 0.0185],\n", - " [-0.0322, -0.0160, 0.0072, ..., -0.0195, -0.0229, 0.0118]])),\n", - " ('model.layers.8.mixer.conv1d.weight',\n", - " tensor([[[-0.0520, 0.3004, -0.1990, 0.2512]],\n", - " \n", - " [[-0.4120, -0.0055, 0.1484, -0.3316]],\n", - " \n", - " [[ 0.3939, -0.0567, 0.1432, 0.1880]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.2849, 0.2494, -0.2141, -0.3375]],\n", - " \n", - " [[-0.2823, -0.2402, 0.2228, 0.2331]],\n", - " \n", - " [[ 0.1914, 0.4269, 0.1228, -0.3408]]])),\n", - " ('model.layers.8.mixer.conv1d.bias',\n", - " tensor([0.1304, 0.2065, 0.3084, ..., 0.3863, 0.4883, 0.4724])),\n", - " ('model.layers.8.mixer.out_proj.weight',\n", - " tensor([[ 0.0008, -0.0019, 0.0084, ..., -0.0003, 0.0045, 0.0024],\n", - " [ 0.0137, -0.0003, -0.0031, ..., 0.0013, 0.0131, 0.0090],\n", - " [ 0.0095, 0.0488, -0.0355, ..., 0.0344, -0.0229, -0.0150],\n", - " ...,\n", - " [ 0.0029, 0.0164, -0.0380, ..., -0.0005, -0.0031, 0.0127],\n", - " [-0.0039, 0.0283, 0.0295, ..., 0.0271, -0.0105, -0.0158],\n", - " [-0.0057, -0.0178, 0.0129, ..., 0.0323, -0.0091, 0.0178]])),\n", - " ('model.layers.8.mlp.gate_proj.weight',\n", - " tensor([[-0.0047, 0.0037, -0.0129, ..., 0.0255, -0.0118, 0.0084],\n", - " [ 0.0418, -0.0020, 0.0205, ..., 0.0161, 0.0306, 0.0250],\n", - " [ 0.0011, 0.0144, 0.0204, ..., -0.0007, 0.0298, -0.0067],\n", - " ...,\n", - " [-0.0536, -0.0083, -0.0049, ..., -0.0028, 0.0301, -0.0205],\n", - " [ 0.0031, 0.0139, 0.0070, ..., 0.0120, 0.0004, -0.0226],\n", - " [ 0.0114, -0.0173, 0.0212, ..., -0.0413, -0.0069, 0.0007]])),\n", - " ('model.layers.8.mlp.up_proj.weight',\n", - " tensor([[-0.0005, 0.0028, -0.0137, ..., 0.0078, 0.0348, 0.0006],\n", - " [-0.0020, 0.0300, -0.0056, ..., -0.0258, -0.0130, -0.0212],\n", - " [-0.0135, -0.0111, 0.0151, ..., 0.0043, -0.0426, -0.0109],\n", - " ...,\n", - " [ 0.0273, 0.0057, -0.0108, ..., -0.0205, 0.0005, -0.0239],\n", - " [ 0.0226, 0.0325, -0.0187, ..., 0.0069, -0.0132, -0.0002],\n", - " [ 0.0280, -0.0007, -0.0047, ..., 0.0159, -0.0054, -0.0172]])),\n", - " ('model.layers.8.mlp.down_proj.weight',\n", - " tensor([[-0.0091, 0.0072, 0.0030, ..., 0.0025, -0.0159, -0.0277],\n", - " [ 0.0159, -0.0260, -0.0076, ..., -0.0059, -0.0129, 0.0358],\n", - " [ 0.0026, -0.0357, -0.0138, ..., -0.0326, -0.0291, 0.0010],\n", - " ...,\n", - " [-0.0237, 0.0272, -0.0130, ..., -0.0280, 0.0097, -0.0563],\n", - " [ 0.0092, 0.0056, 0.0079, ..., -0.0224, 0.0039, -0.0054],\n", - " [-0.0109, -0.0241, -0.0223, ..., -0.0187, 0.0190, 0.0082]])),\n", - " ('model.layers.8.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.8.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.9.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.9.mixer.in_proj.weight',\n", - " tensor([[ 4.9824e-02, 5.7576e-03, -5.1022e-03, ..., -2.5615e-02,\n", - " 7.1750e-04, 1.5247e-02],\n", - " [-2.8065e-02, -1.2649e-02, -2.3566e-02, ..., 1.7742e-02,\n", - " -1.1202e-02, -2.1476e-02],\n", - " [ 2.0911e-02, 1.6496e-02, -1.9818e-02, ..., 4.0223e-02,\n", - " 1.8544e-02, -2.3633e-02],\n", - " ...,\n", - " [-4.3387e-02, -1.6504e-02, 2.2008e-02, ..., -2.5138e-03,\n", - " -5.6073e-03, -4.8212e-03],\n", - " [-1.9964e-05, -1.5835e-02, 1.2977e-02, ..., 4.1913e-03,\n", - " 4.5898e-02, -3.5822e-02],\n", - " [ 3.1376e-02, -5.4614e-03, -2.5093e-02, ..., -3.7903e-03,\n", - " 1.3560e-02, 3.3366e-02]])),\n", - " ('model.layers.9.mixer.conv1d.weight',\n", - " tensor([[[ 0.1986, -0.1666, -0.4140, -0.4607]],\n", - " \n", - " [[-0.3454, -0.3973, 0.2169, -0.2138]],\n", - " \n", - " [[ 0.2006, -0.3736, 0.3944, -0.0589]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4604, 0.1224, -0.2571, -0.0286]],\n", - " \n", - " [[-0.2723, -0.1617, 0.3483, 0.2299]],\n", - " \n", - " [[ 0.4866, 0.2559, 0.3969, 0.0554]]])),\n", - " ('model.layers.9.mixer.conv1d.bias',\n", - " tensor([ 0.3388, 0.4633, -0.3762, ..., -0.3491, -0.2971, 0.0494])),\n", - " ('model.layers.9.mixer.out_proj.weight',\n", - " tensor([[ 0.0023, -0.0181, 0.0358, ..., 0.0243, 0.0070, -0.0183],\n", - " [ 0.0006, 0.0065, 0.0057, ..., -0.0351, -0.0107, 0.0132],\n", - " [ 0.0153, -0.0038, 0.0059, ..., -0.0285, -0.0247, -0.0104],\n", - " ...,\n", - " [ 0.0244, -0.0120, 0.0064, ..., -0.0133, 0.0263, 0.0016],\n", - " [ 0.0056, -0.0111, 0.0029, ..., -0.0017, -0.0172, -0.0071],\n", - " [-0.0056, -0.0192, -0.0238, ..., 0.0245, -0.0102, -0.0331]])),\n", - " ('model.layers.9.mlp.gate_proj.weight',\n", - " tensor([[-0.0132, 0.0014, -0.0413, ..., -0.0254, -0.0245, 0.0031],\n", - " [-0.0195, -0.0107, -0.0192, ..., 0.0012, -0.0026, 0.0148],\n", - " [-0.0074, -0.0070, -0.0078, ..., 0.0013, -0.0011, -0.0111],\n", - " ...,\n", - " [-0.0137, 0.0302, 0.0084, ..., -0.0063, -0.0065, 0.0240],\n", - " [ 0.0072, 0.0134, 0.0161, ..., 0.0122, 0.0182, 0.0137],\n", - " [ 0.0079, 0.0008, 0.0160, ..., 0.0281, 0.0226, 0.0058]])),\n", - " ('model.layers.9.mlp.up_proj.weight',\n", - " tensor([[ 0.0078, 0.0153, -0.0155, ..., 0.0153, -0.0164, -0.0140],\n", - " [-0.0072, -0.0050, 0.0030, ..., 0.0146, -0.0148, -0.0080],\n", - " [ 0.0165, -0.0078, 0.0005, ..., -0.0545, -0.0096, 0.0296],\n", - " ...,\n", - " [-0.0253, 0.0183, -0.0081, ..., -0.0061, 0.0270, -0.0003],\n", - " [-0.0015, -0.0320, 0.0361, ..., -0.0087, 0.0341, -0.0157],\n", - " [ 0.0041, 0.0102, -0.0195, ..., -0.0441, -0.0106, 0.0275]])),\n", - " ('model.layers.9.mlp.down_proj.weight',\n", - " tensor([[-6.3367e-02, -1.8214e-02, 5.7221e-03, ..., 2.1307e-02,\n", - " -3.0707e-02, -1.3281e-02],\n", - " [-7.7457e-05, -9.1894e-05, 6.8686e-03, ..., -4.7175e-03,\n", - " -1.1585e-03, -2.7604e-02],\n", - " [ 2.9301e-02, -5.9431e-03, -2.5356e-03, ..., -2.7858e-02,\n", - " 1.1647e-02, 1.1245e-02],\n", - " ...,\n", - " [-1.0442e-02, -9.6151e-03, -3.6635e-02, ..., -1.1052e-02,\n", - " -4.5122e-03, 4.0012e-03],\n", - " [ 3.2950e-02, -1.3836e-03, -7.8318e-03, ..., -1.2788e-03,\n", - " 2.3422e-02, -3.2098e-02],\n", - " [-9.2294e-03, 1.3838e-02, -2.0327e-02, ..., -3.8760e-02,\n", - " 2.2118e-02, 1.0696e-02]])),\n", - " ('model.layers.9.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.9.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.10.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.10.mixer.in_proj.weight',\n", - " tensor([[ 0.0096, -0.0159, 0.0141, ..., 0.0111, 0.0218, 0.0220],\n", - " [-0.0381, -0.0015, 0.0126, ..., -0.0066, -0.0034, -0.0119],\n", - " [ 0.0223, 0.0032, -0.0195, ..., -0.0107, -0.0018, 0.0059],\n", - " ...,\n", - " [-0.0256, -0.0170, -0.0362, ..., -0.0007, -0.0039, 0.0075],\n", - " [ 0.0136, -0.0045, 0.0128, ..., -0.0017, 0.0083, -0.0004],\n", - " [-0.0246, -0.0021, 0.0073, ..., 0.0020, 0.0071, 0.0090]])),\n", - " ('model.layers.10.mixer.conv1d.weight',\n", - " tensor([[[ 0.0463, -0.4497, -0.0679, -0.2209]],\n", - " \n", - " [[-0.3805, 0.4459, 0.1999, -0.4996]],\n", - " \n", - " [[ 0.1529, 0.1789, -0.1535, 0.1824]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1087, -0.4478, -0.0420, 0.3437]],\n", - " \n", - " [[-0.2809, -0.4617, 0.3209, 0.4873]],\n", - " \n", - " [[ 0.1139, -0.0060, -0.0219, 0.0853]]])),\n", - " ('model.layers.10.mixer.conv1d.bias',\n", - " tensor([ 0.1364, -0.0475, 0.0849, ..., 0.1928, 0.2075, 0.1058])),\n", - " ('model.layers.10.mixer.out_proj.weight',\n", - " tensor([[-0.0164, -0.0188, 0.0174, ..., -0.0106, -0.0107, -0.0036],\n", - " [ 0.0048, -0.0016, -0.0444, ..., -0.0182, -0.0264, -0.0038],\n", - " [ 0.0089, -0.0225, -0.0002, ..., -0.0141, -0.0008, -0.0037],\n", - " ...,\n", - " [-0.0005, 0.0159, 0.0033, ..., 0.0187, -0.0064, 0.0233],\n", - " [-0.0050, 0.0296, 0.0147, ..., -0.0018, 0.0137, -0.0346],\n", - " [-0.0064, -0.0132, -0.0434, ..., -0.0173, -0.0113, -0.0175]])),\n", - " ('model.layers.10.mlp.gate_proj.weight',\n", - " tensor([[-0.0174, -0.0053, -0.0325, ..., -0.0072, -0.0280, 0.0033],\n", - " [ 0.0006, -0.0160, 0.0346, ..., 0.0019, 0.0059, 0.0198],\n", - " [ 0.0231, -0.0187, 0.0115, ..., 0.0085, 0.0080, 0.0061],\n", - " ...,\n", - " [ 0.0153, 0.0241, -0.0184, ..., 0.0089, -0.0242, 0.0010],\n", - " [-0.0019, -0.0322, 0.0011, ..., -0.0097, -0.0305, 0.0065],\n", - " [-0.0107, 0.0240, 0.0168, ..., 0.0226, -0.0238, 0.0117]])),\n", - " ('model.layers.10.mlp.up_proj.weight',\n", - " tensor([[-0.0072, 0.0352, 0.0282, ..., -0.0025, -0.0114, 0.0129],\n", - " [-0.0102, 0.0196, 0.0760, ..., 0.0461, -0.0058, -0.0112],\n", - " [-0.0271, 0.0323, -0.0069, ..., 0.0133, -0.0371, -0.0619],\n", - " ...,\n", - " [ 0.0100, 0.0011, 0.0262, ..., -0.0232, 0.0217, 0.0002],\n", - " [ 0.0151, -0.0266, -0.0074, ..., 0.0096, 0.0036, 0.0033],\n", - " [ 0.0004, 0.0103, 0.0363, ..., -0.0095, -0.0309, -0.0059]])),\n", - " ('model.layers.10.mlp.down_proj.weight',\n", - " tensor([[ 0.0124, -0.0225, -0.0294, ..., 0.0280, 0.0056, 0.0231],\n", - " [ 0.0124, -0.0030, 0.0014, ..., 0.0323, 0.0094, -0.0034],\n", - " [-0.0078, 0.0041, -0.0056, ..., 0.0241, -0.0278, -0.0152],\n", - " ...,\n", - " [-0.0044, 0.0025, -0.0161, ..., -0.0075, -0.0126, 0.0014],\n", - " [-0.0109, -0.0050, 0.0327, ..., -0.0300, -0.0048, 0.0284],\n", - " [ 0.0050, -0.0183, 0.0086, ..., -0.0072, 0.0139, -0.0010]])),\n", - " ('model.layers.10.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.10.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.11.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.11.mixer.in_proj.weight',\n", - " tensor([[-0.0133, 0.0225, 0.0486, ..., -0.0214, -0.0120, -0.0150],\n", - " [ 0.0183, 0.0020, 0.0079, ..., -0.0163, 0.0016, -0.0214],\n", - " [-0.0276, -0.0112, 0.0121, ..., -0.0057, -0.0143, -0.0462],\n", - " ...,\n", - " [-0.0142, -0.0080, -0.0194, ..., 0.0087, -0.0212, -0.0140],\n", - " [ 0.0060, -0.0005, -0.0171, ..., -0.0017, 0.0223, 0.0169],\n", - " [-0.0290, -0.0016, 0.0117, ..., 0.0037, 0.0047, 0.0152]])),\n", - " ('model.layers.11.mixer.conv1d.weight',\n", - " tensor([[[-0.2822, -0.4216, 0.4786, 0.0802]],\n", - " \n", - " [[-0.3671, 0.1761, -0.2686, 0.1631]],\n", - " \n", - " [[-0.3902, -0.2811, -0.0748, 0.4662]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.1623, 0.2871, -0.4585, 0.4755]],\n", - " \n", - " [[-0.0260, 0.4541, -0.2983, 0.2297]],\n", - " \n", - " [[-0.2991, -0.3590, -0.3256, -0.1434]]])),\n", - " ('model.layers.11.mixer.conv1d.bias',\n", - " tensor([ 0.1218, -0.0542, 0.3485, ..., 0.0528, 0.2711, -0.2811])),\n", - " ('model.layers.11.mixer.out_proj.weight',\n", - " tensor([[ 0.0032, 0.0028, -0.0122, ..., -0.0299, -0.0105, 0.0021],\n", - " [-0.0466, -0.0170, -0.0017, ..., 0.0156, -0.0287, 0.0066],\n", - " [ 0.0016, 0.0054, -0.0071, ..., -0.0240, 0.0215, -0.0046],\n", - " ...,\n", - " [-0.0210, 0.0034, -0.0267, ..., 0.0461, -0.0076, -0.0016],\n", - " [-0.0012, -0.0101, 0.0196, ..., 0.0121, -0.0043, -0.0143],\n", - " [-0.0067, 0.0086, 0.0134, ..., 0.0080, 0.0255, 0.0225]])),\n", - " ('model.layers.11.mlp.gate_proj.weight',\n", - " tensor([[ 0.0179, -0.0429, -0.0134, ..., 0.0110, 0.0368, -0.0259],\n", - " [ 0.0013, -0.0231, 0.0072, ..., -0.0056, -0.0012, -0.0037],\n", - " [-0.0172, -0.0162, 0.0088, ..., -0.0175, 0.0079, -0.0065],\n", - " ...,\n", - " [ 0.0287, -0.0289, 0.0045, ..., 0.0039, 0.0269, 0.0199],\n", - " [ 0.0043, -0.0202, -0.0261, ..., 0.0104, -0.0161, -0.0057],\n", - " [-0.0154, 0.0085, 0.0061, ..., 0.0208, 0.0001, 0.0166]])),\n", - " ('model.layers.11.mlp.up_proj.weight',\n", - " tensor([[-0.0107, 0.0328, 0.0065, ..., -0.0190, -0.0082, -0.0047],\n", - " [-0.0001, 0.0102, 0.0310, ..., -0.0396, -0.0278, -0.0095],\n", - " [-0.0288, 0.0052, 0.0137, ..., -0.0220, 0.0007, -0.0170],\n", - " ...,\n", - " [ 0.0213, -0.0074, -0.0033, ..., 0.0183, 0.0336, -0.0180],\n", - " [-0.0098, -0.0162, 0.0486, ..., 0.0191, 0.0064, 0.0269],\n", - " [-0.0251, 0.0081, 0.0053, ..., 0.0110, 0.0023, 0.0041]])),\n", - " ('model.layers.11.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, -0.0410, 0.0066, ..., -0.0273, 0.0220, 0.0184],\n", - " [ 0.0092, 0.0087, -0.0136, ..., 0.0013, -0.0205, 0.0247],\n", - " [-0.0252, -0.0040, -0.0112, ..., -0.0331, 0.0201, -0.0038],\n", - " ...,\n", - " [ 0.0072, 0.0190, 0.0089, ..., 0.0098, -0.0235, -0.0141],\n", - " [-0.0045, -0.0381, -0.0134, ..., 0.0171, -0.0077, -0.0180],\n", - " [ 0.0109, 0.0060, 0.0048, ..., -0.0108, -0.0122, 0.0110]])),\n", - " ('model.layers.11.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.11.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.12.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.12.mixer.in_proj.weight',\n", - " tensor([[ 0.0043, 0.0138, 0.0138, ..., -0.0042, 0.0121, -0.0190],\n", - " [ 0.0002, -0.0199, 0.0315, ..., 0.0170, 0.0051, -0.0062],\n", - " [-0.0053, 0.0043, 0.0283, ..., -0.0087, 0.0069, -0.0160],\n", - " ...,\n", - " [-0.0313, 0.0200, 0.0036, ..., 0.0147, 0.0153, 0.0098],\n", - " [-0.0157, 0.0120, -0.0112, ..., 0.0166, -0.0005, 0.0066],\n", - " [-0.0271, 0.0037, 0.0163, ..., 0.0304, 0.0023, 0.0083]])),\n", - " ('model.layers.12.mixer.conv1d.weight',\n", - " tensor([[[-0.4295, -0.2474, -0.2324, -0.2138]],\n", - " \n", - " [[ 0.3607, -0.4824, 0.1667, 0.1348]],\n", - " \n", - " [[ 0.3596, 0.1167, 0.1089, -0.4010]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3527, -0.3346, -0.3755, 0.1450]],\n", - " \n", - " [[-0.1921, -0.0632, -0.4885, -0.3986]],\n", - " \n", - " [[ 0.1950, 0.3037, -0.1630, 0.0353]]])),\n", - " ('model.layers.12.mixer.conv1d.bias',\n", - " tensor([0.3103, 0.0451, 0.4533, ..., 0.0235, 0.1819, 0.3933])),\n", - " ('model.layers.12.mixer.out_proj.weight',\n", - " tensor([[ 0.0167, -0.0197, -0.0054, ..., 0.0096, 0.0271, -0.0118],\n", - " [ 0.0167, -0.0455, 0.0001, ..., 0.0003, 0.0265, 0.0111],\n", - " [ 0.0231, -0.0113, 0.0195, ..., -0.0171, -0.0044, -0.0244],\n", - " ...,\n", - " [ 0.0042, 0.0048, 0.0357, ..., 0.0126, -0.0288, 0.0149],\n", - " [ 0.0192, 0.0078, 0.0126, ..., 0.0029, 0.0255, -0.0203],\n", - " [-0.0054, -0.0543, 0.0039, ..., -0.0240, 0.0282, 0.0082]])),\n", - " ('model.layers.12.mlp.gate_proj.weight',\n", - " tensor([[-0.0417, -0.0193, -0.0022, ..., 0.0031, 0.0337, 0.0175],\n", - " [ 0.0215, -0.0109, -0.0657, ..., -0.0145, -0.0475, -0.0091],\n", - " [-0.0225, -0.0012, -0.0020, ..., -0.0291, 0.0097, 0.0163],\n", - " ...,\n", - " [-0.0018, 0.0048, -0.0265, ..., -0.0056, 0.0446, 0.0045],\n", - " [ 0.0270, 0.0086, -0.0110, ..., -0.0038, 0.0176, 0.0138],\n", - " [-0.0134, 0.0046, -0.0186, ..., -0.0098, 0.0191, 0.0095]])),\n", - " ('model.layers.12.mlp.up_proj.weight',\n", - " tensor([[ 0.0180, 0.0075, 0.0147, ..., 0.0142, 0.0291, -0.0303],\n", - " [-0.0079, -0.0277, -0.0151, ..., -0.0069, -0.0045, -0.0223],\n", - " [ 0.0180, -0.0087, 0.0074, ..., 0.0215, 0.0274, -0.0199],\n", - " ...,\n", - " [-0.0215, -0.0115, 0.0140, ..., -0.0283, -0.0171, -0.0229],\n", - " [ 0.0231, -0.0179, -0.0386, ..., 0.0364, 0.0311, 0.0048],\n", - " [-0.0111, 0.0079, 0.0328, ..., 0.0285, 0.0423, 0.0039]])),\n", - " ('model.layers.12.mlp.down_proj.weight',\n", - " tensor([[-0.0361, 0.0192, -0.0005, ..., -0.0151, 0.0116, -0.0068],\n", - " [ 0.0203, -0.0064, 0.0061, ..., 0.0325, -0.0004, -0.0299],\n", - " [-0.0028, 0.0131, 0.0141, ..., -0.0108, -0.0070, -0.0090],\n", - " ...,\n", - " [ 0.0165, -0.0198, -0.0242, ..., 0.0162, 0.0099, 0.0025],\n", - " [ 0.0148, 0.0056, -0.0139, ..., 0.0108, -0.0477, 0.0225],\n", - " [ 0.0156, 0.0249, -0.0287, ..., -0.0200, -0.0496, 0.0169]])),\n", - " ('model.layers.12.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.12.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.13.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.13.mixer.in_proj.weight',\n", - " tensor([[-0.0064, -0.0200, 0.0384, ..., -0.0036, 0.0158, -0.0007],\n", - " [-0.0074, 0.0105, 0.0043, ..., 0.0097, 0.0259, -0.0012],\n", - " [ 0.0297, -0.0146, -0.0012, ..., 0.0273, 0.0309, 0.0087],\n", - " ...,\n", - " [ 0.0204, -0.0063, 0.0136, ..., -0.0092, 0.0196, 0.0057],\n", - " [ 0.0195, 0.0059, 0.0228, ..., 0.0093, -0.0183, -0.0003],\n", - " [-0.0131, -0.0447, -0.0262, ..., -0.0125, 0.0237, -0.0404]])),\n", - " ('model.layers.13.mixer.conv1d.weight',\n", - " tensor([[[ 7.7458e-03, 4.9829e-01, 2.1690e-01, -2.3587e-01]],\n", - " \n", - " [[ 3.7281e-01, -4.0991e-03, 2.4588e-01, -1.1600e-01]],\n", - " \n", - " [[-4.8238e-01, -2.8961e-01, -4.4331e-02, 1.0011e-01]],\n", - " \n", - " ...,\n", - " \n", - " [[-3.6304e-01, -1.4106e-01, -3.5434e-01, 1.4923e-01]],\n", - " \n", - " [[-2.3703e-01, 3.9285e-04, -2.1456e-02, -2.5568e-01]],\n", - " \n", - " [[ 1.5303e-02, -8.3474e-03, -3.2668e-01, -4.8096e-01]]])),\n", - " ('model.layers.13.mixer.conv1d.bias',\n", - " tensor([-0.2462, 0.1532, -0.2298, ..., -0.3016, 0.1210, -0.3777])),\n", - " ('model.layers.13.mixer.out_proj.weight',\n", - " tensor([[-0.0019, 0.0103, 0.0098, ..., -0.0050, 0.0180, -0.0117],\n", - " [-0.0153, 0.0134, -0.0102, ..., 0.0327, -0.0387, 0.0025],\n", - " [ 0.0102, -0.0038, 0.0224, ..., -0.0118, 0.0234, 0.0014],\n", - " ...,\n", - " [-0.0201, 0.0233, 0.0189, ..., 0.0010, 0.0313, 0.0130],\n", - " [ 0.0193, 0.0035, -0.0253, ..., 0.0084, -0.0208, 0.0372],\n", - " [ 0.0367, -0.0029, -0.0205, ..., -0.0055, -0.0209, 0.0082]])),\n", - " ('model.layers.13.mlp.gate_proj.weight',\n", - " tensor([[ 0.0148, -0.0052, 0.0371, ..., -0.0118, 0.0397, -0.0234],\n", - " [ 0.0237, -0.0323, 0.0219, ..., 0.0098, -0.0304, 0.0165],\n", - " [ 0.0168, -0.0289, 0.0038, ..., 0.0022, 0.0174, 0.0043],\n", - " ...,\n", - " [-0.0135, 0.0258, -0.0172, ..., 0.0251, -0.0071, -0.0384],\n", - " [ 0.0005, -0.0123, 0.0116, ..., 0.0041, -0.0108, -0.0068],\n", - " [ 0.0116, 0.0069, 0.0063, ..., 0.0045, -0.0145, 0.0185]])),\n", - " ('model.layers.13.mlp.up_proj.weight',\n", - " tensor([[-0.0002, -0.0120, 0.0069, ..., 0.0005, -0.0108, -0.0284],\n", - " [ 0.0215, 0.0045, 0.0167, ..., 0.0177, -0.0030, 0.0051],\n", - " [ 0.0265, 0.0169, 0.0047, ..., 0.0069, -0.0299, 0.0196],\n", - " ...,\n", - " [ 0.0127, -0.0063, 0.0242, ..., -0.0061, -0.0263, 0.0041],\n", - " [ 0.0142, -0.0515, -0.0221, ..., -0.0369, -0.0399, -0.0210],\n", - " [ 0.0123, 0.0133, -0.0269, ..., 0.0092, -0.0177, 0.0226]])),\n", - " ('model.layers.13.mlp.down_proj.weight',\n", - " tensor([[ 0.0048, 0.0360, -0.0037, ..., 0.0169, 0.0304, -0.0162],\n", - " [ 0.0271, -0.0121, 0.0108, ..., -0.0424, 0.0293, -0.0137],\n", - " [ 0.0225, -0.0061, -0.0096, ..., 0.0075, -0.0168, 0.0142],\n", - " ...,\n", - " [ 0.0039, -0.0152, -0.0156, ..., 0.0181, 0.0105, 0.0070],\n", - " [ 0.0311, 0.0205, 0.0259, ..., -0.0025, 0.0060, -0.0125],\n", - " [ 0.0004, -0.0114, 0.0022, ..., -0.0159, -0.0290, 0.0036]])),\n", - " ('model.layers.13.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.13.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.14.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.14.mixer.in_proj.weight',\n", - " tensor([[-0.0123, 0.0054, 0.0059, ..., 0.0285, -0.0292, -0.0184],\n", - " [-0.0146, -0.0175, 0.0155, ..., -0.0206, -0.0190, -0.0172],\n", - " [ 0.0050, -0.0235, -0.0159, ..., -0.0013, -0.0102, 0.0082],\n", - " ...,\n", - " [-0.0243, -0.0013, 0.0312, ..., -0.0141, -0.0156, 0.0279],\n", - " [ 0.0018, 0.0181, -0.0188, ..., 0.0593, -0.0155, 0.0156],\n", - " [ 0.0036, 0.0182, -0.0308, ..., 0.0306, -0.0035, 0.0037]])),\n", - " ('model.layers.14.mixer.conv1d.weight',\n", - " tensor([[[-0.4608, 0.4926, -0.2625, 0.3060]],\n", - " \n", - " [[-0.0932, 0.0153, 0.2298, -0.1735]],\n", - " \n", - " [[-0.1927, 0.1979, -0.1773, 0.3277]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.0538, -0.2180, -0.4857, -0.1428]],\n", - " \n", - " [[-0.1736, 0.2405, 0.3148, -0.4481]],\n", - " \n", - " [[-0.4971, -0.1558, 0.2762, -0.1849]]])),\n", - " ('model.layers.14.mixer.conv1d.bias',\n", - " tensor([-0.2181, -0.2375, 0.0896, ..., 0.0744, 0.0857, 0.4347])),\n", - " ('model.layers.14.mixer.out_proj.weight',\n", - " tensor([[-3.8364e-04, 2.4458e-02, 5.8783e-03, ..., -1.3479e-02,\n", - " -2.4306e-02, 5.7698e-03],\n", - " [ 4.5843e-02, -3.9217e-03, -6.9897e-03, ..., 5.5401e-03,\n", - " -1.4523e-02, 1.2266e-02],\n", - " [-7.1069e-03, 5.5550e-03, 1.1359e-02, ..., 3.5839e-02,\n", - " 1.0787e-02, 8.4053e-03],\n", - " ...,\n", - " [ 3.3029e-03, 5.4333e-03, -9.3382e-03, ..., -1.7376e-02,\n", - " 1.5601e-02, -6.3227e-03],\n", - " [-6.9199e-03, -1.6950e-02, 1.5155e-03, ..., 1.2324e-02,\n", - " 1.2259e-02, 5.5500e-02],\n", - " [-1.6177e-02, -6.5257e-05, -9.3656e-03, ..., 1.0653e-02,\n", - " 1.8864e-02, -1.2508e-02]])),\n", - " ('model.layers.14.mlp.gate_proj.weight',\n", - " tensor([[ 0.0279, 0.0025, 0.0214, ..., -0.0137, -0.0042, 0.0172],\n", - " [-0.0240, -0.0150, 0.0170, ..., 0.0090, 0.0002, 0.0172],\n", - " [-0.0181, 0.0052, -0.0418, ..., 0.0106, 0.0052, -0.0264],\n", - " ...,\n", - " [-0.0295, 0.0323, 0.0387, ..., -0.0116, -0.0140, -0.0053],\n", - " [ 0.0411, 0.0189, 0.0236, ..., 0.0094, -0.0176, -0.0066],\n", - " [ 0.0004, 0.0291, 0.0402, ..., 0.0127, -0.0009, 0.0010]])),\n", - " ('model.layers.14.mlp.up_proj.weight',\n", - " tensor([[ 0.0198, -0.0115, -0.0045, ..., 0.0273, 0.0012, -0.0082],\n", - " [-0.0217, 0.0075, 0.0006, ..., 0.0047, -0.0416, -0.0011],\n", - " [ 0.0012, -0.0214, -0.0211, ..., 0.0030, -0.0176, -0.0215],\n", - " ...,\n", - " [ 0.0062, -0.0305, 0.0310, ..., 0.0044, -0.0379, 0.0155],\n", - " [-0.0062, 0.0451, 0.0167, ..., 0.0062, -0.0033, 0.0012],\n", - " [ 0.0293, -0.0186, 0.0295, ..., 0.0092, 0.0100, 0.0038]])),\n", - " ('model.layers.14.mlp.down_proj.weight',\n", - " tensor([[ 0.0019, 0.0114, -0.0202, ..., 0.0227, -0.0227, -0.0005],\n", - " [-0.0437, -0.0045, -0.0385, ..., -0.0083, -0.0135, 0.0172],\n", - " [-0.0032, -0.0024, 0.0137, ..., 0.0071, 0.0034, 0.0104],\n", - " ...,\n", - " [ 0.0210, -0.0237, -0.0166, ..., -0.0105, 0.0490, 0.0155],\n", - " [-0.0109, 0.0112, 0.0082, ..., -0.0342, -0.0133, -0.0086],\n", - " [ 0.0282, -0.0210, -0.0127, ..., -0.0047, -0.0126, 0.0103]])),\n", - " ('model.layers.14.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.14.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.15.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.15.mixer.in_proj.weight',\n", - " tensor([[-0.0098, -0.0201, -0.0033, ..., -0.0289, 0.0275, 0.0186],\n", - " [ 0.0048, 0.0075, -0.0033, ..., 0.0011, 0.0042, 0.0040],\n", - " [-0.0079, -0.0025, 0.0018, ..., -0.0051, -0.0231, -0.0022],\n", - " ...,\n", - " [ 0.0186, -0.0104, -0.0062, ..., 0.0086, -0.0007, -0.0653],\n", - " [-0.0212, 0.0034, 0.0019, ..., 0.0167, 0.0050, 0.0120],\n", - " [ 0.0066, 0.0381, -0.0225, ..., -0.0043, 0.0229, -0.0004]])),\n", - " ('model.layers.15.mixer.conv1d.weight',\n", - " tensor([[[ 0.2306, 0.2721, 0.3406, 0.4513]],\n", - " \n", - " [[ 0.0991, 0.4973, 0.0010, -0.1445]],\n", - " \n", - " [[ 0.2975, 0.4813, 0.2817, -0.0468]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0104, -0.1473, 0.1685, -0.4390]],\n", - " \n", - " [[ 0.3669, 0.3461, 0.0845, 0.3576]],\n", - " \n", - " [[-0.1177, 0.0524, 0.4329, 0.0687]]])),\n", - " ('model.layers.15.mixer.conv1d.bias',\n", - " tensor([-0.0356, 0.4173, 0.3287, ..., -0.0141, 0.1365, 0.2086])),\n", - " ('model.layers.15.mixer.out_proj.weight',\n", - " tensor([[-0.0137, -0.0239, -0.0133, ..., -0.0177, -0.0125, -0.0015],\n", - " [ 0.0168, 0.0120, 0.0034, ..., 0.0098, 0.0098, 0.0110],\n", - " [-0.0315, 0.0447, 0.0189, ..., 0.0305, 0.0131, -0.0230],\n", - " ...,\n", - " [-0.0480, 0.0170, 0.0025, ..., 0.0317, -0.0378, -0.0236],\n", - " [-0.0319, -0.0290, 0.0023, ..., -0.0093, 0.0354, 0.0126],\n", - " [-0.0107, 0.0100, -0.0101, ..., 0.0046, 0.0205, -0.0203]])),\n", - " ('model.layers.15.mlp.gate_proj.weight',\n", - " tensor([[ 0.0160, 0.0432, 0.0073, ..., -0.0003, -0.0170, 0.0236],\n", - " [ 0.0055, 0.0066, -0.0311, ..., 0.0049, -0.0130, 0.0040],\n", - " [-0.0147, -0.0184, 0.0281, ..., 0.0016, 0.0077, -0.0072],\n", - " ...,\n", - " [-0.0049, -0.0434, -0.0118, ..., 0.0137, -0.0225, -0.0058],\n", - " [ 0.0221, -0.0077, 0.0029, ..., 0.0087, -0.0361, -0.0100],\n", - " [ 0.0263, 0.0228, 0.0050, ..., -0.0557, 0.0037, 0.0196]])),\n", - " ('model.layers.15.mlp.up_proj.weight',\n", - " tensor([[ 0.0093, -0.0189, 0.0173, ..., 0.0276, 0.0075, -0.0215],\n", - " [-0.0147, 0.0241, 0.0109, ..., 0.0120, 0.0032, 0.0327],\n", - " [ 0.0036, 0.0127, 0.0116, ..., 0.0100, -0.0003, 0.0233],\n", - " ...,\n", - " [-0.0063, 0.0160, 0.0138, ..., -0.0078, -0.0098, 0.0150],\n", - " [ 0.0138, -0.0236, 0.0109, ..., -0.0156, -0.0143, 0.0273],\n", - " [ 0.0345, 0.0201, -0.0119, ..., -0.0182, 0.0053, 0.0105]])),\n", - " ('model.layers.15.mlp.down_proj.weight',\n", - " tensor([[-0.0114, 0.0138, -0.0110, ..., 0.0084, -0.0144, 0.0100],\n", - " [ 0.0016, -0.0069, 0.0172, ..., -0.0394, 0.0368, 0.0468],\n", - " [-0.0184, -0.0094, -0.0273, ..., -0.0195, 0.0148, 0.0142],\n", - " ...,\n", - " [ 0.0311, 0.0093, -0.0130, ..., -0.0023, 0.0395, -0.0375],\n", - " [ 0.0056, 0.0027, 0.0061, ..., 0.0058, 0.0225, -0.0153],\n", - " [-0.0031, -0.0107, 0.0020, ..., -0.0173, -0.0050, 0.0423]])),\n", - " ('model.layers.15.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.15.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.16.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.16.mixer.in_proj.weight',\n", - " tensor([[-0.0063, 0.0006, 0.0130, ..., 0.0186, 0.0408, 0.0126],\n", - " [-0.0015, -0.0029, 0.0268, ..., -0.0042, -0.0209, -0.0046],\n", - " [-0.0034, -0.0286, 0.0185, ..., -0.0125, 0.0050, 0.0033],\n", - " ...,\n", - " [ 0.0045, 0.0133, 0.0220, ..., 0.0165, 0.0287, 0.0371],\n", - " [ 0.0100, -0.0232, 0.0103, ..., -0.0083, -0.0105, -0.0187],\n", - " [-0.0412, -0.0035, 0.0028, ..., 0.0286, 0.0349, -0.0037]])),\n", - " ('model.layers.16.mixer.conv1d.weight',\n", - " tensor([[[-0.1874, 0.2517, 0.0537, 0.1258]],\n", - " \n", - " [[ 0.1465, 0.2013, 0.3547, 0.2689]],\n", - " \n", - " [[ 0.4834, 0.4906, 0.0844, -0.0541]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3004, 0.3313, 0.1688, 0.4381]],\n", - " \n", - " [[-0.0606, 0.3455, -0.0910, 0.1148]],\n", - " \n", - " [[-0.1421, -0.1254, -0.2353, -0.1675]]])),\n", - " ('model.layers.16.mixer.conv1d.bias',\n", - " tensor([ 0.2835, 0.2361, 0.1225, ..., -0.2119, -0.1929, 0.3877])),\n", - " ('model.layers.16.mixer.out_proj.weight',\n", - " tensor([[-0.0121, 0.0194, 0.0060, ..., -0.0029, -0.0147, -0.0085],\n", - " [-0.0216, -0.0012, 0.0287, ..., 0.0102, -0.0133, -0.0153],\n", - " [ 0.0136, -0.0296, 0.0417, ..., -0.0118, -0.0283, 0.0359],\n", - " ...,\n", - " [-0.0263, -0.0003, 0.0022, ..., 0.0135, -0.0519, -0.0254],\n", - " [ 0.0121, -0.0144, -0.0026, ..., 0.0096, 0.0130, 0.0095],\n", - " [-0.0147, -0.0217, 0.0099, ..., 0.0267, -0.0072, -0.0213]])),\n", - " ('model.layers.16.mlp.gate_proj.weight',\n", - " tensor([[ 0.0103, -0.0396, -0.0127, ..., 0.0020, -0.0055, 0.0291],\n", - " [ 0.0194, 0.0357, -0.0020, ..., -0.0112, 0.0448, -0.0224],\n", - " [-0.0390, 0.0142, -0.0224, ..., -0.0030, 0.0102, 0.0078],\n", - " ...,\n", - " [ 0.0165, -0.0251, 0.0196, ..., 0.0213, 0.0040, -0.0228],\n", - " [-0.0145, 0.0218, -0.0032, ..., -0.0240, -0.0079, 0.0256],\n", - " [ 0.0539, -0.0027, -0.0227, ..., -0.0184, -0.0109, 0.0236]])),\n", - " ('model.layers.16.mlp.up_proj.weight',\n", - " tensor([[ 7.1125e-03, -3.2583e-04, -2.6297e-02, ..., -4.9575e-03,\n", - " -1.2243e-02, -1.3005e-02],\n", - " [ 2.5637e-02, -1.1874e-02, 1.1376e-02, ..., -1.4700e-02,\n", - " -1.5193e-02, 2.6111e-03],\n", - " [-4.8919e-02, -4.9716e-04, 5.8527e-03, ..., 8.6775e-05,\n", - " 1.0694e-02, 3.7682e-03],\n", - " ...,\n", - " [ 8.8393e-03, -4.3317e-02, 2.8372e-02, ..., 2.2709e-02,\n", - " -4.8128e-03, 1.6899e-02],\n", - " [ 1.3257e-02, 2.1000e-02, 1.5035e-03, ..., 1.5603e-02,\n", - " -5.5857e-03, 4.0449e-03],\n", - " [-2.6754e-02, -1.6263e-02, 1.9013e-02, ..., -9.0918e-03,\n", - " -8.0242e-03, -1.0925e-02]])),\n", - " ('model.layers.16.mlp.down_proj.weight',\n", - " tensor([[ 0.0207, -0.0038, -0.0234, ..., 0.0299, -0.0329, -0.0117],\n", - " [-0.0316, 0.0032, 0.0131, ..., 0.0020, -0.0320, 0.0381],\n", - " [-0.0192, -0.0031, -0.0030, ..., -0.0224, 0.0037, 0.0085],\n", - " ...,\n", - " [ 0.0044, 0.0281, -0.0208, ..., 0.0179, -0.0085, -0.0010],\n", - " [-0.0076, -0.0008, 0.0483, ..., 0.0082, -0.0177, -0.0039],\n", - " [ 0.0224, 0.0019, 0.0181, ..., 0.0143, -0.0252, 0.0022]])),\n", - " ('model.layers.16.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.16.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.17.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.17.mixer.in_proj.weight',\n", - " tensor([[-0.0115, 0.0061, -0.0062, ..., -0.0132, -0.0047, 0.0274],\n", - " [ 0.0076, 0.0278, -0.0147, ..., 0.0439, -0.0093, -0.0154],\n", - " [-0.0383, -0.0264, -0.0053, ..., -0.0206, 0.0275, 0.0188],\n", - " ...,\n", - " [ 0.0096, 0.0228, 0.0351, ..., 0.0227, 0.0138, -0.0164],\n", - " [ 0.0321, -0.0293, -0.0054, ..., 0.0109, -0.0113, -0.0130],\n", - " [-0.0120, -0.0132, 0.0092, ..., -0.0338, 0.0308, -0.0135]])),\n", - " ('model.layers.17.mixer.conv1d.weight',\n", - " tensor([[[-0.4933, 0.4156, 0.2523, -0.0026]],\n", - " \n", - " [[-0.2572, 0.4916, 0.3642, -0.2145]],\n", - " \n", - " [[ 0.0261, 0.4852, -0.1448, 0.2288]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3698, -0.4122, -0.2264, -0.1378]],\n", - " \n", - " [[ 0.1447, 0.4556, -0.0466, 0.0389]],\n", - " \n", - " [[-0.3891, 0.4149, 0.1454, -0.4282]]])),\n", - " ('model.layers.17.mixer.conv1d.bias',\n", - " tensor([-0.3919, -0.4015, 0.2591, ..., -0.3368, 0.2285, 0.1701])),\n", - " ('model.layers.17.mixer.out_proj.weight',\n", - " tensor([[-0.0127, -0.0155, 0.0193, ..., 0.0204, 0.0025, 0.0159],\n", - " [ 0.0192, 0.0194, -0.0169, ..., -0.0062, 0.0262, 0.0070],\n", - " [ 0.0397, 0.0009, 0.0189, ..., -0.0082, 0.0352, -0.0150],\n", - " ...,\n", - " [-0.0339, -0.0142, -0.0151, ..., 0.0229, 0.0032, 0.0038],\n", - " [ 0.0235, 0.0319, -0.0137, ..., -0.0121, 0.0112, 0.0162],\n", - " [ 0.0060, 0.0102, -0.0016, ..., 0.0118, 0.0158, -0.0140]])),\n", - " ('model.layers.17.mlp.gate_proj.weight',\n", - " tensor([[ 0.0285, -0.0090, -0.0095, ..., 0.0315, -0.0065, 0.0189],\n", - " [ 0.0040, -0.0358, -0.0039, ..., -0.0074, -0.0285, -0.0223],\n", - " [ 0.0202, 0.0021, -0.0104, ..., -0.0083, 0.0300, -0.0267],\n", - " ...,\n", - " [ 0.0093, -0.0008, -0.0372, ..., 0.0422, 0.0309, 0.0095],\n", - " [ 0.0027, 0.0252, 0.0378, ..., -0.0238, 0.0234, -0.0062],\n", - " [-0.0061, -0.0022, -0.0033, ..., 0.0157, -0.0296, 0.0034]])),\n", - " ('model.layers.17.mlp.up_proj.weight',\n", - " tensor([[ 0.0061, -0.0135, 0.0029, ..., 0.0328, 0.0008, -0.0072],\n", - " [ 0.0145, -0.0226, -0.0095, ..., 0.0114, 0.0224, -0.0160],\n", - " [ 0.0097, -0.0024, -0.0179, ..., 0.0073, -0.0061, -0.0195],\n", - " ...,\n", - " [ 0.0308, -0.0014, 0.0104, ..., 0.0047, 0.0026, 0.0243],\n", - " [-0.0364, 0.0350, 0.0031, ..., -0.0072, 0.0267, 0.0017],\n", - " [ 0.0227, -0.0146, 0.0146, ..., -0.0434, -0.0159, 0.0230]])),\n", - " ('model.layers.17.mlp.down_proj.weight',\n", - " tensor([[-0.0216, 0.0211, 0.0136, ..., -0.0004, 0.0051, 0.0415],\n", - " [-0.0061, -0.0123, 0.0156, ..., -0.0005, -0.0183, -0.0137],\n", - " [-0.0146, -0.0274, -0.0439, ..., -0.0033, -0.0030, -0.0074],\n", - " ...,\n", - " [-0.0108, -0.0005, -0.0094, ..., -0.0243, 0.0065, -0.0005],\n", - " [-0.0126, 0.0124, -0.0006, ..., -0.0282, -0.0110, 0.0128],\n", - " [-0.0162, -0.0102, 0.0025, ..., -0.0084, 0.0066, -0.0074]])),\n", - " ('model.layers.17.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.17.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.18.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.18.mixer.in_proj.weight',\n", - " tensor([[-9.4961e-03, -1.2349e-04, -7.1455e-03, ..., 1.9508e-02,\n", - " -6.8715e-03, -1.3565e-02],\n", - " [-2.9701e-03, 3.1580e-03, 1.8849e-02, ..., 7.6566e-03,\n", - " -1.0968e-02, -8.0445e-03],\n", - " [-1.5402e-02, -6.7267e-03, 9.6119e-03, ..., 1.9799e-02,\n", - " 2.0198e-03, -1.7366e-03],\n", - " ...,\n", - " [ 8.2379e-03, 5.1668e-03, 3.8116e-02, ..., -3.8710e-03,\n", - " 1.4452e-02, -2.5152e-02],\n", - " [ 1.1949e-02, -1.2245e-03, 1.0568e-02, ..., -3.1690e-02,\n", - " 3.8135e-05, 1.7263e-02],\n", - " [ 1.6173e-04, 5.6721e-04, 2.1043e-02, ..., -3.6167e-02,\n", - " -1.1129e-02, -9.6768e-03]])),\n", - " ('model.layers.18.mixer.conv1d.weight',\n", - " tensor([[[ 0.2776, 0.2169, -0.2840, 0.1736]],\n", - " \n", - " [[-0.0598, -0.2654, 0.2423, -0.0874]],\n", - " \n", - " [[-0.3612, -0.3049, -0.3197, -0.2763]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1389, 0.2034, -0.1739, 0.1634]],\n", - " \n", - " [[-0.2836, -0.0471, 0.1284, -0.0099]],\n", - " \n", - " [[ 0.2952, -0.2676, -0.3961, 0.2656]]])),\n", - " ('model.layers.18.mixer.conv1d.bias',\n", - " tensor([ 0.1804, 0.0336, 0.4006, ..., 0.2943, -0.1079, 0.0963])),\n", - " ('model.layers.18.mixer.out_proj.weight',\n", - " tensor([[ 0.0109, -0.0181, 0.0148, ..., -0.0105, -0.0011, -0.0052],\n", - " [ 0.0507, 0.0100, -0.0273, ..., -0.0069, 0.0054, 0.0129],\n", - " [ 0.0014, 0.0423, -0.0193, ..., -0.0023, -0.0293, 0.0004],\n", - " ...,\n", - " [ 0.0420, -0.0401, 0.0205, ..., 0.0135, -0.0089, -0.0023],\n", - " [ 0.0242, 0.0273, 0.0139, ..., -0.0402, 0.0061, 0.0119],\n", - " [-0.0145, 0.0102, 0.0245, ..., 0.0205, -0.0251, 0.0006]])),\n", - " ('model.layers.18.mlp.gate_proj.weight',\n", - " tensor([[ 0.0241, -0.0086, 0.0136, ..., -0.0219, -0.0064, -0.0142],\n", - " [-0.0067, 0.0252, 0.0246, ..., -0.0205, -0.0273, 0.0137],\n", - " [-0.0030, 0.0055, -0.0063, ..., 0.0107, 0.0083, -0.0037],\n", - " ...,\n", - " [-0.0154, 0.0101, 0.0221, ..., 0.0025, -0.0109, 0.0133],\n", - " [-0.0175, 0.0105, -0.0246, ..., 0.0244, 0.0023, 0.0080],\n", - " [-0.0060, 0.0183, 0.0297, ..., 0.0420, -0.0006, -0.0119]])),\n", - " ('model.layers.18.mlp.up_proj.weight',\n", - " tensor([[ 0.0066, -0.0009, -0.0070, ..., -0.0064, 0.0002, 0.0196],\n", - " [-0.0173, -0.0362, -0.0011, ..., 0.0158, -0.0198, -0.0046],\n", - " [ 0.0133, -0.0090, -0.0092, ..., 0.0039, -0.0052, -0.0101],\n", - " ...,\n", - " [ 0.0077, -0.0063, 0.0010, ..., 0.0091, 0.0218, 0.0132],\n", - " [ 0.0005, -0.0046, 0.0207, ..., 0.0112, 0.0183, -0.0020],\n", - " [ 0.0238, -0.0022, 0.0364, ..., -0.0042, 0.0237, 0.0183]])),\n", - " ('model.layers.18.mlp.down_proj.weight',\n", - " tensor([[ 0.0305, 0.0178, -0.0264, ..., -0.0158, 0.0135, 0.0132],\n", - " [ 0.0248, -0.0061, 0.0144, ..., -0.0165, 0.0098, 0.0410],\n", - " [-0.0156, -0.0039, 0.0112, ..., -0.0431, -0.0084, -0.0197],\n", - " ...,\n", - " [ 0.0071, 0.0236, -0.0038, ..., 0.0035, -0.0236, 0.0106],\n", - " [-0.0369, -0.0029, -0.0182, ..., -0.0008, -0.0417, 0.0064],\n", - " [-0.0273, 0.0207, 0.0130, ..., 0.0372, 0.0163, 0.0273]])),\n", - " ('model.layers.18.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.18.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.19.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.19.mixer.in_proj.weight',\n", - " tensor([[-0.0079, 0.0147, -0.0337, ..., -0.0201, -0.0254, 0.0035],\n", - " [ 0.0139, 0.0054, -0.0093, ..., -0.0208, -0.0289, -0.0087],\n", - " [ 0.0004, -0.0034, 0.0090, ..., -0.0109, -0.0093, 0.0102],\n", - " ...,\n", - " [ 0.0128, 0.0015, -0.0101, ..., -0.0482, -0.0217, 0.0144],\n", - " [-0.0100, -0.0079, 0.0286, ..., -0.0025, -0.0210, 0.0164],\n", - " [-0.0264, 0.0015, 0.0031, ..., 0.0027, 0.0131, -0.0384]])),\n", - " ('model.layers.19.mixer.conv1d.weight',\n", - " tensor([[[ 0.4729, 0.3708, -0.4394, -0.3549]],\n", - " \n", - " [[ 0.2230, -0.3271, 0.3017, -0.2552]],\n", - " \n", - " [[-0.0417, 0.1893, 0.4552, -0.0644]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2565, 0.0407, 0.3521, 0.4116]],\n", - " \n", - " [[ 0.0795, -0.0374, 0.1034, 0.4254]],\n", - " \n", - " [[ 0.3333, 0.2431, 0.3459, -0.2676]]])),\n", - " ('model.layers.19.mixer.conv1d.bias',\n", - " tensor([-0.2287, -0.4446, -0.2300, ..., -0.2317, -0.3395, 0.4310])),\n", - " ('model.layers.19.mixer.out_proj.weight',\n", - " tensor([[-0.0456, -0.0167, -0.0117, ..., -0.0068, -0.0150, 0.0125],\n", - " [ 0.0194, 0.0172, -0.0232, ..., -0.0202, -0.0066, 0.0083],\n", - " [ 0.0320, -0.0065, 0.0274, ..., 0.0200, 0.0090, 0.0105],\n", - " ...,\n", - " [ 0.0315, 0.0415, 0.0128, ..., -0.0143, -0.0338, -0.0231],\n", - " [ 0.0227, -0.0177, -0.0034, ..., 0.0174, 0.0006, 0.0212],\n", - " [ 0.0358, 0.0084, 0.0075, ..., 0.0091, 0.0062, 0.0114]])),\n", - " ('model.layers.19.mlp.gate_proj.weight',\n", - " tensor([[-0.0010, 0.0156, 0.0042, ..., -0.0181, 0.0113, 0.0089],\n", - " [-0.0182, 0.0068, -0.0043, ..., -0.0323, -0.0019, -0.0045],\n", - " [ 0.0168, -0.0093, -0.0162, ..., -0.0074, 0.0166, -0.0334],\n", - " ...,\n", - " [ 0.0038, -0.0211, -0.0054, ..., -0.0229, 0.0193, -0.0210],\n", - " [ 0.0153, -0.0372, 0.0119, ..., 0.0043, -0.0097, -0.0025],\n", - " [ 0.0037, 0.0208, -0.0135, ..., 0.0052, -0.0125, -0.0282]])),\n", - " ('model.layers.19.mlp.up_proj.weight',\n", - " tensor([[-0.0026, 0.0360, 0.0161, ..., 0.0199, -0.0283, -0.0026],\n", - " [ 0.0185, 0.0122, -0.0299, ..., 0.0125, 0.0063, 0.0387],\n", - " [-0.0085, -0.0010, -0.0054, ..., -0.0088, -0.0034, -0.0179],\n", - " ...,\n", - " [-0.0179, 0.0211, -0.0003, ..., -0.0071, -0.0145, 0.0235],\n", - " [-0.0002, 0.0060, -0.0172, ..., -0.0086, 0.0175, -0.0232],\n", - " [-0.0081, -0.0280, -0.0152, ..., -0.0221, 0.0047, -0.0077]])),\n", - " ('model.layers.19.mlp.down_proj.weight',\n", - " tensor([[ 0.0038, -0.0027, -0.0122, ..., 0.0090, 0.0044, 0.0128],\n", - " [ 0.0054, 0.0075, 0.0116, ..., 0.0232, 0.0130, 0.0298],\n", - " [-0.0498, -0.0208, -0.0127, ..., 0.0166, -0.0221, 0.0038],\n", - " ...,\n", - " [ 0.0101, 0.0051, 0.0209, ..., 0.0137, -0.0225, 0.0142],\n", - " [-0.0433, -0.0217, -0.0167, ..., -0.0179, -0.0191, -0.0021],\n", - " [-0.0020, 0.0084, -0.0114, ..., 0.0324, 0.0216, -0.0062]])),\n", - " ('model.layers.19.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.19.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.20.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.20.mixer.in_proj.weight',\n", - " tensor([[ 3.3776e-02, 3.6619e-02, 6.8532e-03, ..., 5.7664e-02,\n", - " -2.3083e-02, -6.2962e-02],\n", - " [-2.9787e-03, -2.5050e-03, -3.4841e-03, ..., 5.4946e-03,\n", - " 9.0683e-03, 2.1583e-04],\n", - " [ 7.4430e-03, -1.0495e-02, 3.5169e-02, ..., -5.1808e-02,\n", - " 3.2650e-03, -3.1967e-02],\n", - " ...,\n", - " [-5.8685e-02, 4.8452e-02, -1.2612e-02, ..., 1.2174e-02,\n", - " 1.0566e-02, -4.9561e-03],\n", - " [ 3.1722e-03, -2.9390e-03, 1.4502e-05, ..., -2.3297e-02,\n", - " -7.5403e-03, -1.3599e-02],\n", - " [ 1.4845e-02, -4.3150e-02, -1.0338e-02, ..., -1.1149e-02,\n", - " -3.3432e-02, 3.8337e-03]])),\n", - " ('model.layers.20.mixer.conv1d.weight',\n", - " tensor([[[-0.3842, 0.2397, 0.4873, -0.3091]],\n", - " \n", - " [[-0.1886, 0.0751, 0.2026, -0.2674]],\n", - " \n", - " [[-0.0594, 0.3119, -0.2404, 0.1652]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.0028, 0.1315, 0.0515, 0.3189]],\n", - " \n", - " [[-0.1461, -0.0457, -0.0536, -0.2306]],\n", - " \n", - " [[-0.3025, -0.3339, 0.3007, -0.3007]]])),\n", - " ('model.layers.20.mixer.conv1d.bias',\n", - " tensor([-0.4901, -0.3784, -0.0173, ..., -0.3946, -0.0728, 0.2187])),\n", - " ('model.layers.20.mixer.out_proj.weight',\n", - " tensor([[ 0.0095, -0.0037, -0.0218, ..., 0.0080, 0.0062, 0.0246],\n", - " [-0.0197, 0.0037, 0.0076, ..., 0.0171, 0.0238, -0.0195],\n", - " [ 0.0364, -0.0165, 0.0224, ..., -0.0099, 0.0007, 0.0340],\n", - " ...,\n", - " [ 0.0235, -0.0072, -0.0319, ..., 0.0045, -0.0196, 0.0011],\n", - " [-0.0369, 0.0083, 0.0021, ..., -0.0357, -0.0039, -0.0150],\n", - " [-0.0174, -0.0211, 0.0111, ..., 0.0251, 0.0040, -0.0308]])),\n", - " ('model.layers.20.mlp.gate_proj.weight',\n", - " tensor([[ 0.0161, -0.0019, -0.0473, ..., 0.0019, 0.0075, -0.0038],\n", - " [-0.0321, -0.0020, -0.0100, ..., 0.0035, 0.0291, -0.0058],\n", - " [-0.0158, 0.0020, 0.0353, ..., 0.0125, 0.0228, -0.0392],\n", - " ...,\n", - " [ 0.0113, 0.0171, 0.0235, ..., 0.0043, 0.0378, 0.0391],\n", - " [ 0.0090, 0.0067, 0.0031, ..., 0.0291, -0.0052, -0.0216],\n", - " [ 0.0042, -0.0112, -0.0161, ..., -0.0063, -0.0156, 0.0211]])),\n", - " ('model.layers.20.mlp.up_proj.weight',\n", - " tensor([[ 0.0104, -0.0302, -0.0220, ..., -0.0072, -0.0083, -0.0066],\n", - " [ 0.0409, -0.0116, -0.0125, ..., 0.0182, 0.0267, 0.0099],\n", - " [-0.0055, 0.0104, 0.0027, ..., -0.0075, -0.0368, -0.0092],\n", - " ...,\n", - " [-0.0089, 0.0243, -0.0028, ..., -0.0136, -0.0176, -0.0054],\n", - " [ 0.0088, 0.0365, -0.0354, ..., 0.0035, 0.0280, 0.0155],\n", - " [-0.0472, 0.0088, 0.0102, ..., -0.0120, 0.0004, -0.0011]])),\n", - " ('model.layers.20.mlp.down_proj.weight',\n", - " tensor([[-0.0089, -0.0112, -0.0007, ..., 0.0360, -0.0077, 0.0261],\n", - " [ 0.0080, -0.0128, -0.0445, ..., 0.0095, -0.0298, 0.0176],\n", - " [ 0.0357, -0.0262, 0.0028, ..., 0.0162, 0.0089, 0.0050],\n", - " ...,\n", - " [-0.0129, 0.0216, 0.0125, ..., -0.0062, -0.0344, -0.0218],\n", - " [ 0.0006, -0.0143, -0.0099, ..., -0.0359, 0.0268, 0.0259],\n", - " [ 0.0222, -0.0154, 0.0013, ..., 0.0108, -0.0077, 0.0186]])),\n", - " ('model.layers.20.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.20.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.21.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.21.mixer.in_proj.weight',\n", - " tensor([[-0.0300, 0.0058, -0.0107, ..., -0.0318, 0.0350, 0.0350],\n", - " [ 0.0186, 0.0238, -0.0268, ..., 0.0142, -0.0277, -0.0095],\n", - " [-0.0061, 0.0083, 0.0072, ..., 0.0161, 0.0027, -0.0051],\n", - " ...,\n", - " [-0.0358, 0.0330, 0.0151, ..., -0.0376, 0.0057, 0.0174],\n", - " [-0.0021, 0.0068, 0.0151, ..., 0.0077, -0.0353, 0.0095],\n", - " [-0.0113, -0.0043, 0.0064, ..., -0.0063, -0.0232, -0.0058]])),\n", - " ('model.layers.21.mixer.conv1d.weight',\n", - " tensor([[[ 0.0354, 0.0496, -0.0106, 0.0084]],\n", - " \n", - " [[ 0.2553, 0.3217, -0.0078, -0.2333]],\n", - " \n", - " [[-0.1390, 0.0323, 0.4914, -0.2047]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.2243, 0.2984, 0.0188, 0.1830]],\n", - " \n", - " [[ 0.0756, 0.1443, -0.4898, -0.2082]],\n", - " \n", - " [[-0.3685, -0.1311, -0.4037, -0.3276]]])),\n", - " ('model.layers.21.mixer.conv1d.bias',\n", - " tensor([-0.2444, -0.1852, 0.2215, ..., 0.4515, 0.2532, -0.2388])),\n", - " ('model.layers.21.mixer.out_proj.weight',\n", - " tensor([[ 0.0232, 0.0328, 0.0026, ..., -0.0575, 0.0157, -0.0072],\n", - " [-0.0226, 0.0058, -0.0346, ..., 0.0092, 0.0078, 0.0108],\n", - " [ 0.0045, 0.0247, 0.0150, ..., -0.0085, 0.0268, 0.0253],\n", - " ...,\n", - " [ 0.0268, 0.0092, 0.0141, ..., 0.0062, 0.0177, -0.0405],\n", - " [ 0.0163, -0.0269, -0.0177, ..., 0.0029, -0.0080, -0.0036],\n", - " [ 0.0064, 0.0126, 0.0126, ..., -0.0400, -0.0015, -0.0088]])),\n", - " ('model.layers.21.mlp.gate_proj.weight',\n", - " tensor([[-3.7050e-02, 4.5834e-02, 1.9280e-02, ..., 1.6761e-02,\n", - " -5.8295e-03, -1.4284e-02],\n", - " [ 3.0156e-02, 3.2832e-02, 1.1083e-02, ..., -5.8261e-03,\n", - " -3.9076e-02, 5.3379e-03],\n", - " [ 1.3118e-03, 3.1510e-02, 1.5472e-02, ..., 1.8213e-02,\n", - " -2.5180e-02, 6.1512e-04],\n", - " ...,\n", - " [ 4.2010e-02, 1.0362e-02, 7.1759e-03, ..., 1.8667e-03,\n", - " -7.2165e-03, 1.6297e-02],\n", - " [ 1.8175e-02, 1.2840e-02, 3.2857e-03, ..., 1.8495e-02,\n", - " -7.7709e-03, 4.3964e-04],\n", - " [-9.2628e-05, 2.1701e-02, 2.1256e-02, ..., 2.5241e-02,\n", - " 5.0683e-02, -2.5481e-02]])),\n", - " ('model.layers.21.mlp.up_proj.weight',\n", - " tensor([[ 0.0228, 0.0082, -0.0083, ..., 0.0288, 0.0211, 0.0085],\n", - " [-0.0155, 0.0179, 0.0111, ..., -0.0218, -0.0162, -0.0052],\n", - " [ 0.0016, 0.0009, 0.0230, ..., -0.0017, 0.0131, 0.0255],\n", - " ...,\n", - " [-0.0098, -0.0098, -0.0188, ..., 0.0063, 0.0082, 0.0052],\n", - " [-0.0028, 0.0249, -0.0153, ..., -0.0208, 0.0130, -0.0093],\n", - " [ 0.0105, -0.0072, -0.0379, ..., 0.0035, 0.0182, 0.0307]])),\n", - " ('model.layers.21.mlp.down_proj.weight',\n", - " tensor([[-0.0445, -0.0116, 0.0058, ..., 0.0081, -0.0099, 0.0094],\n", - " [ 0.0106, -0.0387, 0.0051, ..., 0.0017, 0.0075, 0.0136],\n", - " [ 0.0022, 0.0058, -0.0268, ..., -0.0088, -0.0149, 0.0125],\n", - " ...,\n", - " [-0.0015, -0.0156, -0.0225, ..., 0.0100, -0.0118, -0.0019],\n", - " [-0.0161, -0.0225, -0.0060, ..., 0.0073, -0.0072, 0.0205],\n", - " [-0.0112, 0.0046, -0.0089, ..., -0.0014, -0.0221, 0.0124]])),\n", - " ('model.layers.21.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.21.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.22.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.22.mixer.in_proj.weight',\n", - " tensor([[-1.1591e-02, -6.0118e-03, -2.2227e-03, ..., -7.1433e-03,\n", - " -1.5757e-02, -1.5315e-03],\n", - " [-7.6057e-03, -4.2199e-02, 1.4478e-02, ..., 5.6496e-02,\n", - " 8.9105e-05, -3.8658e-03],\n", - " [-1.0330e-03, 2.3586e-02, 2.1835e-02, ..., -1.4911e-03,\n", - " -1.6604e-02, -4.5245e-03],\n", - " ...,\n", - " [-6.7261e-03, -6.9826e-03, -9.3003e-03, ..., -4.3939e-02,\n", - " 2.3792e-02, -5.5165e-03],\n", - " [-1.1798e-02, -3.4709e-02, -4.1277e-03, ..., -5.1867e-03,\n", - " 5.2496e-03, -6.0055e-03],\n", - " [ 7.3402e-04, -1.9525e-02, -5.8966e-03, ..., -1.5972e-02,\n", - " -1.5446e-02, -2.7164e-02]])),\n", - " ('model.layers.22.mixer.conv1d.weight',\n", - " tensor([[[-0.3791, 0.0616, 0.0369, 0.1365]],\n", - " \n", - " [[-0.4674, -0.4557, 0.3894, -0.4765]],\n", - " \n", - " [[ 0.3333, 0.2265, 0.1385, -0.1352]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4363, -0.3526, -0.3982, -0.1049]],\n", - " \n", - " [[ 0.4798, -0.3912, 0.4059, -0.1379]],\n", - " \n", - " [[-0.4427, 0.4661, -0.1990, 0.1668]]])),\n", - " ('model.layers.22.mixer.conv1d.bias',\n", - " tensor([-0.1823, -0.4117, 0.4443, ..., -0.0024, 0.2144, -0.4922])),\n", - " ('model.layers.22.mixer.out_proj.weight',\n", - " tensor([[ 0.0138, -0.0169, -0.0349, ..., -0.0045, 0.0023, -0.0389],\n", - " [ 0.0250, 0.0040, -0.0259, ..., 0.0458, 0.0311, -0.0054],\n", - " [-0.0056, 0.0012, -0.0027, ..., 0.0095, -0.0089, -0.0106],\n", - " ...,\n", - " [ 0.0228, -0.0258, 0.0040, ..., 0.0276, -0.0121, -0.0239],\n", - " [ 0.0082, 0.0041, 0.0145, ..., 0.0079, -0.0076, 0.0177],\n", - " [ 0.0310, -0.0092, -0.0174, ..., 0.0179, 0.0231, -0.0035]])),\n", - " ('model.layers.22.mlp.gate_proj.weight',\n", - " tensor([[ 0.0090, -0.0178, -0.0120, ..., -0.0073, -0.0149, 0.0187],\n", - " [ 0.0263, -0.0093, -0.0074, ..., -0.0472, 0.0049, 0.0288],\n", - " [ 0.0159, -0.0083, 0.0291, ..., 0.0089, -0.0076, -0.0167],\n", - " ...,\n", - " [-0.0008, 0.0206, 0.0199, ..., -0.0134, -0.0366, -0.0202],\n", - " [-0.0069, -0.0275, 0.0054, ..., 0.0093, 0.0108, 0.0094],\n", - " [ 0.0198, 0.0033, -0.0118, ..., -0.0262, 0.0241, 0.0084]])),\n", - " ('model.layers.22.mlp.up_proj.weight',\n", - " tensor([[-0.0277, 0.0038, 0.0006, ..., -0.0222, -0.0313, -0.0133],\n", - " [ 0.0132, -0.0373, 0.0109, ..., 0.0359, -0.0116, 0.0099],\n", - " [ 0.0139, -0.0185, 0.0247, ..., 0.0178, 0.0192, 0.0049],\n", - " ...,\n", - " [ 0.0362, 0.0072, -0.0236, ..., -0.0238, 0.0319, -0.0210],\n", - " [ 0.0013, -0.0047, -0.0060, ..., 0.0106, -0.0074, -0.0185],\n", - " [-0.0228, 0.0176, -0.0047, ..., -0.0034, -0.0174, -0.0264]])),\n", - " ('model.layers.22.mlp.down_proj.weight',\n", - " tensor([[ 0.0149, 0.0122, -0.0037, ..., 0.0044, 0.0171, -0.0186],\n", - " [-0.0037, -0.0002, 0.0066, ..., 0.0263, -0.0025, -0.0012],\n", - " [-0.0075, 0.0209, 0.0045, ..., 0.0082, -0.0160, 0.0079],\n", - " ...,\n", - " [ 0.0001, 0.0507, -0.0078, ..., 0.0001, -0.0119, 0.0286],\n", - " [-0.0198, -0.0122, 0.0047, ..., -0.0052, 0.0130, -0.0007],\n", - " [ 0.0241, -0.0002, -0.0147, ..., 0.0219, -0.0020, -0.0071]])),\n", - " ('model.layers.22.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.22.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.23.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.23.mixer.in_proj.weight',\n", - " tensor([[-0.0017, 0.0027, -0.0150, ..., 0.0392, -0.0079, -0.0367],\n", - " [ 0.0183, 0.0261, -0.0262, ..., -0.0157, 0.0197, 0.0135],\n", - " [-0.0030, 0.0170, 0.0032, ..., 0.0059, 0.0299, 0.0158],\n", - " ...,\n", - " [-0.0149, 0.0218, 0.0072, ..., -0.0302, 0.0035, 0.0153],\n", - " [-0.0135, 0.0425, 0.0331, ..., -0.0119, -0.0364, 0.0365],\n", - " [-0.0215, -0.0242, 0.0271, ..., 0.0500, 0.0293, 0.0100]])),\n", - " ('model.layers.23.mixer.conv1d.weight',\n", - " tensor([[[ 0.2464, 0.3726, 0.2719, 0.3580]],\n", - " \n", - " [[-0.0520, 0.0010, 0.1396, -0.4634]],\n", - " \n", - " [[ 0.1383, 0.4039, -0.3622, 0.1499]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4094, 0.0541, 0.2240, -0.1545]],\n", - " \n", - " [[-0.4393, 0.1323, 0.1705, -0.1722]],\n", - " \n", - " [[ 0.2166, -0.4335, -0.4088, -0.1159]]])),\n", - " ('model.layers.23.mixer.conv1d.bias',\n", - " tensor([ 0.3175, -0.0325, -0.4654, ..., 0.3869, -0.2534, 0.1588])),\n", - " ('model.layers.23.mixer.out_proj.weight',\n", - " tensor([[-0.0354, -0.0041, 0.0196, ..., -0.0218, -0.0222, 0.0126],\n", - " [-0.0155, -0.0067, -0.0007, ..., 0.0112, -0.0036, -0.0054],\n", - " [ 0.0141, 0.0040, -0.0218, ..., -0.0178, -0.0031, 0.0162],\n", - " ...,\n", - " [ 0.0264, 0.0063, 0.0088, ..., -0.0310, -0.0116, 0.0239],\n", - " [-0.0031, 0.0056, -0.0243, ..., -0.0350, 0.0004, 0.0004],\n", - " [ 0.0229, -0.0201, 0.0124, ..., 0.0313, -0.0412, -0.0033]])),\n", - " ('model.layers.23.mlp.gate_proj.weight',\n", - " tensor([[ 0.0026, -0.0155, 0.0595, ..., 0.0204, 0.0172, 0.0378],\n", - " [-0.0011, -0.0253, 0.0039, ..., 0.0330, -0.0487, -0.0195],\n", - " [ 0.0174, 0.0039, -0.0029, ..., -0.0026, 0.0104, 0.0108],\n", - " ...,\n", - " [-0.0159, 0.0008, 0.0173, ..., -0.0020, 0.0085, -0.0043],\n", - " [ 0.0101, 0.0221, -0.0034, ..., -0.0268, 0.0056, 0.0137],\n", - " [-0.0031, -0.0151, 0.0073, ..., -0.0083, -0.0064, 0.0109]])),\n", - " ('model.layers.23.mlp.up_proj.weight',\n", - " tensor([[ 0.0173, -0.0132, -0.0027, ..., 0.0391, 0.0268, -0.0185],\n", - " [ 0.0221, -0.0110, -0.0108, ..., -0.0302, 0.0170, 0.0139],\n", - " [-0.0047, -0.0373, 0.0056, ..., -0.0389, -0.0175, -0.0410],\n", - " ...,\n", - " [ 0.0003, 0.0153, 0.0160, ..., 0.0002, -0.0136, 0.0417],\n", - " [-0.0059, -0.0150, -0.0111, ..., 0.0163, 0.0171, 0.0267],\n", - " [-0.0123, -0.0032, 0.0193, ..., -0.0051, -0.0051, -0.0089]])),\n", - " ('model.layers.23.mlp.down_proj.weight',\n", - " tensor([[-0.0092, -0.0148, -0.0345, ..., -0.0240, 0.0425, -0.0099],\n", - " [ 0.0458, 0.0156, -0.0067, ..., -0.0283, 0.0401, 0.0074],\n", - " [ 0.0180, -0.0008, 0.0049, ..., -0.0085, -0.0157, 0.0044],\n", - " ...,\n", - " [-0.0207, 0.0074, -0.0176, ..., 0.0038, -0.0238, -0.0026],\n", - " [-0.0201, 0.0078, 0.0243, ..., -0.0031, 0.0080, -0.0176],\n", - " [-0.0034, 0.0191, 0.0391, ..., -0.0114, 0.0133, -0.0261]])),\n", - " ('model.layers.23.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.23.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.24.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.24.mixer.in_proj.weight',\n", - " tensor([[-0.0184, -0.0299, 0.0165, ..., 0.0035, 0.0417, -0.0170],\n", - " [-0.0346, -0.0226, 0.0064, ..., 0.0072, 0.0457, -0.0148],\n", - " [ 0.0032, -0.0245, -0.0474, ..., -0.0054, -0.0044, 0.0278],\n", - " ...,\n", - " [ 0.0139, 0.0133, -0.0185, ..., 0.0188, 0.0119, -0.0205],\n", - " [ 0.0235, 0.0161, -0.0095, ..., 0.0013, -0.0382, 0.0213],\n", - " [ 0.0031, -0.0394, 0.0275, ..., -0.0068, 0.0024, 0.0179]])),\n", - " ('model.layers.24.mixer.conv1d.weight',\n", - " tensor([[[-0.1857, -0.4692, 0.4791, 0.3706]],\n", - " \n", - " [[ 0.1749, 0.4182, -0.2338, 0.0838]],\n", - " \n", - " [[-0.1204, -0.2985, -0.0470, 0.4674]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.1485, 0.3118, -0.4916, -0.1610]],\n", - " \n", - " [[ 0.0684, -0.2980, 0.4517, -0.3662]],\n", - " \n", - " [[ 0.2353, -0.2156, -0.3332, -0.0665]]])),\n", - " ('model.layers.24.mixer.conv1d.bias',\n", - " tensor([-0.4464, -0.3485, -0.3916, ..., 0.2513, -0.0601, 0.1546])),\n", - " ('model.layers.24.mixer.out_proj.weight',\n", - " tensor([[-0.0023, 0.0087, -0.0280, ..., 0.0338, -0.0095, -0.0237],\n", - " [-0.0086, -0.0084, 0.0180, ..., 0.0350, 0.0463, -0.0270],\n", - " [-0.0093, -0.0009, 0.0236, ..., 0.0158, 0.0246, 0.0068],\n", - " ...,\n", - " [ 0.0526, 0.0009, 0.0039, ..., -0.0206, -0.0538, 0.0287],\n", - " [ 0.0054, -0.0053, -0.0108, ..., 0.0167, -0.0997, 0.0036],\n", - " [ 0.0009, -0.0297, -0.0424, ..., -0.0096, -0.0235, 0.0117]])),\n", - " ('model.layers.24.mlp.gate_proj.weight',\n", - " tensor([[-0.0265, 0.0259, 0.0224, ..., -0.0080, -0.0394, 0.0290],\n", - " [-0.0101, -0.0256, 0.0079, ..., -0.0017, -0.0287, -0.0163],\n", - " [ 0.0079, -0.0021, -0.0299, ..., 0.0076, 0.0063, 0.0082],\n", - " ...,\n", - " [ 0.0061, 0.0121, 0.0275, ..., -0.0162, 0.0025, -0.0075],\n", - " [-0.0039, -0.0217, -0.0428, ..., -0.0253, 0.0231, 0.0095],\n", - " [-0.0187, 0.0077, -0.0442, ..., 0.0358, -0.0084, -0.0132]])),\n", - " ('model.layers.24.mlp.up_proj.weight',\n", - " tensor([[-0.0201, -0.0119, 0.0505, ..., -0.0025, -0.0187, 0.0011],\n", - " [-0.0105, 0.0154, -0.0163, ..., 0.0248, 0.0028, 0.0178],\n", - " [-0.0163, -0.0271, -0.0100, ..., 0.0129, -0.0220, 0.0269],\n", - " ...,\n", - " [ 0.0138, 0.0329, -0.0091, ..., 0.0038, -0.0194, -0.0223],\n", - " [ 0.0469, 0.0291, -0.0027, ..., 0.0231, 0.0261, 0.0151],\n", - " [-0.0093, -0.0098, 0.0013, ..., 0.0078, -0.0145, 0.0268]])),\n", - " ('model.layers.24.mlp.down_proj.weight',\n", - " tensor([[-0.0195, -0.0003, -0.0046, ..., -0.0132, -0.0118, 0.0242],\n", - " [-0.0267, 0.0199, 0.0243, ..., -0.0063, 0.0134, -0.0163],\n", - " [-0.0044, -0.0303, -0.0215, ..., -0.0148, -0.0216, 0.0079],\n", - " ...,\n", - " [ 0.0159, 0.0180, 0.0098, ..., -0.0126, 0.0176, 0.0087],\n", - " [-0.0203, 0.0041, -0.0256, ..., -0.0047, -0.0236, -0.0256],\n", - " [-0.0017, 0.0133, 0.0490, ..., -0.0344, -0.0118, 0.0020]])),\n", - " ('model.layers.24.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.24.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.25.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.25.mixer.in_proj.weight',\n", - " tensor([[ 0.0064, 0.0039, 0.0014, ..., 0.0130, -0.0169, 0.0010],\n", - " [ 0.0371, 0.0241, 0.0203, ..., 0.0078, 0.0463, 0.0034],\n", - " [ 0.0184, -0.0431, -0.0026, ..., -0.0164, 0.0279, -0.0138],\n", - " ...,\n", - " [ 0.0146, -0.0138, -0.0418, ..., 0.0234, 0.0145, -0.0213],\n", - " [ 0.0124, -0.0298, -0.0164, ..., -0.0169, 0.0026, -0.0180],\n", - " [-0.0250, -0.0008, -0.0133, ..., -0.0131, -0.0064, 0.0071]])),\n", - " ('model.layers.25.mixer.conv1d.weight',\n", - " tensor([[[ 0.0171, -0.3423, -0.1701, 0.4869]],\n", - " \n", - " [[-0.4648, 0.4797, 0.3531, -0.3819]],\n", - " \n", - " [[-0.1660, -0.3489, -0.2488, 0.4428]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.3545, -0.1567, -0.2646, 0.3590]],\n", - " \n", - " [[-0.2175, 0.4394, 0.3840, 0.2620]],\n", - " \n", - " [[ 0.1335, -0.3655, 0.3256, -0.1752]]])),\n", - " ('model.layers.25.mixer.conv1d.bias',\n", - " tensor([-0.0935, 0.0170, 0.0779, ..., -0.2362, 0.2879, 0.2390])),\n", - " ('model.layers.25.mixer.out_proj.weight',\n", - " tensor([[ 2.0220e-02, 5.0645e-05, -1.7425e-02, ..., 8.6082e-03,\n", - " -1.8566e-02, 1.3872e-02],\n", - " [ 2.9139e-02, 1.1096e-02, 4.4168e-02, ..., 3.5600e-02,\n", - " 7.3446e-03, -1.6368e-02],\n", - " [-3.2418e-02, 6.9682e-03, 3.1648e-02, ..., 1.4050e-02,\n", - " -1.6554e-02, 7.2751e-03],\n", - " ...,\n", - " [-3.3057e-02, -7.0545e-04, 3.9661e-02, ..., 2.0690e-02,\n", - " -1.0262e-02, -4.9292e-03],\n", - " [ 1.9849e-02, 1.9666e-02, -1.9398e-02, ..., 1.9285e-02,\n", - " 2.2522e-02, -6.0243e-03],\n", - " [ 1.7683e-02, 2.4301e-02, 7.2223e-03, ..., 3.1373e-02,\n", - " -5.7889e-03, 1.1855e-02]])),\n", - " ('model.layers.25.mlp.gate_proj.weight',\n", - " tensor([[-1.6223e-02, 4.5519e-03, -1.9218e-02, ..., 6.3580e-03,\n", - " -1.2723e-02, -9.7756e-03],\n", - " [-7.4200e-03, 1.8729e-02, 2.6924e-03, ..., 8.2305e-03,\n", - " -1.5727e-02, -9.8748e-03],\n", - " [ 3.2143e-02, -6.1559e-02, 1.6362e-02, ..., -3.6189e-04,\n", - " 1.2017e-04, -1.5734e-02],\n", - " ...,\n", - " [-1.4649e-02, -4.7663e-03, -1.9292e-02, ..., -1.9359e-02,\n", - " 1.8795e-02, 1.0221e-02],\n", - " [-2.4459e-02, 1.1684e-02, -2.8023e-02, ..., 8.0104e-03,\n", - " 8.5950e-05, 1.0542e-02],\n", - " [-4.5679e-03, -1.1421e-02, -2.1099e-02, ..., 4.5089e-03,\n", - " -3.0686e-02, -9.6116e-03]])),\n", - " ('model.layers.25.mlp.up_proj.weight',\n", - " tensor([[-0.0204, -0.0013, -0.0264, ..., -0.0081, -0.0027, 0.0215],\n", - " [-0.0161, 0.0051, -0.0111, ..., -0.0244, 0.0043, -0.0043],\n", - " [-0.0511, 0.0006, -0.0249, ..., 0.0069, 0.0615, 0.0123],\n", - " ...,\n", - " [-0.0086, -0.0016, 0.0064, ..., -0.0347, 0.0097, -0.0134],\n", - " [-0.0003, 0.0015, -0.0053, ..., 0.0210, 0.0135, 0.0337],\n", - " [-0.0205, 0.0028, -0.0272, ..., -0.0168, -0.0072, 0.0019]])),\n", - " ('model.layers.25.mlp.down_proj.weight',\n", - " tensor([[ 0.0166, 0.0044, 0.0180, ..., -0.0127, 0.0070, -0.0066],\n", - " [-0.0056, 0.0140, 0.0151, ..., -0.0239, -0.0140, 0.0470],\n", - " [-0.0030, -0.0093, -0.0188, ..., -0.0090, -0.0092, -0.0088],\n", - " ...,\n", - " [ 0.0465, 0.0277, -0.0349, ..., 0.0424, 0.0015, 0.0206],\n", - " [-0.0096, 0.0174, 0.0250, ..., -0.0142, -0.0022, -0.0141],\n", - " [-0.0195, -0.0174, 0.0033, ..., 0.0027, -0.0061, -0.0108]])),\n", - " ('model.layers.25.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.25.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.26.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.26.mixer.in_proj.weight',\n", - " tensor([[ 0.0112, 0.0060, -0.0038, ..., -0.0164, 0.0111, 0.0105],\n", - " [ 0.0227, -0.0248, 0.0240, ..., 0.0103, -0.0373, -0.0051],\n", - " [-0.0073, 0.0227, -0.0190, ..., 0.0048, -0.0101, -0.0137],\n", - " ...,\n", - " [ 0.0086, -0.0084, 0.0177, ..., -0.0245, 0.0119, 0.0022],\n", - " [-0.0080, -0.0284, 0.0440, ..., 0.0340, -0.0093, 0.0130],\n", - " [-0.0107, 0.0234, -0.0279, ..., 0.0106, -0.0169, -0.0001]])),\n", - " ('model.layers.26.mixer.conv1d.weight',\n", - " tensor([[[ 0.0550, -0.3464, -0.2378, -0.1244]],\n", - " \n", - " [[-0.0925, -0.2497, 0.2629, -0.1821]],\n", - " \n", - " [[-0.4524, 0.3462, -0.4604, -0.2758]],\n", - " \n", - " ...,\n", - " \n", - " [[-0.4555, -0.0839, 0.3936, -0.3707]],\n", - " \n", - " [[ 0.3409, -0.4109, 0.0890, -0.3629]],\n", - " \n", - " [[-0.2769, 0.4033, -0.1090, 0.3055]]])),\n", - " ('model.layers.26.mixer.conv1d.bias',\n", - " tensor([-0.2286, -0.2395, -0.2517, ..., 0.0537, 0.0906, 0.4936])),\n", - " ('model.layers.26.mixer.out_proj.weight',\n", - " tensor([[-0.0316, -0.0423, -0.0053, ..., 0.0024, 0.0084, -0.0270],\n", - " [ 0.0458, -0.0243, 0.0060, ..., -0.0007, -0.0161, -0.0232],\n", - " [ 0.0388, -0.0126, 0.0184, ..., -0.0059, 0.0061, 0.0090],\n", - " ...,\n", - " [ 0.0487, 0.0305, -0.0175, ..., -0.0250, -0.0158, -0.0035],\n", - " [-0.0148, -0.0224, 0.0095, ..., -0.0102, -0.0226, 0.0272],\n", - " [-0.0061, 0.0067, 0.0069, ..., 0.0038, -0.0277, -0.0168]])),\n", - " ('model.layers.26.mlp.gate_proj.weight',\n", - " tensor([[-1.9812e-02, 8.3232e-03, 3.0347e-03, ..., 2.1982e-02,\n", - " 1.3550e-02, -1.1203e-02],\n", - " [ 2.2460e-02, 4.9811e-03, -2.2167e-02, ..., 1.3932e-03,\n", - " 5.3891e-03, -2.8310e-02],\n", - " [ 1.1011e-02, -1.2903e-02, -2.8861e-02, ..., 2.6808e-02,\n", - " -2.8479e-03, -1.3105e-02],\n", - " ...,\n", - " [ 1.1078e-03, -1.1789e-02, -4.4165e-02, ..., 8.2950e-03,\n", - " -1.8015e-02, -1.2234e-02],\n", - " [-2.0721e-02, -4.7919e-04, -4.9474e-02, ..., 7.9999e-05,\n", - " 1.7886e-02, -4.4699e-02],\n", - " [ 8.1279e-03, 1.2636e-02, -2.0932e-02, ..., -3.0361e-03,\n", - " 3.3468e-03, 2.7677e-02]])),\n", - " ('model.layers.26.mlp.up_proj.weight',\n", - " tensor([[-0.0301, -0.0025, -0.0147, ..., -0.0186, 0.0058, -0.0057],\n", - " [ 0.0303, -0.0341, 0.0142, ..., -0.0252, -0.0247, 0.0280],\n", - " [ 0.0209, -0.0425, 0.0073, ..., 0.0063, -0.0040, -0.0076],\n", - " ...,\n", - " [-0.0172, -0.0199, 0.0125, ..., 0.0363, 0.0118, -0.0124],\n", - " [-0.0108, 0.0042, -0.0475, ..., 0.0091, -0.0185, 0.0144],\n", - " [-0.0275, -0.0049, 0.0183, ..., -0.0001, -0.0119, -0.0359]])),\n", - " ('model.layers.26.mlp.down_proj.weight',\n", - " tensor([[-0.0197, -0.0082, -0.0224, ..., -0.0469, -0.0076, -0.0375],\n", - " [-0.0070, -0.0071, 0.0190, ..., -0.0125, 0.0068, 0.0166],\n", - " [ 0.0062, -0.0072, 0.0189, ..., -0.0244, -0.0292, -0.0328],\n", - " ...,\n", - " [-0.0054, 0.0219, 0.0058, ..., 0.0118, 0.0136, -0.0221],\n", - " [-0.0133, 0.0299, -0.0182, ..., -0.0496, -0.0202, 0.0196],\n", - " [-0.0131, -0.0237, -0.0473, ..., 0.0066, 0.0119, 0.0100]])),\n", - " ('model.layers.26.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.26.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.z_bias',\n", - " tensor([0., 0., 0., ..., 0., 0., 0.])),\n", - " ('model.layers.27.mixer.D',\n", - " tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1.])),\n", - " ('model.layers.27.mixer.in_proj.weight',\n", - " tensor([[ 0.0200, -0.0276, -0.0274, ..., 0.0282, 0.0025, 0.0215],\n", - " [ 0.0054, 0.0218, -0.0175, ..., -0.0054, 0.0211, -0.0073],\n", - " [ 0.0100, -0.0023, 0.0162, ..., 0.0008, -0.0193, -0.0050],\n", - " ...,\n", - " [-0.0241, -0.0197, -0.0142, ..., 0.0039, -0.0175, 0.0045],\n", - " [ 0.0214, 0.0137, -0.0155, ..., -0.0212, 0.0089, 0.0165],\n", - " [ 0.0086, 0.0181, 0.0069, ..., -0.0093, -0.0272, 0.0068]])),\n", - " ('model.layers.27.mixer.conv1d.weight',\n", - " tensor([[[ 0.0519, 0.2061, 0.2635, 0.4916]],\n", - " \n", - " [[ 0.3745, -0.0860, -0.2310, -0.4250]],\n", - " \n", - " [[ 0.0565, 0.3699, 0.2812, -0.4201]],\n", - " \n", - " ...,\n", - " \n", - " [[ 0.4073, 0.1852, -0.1687, -0.2643]],\n", - " \n", - " [[-0.0865, -0.0894, 0.2650, -0.4522]],\n", - " \n", - " [[-0.0987, 0.0925, -0.2098, 0.0325]]])),\n", - " ('model.layers.27.mixer.conv1d.bias',\n", - " tensor([-0.4788, -0.0231, -0.4210, ..., -0.3143, -0.2893, 0.0570])),\n", - " ('model.layers.27.mixer.out_proj.weight',\n", - " tensor([[-0.0294, -0.0038, -0.0213, ..., -0.0141, 0.0072, -0.0359],\n", - " [ 0.0131, 0.0173, 0.0159, ..., 0.0030, 0.0400, -0.0065],\n", - " [-0.0111, 0.0374, 0.0109, ..., -0.0338, 0.0312, 0.0073],\n", - " ...,\n", - " [-0.0004, 0.0282, 0.0148, ..., 0.0165, 0.0062, -0.0177],\n", - " [ 0.0265, -0.0331, -0.0056, ..., 0.0407, 0.0154, 0.0176],\n", - " [ 0.0209, -0.0293, 0.0009, ..., -0.0240, -0.0029, -0.0407]])),\n", - " ('model.layers.27.mlp.gate_proj.weight',\n", - " tensor([[-0.0118, 0.0202, -0.0012, ..., 0.0101, 0.0075, 0.0102],\n", - " [ 0.0102, -0.0062, 0.0330, ..., -0.0024, -0.0245, -0.0237],\n", - " [-0.0008, 0.0202, -0.0097, ..., 0.0022, -0.0152, -0.0128],\n", - " ...,\n", - " [-0.0461, 0.0178, 0.0253, ..., 0.0319, 0.0173, -0.0099],\n", - " [ 0.0014, -0.0256, 0.0224, ..., 0.0272, 0.0045, 0.0192],\n", - " [ 0.0146, -0.0357, -0.0089, ..., -0.0147, 0.0383, 0.0354]])),\n", - " ('model.layers.27.mlp.up_proj.weight',\n", - " tensor([[-3.1854e-02, -1.0290e-03, -3.4564e-03, ..., 3.3551e-03,\n", - " 3.2845e-02, 2.1107e-02],\n", - " [-4.8083e-04, -5.8388e-03, 1.7324e-03, ..., 2.0575e-02,\n", - " -1.1685e-02, 1.2504e-02],\n", - " [ 4.6267e-02, -1.8935e-02, -2.4184e-02, ..., -4.8211e-02,\n", - " -3.3912e-04, 3.0527e-02],\n", - " ...,\n", - " [-6.9427e-03, -4.8680e-03, 3.2021e-02, ..., 1.4236e-02,\n", - " 1.9532e-02, 1.3339e-02],\n", - " [ 1.2463e-02, -5.5923e-03, -1.5680e-02, ..., 8.7956e-03,\n", - " 2.8262e-02, -1.2526e-02],\n", - " [-4.8530e-03, -8.8749e-05, 3.3507e-02, ..., -2.8260e-02,\n", - " -2.0571e-03, -8.3943e-03]])),\n", - " ('model.layers.27.mlp.down_proj.weight',\n", - " tensor([[-0.0457, -0.0267, -0.0210, ..., -0.0093, -0.0016, -0.0008],\n", - " [-0.0053, 0.0284, -0.0003, ..., 0.0065, -0.0117, 0.0243],\n", - " [ 0.0120, 0.0023, -0.0180, ..., -0.0003, -0.0313, 0.0163],\n", - " ...,\n", - " [-0.0160, 0.0207, 0.0082, ..., 0.0153, 0.0131, 0.0034],\n", - " [-0.0073, 0.0424, 0.0274, ..., -0.0075, -0.0554, -0.0114],\n", - " [-0.0192, 0.0268, 0.0036, ..., 0.0094, 0.0045, 0.0030]])),\n", - " ('model.layers.27.input_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.layers.27.post_attention_layernorm.weight',\n", - " tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('model.norm.weight', tensor([1., 1., 1., ..., 1., 1., 1.])),\n", - " ('lm_head.weight',\n", - " tensor([[-0.0141, -0.0445, 0.0071, ..., -0.0143, -0.0239, -0.0512],\n", - " [ 0.0295, -0.0317, -0.0201, ..., -0.0082, 0.0231, -0.0030],\n", - " [-0.0255, -0.0139, 0.0020, ..., -0.0040, -0.0154, 0.0336],\n", - " ...,\n", - " [ 0.0095, 0.0361, 0.0135, ..., -0.0018, 0.0074, -0.0311],\n", - " [-0.0092, 0.0060, 0.0594, ..., -0.0046, 0.0117, 0.0364],\n", - " [ 0.0228, -0.0265, -0.0262, ..., 0.0038, 0.0097, -0.0257]]))])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.state_dict()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "N params SSM: 5.305533088\n" - ] - } - ], - "source": [ - "print(\"N params SSM:\", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load State dict into SSM" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.weight', 'model.layers.7.mixer.conv1d.bias', 'model.layers.7.mixer.out_proj.weight', 'model.layers.8.mixer.z_bias', 'model.layers.8.mixer.D', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.out_proj.weight', 'model.layers.9.mixer.z_bias', 'model.layers.9.mixer.D', 'model.layers.9.mixer.in_proj.weight', 'model.layers.9.mixer.conv1d.weight', 'model.layers.9.mixer.conv1d.bias', 'model.layers.9.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.19.mixer.z_bias', 'model.layers.19.mixer.D', 'model.layers.19.mixer.in_proj.weight', 'model.layers.19.mixer.conv1d.weight', 'model.layers.19.mixer.conv1d.bias', 'model.layers.19.mixer.out_proj.weight', 'model.layers.20.mixer.z_bias', 'model.layers.20.mixer.D', 'model.layers.20.mixer.in_proj.weight', 'model.layers.20.mixer.conv1d.weight', 'model.layers.20.mixer.conv1d.bias', 'model.layers.20.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 'model.layers.21.mixer.out_proj.weight', 'model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight', 'model.layers.23.mixer.z_bias', 'model.layers.23.mixer.D', 'model.layers.23.mixer.in_proj.weight', 'model.layers.23.mixer.conv1d.weight', 'model.layers.23.mixer.conv1d.bias', 'model.layers.23.mixer.out_proj.weight', 'model.layers.24.mixer.z_bias', 'model.layers.24.mixer.D', 'model.layers.24.mixer.in_proj.weight', 'model.layers.24.mixer.conv1d.weight', 'model.layers.24.mixer.conv1d.bias', 'model.layers.24.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight', 'model.layers.26.mixer.z_bias', 'model.layers.26.mixer.D', 'model.layers.26.mixer.in_proj.weight', 'model.layers.26.mixer.conv1d.weight', 'model.layers.26.mixer.conv1d.bias', 'model.layers.26.mixer.out_proj.weight', 'model.layers.27.mixer.z_bias', 'model.layers.27.mixer.D', 'model.layers.27.mixer.in_proj.weight', 'model.layers.27.mixer.conv1d.weight', 'model.layers.27.mixer.conv1d.bias', 'model.layers.27.mixer.out_proj.weight'], unexpected_keys=['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight'])" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.load_state_dict(apriel_state_dict, strict=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "apriel_ssm.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Save checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'apriel_ssm' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mapriel_ssm\u001b[49m\u001b[38;5;241m.\u001b[39msave_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/mnt/checkpoints/ssm/apriel_ssm_instruct_base\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2\u001b[0m save_config\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", - "\u001b[0;31mNameError\u001b[0m: name 'apriel_ssm' is not defined" - ] - } - ], - "source": [ - "apriel_ssm.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_base\",\n", - " save_config=True)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "24" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.model.layers[0].mixer.n_v_heads" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Try a forward pass" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", - "batch_size = 1\n", - "max_length = 128\n", - "state = SimpleNamespace()\n", - "state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", - "state.batch_size = batch_size\n", - "state.seqlen_offset = 0\n", - "static_inputs = {\"inference_params\": state,\n", - " \"input_ids\": input_ids,\n", - " \"use_cache\": True,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781, 2.3594, 1.4609, ..., -2.3438, -1.9688, 0.6484],\n", - " [-5.8125, 4.9688, 0.4414, ..., -4.2500, -3.5156, -4.8125],\n", - " [-5.5000, 3.3594, 1.1484, ..., -3.4375, -2.3125, -4.4375],\n", - " ...,\n", - " [-2.2812, 0.1465, 2.2344, ..., -7.6875, -3.0312, -6.2500],\n", - " [-6.8750, 1.7812, -1.3750, ..., -7.4688, -5.6875, -4.4062],\n", - " [-2.0156, 2.0938, 3.1094, ..., -3.0156, -2.1406, -2.2812]]],\n", - " device='cuda:0', grad_fn=), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828, 0.0625, -2.7500, ..., -0.6523, -0.8906, 1.4609],\n", - " [ 2.1406, -0.0247, -3.0156, ..., -0.0074, 1.0234, 1.3828],\n", - " [ 1.6016, -0.7266, -1.2422, ..., -0.4004, -0.8242, -0.5586],\n", - " ...,\n", - " [ 1.5234, -0.0262, -1.5469, ..., -0.4922, -1.0078, 1.2344],\n", - " [-0.4629, -0.6055, -1.3906, ..., -0.9922, -0.3066, 1.1875],\n", - " [-0.7539, -0.0243, -2.4688, ..., -1.0625, -2.7188, 2.6875]]],\n", - " device='cuda:0', dtype=torch.bfloat16, grad_fn=))" - ] - }, - "execution_count": 24, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm.forward(**static_inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load Apriel SSM into HF class" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], - "source": [ - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from mamba_ssm.models.config_mamba import MambaConfig\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "import os\n", - "import shutil\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/apriel_ssminstr-distil-randinit-bs768-lr0.0003-sl4096_ti5000_luke_mix1/export/apriel_ssm/5000\"\n", - "modeling_path = \"/home/toolkit/dev/Fast-LLM/fast_llm/models/ssm/external\"\n", - "# # copy the config.json to the model path\n", - "shutil.copy(os.path.join(modeling_path, \"modeling_ssm_apriel.py\"), os.path.join(model_path, \"modeling_ssm_apriel.py\"))\n", - "shutil.copy(os.path.join(modeling_path, \"configuration_ssm_apriel.py\"), os.path.join(model_path, \"configuration_ssm_apriel.py\"))\n", - "\n", - "tokenizer_path = \"/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/\"\n", - "# # cp tokenizer*\n", - "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer.json\"), os.path.join(model_path, \"tokenizer.json\"))\n", - "# shutil.copy(os.path.join(tokenizer_path, \"tokenizer_config.json\"), os.path.join(model_path, \"tokenizer_config.json\"))\n", - "# shutil.copy(os.path.join(tokenizer_path, \"special_tokens_map.json\"), os.path.join(model_path, \"special_tokens_map.json\"))\n", - "# shutil.copy(os.path.join(tokenizer_path, \"vocab.json\"), os.path.join(model_path, \"vocab.json\"))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.08s/it]\n" - ] - } - ], - "source": [ - "\n", - "apriel_ssm = AprielSSMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device=\"cuda\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMForCausalLM(\n", - " (model): AprielSSMModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "apriel_ssm" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "config = apriel_ssm.config" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Mamba in Llama: SSM hybrid " - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], - "source": [ - "\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "import torch\n", - "from mamba_ssm import MambaLMHeadModel\n", - "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", - "from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM\n", - "from transformers.cache_utils import StaticCache\n", - "from types import SimpleNamespace\n", - "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", - "from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer\n", - "# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper\n", - "# make sure the code changes reflected without reload\n", - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "\n", - "# d_xb = config.num_key_value_heads * config.head_dim\n", - "d_inner = config.num_attention_heads * config.head_dim\n", - "d_state = config.head_dim\n", - "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", - " ssm_block_pattern=[\"m2d\", \"t\"] * 14,\n", - " ssm_cfg={\n", - " \"d_state\": 64,\n", - " \"n_v_heads\": 24,\n", - " \"n_qk_heads\": 24,\n", - " # \"d_xb\": d_xb,\n", - " \"expand\": 1,\n", - " \"chunk_size\": 128,\n", - " \"activation\": \"identity\",\n", - " \"bias\": False,\n", - " \"d_inner\": 24 * 128, # num_heads * head_dim\n", - " })\n", - "# hybrdif_apriel_config" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "metadata": {}, - "outputs": [], - "source": [ - "hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AprielSSMDecoderLayer(\n", - " (mixer): DiscreteMamba2(\n", - " (in_proj): Linear(in_features=4096, out_features=9240, bias=False)\n", - " (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)\n", - " (act): Identity()\n", - " (out_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - ")" - ] - }, - "execution_count": 88, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hybrid_apriel_model.layers[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 91, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "isinstance(hybrid_apriel_model.layers[0], AprielSSMDecoderLayer)" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "metadata": {}, - "outputs": [], - "source": [ - "device = \"cpu\" #if torch.cuda.is_available() else \"cpu\"\n", - "input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)\n", - "batch_size = 1\n", - "max_length = 128\n", - "state = SimpleNamespace()\n", - "state.key_value_memory_dict = hybrid_apriel_model.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)\n", - "state.batch_size = batch_size\n", - "state.seqlen_offset = 0\n", - "static_inputs = {\"inference_params\": state,\n", - " \"input_ids\": input_ids,\n", - " \"use_cache\": True,\n", - "}\n" - ] - }, - { - "cell_type": "code", - "execution_count": 73, - "metadata": {}, - "outputs": [ - { - "ename": "OutOfMemoryError", - "evalue": "CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[73], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mbfloat16)\n", - "File \u001b[0;32m~/.local/lib/python3.12/site-packages/transformers/modeling_utils.py:3110\u001b[0m, in \u001b[0;36mPreTrainedModel.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_present_in_args:\n\u001b[1;32m 3106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3107\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m `dtype` by passing the correct `torch_dtype` argument.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3109\u001b[0m )\n\u001b[0;32m-> 3110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1174\u001b[0m, in \u001b[0;36mModule.to\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1174\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:780\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m recurse:\n\u001b[1;32m 779\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mchildren():\n\u001b[0;32m--> 780\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 782\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mcompute_should_use_set_data\u001b[39m(tensor, tensor_applied):\n\u001b[1;32m 783\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_has_compatible_shallow_copy_type(tensor, tensor_applied):\n\u001b[1;32m 784\u001b[0m \u001b[38;5;66;03m# If the new tensor has compatible tensor type as the existing tensor,\u001b[39;00m\n\u001b[1;32m 785\u001b[0m \u001b[38;5;66;03m# the current behavior is to change the tensor in-place using `.data =`,\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 790\u001b[0m \u001b[38;5;66;03m# global flag to let the user control whether they want the future\u001b[39;00m\n\u001b[1;32m 791\u001b[0m \u001b[38;5;66;03m# behavior of overwriting the existing tensor or not.\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:805\u001b[0m, in \u001b[0;36mModule._apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 801\u001b[0m \u001b[38;5;66;03m# Tensors stored in modules are graph leaves, and we don't want to\u001b[39;00m\n\u001b[1;32m 802\u001b[0m \u001b[38;5;66;03m# track autograd history of `param_applied`, so we have to use\u001b[39;00m\n\u001b[1;32m 803\u001b[0m \u001b[38;5;66;03m# `with torch.no_grad():`\u001b[39;00m\n\u001b[1;32m 804\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m--> 805\u001b[0m param_applied \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 806\u001b[0m p_should_use_set_data \u001b[38;5;241m=\u001b[39m compute_should_use_set_data(param, param_applied)\n\u001b[1;32m 808\u001b[0m \u001b[38;5;66;03m# subclasses may have multiple child tensors so we need to use swap_tensors\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1160\u001b[0m, in \u001b[0;36mModule.to..convert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m convert_to_format \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m t\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m5\u001b[39m):\n\u001b[1;32m 1154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m t\u001b[38;5;241m.\u001b[39mto(\n\u001b[1;32m 1155\u001b[0m device,\n\u001b[1;32m 1156\u001b[0m dtype \u001b[38;5;28;01mif\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_floating_point() \u001b[38;5;129;01mor\u001b[39;00m t\u001b[38;5;241m.\u001b[39mis_complex() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1157\u001b[0m non_blocking,\n\u001b[1;32m 1158\u001b[0m memory_format\u001b[38;5;241m=\u001b[39mconvert_to_format,\n\u001b[1;32m 1159\u001b[0m )\n\u001b[0;32m-> 1160\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_floating_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mis_complex\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43mnon_blocking\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1166\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(e) \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot copy out of meta tensor; no data!\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)" - ] - } - ], - "source": [ - "hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[79], line 2\u001b[0m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mhybrid_apriel_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstatic_inputs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:1043\u001b[0m, in \u001b[0;36mAprielSSMHybridModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, inference_params, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states:\n\u001b[1;32m 1042\u001b[0m all_hidden_states \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (hidden_states,)\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1056\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(decoder_layer, AprielDecoderLayer):\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:805\u001b[0m, in \u001b[0;36mAprielSSMDecoderLayer.forward\u001b[0;34m(self, hidden_states, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 801\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 803\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[0;32m--> 805\u001b[0m mixer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmixer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 806\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 807\u001b[0m \u001b[43m \u001b[49m\u001b[43minference_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 808\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 810\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m mixer_outputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mto(residual\u001b[38;5;241m.\u001b[39mdtype) \u001b[38;5;241m+\u001b[39m residual\n\u001b[1;32m 812\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/dev/Fast-LLM/fast_llm/models/ssm/external/modeling_ssm_hybrid_apriel.py:460\u001b[0m, in \u001b[0;36mDiscreteMamba2.forward\u001b[0;34m(self, u, return_mixer_matrix, inference_params, **kwargs)\u001b[0m\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# Project input\u001b[39;00m\n\u001b[1;32m 459\u001b[0m xBCzA_log \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj(u)\n\u001b[0;32m--> 460\u001b[0m xBC, z, A_log \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 461\u001b[0m \u001b[43m \u001b[49m\u001b[43mxBCzA_log\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 462\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 463\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_qk_heads\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 464\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43md_inner\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 465\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_v_heads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 466\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 467\u001b[0m \u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 468\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m state \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 471\u001b[0m \u001b[38;5;66;03m# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv\u001b[39;00m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.\u001b[39;00m\n\u001b[1;32m 473\u001b[0m xBC_t \u001b[38;5;241m=\u001b[39m rearrange(xBC[:, :seqlen, :], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mb l d -> b d l\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/functional.py:196\u001b[0m, in \u001b[0;36msplit\u001b[0;34m(tensor, split_size_or_sections, dim)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 191\u001b[0m split, (tensor,), tensor, split_size_or_sections, dim\u001b[38;5;241m=\u001b[39mdim)\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# Overwriting reason:\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;66;03m# This dispatches to two ATen functions depending on the type of\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# split_size_or_sections. The branching code is in _tensor.py, which we\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# call here.\u001b[39;00m\n\u001b[0;32m--> 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_size_or_sections\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/_tensor.py:917\u001b[0m, in \u001b[0;36mTensor.split\u001b[0;34m(self, split_size, dim)\u001b[0m\n\u001b[1;32m 915\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_VF\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m, split_size, dim) \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[1;32m 916\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 917\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msplit_with_sizes\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mRuntimeError\u001b[0m: split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]" - ] - } - ], - "source": [ - "\n", - "hybrid_apriel_model.forward(**static_inputs)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.44it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "AprielForCausalLM(\n", - " (model): AprielModel(\n", - " (embed_tokens): Embedding(131072, 4096)\n", - " (layers): ModuleList(\n", - " (0-27): 28 x AprielDecoderLayer(\n", - " (self_attn): AprielAttention(\n", - " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", - " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", - " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", - " )\n", - " (mlp): AprielMLP(\n", - " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", - " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", - " (act_fn): SiLU()\n", - " )\n", - " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", - " )\n", - " )\n", - " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", - " (rotary_emb): AprielRotaryEmbedding()\n", - " )\n", - " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", - ")" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", - "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", - "apriel_state_dict = apriel_model.state_dict()\n", - "apriel_model.to(device).to(dtype=torch.bfloat16)" - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [], - "source": [ - "# Innitialization using k, q, v from Apriel transformer\n", - "def expand_k_q(k):\n", - " Hq = config.num_attention_heads\n", - " Hk = config.num_key_value_heads\n", - " d_head = config.head_dim\n", - " d = k.shape[-1]\n", - " \n", - " # Expand k\n", - " repeat_factor = Hq // Hk\n", - " k_expanded = k.view(Hk, d_head, d)\n", - " k_expanded = k_expanded.repeat_interleave(repeat_factor, dim=0)\n", - " k_expanded = k_expanded.view(d_head * Hq, d)\n", - " return k_expanded\n", - "\n", - "for block_h, block_t in zip(hybrid_apriel_model.layers, apriel_model.model.layers):\n", - " # print(isinstance(block_h, AprielSSMDecoderLayer))\n", - " if isinstance(block_h, AprielSSMDecoderLayer):\n", - " # print(block_h.mixer.n_v_heads)\n", - " # print(block_t.self_attn.v_proj.weight.shape)\n", - " # print(block_h.mixer.in_proj.weight.shape)\n", - "\n", - " # print(block_h.mixer.in_proj.weight.shape)\n", - " # print(block_t.self_attn.v_proj.weight.shape)\n", - " block_h.mlp.load_state_dict(block_t.mlp.state_dict())\n", - " block_h.input_layernorm.load_state_dict(block_t.input_layernorm.state_dict())\n", - " block_h.post_attention_layernorm.load_state_dict(block_t.post_attention_layernorm.state_dict())\n", - " block_h.mixer.out_proj.load_state_dict(block_t.self_attn.o_proj.state_dict())\n", - " # [x B C z A_log]\n", - " # print(block_h.mixer.d_inner)\n", - " # init x, but interleave to address GQA\n", - " v_expended = expand_k_q(block_t.self_attn.v_proj.weight.data)\n", - " block_h.mixer.in_proj.weight.data[:block_h.mixer.d_inner, : ].copy_(v_expended)\n", - " # init k, but interleave to address GQA\n", - " k_expended = expand_k_q(block_t.self_attn.k_proj.weight.data)\n", - " block_h.mixer.in_proj.weight.data[block_h.mixer.d_inner: 2*block_h.mixer.d_inner, : ].copy_(k_expended)\n", - " # init C ewith Q\n", - " block_h.mixer.in_proj.weight.data[2*block_h.mixer.d_inner: 3*block_h.mixer.d_inner, : ].copy_(block_t.self_attn.q_proj.weight.data)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 124, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1024, 4096])" - ] - }, - "execution_count": 124, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "block_t.self_attn.v_proj.weight.data.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "d_xb = config.num_key_value_heads * config.head_dim\n", - "ssm_layers = [2,4,8]\n", - "attn_layers = [i for i in range(config.num_hidden_layers) if i not in ssm_layers]\n", - "model_name = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", - "ngroups = config.num_attention_heads # n heads\n", - "d_inner = config.head_dim * config.num_attention_heads\n", - "headdim = 128 # d_state\n", - "d_state = config.head_dim\n", - "d_model = config.hidden_size \n", - "assert d_inner == ngroups * d_state\n", - "\n", - "mamba_config = AprielSSMConfig(\n", - " ssm_cfg={\n", - " \"d_state\": 64,\n", - " \"n_v_heads\": 24,\n", - " \"n_qk_heads\": 24,\n", - " \"expand\": 1,\n", - " \"chunk_size\": 128,\n", - " \"activation\": \"identity\",\n", - " \"bias\": False,\n", - " \"d_inner\": 24 * headdim, # num_heads * head_dim\n", - " },\n", - " vocab_size=config.vocab_size, \n", - " hidden_size=config.hidden_size,\n", - " intermediate_size=config.intermediate_size,\n", - " num_hidden_layers=config.num_hidden_layers,\n", - " hidden_act=config.hidden_act,\n", - " initializer_range=config.initializer_range,\n", - " use_cache=config.use_cache,\n", - " mlp_bias=config.mlp_bias,\n", - " tie_word_embeddings=config.tie_word_embeddings,\n", - " pad_token_id=config.pad_token_id,\n", - " bos_token_id=config.bos_token_id,\n", - " eos_token_id=config.eos_token_id,\n", - " head_dim=config.head_dim,\n", - " rms_norm_eps=config.rms_norm_eps\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "student_model = MambaTransformerHybridModelWrapper.init_distillation(None, model_name, \n", - " mamba_config, \n", - " attn_layers=attn_layers, \n", - " init_with_kqvo=True, \n", - " attn_implementation=\"flash_attention_2\")\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hymba2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.8" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 94537c33..38ad5edf 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -25,13 +25,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig + from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM self._model = AprielSSMForCausalLM.from_pretrained( pretrained, @@ -57,3 +57,56 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): max_length=max_length, **generation_kwargs, ) + + +@register_model("apriel_hybrid_ssm") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/configuration_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py similarity index 100% rename from fast_llm/models/ssm/external/configuration_mtp_llamba.py rename to fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py diff --git a/fast_llm/models/ssm/external/modeling_mtp_llamba.py b/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py similarity index 100% rename from fast_llm/models/ssm/external/modeling_mtp_llamba.py rename to fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 5863f903..bb6d5457 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -15,7 +15,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat try: from fast_llm.layers.ssm.config import SSMConfig @@ -112,24 +112,102 @@ def get_hf_llamba_out(input_ids, path, format): return output, parameter_sum +# @pytest.mark.slow +# @pytest.mark.skipif( +# not run_test or LMHeadModel is None, +# reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", +# ) +# def test_load_from_llamba_checkpoint(distributed_config): +# """ +# Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. +# """ +# vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json +# batch_size = 2 +# seq_length = 32 + +# path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") +# format = LLambaHuggingfaceCheckpointFormat + +# x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") +# hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) +# hf_logits = hf_logits["logits"].cpu() + +# # Create checkpoint load config +# checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) +# # Initialize model +# model = HybridSSMModel.from_pretrained(checkpoint_config) +# param_sum = 0 +# for stage in model.stages: +# for fsdp in stage.fsdps: +# if hasattr(fsdp, "_weight_shard"): +# param_sum += torch.sum(fsdp._weight_shard).item() +# assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + +# # model = GPTModel.from_pretrained(checkpoint_config) +# assert model.config.base_model.vocab_size == vocab_size +# schedule_config = ScheduleConfig() +# with NoAutoValidate(): +# batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) +# batch_config.setup(distributed_config) +# batch_config.validate() +# schedule_runner = ScheduleRunner( +# config=schedule_config, +# multi_stage=model, +# distributed_config=model.distributed.config, +# ) +# schedule = Schedule( +# multi_stage=model, +# batch_config=batch_config, +# schedule_config=schedule_config, +# distributed_config=model.distributed.config, +# phase=PhaseType.inference, +# ) +# schedule_runner.setup(model.distributed, optimizer=None) + +# common_kwargs = { +# TransformerKwargs.sequence_first: True, +# TransformerKwargs.grad_output: False, +# } +# input_data = [(x, common_kwargs)] + +# losses, success, metrics = schedule_runner.run_step( +# iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True +# ) + +# logits = input_data[0][1]["logits"].cpu() +# assert torch.allclose(logits, hf_logits, atol=1e-2) + + +def get_hf_apriel_hybrid_out(input_ids, path, format): + from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") + parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) + print(f"Parameter sum: {parameter_sum}") + output = model(input_ids) + del model + torch.cuda.empty_cache() + return output, parameter_sum + + @pytest.mark.slow @pytest.mark.skipif( - not run_test or LMHeadModel is None, - reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", + not run_test, + reason=f"Skipping because no CUDA available or Mamba not installed", ) -def test_load_from_llamba_checkpoint(distributed_config): +def test_load_from_hybridssm_checkpoint(distributed_config): """ Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. """ - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json batch_size = 2 seq_length = 32 - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat + path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") + format = AprielSSMHHybridHuggingfaceCheckpointFormat x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) + hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) hf_logits = hf_logits["logits"].cpu() # Create checkpoint load config @@ -163,15 +241,21 @@ def test_load_from_llamba_checkpoint(distributed_config): phase=PhaseType.inference, ) schedule_runner.setup(model.distributed, optimizer=None) + from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType + from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies + + rotary_config = RotaryConfig(type=RotaryEmbeddingType.default, theta=10000.0) # or whatever type your model uses + frequencies = get_rotary_frequencies(rotary_config, seq_length, 4096, device="cuda") - common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] + from types import SimpleNamespace + batch = SimpleNamespace( + token_ids=x, + sequence_lengths=[[seq_length, seq_length]], + ) + input_data = [(batch)] losses, success, metrics = schedule_runner.run_step( - iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True + iter(input_data), schedule, iteration=0, return_metrics=True, preprocessed=False ) logits = input_data[0][1]["logits"].cpu() @@ -183,7 +267,7 @@ def test_load_from_llamba_checkpoint(distributed_config): "hybrid_block_layout,LAYER_CLS", [ (["m", "t"], MambaLayer), - (["m2", "t"], DiscreteMamba2), + (["m2d", "t"], DiscreteMamba2), ], ids=["mamba", "descrete_mamba2"], ) @@ -251,7 +335,7 @@ def test_mamba_block(distributed_config, distributed): "hybrid_block_layout", [ (["m", "t"]), - (["m2", "t"]), + (["m2d", "t"]), ], ids=["mamba", "descrete_mamba2"], ) @@ -338,3 +422,6 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo # }, # losses=losses, # ) + +if __name__ == "__main__": + pytest.main(["-s", __file__]) From 30ad8b8f890e43df61bc733fb6c04da8f9d59889 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 May 2025 12:32:06 +0000 Subject: [PATCH 069/114] wip --- .../ssm/external/aperiel_ssm/configuration_ssm_apriel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py index c3f7ef38..6943a312 100644 --- a/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py +++ b/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py @@ -96,8 +96,8 @@ def __init__( "bias": False, "d_inner": 24 * self.head_dim, # num_heads * head_dim } - if self.head_dim == self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: - logger.warning("Head dim is equal to d_inner // n_qk_heads.") + if self.head_dim != self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: + logger.warning("Head dim is not equal to d_inner // n_qk_heads.") __all__ = ["AprielConfig"] From 9c4f38f92c26c4a3a44ab67795f9dd3b58840245 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 7 May 2025 15:39:51 +0000 Subject: [PATCH 070/114] layer-lr scale for mlp as well --- fast_llm/layers/transformer/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 9e1e0bcf..98238172 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -46,7 +46,7 @@ def __init__( self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp" + self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index ) # PEFT. From 1784dcaf64fd312734042176dea827382214a16c Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 May 2025 21:55:30 +0000 Subject: [PATCH 071/114] wip --- fast_llm/models/ssm/config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7e0a69fd..15a75ac2 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -100,19 +100,19 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] - len_block_layout = len(self.hybrid_block_layout) - if len_block_layout != self.transformer.num_layers: - if self.transformer.num_layers % len_block_layout != 0: + + if len(self.hybrid_block_layout) != self.transformer.num_layers: + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: raise ValueError( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" ) - num_repeats = int(self.transformer.num_layers // len_block_layout) + num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) logger.warning( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" ) self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len_block_layout, self.transformer.num_layers) + Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) Assert.custom( lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", From 1e3cc2847887070c90d366ccf8497babd3feb661 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 7 May 2025 22:01:11 +0000 Subject: [PATCH 072/114] nvm --- fast_llm/models/ssm/conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 39aaa6e9..357a26c0 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -141,7 +141,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _create_weight_converters(self) -> list[WeightConverter]: - converters = super()._create_weight_converters() + converters = super()._create_weight_converters() or [] num_layers = self._model.config.base_model.transformer.num_layers ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear From 2dc945b9402dbbe4379458f8238e5ab2fb2b4cff Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 12:48:08 +0000 Subject: [PATCH 073/114] hybrid modeling --- .../configuration_ssm_hybrid_apriel.py | 10 +- .../modeling_ssm_hybrid_apriel.py | 288 ++++++++++++++---- .../ssm/external/eval/apriel_eval_wrapper.py | 30 +- 3 files changed, 250 insertions(+), 78 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py index b030150c..1d230bb6 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py @@ -431,7 +431,7 @@ def __init__( **kwargs, ) - self.ssm_cfg = ssm_cfg or { + ssm_defaults = { "d_state": 64, "n_v_heads": 24, "n_qk_heads": 24, @@ -439,8 +439,10 @@ def __init__( "chunk_size": 128, "activation": "identity", "bias": False, + "d_conv": 4, "d_inner": 24 * self.head_dim, # num_heads * head_dim } - - -__all__ = ["AprielConfig"] + self.ssm_cfg = ssm_cfg or ssm_defaults + for k, v in ssm_defaults.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 950327df..d6fd3518 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -8,7 +8,6 @@ from einops import rearrange, repeat from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined -from mamba_ssm.utils.generation import GenerationMixin from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache @@ -29,14 +28,133 @@ logger = logging.get_logger(__name__) +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + @dataclass -class CustomMambaCausalLMOutput(ModelOutput): +class AprielHybridCausalOutput(ModelOutput): """Custom output class for MambaLMHeadModel.""" loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None class AprielRMSNorm(nn.Module): @@ -333,6 +451,7 @@ def materialize_mixer(A_log, B, C, D): return T +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py class DiscreteMamba2(nn.Module): def __init__( self, @@ -424,7 +543,14 @@ def d_output(self): def state_to_tensor(self): return self.layer.state_to_tensor - def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): """ u: (B, L, D) Returns: same shape as u @@ -433,16 +559,17 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) # assert state is None batch, seqlen, dim = u.shape - state = None - if inference_params is not None: - state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if cache_position[0] > 0: # States are updated inplace - out, _ = self.step(u, state) + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _ = self.step(u, ssm_state, conv_state) return {"hidden_states": out} # Hacky way to initialize state during inference - chunk_size = self.chunk_size if state is None else seqlen + chunk_size = self.chunk_size if ssm_state is None else seqlen # Pad input to nearest multiple of chunklen padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size @@ -460,11 +587,11 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) dim=-1, ) - if state is not None: + if ssm_state is not None: # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") - state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) # Convolutional layer xBC = self.convolutional_forward(xBC, padded_len) @@ -493,12 +620,12 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) C=C, chunk_size=chunk_size, # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation - return_final_states=(state is not None), + return_final_states=(ssm_state is not None), ) - if state is not None: - y, ssm_state = result - state["ssm"].copy_(ssm_state) + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) else: y = result @@ -513,7 +640,7 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] return outputs - def step(self, u, state, **kwargs): + def step(self, u, ssm_state, conv_state, **kwargs): """ u: (B D) state: dict of states @@ -521,7 +648,7 @@ def step(self, u, state, **kwargs): """ # Project input - xBCzA_log = self.in_proj(u.squeeze(1)) + xBCzA_log = self.in_proj(u) xBC, z, A_log = torch.split( xBCzA_log, [ @@ -532,8 +659,8 @@ def step(self, u, state, **kwargs): dim=-1, ) - xBC, conv_state = self.convolutional_step(xBC, state["conv"]) - state["conv"].copy_(conv_state) # update state in place + xBC, conv_state = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state) # update state in place x, B, C = torch.split( xBC, @@ -549,7 +676,7 @@ def step(self, u, state, **kwargs): B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) - state["ssm"] = state["ssm"].to(x.dtype) + ssm_state = ssm_state.to(x.dtype) zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) y = selective_state_update( @@ -559,7 +686,7 @@ def step(self, u, state, **kwargs): A=-ones, B=B, C=C, - state=state["ssm"], # will be updated in place + state=ssm_state, # will be updated in place dt_bias=zeros, D=zeros, ) @@ -570,7 +697,7 @@ def step(self, u, state, **kwargs): # Norm and gate out = self.out_proj(y * F.silu(z + self.z_bias)) - return out, state + return out, ssm_state def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): device = self.in_proj.weight.device @@ -602,16 +729,17 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states """ assert self.layer_idx is not None # Allocate memory if not exists - if self.layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( - batch_size, inference_params.max_seqlen, dtype=torch.float32 - ) + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) # Get states - states = inference_params.key_value_memory_dict[self.layer_idx] + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] if initialize_states: - states["conv"].zero_() - states["ssm"].zero_() - return states + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states def convolutional_forward(self, xBC, padded_len): if causal_conv1d_fn is None or self.activation not in [ @@ -724,7 +852,7 @@ def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, d self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) def forward( - self, hidden_states: torch.Tensor, inference_params=None, **kwargs + self, hidden_states: torch.Tensor, **kwargs ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: outputs = {} @@ -734,7 +862,7 @@ def forward( mixer_outputs = self.mixer( hidden_states, - inference_params=inference_params, + **kwargs, ) hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual @@ -878,6 +1006,7 @@ def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwa factory_kwargs = {"device": device, "dtype": dtype} self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): if type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) @@ -913,7 +1042,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -943,7 +1072,11 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() + # past_key_values = HybridMambaAttentionDynamicCache() + logger.warning_once( + "Hybrid Apriel requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1133,12 +1266,11 @@ class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - def __init__(self, config, device=None, dtype=None, **kwargs): - super().__init__(config, device=device, dtype=dtype, **kwargs) + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) self.model = AprielSSMHybridModel(config) self.vocab_size = config.vocab_size - factory_kwargs = {"device": device, "dtype": dtype} - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() @@ -1161,23 +1293,82 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + def forward( self, input_ids: torch.LongTensor = None, position_ids=None, return_hidden_states=False, return_logits=True, - inference_params=None, num_last_tokens=0, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[tuple, CausalLMOutputWithPast]: + # past_key_values is None if prepare_inputs_for_generation is not called, which is the case when we evaluate without calling generate (non-generation tasks) + # Its generally ok if cache is nto instantiated in this case, since we do single pass per sample anyways, a warning will be triggered in the model outputs: BaseModelOutputWithPast = self.model( input_ids, return_hidden_states=return_hidden_states, - inference_params=inference_params, position_ids=position_ids, - return_dict=True, + past_key_values=past_key_values, + **kwargs, ) if outputs["last_hidden_state"] is not None and return_logits: @@ -1186,22 +1377,17 @@ def forward( else: outputs["logits"] = None - return CustomMambaCausalLMOutput( + return AprielHybridCausalOutput( loss=None, logits=outputs["logits"], all_hidden_states=outputs.hidden_states, last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, ) - def generate(self, *args, **kwargs): - """ - This is a wrapper to make sure we comply with the HF generation interface for eval harness - """ - return super().generate(*args, **kwargs) - __all__ = [ - "AprielSSMForCausalLM", - "AprielModel", + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", "AprielSSMPreTrainedModel", ] diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 38ad5edf..a7cf34e4 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -8,11 +8,10 @@ @register_model("apriel_ssm") class AprielSSMWrapper(HFLM): - """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + """Wrapper for AprielSSM model for compatibility with lm-evaluation-harness.""" def __init__(self, pretrained, **kwargs) -> None: if "backend" in kwargs: - # rene currently only supports causal models assert kwargs["backend"] == "causal" super().__init__( @@ -61,11 +60,10 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): @register_model("apriel_hybrid_ssm") class AprielHybridSSMWrapper(HFLM): - """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" def __init__(self, pretrained, **kwargs) -> None: if "backend" in kwargs: - # rene currently only supports causal models assert kwargs["backend"] == "causal" super().__init__( @@ -80,7 +78,7 @@ def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig - self._config = AprielSSMHybridConfig.from_pretrained(pretrained) + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" @@ -89,24 +87,10 @@ def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype] self._model = AprielSSMHybridForCausalLM.from_pretrained( pretrained, device=self._device, - dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), - trust_remote_code=True, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, ) def _model_generate(self, context, max_length, stop, **generation_kwargs): - """Generate text from the model.""" - for key in ("do_sample", "attention_mask"): - if key in generation_kwargs: - generation_kwargs.pop(key) - - # The custom GenerationMixin imported from mamba_ssm currently does not support - # passing stopping criteria. - # For the time being, we simply generate to max length, then truncate (equivalent result). - # This should be revisited to speed up generation - # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) - - return self.model.generate( - input_ids=context, - max_length=max_length, - **generation_kwargs, - ) + # FOR now evaluating with non-generation tasks + raise NotImplementedError("Generation not implemented yet for AprielHybridSSMWrapper") From 4277e6779ab78d8dc62615d729e54c685cfa5d92 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 12:49:57 +0000 Subject: [PATCH 074/114] modeling --- .../modeling_ssm_hybrid_apriel.py | 75 +++++++++---------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index d6fd3518..95c09e0c 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -699,28 +699,28 @@ def step(self, u, ssm_state, conv_state, **kwargs): return out, ssm_state - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - device = self.in_proj.weight.device - # conv_state: - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.d_conv, - self.conv1d.weight.shape[0], - device=device, - dtype=conv_dtype, - ).transpose(1, 2) - # ssm_state: - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - ssm_state = torch.zeros( - batch_size, - self.n_v_heads, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return {"conv": conv_state, "ssm": ssm_state} + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): """ @@ -800,7 +800,6 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - inference_params=None, # just to be compatible with SSM block **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -878,11 +877,11 @@ def forward( return outputs - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - """Allocate inference cache for the model.""" - if getattr(self.mixer, "allocate_inference_cache", None) is None: - return - return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # """Allocate inference cache for the model.""" + # if getattr(self.mixer, "allocate_inference_cache", None) is None: + # return + # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) APRIEL_START_DOCSTRING = r""" @@ -920,9 +919,9 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) APRIEL_INPUTS_DOCSTRING = r""" @@ -1028,13 +1027,13 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def allocate_inference_cache(self, *args, **kwargs): - """Allocate inference cache for the model.""" - cache = {} - for i, layer in enumerate(self.layers): - if isinstance(layer, AprielSSMDecoderLayer): - cache[i] = layer.allocate_inference_cache(*args, **kwargs) - return cache + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # cache = {} + # for i, layer in enumerate(self.layers): + # if isinstance(layer, AprielSSMDecoderLayer): + # cache[i] = layer.allocate_inference_cache(*args, **kwargs) + # return cache @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) def forward( From c71cb16db8df7d637f83853ccd9419162c151177 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:01:10 +0000 Subject: [PATCH 075/114] nvm --- tests/test_ssms.py | 45 +++------------------------------------------ 1 file changed, 3 insertions(+), 42 deletions(-) diff --git a/tests/test_ssms.py b/tests/test_ssms.py index 551da7e1..9f3382b6 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -13,6 +13,7 @@ from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat @@ -182,53 +183,13 @@ def test_load_from_hybridssm_checkpoint(distributed_config): param_sum += torch.sum(fsdp._weight_shard).item() assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - # # model = GPTModel.from_pretrained(checkpoint_config) - # assert model.config.base_model.vocab_size == vocab_size - # schedule_config = ScheduleConfig() - # with NoAutoValidate(): - # batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - # batch_config.setup(distributed_config) - # batch_config.validate() - # schedule_runner = ScheduleRunner( - # config=schedule_config, - # multi_stage=model, - # distributed_config=model.distributed.config, - # ) - # schedule = Schedule( - # multi_stage=model, - # batch_config=batch_config, - # schedule_config=schedule_config, - # distributed_config=model.distributed.config, - # phase=PhaseType.inference, - # ) - # schedule_runner.setup(model.distributed, optimizer=None) - # from fast_llm.layers.transformer.config import RotaryConfig, RotaryEmbeddingType - # from fast_llm.layers.transformer.preprocessing import get_rotary_frequencies - - # rotary_config = RotaryConfig(type=RotaryEmbeddingType.default, theta=10000.0) # or whatever type your model uses - # frequencies = get_rotary_frequencies(rotary_config, seq_length, 4096, device="cuda") - - # from types import SimpleNamespace - - # batch = SimpleNamespace( - # token_ids=x, - # sequence_lengths=[[seq_length, seq_length]], - # ) - # input_data = [batch] - # losses, success, metrics = schedule_runner.run_step( - # iter(input_data), schedule, iteration=0, return_metrics=True, preprocessed=False - # ) - - # logits = input_data[0][1]["logits"].cpu() - # assert torch.allclose(logits, hf_logits, atol=1e-2) - @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", [ - (["m", "t"], MambaLayer), - (["m2d", "t"], DiscreteMamba2), + ([SSMBlockType.mamba, SSMBlockType.transformer], MambaLayer), + ([SSMBlockType.mamba2_discrete, SSMBlockType.transformer], DiscreteMamba2), ], ids=["mamba", "discrete_mamba2"], ) From be04c192575be79b3568162d54847c692d4cb56a Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:14:51 +0000 Subject: [PATCH 076/114] output lr scale --- fast_llm/layers/language_model/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ef1e3a37..96ea3f7c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -146,6 +146,12 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() From 1311f5b28aafe2c3df4e2deab7b8b350cc4d60d7 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:15:12 +0000 Subject: [PATCH 077/114] output_lr_scale --- fast_llm/layers/language_model/head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 813dcc07..2cc7730b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -104,6 +104,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( From baf4011943a4912467ab16930336febeb005d6b6 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 9 May 2025 13:44:15 +0000 Subject: [PATCH 078/114] nvm --- .../ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 95c09e0c..6800158e 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -9,6 +9,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from torch import nn +from transformers import GenerationMixin from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter From 6cf26c5d78a52dfcd0fddd02b156e9dd53fdfa40 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sat, 10 May 2025 15:41:02 +0000 Subject: [PATCH 079/114] eval --- fast_llm/models/ssm/external/eval/run_lm_eval.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/ssm/external/eval/run_lm_eval.py index af07869a..c910bcc3 100644 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ b/fast_llm/models/ssm/external/eval/run_lm_eval.py @@ -1,6 +1,9 @@ from lm_eval.__main__ import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import AprielSSMWrapper # noqa: F401 +from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybridSSMWrapper, + AprielSSMWrapper, +) if __name__ == "__main__": cli_evaluate() From 901d1b6ad38cd6b498e43ada472074bbeb6a3766 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 12:50:33 +0000 Subject: [PATCH 080/114] rename --- .../{aperiel_ssm => apriel_ssm}/configuration_ssm_apriel.py | 0 .../{aperiel_ssm => apriel_ssm}/modeling_ssm_apriel.py | 2 +- fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) rename fast_llm/models/ssm/external/{aperiel_ssm => apriel_ssm}/configuration_ssm_apriel.py (100%) rename fast_llm/models/ssm/external/{aperiel_ssm => apriel_ssm}/modeling_ssm_apriel.py (99%) diff --git a/fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py similarity index 100% rename from fast_llm/models/ssm/external/aperiel_ssm/configuration_ssm_apriel.py rename to fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py diff --git a/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py similarity index 99% rename from fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py rename to fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py index dd228024..a46530fc 100644 --- a/fast_llm/models/ssm/external/aperiel_ssm/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index a7cf34e4..02c9176b 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -24,13 +24,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig + from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM self._model = AprielSSMForCausalLM.from_pretrained( pretrained, From 616c54069802ccf2e65ca872f84e30024bc0ef20 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 13:39:27 +0000 Subject: [PATCH 081/114] per_layer_lr_scale --- fast_llm/layers/common/config.py | 11 +++++++++++ fast_llm/layers/ssm/config.py | 13 +++++++++---- fast_llm/layers/ssm/discrete_mamba2.py | 11 ++++++++++- fast_llm/layers/ssm/mamba_layer.py | 10 ++++++++++ fast_llm/layers/transformer/config.py | 10 ++-------- 5 files changed, 42 insertions(+), 13 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 16be4987..e8e068c0 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -11,6 +11,17 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm +class LLMBlockConfig(BaseModelConfig): + _abstract = False + + per_layer_lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + + class NormalizationImplementation(str, enum.Enum): """ An enum for the available implementations of layer norm. diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 9faec879..846ab43d 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,9 +1,8 @@ import enum -from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig from fast_llm.utils import Assert @@ -34,7 +33,7 @@ class SSMBlockType(str, enum.Enum): @config_class() -class SSMConfig(BaseModelConfig): +class SSMConfig(LLMBlockConfig): _abstract = False # Normalization @@ -122,6 +121,12 @@ class SSMConfig(BaseModelConfig): desc="Inner dimension for Mamba2 blocks.", hint=FieldHint.core, ) + mamba_lr_scale: float = Field( + default=None, + desc="Learning rate scale for Mamba blocks.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 49dacb91..16686fe8 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py @@ -44,6 +45,8 @@ def __init__( bias = config.add_bias_linear self.layer_idx = layer_idx self._return_input = return_input + layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -73,6 +76,7 @@ def __init__( (td_inner,), weight_decay=False, init_method=init_zeros_, + lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -84,14 +88,18 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + lr_scale=mamba_layer_lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale ) - self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight)) # D "skip" parameter self.D = ParameterMeta.from_dims( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -100,6 +108,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) @property diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 4704b522..e44a4e1d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. @@ -81,6 +82,8 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size + layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), @@ -90,6 +93,7 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = None @@ -102,6 +106,7 @@ def __init__( td_x_proj, weight_init_method=kaiming_init_(td_inner.size), bias=False, + layer_lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -110,6 +115,7 @@ def __init__( self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), init_method=kaiming_init_(tdt_rank.size), + lr_scale=mamba_layer_lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( @@ -117,12 +123,14 @@ def __init__( init_method=init_dtprojbias( self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs ), + lr_scale=mamba_layer_lr_scale, ) self.A_log = ParameterMeta.from_dims( (td_inner, td_state), weight_decay=False, init_method=init_A(self.d_state, self.d_inner), + lr_scale=mamba_layer_lr_scale, ) # D "skip" parameter @@ -130,6 +138,7 @@ def __init__( (td_inner,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) self.out_proj = Linear( @@ -137,6 +146,7 @@ def __init__( td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c6ea98b1..e4eaac1d 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -11,7 +11,7 @@ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationConfig, PeftConfig, PeftType +from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig, PeftConfig, PeftType from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -245,7 +245,7 @@ def _validate(self) -> None: @config_class() -class TransformerConfig(BaseModelConfig): +class TransformerConfig(LLMBlockConfig): _abstract = False normalization: NormalizationConfig = Field( default_factory=NormalizationConfig, @@ -496,12 +496,6 @@ class TransformerConfig(BaseModelConfig): doc="May be used to freeze some experts by setting their scale to zero.", hint=FieldHint.feature, ) - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) router_lr_scale: float | None = Field( default=None, desc="Custom learning rate for the MoE router weight.", From 9af5ee5da24fb693b0a03a3dd722af19ed0f98ce Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 15:00:22 +0000 Subject: [PATCH 082/114] merged also prediction_loss_coefficient from #243 --- fast_llm/layers/common/config.py | 1 + fast_llm/layers/language_model/config.py | 10 ++++++++++ fast_llm/layers/ssm/config.py | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index e8e068c0..50ccab01 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -11,6 +11,7 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm +@config_class() class LLMBlockConfig(BaseModelConfig): _abstract = False diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 96ea3f7c..f6d37616 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -152,6 +152,12 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="May be used to freeze the output weights by setting their scale to zero.", hint=FieldHint.feature, ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -170,6 +176,10 @@ def _validate(self) -> None: if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: self.transformer.setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 846ab43d..6cfe2ebe 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -121,7 +121,7 @@ class SSMConfig(LLMBlockConfig): desc="Inner dimension for Mamba2 blocks.", hint=FieldHint.core, ) - mamba_lr_scale: float = Field( + mamba_lr_scale: float | None = Field( default=None, desc="Learning rate scale for Mamba blocks.", hint=FieldHint.feature, From 1a7939bf52ee06c7c2b296364343d223ac017b24 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 15:23:40 +0000 Subject: [PATCH 083/114] added logging in mamba --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 16686fe8..5526516f 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,3 +1,4 @@ +import logging import math import causal_conv1d @@ -11,6 +12,8 @@ from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale +logger = logging.getLogger(__name__) + """ This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ @@ -47,6 +50,7 @@ def __init__( self._return_input = return_input layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) From 532d0d577412d1601f5cfc187f1c1e853471ab43 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 17:02:15 +0000 Subject: [PATCH 084/114] no norm layer freezing --- fast_llm/layers/transformer/transformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 0b0d5334..fd56ba08 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -38,9 +38,10 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) - self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=layer_lr_scale) + # layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + # we dont want to freeze norm layers here + self.norm_1 = self._config.normalization.get_layer(hidden_dim) + self.norm_2 = self._config.normalization.get_layer(hidden_dim) self._create_mixer() From 834913060a7e7eab7c4d0cfd7dba7e3bc6fa1228 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 20:07:52 +0000 Subject: [PATCH 085/114] test --- fast_llm/layers/ssm/discrete_mamba2.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 5526516f..c518e798 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -74,13 +74,15 @@ def __init__( # TODO: double check innitializations # Projections - self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size)) + self.in_proj = Linear( + td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size) + ) # , lr_scale=mamba_layer_lr_scale) self.z_bias = ( ParameterMeta.from_dims( (td_inner,), weight_decay=False, init_method=init_zeros_, - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -92,10 +94,12 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale + (td_conv,), + init_method=bias_init_method(self.conv1d_weight), + # , lr_scale=mamba_layer_lr_scale ) # D "skip" parameter @@ -103,7 +107,7 @@ def __init__( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -112,7 +116,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), - lr_scale=mamba_layer_lr_scale, + # lr_scale=mamba_layer_lr_scale, ) @property From 023102c4bcab18b05408a791318c780f84d6462b Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 20:19:57 +0000 Subject: [PATCH 086/114] test --- fast_llm/tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 611eb9f4..c82a3bf1 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -235,8 +235,8 @@ def __init__( self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) # TODO: re-enable when fixed? - # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - self.requires_grad = requires_grad + self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 865da957496f50e8f703e5a52b2be1d5cd9cbeb3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 20:43:38 +0000 Subject: [PATCH 087/114] debug --- fast_llm/layers/ssm/discrete_mamba2.py | 20 +++++++++++--------- fast_llm/tensor.py | 2 -- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c518e798..d4f9f84d 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -75,14 +75,18 @@ def __init__( # TODO: double check innitializations # Projections self.in_proj = Linear( - td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size) - ) # , lr_scale=mamba_layer_lr_scale) + td_model, + td_inner_proj, + bias=bias, + weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, + ) self.z_bias = ( ParameterMeta.from_dims( (td_inner,), weight_decay=False, init_method=init_zeros_, - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -94,12 +98,10 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (td_conv,), - init_method=bias_init_method(self.conv1d_weight), - # , lr_scale=mamba_layer_lr_scale + (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale ) # D "skip" parameter @@ -107,7 +109,7 @@ def __init__( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -116,7 +118,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), - # lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, ) @property diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index c82a3bf1..84930756 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,9 +234,7 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: re-enable when fixed? self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - # self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 87c93d32583d1117b1410c1d25dd1b97de656f38 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 21:06:50 +0000 Subject: [PATCH 088/114] comment --- fast_llm/layers/transformer/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index fd56ba08..b51ba1e9 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -38,8 +38,8 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - # layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - # we dont want to freeze norm layers here + # Note, layer_lr_scale does not impact the norms + # TODO: add a seperate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) From a18b80f575c0be1aacf827eb49bce9d7d295ca9e Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 May 2025 22:04:18 +0000 Subject: [PATCH 089/114] debug --- fast_llm/tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 84930756..611eb9f4 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,9 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + # TODO: re-enable when fixed? + # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) + self.requires_grad = requires_grad # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) From 40d5437917869a855f7b31ef96296ff5543b6518 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 14 May 2025 13:50:55 +0000 Subject: [PATCH 090/114] wip --- .../modeling_ssm_hybrid_apriel.py | 215 ++++++++++++++++-- .../apriel_ssm/modeling_ssm_apriel.py | 2 + tests/common.py | 8 +- 3 files changed, 204 insertions(+), 21 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 6800158e..5d8f4cc5 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -29,6 +29,175 @@ logger = logging.get_logger(__name__) +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py class HybridMambaAttentionDynamicCache(DynamicCache): """ @@ -111,14 +280,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor): device = self.ssm_states[layer_idx].device self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") @@ -549,7 +710,7 @@ def forward( u, return_mixer_matrix=False, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, + inference_params=None, **kwargs, ): """ @@ -563,10 +724,12 @@ def forward( ssm_state, conv_state = None, None if past_key_value is not None: ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) - if cache_position[0] > 0: + if inference_params is not None and inference_params.seqlen_offset > 0: # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized u = u.squeeze(1) if len(u.shape) == 3 else u - out, _ = self.step(u, ssm_state, conv_state) + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out return {"hidden_states": out} # Hacky way to initialize state during inference @@ -660,8 +823,8 @@ def step(self, u, ssm_state, conv_state, **kwargs): dim=-1, ) - xBC, conv_state = self.convolutional_step(xBC, conv_state) - conv_state.copy_(conv_state) # update state in place + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place x, B, C = torch.split( xBC, @@ -698,7 +861,7 @@ def step(self, u, ssm_state, conv_state, **kwargs): # Norm and gate out = self.out_proj(y * F.silu(z + self.z_bias)) - return out, ssm_state + return out, ssm_state, conv_state # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): # device = self.in_proj.weight.device @@ -908,6 +1071,13 @@ class AprielSSMPreTrainedModel(PreTrainedModel): config_class = AprielSSMHybridConfig base_model_prefix = "model" _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True def _init_weights(self, module): std = self.config.initializer_range @@ -1018,6 +1188,7 @@ def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwa self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) self.gradient_checkpointing = False self.rotary_emb = AprielRotaryEmbedding(config=config) + self.has_transformer_layers = any(type == "t" for type in config.hybrid_block_layout) # Initialize weights and apply final processing self.post_init() @@ -1078,23 +1249,25 @@ def forward( "provided, so no cache will be returned." ) - if cache_position is None: + if cache_position is None and self.has_transformer_layers: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: + if position_ids is None and self.has_transformer_layers: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + causal_mask = ( + self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) + if self.has_transformer_layers + else None ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.has_transformer_layers else None # decoder layers all_hidden_states = () if output_hidden_states else None @@ -1152,7 +1325,9 @@ def _update_causal_mask( # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) + using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( + past_key_values, HybridMambaAttentionStaticCache + ) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: diff --git a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py index a46530fc..09dc8259 100644 --- a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py +++ b/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py @@ -225,7 +225,9 @@ def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs) state = self._get_states_from_cache(inference_params, batch) if inference_params.seqlen_offset > 0: # States are updated inplace + u = u.squeeze(1) if len(u.shape) == 3 else u out, _ = self.step(u, state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out return {"hidden_states": out} # Hacky way to initialize state during inference diff --git a/tests/common.py b/tests/common.py index 569d690c..2f2fd5f7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -63,7 +63,8 @@ f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + # f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", + f"model.multi_stage.debug_all_param_gradients=0", "model.multi_stage.debug_tensor_parallel=True", "model.distributed.reproducible_init=True", "model.distributed.timeout=10", @@ -201,6 +202,11 @@ CONFIG_LLAMA_MTP_MEGATRON = None CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ "model.base_model.prediction_heads=4", + "model.base_model.embeddings_lr_scale=0", + "model.base_model.transformer.per_layer_lr_scale=[0.1,0.0000001,0.0000001,1,1,.1]", + # "model.base_model.output_lr_scale=0", + # "model.base_model.prediction_loss_coefficient=[1, .5, .5, 0]", + # "model.base_model.cross_entropy_splits=4", ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] From 72ace3b08cbd9255d8cdf1a64030625dc9e28fe4 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Wed, 14 May 2025 19:01:51 +0000 Subject: [PATCH 091/114] fix --- fast_llm/engine/checkpoint/safe_load.py | 4 ++-- fast_llm/engine/multi_stage/config.py | 5 +++++ fast_llm/engine/multi_stage/fsdp.py | 11 ++++++++--- fast_llm/engine/multi_stage/multi_stage.py | 7 +------ fast_llm/engine/multi_stage/stage.py | 2 +- fast_llm/engine/multi_stage/stage_base.py | 4 ++-- tests/test_checkpoint.py | 3 +-- 7 files changed, 20 insertions(+), 16 deletions(-) diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py index 2eec57e0..4b7366e4 100644 --- a/fast_llm/engine/checkpoint/safe_load.py +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -40,8 +40,8 @@ def __enter__(self) -> "SafeLoad": triton_fill(self_shard, math.nan) # Reset and count shard pads for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards): - for fsdp_shard in fsdp_shards.values(): - self._loaded += fsdp.reset_shard_pad(fsdp_shard) + for shard_name, fsdp_shard in fsdp_shards.items(): + self._loaded += fsdp.reset_shard_pad(fsdp_shard, shard_name) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80..40002ce6 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -36,6 +36,11 @@ logger = logging.getLogger(__name__) +class ShardName: + weights = "weights" + grads = "grads" + + class StageMode(str, enum.Enum): # Allow forward and backward passes and optimizer. # TODO: Add mode for forward and backward but not optimizer? diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index e9c84aa3..61d1c7a8 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedDim from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageMode +from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.logging import log_distributed_tensor from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta @@ -246,13 +246,14 @@ def setup( ) self._parameter_buffers[parameter_name] = parameter_buffer - def reset_shard_pad(self, shard: torch.Tensor) -> int: + def reset_shard_pad(self, shard: torch.Tensor, shard_name: str) -> int: assert self._is_setup assert self._mode.on_device # TODO: Needed? # Prevent nans with the padded values # Also ensures a correct parameter count in loading context. - self._weight_shard_meta.validate(shard) + shard_meta = self._weight_shard_meta if shard_name == ShardName.weights else self._grad_shard_meta + shard_meta.validate(shard) if self._shard_pad > 0: shard[-self._shard_pad :].zero_() return self._shard_pad @@ -452,5 +453,9 @@ def copy_shard_overlaps( begin, end = self._parameter_range_in_shard(name) for shard_name, shard in shards.items(): + # Shards can be empty (frozen weights) + if shard.numel() == 0: + Assert.eq(loaded_shards[shard_name].numel(), 0) + continue shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] counter += overlap_count diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 21d0fe55..497d1110 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -14,7 +14,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.config import ParamGroup @@ -24,11 +24,6 @@ logger = logging.getLogger(__name__) -class ShardName: - weights = "weights" - grads = "grads" - - class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 675e878b..179a94c1 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -156,7 +156,7 @@ def reduce_gradients(self, accumulate=False) -> None: level=self._config.debug_param_gradients, global_=False, ) - if self._config.debug_all_param_gradients: + if self._config.debug_all_param_gradients and fsdp.requires_grad: fsdp.log_shard( name="gradient", shard=fsdp.grad_shard, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index fd50f55c..e1b44471 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageConfig, StageMode +from fast_llm.engine.multi_stage.config import ShardName, StageConfig, StageMode from fast_llm.engine.multi_stage.fsdp import FSDP from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.logging import log_generator @@ -209,7 +209,7 @@ def initialize_weights(self) -> None: meta.init_parameter(parameter, self._distributed) if self.mode.on_device: - fsdp.reset_shard_pad(fsdp.weight_shard) + fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights) if self._config.debug_param_init: log_generator("CPU generator after reset", torch.random.default_generator) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e9..042f2bb2 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -14,8 +14,7 @@ FastLLMCheckpointFormat, ModelConfigType, ) -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode -from fast_llm.engine.multi_stage.multi_stage import ShardName +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig from tests.common import ( From 121e9064c0970ff15595f80077d29f88df36f6c4 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 14 May 2025 19:43:05 +0000 Subject: [PATCH 092/114] test + comment --- fast_llm/tensor.py | 5 ++--- tests/common.py | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index 611eb9f4..ad2d42d1 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,9 +234,8 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: re-enable when fixed? - # self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) - self.requires_grad = requires_grad + # TODO: note, this pevents the tes_checkpoints to pass for MODEL=llama-mtp, they pass with `self.requires_grad=requires_grad` instead. However, the model export seem to work as expected, at least for hybrid SSM. + self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) diff --git a/tests/common.py b/tests/common.py index 2f2fd5f7..16d5114b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -203,10 +203,7 @@ CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ "model.base_model.prediction_heads=4", "model.base_model.embeddings_lr_scale=0", - "model.base_model.transformer.per_layer_lr_scale=[0.1,0.0000001,0.0000001,1,1,.1]", - # "model.base_model.output_lr_scale=0", - # "model.base_model.prediction_loss_coefficient=[1, .5, .5, 0]", - # "model.base_model.cross_entropy_splits=4", + "model.base_model.transformer.per_layer_lr_scale=[0.1,0,0,1,1,.1]", ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] From aa3bc0be1368c22a9eceb5d00a1f69db779858b9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 16:43:45 -0400 Subject: [PATCH 093/114] stuff --- fast_llm/config.py | 92 ++++++++----------- fast_llm/data/data/config.py | 4 +- fast_llm/data/data/gpt/config.py | 3 +- fast_llm/data/dataset/config.py | 2 - fast_llm/data/dataset/gpt/config.py | 5 +- fast_llm/data/preparator/gpt_memmap/config.py | 7 +- fast_llm/engine/base_model/base_model.py | 2 +- fast_llm/engine/base_model/config.py | 1 + fast_llm/engine/config_utils/run.py | 8 +- fast_llm/engine/multi_stage/config.py | 17 +--- fast_llm/engine/training/config.py | 32 ++----- fast_llm/layers/language_model/config.py | 1 - fast_llm/layers/ssm/config.py | 1 - fast_llm/layers/transformer/config.py | 3 - fast_llm/models/custom/config.py | 8 +- fast_llm/models/gpt/config.py | 8 +- fast_llm/models/ssm/config.py | 9 +- fast_llm/tools/cli.py | 2 +- fast_llm/tools/convert.py | 18 ++-- tests/config/common.py | 8 +- tests/config/test_config.py | 30 ++++++ tests/data/common.py | 2 +- tests/test_checkpoint.py | 16 ++-- tests/test_ssms.py | 2 + tools/push_model.py | 4 +- 25 files changed, 128 insertions(+), 157 deletions(-) create mode 100644 tests/config/test_config.py diff --git a/fast_llm/config.py b/fast_llm/config.py index 46c903f1..3b277202 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1,3 +1,4 @@ +import abc import contextlib import copy import dataclasses @@ -137,7 +138,6 @@ def __init__( default=dataclasses.MISSING, default_factory=dataclasses.MISSING, init: bool = True, - repr: bool = True, hash=None, compare: bool = True, metadata=None, @@ -146,12 +146,11 @@ def __init__( if default is not dataclasses.MISSING and default_factory is not dataclasses.MISSING: raise ValueError("cannot specify both default and default_factory") if isinstance(default_factory, type) and issubclass(default_factory, Config): - default_factory = _ConfigFactory(default_factory) + raise ValueError("Config classes should not be used as `default_factory`") super().__init__( default=default, default_factory=default_factory, init=init, - repr=repr, hash=hash, compare=compare, metadata=metadata, @@ -223,20 +222,6 @@ def valid(x): return valid -class _ConfigFactory: - """ - A dataclass default factory that prevents early validation. - Validation is still done through the parent config if needed. - """ - - def __init__(self, factory: typing.Callable[[], "Config"] | type["Config"]): - self._factory = factory - - def __call__(self): - with NoAutoValidate(): - return self._factory() - - class ValidationError(ValueError): pass @@ -257,7 +242,7 @@ def _process_config_class(cls: type["Config"]): return cls -def config_class(cls=None): +def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -280,20 +265,23 @@ def __init__(self, **kwargs): if _AUTO_VALIDATE: self.validate() - cls.__init__ = __init__ + wrapped.__init__ = __init__ return wrapped - # See if we're being called as @config_class or @config_class(). - if cls is None: - # We're called with parens. - return wrap + return wrap + - # We're called as @config_class without parens. - return wrap(cls) +class ConfigMeta(abc.ABCMeta): + def __call__(cls: "type[Config]", **kwargs): + # Always go through `_from_dict` for correct dynamic class selection and nested config instantiation. + if not kwargs.pop("_from_dict_check", False): + # with NoAutoValidate(): + return cls._from_dict(kwargs) + return super().__call__(**kwargs) -@dataclasses.dataclass() -class Config: +@dataclasses.dataclass(kw_only=True, repr=False) +class Config(metaclass=ConfigMeta): """ An advanced `dataclass` with basic type checking, validation and argparse support. Typically, a subclass will: @@ -307,14 +295,14 @@ class Config: # Set to true to prevent instantiation. _abstract: typing.ClassVar[bool] = False # Keep track of whether an instance has been validated - _validated: bool = Field(init=False, repr=False) + _validated: bool = Field(init=False) # Keep track of unknown fields so they can be reported during validation. - _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) + _unknown_fields: dict[str, typing.Any] = Field(init=False) # Keep track of explicitly set fields to ensure they get serialized and used as config updates. - _explicit_fields: set[str] = Field(init=False, repr=False) + _explicit_fields: set[str] = Field(init=False) # Used within `_set_implicit_default` to set implicit defaults for fields # without them being automatically added to `_explicit_fields`. - _setting_implicit_default: bool | None = Field(init=False, repr=False) + _setting_implicit_default: bool | None = Field(init=False) def __setattr__(self, key: str, value: typing.Any) -> None: """ @@ -339,7 +327,7 @@ def __setattr__(self, key: str, value: typing.Any) -> None: ) else: field = self.get_field(key) - if field.init and field._field_type != dataclasses._FIELD_CLASSVAR: + if field.init and field._field_type == dataclasses._FIELD: # Adding to explicit field list except within `_set_implicit_default` context, # during dataclass initialization (`_setting_implicit_default` not yet set) # and during automated config validation (`_setting_implicit_default=None`) @@ -358,13 +346,13 @@ def __delattr__(self, key: str) -> None: super().__delattr__(key) @contextlib.contextmanager - def _set_implicit_default(self, _value: bool | int = True): + def _set_implicit_default(self, _value: bool | None = True): assert self._setting_implicit_default is False self._setting_implicit_default = _value yield self._setting_implicit_default = False - def validate[T](self: T, *, _is_validating: bool = False) -> T: + def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: """ Validate a class and mark it as read-only This should not be overridden in derived classes. @@ -388,11 +376,16 @@ def _validate(self) -> None: Can be extended to add custom post-processing (typically before the super() call) and validation (typically after) """ - self._check_abstract() + if self._abstract: + raise ValidationError(f"{type(self).__name__} is abstract") + if not self.__class_validated__: + raise ValidationError( + f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator." + ) errors = [] with self._set_implicit_default(None): for name, field in self.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue value = getattr(self, name) if isinstance(value, Tag): @@ -610,11 +603,7 @@ def _add_field_to_args( all_fields: bool = False, serializable: bool = True, ) -> None: - if ( - field is not None - and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR) - and not all_fields - ): + if field is not None and (not field.init or field._field_type != dataclasses._FIELD) and not all_fields: # Exclude class variables and derived fields unless requested explicitly. return explicit_field = ( @@ -677,6 +666,9 @@ def to_copy[ ) -> T: return self.from_dict(self, *updates, strict=strict, update_type=update_type) + def __repr__(self): + return self.to_logs(log_fn=str) + def to_logs[ T ]( @@ -739,7 +731,7 @@ def _from_dict( flat: bool = False, ) -> typing.Self: # TODO v0.3: Remove flat format - out_arg_dict = {} + out_arg_dict = {"_from_dict_check": True} # TODO v0.3: Remove backward compatibility fix if "__class__" in default: @@ -748,7 +740,7 @@ def _from_dict( # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): - if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa + if not field.init or field._field_type != dataclasses._FIELD: # noqa continue if flat: if isinstance(field.type, type) and issubclass(field.type, Config): @@ -869,22 +861,15 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ f"Config comparison errors:\n " + "\n".join(errors), log_fn=log_fn, ) - - @classmethod - def _check_abstract(cls) -> None: - if cls._abstract: - raise ValidationError(f"{cls.__name__} is abstract") - if not cls.__class_validated__: - raise ValidationError( - f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator." - ) + return None def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ + Assert.eq(cls.__name__, cls.__qualname__) for base_class in cls.__mro__: - if issubclass(base_class, Config): + if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated." f" Make sure to use the @config_class decorator." @@ -913,7 +898,6 @@ def __init_subclass__(cls): valid=value.pop("valid", base_class_field.valid), default=value.pop("default", base_class_field.default), default_factory=value.pop("default_factory", base_class_field.default_factory), - repr=value.pop("repr", base_class_field.repr), hash=value.pop("hash", base_class_field.hash), compare=value.pop("compare", base_class_field.compare), metadata=value.pop("metadata", base_class_field.metadata), diff --git a/fast_llm/data/data/config.py b/fast_llm/data/data/config.py index 25850ac3..41dbb5d9 100644 --- a/fast_llm/data/data/config.py +++ b/fast_llm/data/data/config.py @@ -9,6 +9,4 @@ class DataConfig(Config): _abstract = True _sampling_config_class: typing.ClassVar[type[SamplingData]] - sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Default configuration for dataset sampling." - ) + sampling: SamplingConfig = Field(desc="Default configuration for dataset sampling.") diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 6c598c0c..85bcc656 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -27,7 +27,6 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): _abstract = False tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer (for FIM).", hint=FieldHint.feature, ) @@ -37,7 +36,7 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig): desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) + sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 7901d6e7..1bb4b6be 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -174,12 +174,10 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): _abstract = True sampling: SamplingConfig = Field( - default_factory=SamplingConfig, desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) dataset: SampledDatasetConfig = Field( - default_factory=SampledDatasetConfig, desc="The dataset to sample from.", hint=FieldHint.core, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ed9128c6..f4f6e282 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -230,8 +230,8 @@ def build(self) -> "GPTDatasetSlice": class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False type_: typing.ClassVar[str | None] = "sampled" - sampling: GPTSamplingConfig = FieldUpdate(default_factory=GPTSamplingConfig) - dataset: GPTSampledDatasetConfig = FieldUpdate(default_factory=GPTSampledDatasetConfig) + sampling: GPTSamplingConfig = FieldUpdate() + dataset: GPTSampledDatasetConfig = FieldUpdate() @config_class() @@ -450,7 +450,6 @@ class GPTLegacyConfig(Config): valid=_validate_path, ) fim: FimConfig = Field( - default_factory=FimConfig, desc="Configuration for Fill In the Middle (FIM).", hint=FieldHint.feature, ) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c3..7091f3c8 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -24,7 +24,7 @@ MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" -@config_class +@config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( default=None, @@ -77,7 +77,7 @@ class GPTHuggingfaceDatasetConfig(Config): ) -@config_class +@config_class() class DatasetPreparatorDistributedConfig(Config): # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig @@ -120,7 +120,6 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.core, ) distributed: DatasetPreparatorDistributedConfig = Field( - default_factory=DatasetPreparatorDistributedConfig, desc="Configuration for distributed processing.", hint=FieldHint.feature, ) @@ -149,12 +148,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): valid=check_field(Assert.geq, 1), ) dataset: GPTHuggingfaceDatasetConfig = Field( - default_factory=GPTHuggingfaceDatasetConfig, desc="Configuration for the dataset.", hint=FieldHint.feature, ) tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 2be1e487..df603a91 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -90,7 +90,7 @@ def __init__( config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space = TensorSpace(distributed_config) + self._tensor_space: TensorSpace = TensorSpace(distributed_config) config.setup_tensor_space(self._tensor_space) super().__init__(config) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 25f53e4a..4be42e06 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -42,6 +42,7 @@ def _get_architecture(self) -> dict[str, typing.Any]: assert isinstance(field, Field), f"{name}, {field}" if field.hint == FieldHint.architecture: architecture[name] = self._serialize_architecture_field(getattr(self, name, MISSING)) + return architecture def _serialize_architecture_field(self, value: typing.Any) -> typing.Any: if isinstance(value, BaseModelConfig): diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index d6377409..126e0ae8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -20,9 +20,7 @@ @config_class() class RunConfig(Config): - tensor_logs: TensorLogsConfig = Field( - default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging - ) + tensor_logs: TensorLogsConfig = Field(desc="Configuration for debug tensor logs.", hint=FieldHint.logging) # TODO v0.3: Adjust (now only affects logging to file). structured_logs: bool = Field( default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging @@ -70,9 +68,7 @@ def _validate(self): @config_class() class ExperimentConfig(RunnableConfig): - run: RunConfig = Field( - default_factory=RunConfig, desc="Global properties for the experiment.", hint=FieldHint.core - ) + run: RunConfig = Field(desc="Global properties for the experiment.", hint=FieldHint.core) def _show( self, diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e2d04f80..9434fba6 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -211,17 +211,12 @@ class FastLLMModelConfig(Config): FastLLMCheckpointFormat, ) model_name: typing.ClassVar[str] - base_model: BaseModelConfig = Field( - default_factory=BaseModelConfig, desc="Configuration for the base model.", hint=FieldHint.core - ) + base_model: BaseModelConfig = Field(desc="Configuration for the base model.", hint=FieldHint.core) multi_stage: MultiStageConfig = Field( - default_factory=MultiStageConfig, desc="Configuration for the stage breakdown of the model.", hint=FieldHint.core, ) - distributed: DistributedConfig = Field( - default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core - ) + distributed: DistributedConfig = Field(desc="Distributed configuration.", hint=FieldHint.core) @classmethod def __fast_llm_serialize__(cls) -> str: @@ -291,11 +286,8 @@ class PretrainedFastLLMModelConfig(Config): # TODO: Generalize data, schedule, logging, etc. _abstract = True # This configs may be overridden with the pretrained config during validation, so we should be careful about accessing them before. - model: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core - ) + model: FastLLMModelConfig = Field(desc="Configuration for the Fast-LLM model.", hint=FieldHint.core) pretrained: CheckpointLoadConfig = Field( - default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) @@ -315,7 +307,7 @@ def _setup(self) -> None: pass -@config_class +@config_class() class CheckpointMetadata(Config): # TODO: Make entries more flexible? # I.e.. model / format / usage (ex. training) - specific entries instead of a generic metadata? @@ -336,7 +328,6 @@ class CheckpointMetadata(Config): hint=FieldHint.core, ) config: FastLLMModelConfig = Field( - default_factory=FastLLMModelConfig, desc="The Fast-LLM model configuration for the saved model.", hint=FieldHint.core, ) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 1e990e9c..a5be2e7e 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -120,7 +120,7 @@ class WandbAlertConfig(IntervalConfig): "The update may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.feature, ) - post_alerts: bool = Field(init=False, repr=False) + post_alerts: bool = Field(init=False) def _validate(self) -> None: if self.status_updates is None: @@ -141,7 +141,6 @@ class MetricsLogsConfig(IntervalConfig): @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( - default_factory=WandbAlertConfig, desc="Configuration for Wandb alerts." " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, @@ -175,7 +174,6 @@ class TrainingCheckpointBaseConfig(IntervalConfig): _abstract = True save_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( - default_factory=CallbackConfig, desc="Callback (shell script).", hint=FieldHint.core, ) @@ -257,7 +255,6 @@ class TrainingExportConfig(TrainingCheckpointBaseConfig, CheckpointStateSaveConf offset = FieldUpdate(desc="Offset for the first export.") callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") - @abc.abstractmethod def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path: return experiment_directory / "export" / self.format.name @@ -284,19 +281,11 @@ class TrainingConfig(Config): desc="A dictionary of evaluation dataset names and their configurations for the validation phase.", hint=FieldHint.core, ) - logs: MetricsLogsConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for metric logging.", hint=FieldHint.core - ) - checkpoint: TrainingCheckpointConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for checkpoints.", hint=FieldHint.core - ) - export: TrainingExportConfig = Field( - default_factory=MetricsLogsConfig, desc="Configuration for exports.", hint=FieldHint.core - ) - shutdown: ShutdownConfig = Field( - default_factory=ShutdownConfig, desc="Configuration for automated shutdown.", hint=FieldHint.core - ) - wandb: WandbConfig = Field(default_factory=WandbConfig, desc="Configuration for Wandb.", hint=FieldHint.core) + logs: MetricsLogsConfig = Field(desc="Configuration for metric logging.", hint=FieldHint.core) + checkpoint: TrainingCheckpointConfig = Field(desc="Configuration for checkpoints.", hint=FieldHint.core) + export: TrainingExportConfig = Field(desc="Configuration for exports.", hint=FieldHint.core) + shutdown: ShutdownConfig = Field(desc="Configuration for automated shutdown.", hint=FieldHint.core) + wandb: WandbConfig = Field(desc="Configuration for Wandb.", hint=FieldHint.core) train_iters: int = Field( default=0, desc="Total number of training iterations.", hint=FieldHint.core, valid=check_field(Assert.geq, 0) ) @@ -349,30 +338,23 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): _abstract = True # TODO: Generalize data, schedule, logging, etc. training: TrainingConfig = Field( - default_factory=TrainingConfig, desc="Configuration for the training phases and global properties.", hint=FieldHint.core, ) batch: BatchConfig = Field( - default_factory=BatchConfig, desc="Configuration for the training, validation and test batches.", hint=FieldHint.core, ) - schedule: ScheduleConfig = Field( - default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core - ) + schedule: ScheduleConfig = Field(desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core) data: DataConfig = Field( - default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) profiling: ProfilingConfig = Field( - default_factory=ProfilingConfig, desc="Configuration for the optional profiling of GPU and CPU CUDA operations.", hint=FieldHint.logging, ) optimizer: OptimizerConfig = Field( - default_factory=OptimizerConfig, desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf..0db76ad1 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -40,7 +40,6 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c6fe622e..25ad3d22 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -26,7 +26,6 @@ class SSMConfig(BaseModelConfig): # Normalization normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e69b1841..c621139c 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -248,17 +248,14 @@ def _validate(self) -> None: class TransformerConfig(BaseModelConfig): _abstract = False normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) rotary: RotaryConfig = Field( - default_factory=RotaryConfig, desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) peft: TransformerPeftConfig = Field( - default_factory=TransformerPeftConfig, desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 8be45e1c..08902e2c 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -26,7 +26,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig): class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). model_name: typing.ClassVar[str] = "gpt_custom" - base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate() @classmethod def get_model_class(cls) -> type["CustomModel"]: @@ -43,14 +43,14 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceCustomModelForCausalLM" @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate() @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) - reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) + data: CustomDataConfig = FieldUpdate() + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e..0ec3fb51 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -129,7 +129,7 @@ def _from_dict( class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" - base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate() checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + ( AutoGPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, @@ -156,13 +156,13 @@ def get_huggingface_model_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate() @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() # TODO: Use dynamic model type? reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 0311cc69..771a4fca 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -26,7 +26,6 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): _abstract = False ssm: SSMConfig = Field( - default_factory=SSMConfig, desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -129,7 +128,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) + base_model: HybridSSMBaseModelConfig = FieldUpdate() checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) @classmethod @@ -154,13 +153,13 @@ def _validate(self): @config_class() class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate(default_factory=HybridSSMModelConfig) + model: HybridSSMModelConfig = FieldUpdate() @config_class() class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() @classmethod def get_trainer_class(cls) -> type["SSMTrainer"]: diff --git a/fast_llm/tools/cli.py b/fast_llm/tools/cli.py index 0cc02f42..8df884fe 100644 --- a/fast_llm/tools/cli.py +++ b/fast_llm/tools/cli.py @@ -21,7 +21,7 @@ def fast_llm(args=None): if parsed.subcommand == "train": from fast_llm.tools.train import CliTrainingConfig as Runnable elif parsed.subcommand == "convert": - from fast_llm.tools.convert import ConversionConfig as Runnable + from fast_llm.tools.convert import ConvertConfig as Runnable elif parsed.subcommand == "prepare": from fast_llm.tools.prepare_dataset import PrepareDatasetConfig as Runnable else: diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index d3db3745..3ee580aa 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -19,13 +19,13 @@ @config_class() -class ConversionConfig(RunnableConfig): - input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) - output: CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) +class ConvertConfig(RunnableConfig): + input: CheckpointLoadConfig = Field() + output: CheckpointSaveConfig = Field() use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) - model_config_class: type[FastLLMModelConfig] = Field(default=None) + model: type[FastLLMModelConfig] = Field(default=None) @classmethod def _get_parser(cls): @@ -44,9 +44,9 @@ def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): return config def _validate(self): - assert self.model_config_class is not None - self.input.setup(self.model_config_class) - self.output.setup(self.model_config_class) + assert self.model is not None + self.input.setup(self.model) + self.output.setup(self.model) super()._validate() def _convert_model_partial( @@ -81,7 +81,7 @@ def run(self): f"Output path {self.output.path} already exists and has been processed. Skipping model conversion..." ) return - model_class = self.model_config_class.get_model_class() + model_class = self.model.get_model_class() if self.layers_per_step is None: self._convert_model_partial(model_class, self.output) else: @@ -161,4 +161,4 @@ def run(self): if __name__ == "__main__": - ConversionConfig.parse_and_run() + ConvertConfig.parse_and_run() diff --git a/tests/config/common.py b/tests/config/common.py index a2657926..9ccfb597 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -13,7 +13,7 @@ class ExampleEnum(enum.StrEnum): c = "c" -@config_class +@config_class() class ExampleConfig(Config): int_field: int = Field(default=0, hint=FieldHint.optional) bool_field: bool = Field(default=False, hint=FieldHint.optional) @@ -40,7 +40,7 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleVerboseConfig(Config): # These fields will have non-empty default serialized values. list_default_field: list[int] = Field(default_factory=lambda: [0], hint=FieldHint.optional) @@ -56,9 +56,9 @@ def _validate(self) -> None: super()._validate() -@config_class +@config_class() class ExampleNestedConfig(ExampleConfig): - nested_field: ExampleConfig = Field(default_factory=ExampleConfig, hint=FieldHint.core) + nested_field: ExampleConfig = Field(hint=FieldHint.core) def check_config( diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 00000000..4c473fa6 --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,30 @@ +import pytest + +from fast_llm.config import NoAutoValidate +from tests.config.common import ExampleConfig + + +def test_auto_validate(): + assert (config := ExampleConfig())._validated + with pytest.raises(RuntimeError): + config.bool_field = True + config.bool_field = False + + assert ExampleConfig.from_dict({})._validated + + with NoAutoValidate(): + assert not (config := ExampleConfig())._validated + + config.bool_field = True + + config.validate() + + assert config._validated + with pytest.raises(RuntimeError): + config.bool_field = False + config.bool_field = True + + with NoAutoValidate(): + assert not (config := ExampleConfig.from_dict({}))._validated + config.validate() + assert config._validated diff --git a/tests/data/common.py b/tests/data/common.py index 47b53195..00c3ff20 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -31,7 +31,6 @@ def get_sampling_data( *, seed: int = 54983, cache_directory: pathlib.Path | None = None, - distributed: Distributed = Distributed(DistributedConfig(), use_cpu=True), phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, @@ -41,6 +40,7 @@ def get_sampling_data( truncate_documents=True, ) -> GPTSamplingData: # Config with convenient defaults. + distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( config=GPTSamplingConfig( seed=seed, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 257947e9..77a4b482 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -17,7 +17,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import ShardName from fast_llm.models.auto import model_registry -from fast_llm.tools.convert import ConversionConfig +from fast_llm.tools.convert import ConvertConfig from tests.common import ( CONFIG_COMMON, FORCE_REUSE_RESULTS, @@ -90,7 +90,7 @@ def test_resume(): ) -def _run_conversion(config: ConversionConfig): +def _run_conversion(config: ConvertConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) if not config.output.path.is_dir(): @@ -106,7 +106,7 @@ def _run_conversion(config: ConversionConfig): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_convert_distributed_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -125,7 +125,7 @@ def test_convert_fast_llm_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, @@ -142,7 +142,7 @@ def test_convert_fast_llm_to_huggingface(): @pytest.mark.depends(on=["test_convert_fast_llm_to_huggingface"]) def test_convert_huggingface_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -161,7 +161,7 @@ def test_convert_distributed_to_huggingface(): if HUGGINGFACE_CHECKPOINT_FORMAT is None: pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CKPT_PATH, format=DistributedCheckpointFormat, @@ -178,7 +178,7 @@ def test_convert_distributed_to_huggingface(): @pytest.mark.depends(on=["test_convert_distributed_to_huggingface"]) def test_convert_huggingface_to_fast_llm(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, @@ -195,7 +195,7 @@ def test_convert_huggingface_to_fast_llm(): @pytest.mark.depends(on=["test_convert_huggingface_to_fast_llm"]) def test_convert_fast_llm_to_distributed(): _run_conversion( - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, diff --git a/tests/test_ssms.py b/tests/test_ssms.py index e6c9aafd..0fec3741 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -139,6 +139,7 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +@pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", @@ -207,6 +208,7 @@ def test_mamba_block(distributed_config, distributed): assert not torch.isinf(hidden_states).any() +@pytest.mark.slow @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( ("hybrid_block_layout"), diff --git a/tools/push_model.py b/tools/push_model.py index cd98b93c..edab3312 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -27,7 +27,7 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.tools.convert import ConversionConfig # isort:skip +from fast_llm.tools.convert import ConvertConfig # isort:skip logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ def run(self) -> None: for _, checkpoint_path in new_checkpoint_paths: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done - ConversionConfig( + ConvertConfig( input=CheckpointLoadConfig( path=checkpoint_path, format=DistributedCheckpointFormat, From 28d321e320354539f2939fe3f94095a96fc43dcc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 17:02:59 -0400 Subject: [PATCH 094/114] stuff --- fast_llm/config.py | 1 + fast_llm/engine/optimizer/config.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 3b277202..4928cdbd 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -151,6 +151,7 @@ def __init__( default=default, default_factory=default_factory, init=init, + repr=False, hash=hash, compare=compare, metadata=metadata, diff --git a/fast_llm/engine/optimizer/config.py b/fast_llm/engine/optimizer/config.py index 3a154c9e..f4303a5d 100644 --- a/fast_llm/engine/optimizer/config.py +++ b/fast_llm/engine/optimizer/config.py @@ -74,12 +74,10 @@ class GradientScalerConfig(Config): class OptimizerConfig(Config): learning_rate: LearningRateScheduleConfig = Field( - default_factory=LearningRateScheduleConfig, desc="A schedule for the learning rate.", hint=FieldHint.core, ) gradient_scaler: GradientScalerConfig = Field( - default_factory=GradientScalerConfig, desc="Configuration for the fixed or dynamic gradient scaling.", hint=FieldHint.feature, ) From 1bbd7fb1bf55258a4f7589d6eb0b63d71d2f0fa5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 17:46:02 -0400 Subject: [PATCH 095/114] stuff --- fast_llm/tools/convert.py | 2 +- tests/test_checkpoint.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index 3ee580aa..648138ec 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -40,7 +40,7 @@ def _get_parser(cls): @classmethod def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]): config = super()._from_parsed_args(parsed, unparsed) - config.model_config_class = model_registry[parsed.model_type] + config.model = model_registry[parsed.model_type] return config def _validate(self): diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 77a4b482..e0845a4c 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -115,7 +115,7 @@ def test_convert_distributed_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_0", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -134,7 +134,7 @@ def test_convert_fast_llm_to_huggingface(): path=_CONVERT_PATH / "huggingface_0", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -151,7 +151,7 @@ def test_convert_huggingface_to_distributed(): path=_CONVERT_PATH / "distributed_0", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -170,7 +170,7 @@ def test_convert_distributed_to_huggingface(): path=_CONVERT_PATH / "huggingface_1", format=HUGGINGFACE_CHECKPOINT_FORMAT, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -187,7 +187,7 @@ def test_convert_huggingface_to_fast_llm(): path=_CONVERT_PATH / "fast_llm_1", format=FastLLMCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) @@ -204,7 +204,7 @@ def test_convert_fast_llm_to_distributed(): path=_CONVERT_PATH / "distributed_1", format=DistributedCheckpointFormat, ), - model_config_class=TEST_MODEL_CONFIG_CLS, + model=TEST_MODEL_CONFIG_CLS, ) ) From 35959491886cd87d6a592b61745c42214a08c4a5 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 19:29:34 -0400 Subject: [PATCH 096/114] Minimalistic dynamic configs --- fast_llm/config.py | 73 ++++++++++++++++++++++- fast_llm/data/dataset/gpt/config.py | 92 +++++------------------------ 2 files changed, 85 insertions(+), 80 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 4928cdbd..6e3e92dc 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -12,7 +12,7 @@ import yaml -from fast_llm.utils import Assert, Tag, compare_nested, get_type_name, header, log +from fast_llm.utils import Assert, Registry, Tag, compare_nested, get_type_name, header, log logger = logging.getLogger(__name__) @@ -243,7 +243,9 @@ def _process_config_class(cls: type["Config"]): return cls -def config_class[T: Config]() -> typing.Callable[[type[T]], type[T]]: +def config_class[ + T: Config +](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]: """ Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications. """ @@ -253,7 +255,7 @@ def wrap(cls): if hasattr(cls, "__post_init__"): raise TypeError(f"`__post_init__` should not be implemented for `Config` classes") - wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True)) + wrapped = _process_config_class(dataclasses.dataclass(cls, kw_only=True, repr=False)) wrapped_init = cls.__init__ @@ -267,6 +269,13 @@ def __init__(self, **kwargs): self.validate() wrapped.__init__ = __init__ + + wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None + + if dynamic_type is not None: + for cls_, name in dynamic_type.items(): + cls_.register_subclass(name, wrapped) + return wrapped return wrap @@ -305,6 +314,9 @@ class Config(metaclass=ConfigMeta): # without them being automatically added to `_explicit_fields`. _setting_implicit_default: bool | None = Field(init=False) + # A registry for all the config classes. + _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None + def __setattr__(self, key: str, value: typing.Any) -> None: """ Make the class read-only after validation. @@ -358,6 +370,17 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: Validate a class and mark it as read-only This should not be overridden in derived classes. """ + # Should be handled in `from_dict`, but can fail if instantiating directly. + try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValidationError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.is_(self.__class__, expected_class) + if not self._validated: try: self._validate() @@ -738,6 +761,14 @@ def _from_dict( if "__class__" in default: del default["__class__"] + try: + actual_cls = cls.get_subclass(default.get("type")) + if actual_cls is not None and actual_cls is not cls: + return actual_cls._from_dict(default, strict=strict, flat=flat) + except KeyError: + # Postpone error to validation. + pass + # Do not validate yet in case the root class sets cross-dependencies in validation. with NoAutoValidate(): for name, field in cls.fields(): @@ -864,6 +895,42 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ ) return None + @classmethod + def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: + Assert.custom(issubclass, cls_, cls) + if cls._registry is None: + raise NotImplementedError(f"Subclass `{name}` doesn't have a registry..") + if name in cls._registry: + old_cls = cls._registry[name] + if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: + del cls._registry[name] + else: + raise KeyError(f"{cls.__name__} class registry already has an entry {name} from class {cls.__name__}.") + cls._registry[name] = cls_ + + @classmethod + def get_subclass(cls, name: str | None): + # TODO: Make it case-insensitive? + if name is None: + return None + cls_ = None + for base_class in cls.__mro__: + if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry: + if cls_ is None: + cls_ = base_class._registry[name] + if not issubclass(cls_, cls): + raise KeyError(f" {cls_.__name__} is not a subclass of {cls.__name__} (from type {name})") + elif base_class._registry[name] is not cls_: + # We explicitly prevent ambiguous classes to ensure safe and unambiguous serialization. + # TODO: Only really need to avoid conflict with `Config`'s registry, relax this a bit? + raise KeyError( + f"Ambiguous type `{name}` for base class {cls.__name__}." + f" ({cls_.__name__} vs {base_class._registry[name]})" + ) + if cls_ is None: + raise KeyError(f"Unknown type {name} for base class {cls.__name__}") + return cls_ + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index f4f6e282..4ab0b7df 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -23,7 +23,7 @@ SamplingParameters, ) from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, Registry, normalize_probabilities, padded_cumsum +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum if typing.TYPE_CHECKING: from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset @@ -92,61 +92,9 @@ class GPTSamplingData(SamplingData): truncate_documents: bool = True -@config_class() +@config_class(registry=True) class GPTSampledDatasetConfig(SampledDatasetConfig): - - # TODO: Generalize dynamic types? - _registry: typing.ClassVar[Registry[str, type["GPTSampledDatasetConfig"]]] = Registry[ - str, type["GPTDatasetConfig"] - ]("gpt_dataset_class", {}) - type_: typing.ClassVar[str | None] = None - type: str | None = Field( - default=None, - desc="The type of dataset.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - if self.type is None: - self.type = self.type_ - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.eq(self.type, self.__class__.type_) - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - type_ = default.get("type") - if type_ is None: - actual_cls = cls - else: - if type_ not in cls._registry: - raise ValueError( - f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}" - ) - actual_cls = cls._registry[type_] - Assert.custom(issubclass, actual_cls, cls) - if actual_cls == cls: - return super()._from_dict(default, strict=strict, flat=flat) - else: - return actual_cls._from_dict(default, strict=strict, flat=flat) - - def __init_subclass__(cls) -> None: - if cls._abstract and cls.type_ is not None: - # Abstract classes should not have a `type_` - raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") - if cls.type_ is not None: - if cls.type_ in cls._registry: - raise ValueError( - f"Registry {cls._registry.name} already contains type {cls.type_}." - f" Make sure all classes either have a unique or `None` type." - ) - GPTSampledDatasetConfig._registry[cls.type_] = cls - super().__init_subclass__() + pass @config_class() @@ -160,10 +108,9 @@ def build(self) -> "GPTIndexedDataset": raise NotImplementedError() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "random" name: str = Field( default="dummy", desc="The name of the dataset.", @@ -176,10 +123,9 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "memmap" path: pathlib.Path = Field( default=None, desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", @@ -202,10 +148,9 @@ def build(self) -> "GPTMemmapDataset": return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated" datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() def build(self) -> "GPTConcatenatedDataset": @@ -214,10 +159,9 @@ def build(self) -> "GPTConcatenatedDataset": return self._build(GPTConcatenatedDataset) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "slice" dataset: GPTIndexedDatasetConfig = FieldUpdate() def build(self) -> "GPTDatasetSlice": @@ -226,25 +170,22 @@ def build(self) -> "GPTDatasetSlice": return self._build(GPTDatasetSlice) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): _abstract = False - type_: typing.ClassVar[str | None] = "sampled" sampling: GPTSamplingConfig = FieldUpdate() dataset: GPTSampledDatasetConfig = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "blended" datasets: list[GPTSampledDatasetConfig] = FieldUpdate() -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "file" path: pathlib.Path = Field( default=None, desc="The path to a dataset config file.", @@ -280,11 +221,11 @@ def _convert_paths(self, config): return config -@config_class() +# Add user-friendly names for the configs. +@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated_memmap"}) class GPTConcatenatedMemmapConfig(GPTIndexedDatasetConfig): # TODO v0.3: Remove. _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "concatenated_memmap" path: pathlib.Path = Field( default=None, desc="The path to a dataset directory.", @@ -387,14 +328,13 @@ class FimConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "fim" dataset: GPTSampledDatasetConfig = Field( default=None, @@ -455,10 +395,9 @@ class GPTLegacyConfig(Config): ) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "legacy"}) class GPTLegacyDatasetConfig(GPTSampledDatasetConfig, GPTLegacyConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "legacy" def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: @@ -537,7 +476,7 @@ def build_and_sample(self, sampling: GPTSamplingData) -> SampledDataset: return GPTSampledDatasetConfig.from_dict(dataset_config).build_and_sample(sampling) -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. @@ -545,7 +484,6 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): # TODO: This belongs to a testing plugin. _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "test_slow" sleep: float = Field( default=1, desc="Sleep time during build, in seconds.", From 39b1a04fd140718afda39e96a5882754819b49d8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 14 May 2025 20:42:56 -0400 Subject: [PATCH 097/114] stuff --- fast_llm/config.py | 10 +++++++++- fast_llm/layers/common/config.py | 8 +++++++- fast_llm/layers/transformer/config.py | 16 ++++++++++++++-- tests/data/common.py | 3 +-- 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 6e3e92dc..380100e3 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -274,6 +274,7 @@ def __init__(self, **kwargs): if dynamic_type is not None: for cls_, name in dynamic_type.items(): + print(cls_, name, wrapped) cls_.register_subclass(name, wrapped) return wrapped @@ -899,7 +900,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None: Assert.custom(issubclass, cls_, cls) if cls._registry is None: - raise NotImplementedError(f"Subclass `{name}` doesn't have a registry..") + raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..") if name in cls._registry: old_cls = cls._registry[name] if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__: @@ -980,6 +981,13 @@ def __init_subclass__(cls): # dataclasses expects an annotation, so we use the one from the base class. cls.__annotations__[name] = base_class_field.type + # Type for the field. At the end of class definition to avoid shadowing builtin. + type: str | None = Field( + default=None, + desc="The config class name.", + hint=FieldHint.feature, + ) + class Configurable[ConfigType: Config]: config_class: typing.ClassVar[type[Config]] = Config diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 269989ce..054c26c3 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -33,7 +33,7 @@ class NormalizationType(str, enum.Enum): rms_norm = "rms_norm" -@config_class() +@config_class(registry=True) class NormalizationConfig(BaseModelConfig): _abstract = False @@ -107,6 +107,12 @@ def _from_dict( return super()._from_dict(default, strict, flat) +for name in NormalizationType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + NormalizationConfig.register_subclass(name.value, NormalizationConfig) + + class PeftType(str, enum.Enum): # TODO : Use a dynamic config type instead. none = "none" diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index c621139c..e7ef0b15 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -95,7 +95,7 @@ class RotaryEmbeddingType(str, enum.Enum): yarn = "yarn" -@config_class() +@config_class(registry=True) class RotaryConfig(BaseModelConfig): _abstract = False type: RotaryEmbeddingType = Field( @@ -158,6 +158,12 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") +for name in RotaryEmbeddingType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + RotaryConfig.register_subclass(name.value, RotaryConfig) + + class AddLinearBiasChoices(str, enum.Enum): nowhere = "nowhere" everywhere = "everywhere" @@ -175,7 +181,7 @@ class TransformerSubLayerName(str, enum.Enum): mlp_2 = "mlp_2" -@config_class() +@config_class(registry=True) class TransformerPeftConfig(PeftConfig): layers: list[TransformerSubLayerName] = Field( default=None, @@ -244,6 +250,12 @@ def _validate(self) -> None: ) +for name in PeftType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) + + @config_class() class TransformerConfig(BaseModelConfig): _abstract = False diff --git a/tests/data/common.py b/tests/data/common.py index 00c3ff20..cacb28e6 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -189,10 +189,9 @@ def validate_indexed_dataset_sampling( return token_ids -@config_class() +@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"}) class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "mock_memmap" num_documents: int | None = Field( default=None, desc="Expected number of documents in the dataset.", From 8a8fa77be71a12a770ae43ee0d369ddec5e189eb Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 16 May 2025 15:42:04 +0000 Subject: [PATCH 098/114] fix --- fast_llm/engine/multi_stage/fsdp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 61d1c7a8..d24c8f84 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -455,7 +455,10 @@ def copy_shard_overlaps( for shard_name, shard in shards.items(): # Shards can be empty (frozen weights) if shard.numel() == 0: - Assert.eq(loaded_shards[shard_name].numel(), 0) + continue + if loaded_shards[shard_name].numel() == 0: + shard[begin:end][overlap_mask] = 0 + counter += overlap_count continue shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked] counter += overlap_count From 8e259904a48725a7b2b3e30bd82139fcaeb89fcd Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 16 May 2025 15:45:23 +0000 Subject: [PATCH 099/114] add test with frozen weights --- tests/test_checkpoint.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 042f2bb2..9d5c86e9 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -89,6 +89,23 @@ def test_resume(): ) +@pytest.mark.depends(on=["test_checkpoint_and_eval"]) +def test_resume_frozen(): + run_test_script( + f"test_{TEST_MODEL}_resume_frozen", + CONFIG_COMMON + + [ + "training.checkpoint.interval=1", + "training.evaluations.validation.interval=2", + "training.evaluations.validation.iterations=1", + "model.base_model.transformer.mlp_lr_scale=0.", + ], + compare=f"test_{TEST_MODEL}_checkpoint_and_eval", + prepare_fn=_prepare_resume_fn, + compare_fn=_compare_resume_fn, + ) + + def _run_conversion(config: ConversionConfig): if config.output.path.is_dir() and not REUSE_RESULTS: shutil.rmtree(config.output.path) From 456a0c528e7ed5d6648727c5f2f8fb7c7e18c318 Mon Sep 17 00:00:00 2001 From: Toolkit User Date: Fri, 16 May 2025 20:15:57 +0000 Subject: [PATCH 100/114] add description for tests --- tests/common.py | 3 ++- tests/test_checkpoint.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/common.py b/tests/common.py index 569d690c..6179957b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -361,6 +361,7 @@ def run_test_script( config: CompareConfig | None = None, prepare_fn=None, compare_fn=None, + do_compare: bool = True, ): if torch.cuda.device_count() < num_gpus: pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})") @@ -413,7 +414,7 @@ def run_test_script( completed_proc = subprocess.run(command, env=env, timeout=60) if completed_proc.returncode: raise RuntimeError(f"Process failed with return code {completed_proc.returncode}") - if compare: + if compare and do_compare: if compare_fn is not None: compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare) compare_tensor_logs( diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 9d5c86e9..4dfd23a8 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -75,6 +75,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_resume(): + # Resume from iteration=1 and compare outputs with the baseline run. run_test_script( f"test_{TEST_MODEL}_resume", CONFIG_COMMON @@ -91,6 +92,7 @@ def test_resume(): @pytest.mark.depends(on=["test_checkpoint_and_eval"]) def test_resume_frozen(): + # Resume with frozen mlp. No comparison. run_test_script( f"test_{TEST_MODEL}_resume_frozen", CONFIG_COMMON @@ -102,7 +104,7 @@ def test_resume_frozen(): ], compare=f"test_{TEST_MODEL}_checkpoint_and_eval", prepare_fn=_prepare_resume_fn, - compare_fn=_compare_resume_fn, + do_compare=False, ) From 87efd455af8fa4bf6f8e4a42cdfa73bf89913498 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 12:17:32 +0000 Subject: [PATCH 101/114] 15b model apriel hybrid --- .../configuration_ssm_hybrid_apriel15b.py | 21 + .../modeling_ssm_hybrid_apriel15b.py | 948 ++++++++++++++++++ .../modeling_ssm_hybrid_apriel.py | 2 +- .../ssm/external/eval/apriel_eval_wrapper.py | 88 +- .../ssm/external/make_hybrid_checkpoint.py | 41 + 5 files changed, 1097 insertions(+), 3 deletions(-) create mode 100644 fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py create mode 100644 fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py create mode 100644 fast_llm/models/ssm/external/make_hybrid_checkpoint.py diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py new file mode 100644 index 00000000..bc2e603c --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -0,0 +1,21 @@ +from transformers import MistralConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class AprielSSMHybridConfig(MistralConfig): + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py new file mode 100644 index 00000000..ba798a1c --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -0,0 +1,948 @@ +import copy +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin, MistralModel +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer, + MistralMLP, + MistralPreTrainedModel, + MistralRMSNorm, +) +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + inference_params=None, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if inference_params is not None and inference_params.seqlen_offset > 0: + # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + +class AprielHybridIdentity(nn.Module): + def __init__(self, config: AprielSSMHybridConfig): + super().__init__() + self.config = config + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return (hidden_states,) + + +class AprielSSMHybridModel(MistralModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + config_copy = copy.deepcopy(config) + config_copy.num_hidden_layers = 0 + super().__init__(config_copy, **kwargs) + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx)) + elif type == "t": + blocks.append(MistralDecoderLayer(config, layer_idx)) + elif type == "i": + blocks.append(AprielHybridIdentity(config)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielHybridPreTrainedModel(MistralPreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + +class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return AprielHybridCausalOutput( + loss=loss, + logits=logits, + all_hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index 5d8f4cc5..ddb7d0f7 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -1482,7 +1482,7 @@ def prepare_inputs_for_generation( ): # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` - empty_past_kv = past_key_values is None + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index 02c9176b..e15de8bb 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -92,5 +92,89 @@ def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype] ) def _model_generate(self, context, max_length, stop, **generation_kwargs): - # FOR now evaluating with non-generation tasks - raise NotImplementedError("Generation not implemented yet for AprielHybridSSMWrapper") + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) + + +@register_model("apriel_hybrid_ssm_15b") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( + AprielSSMHybridConfig, + ) + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMHybridForCausalLM, + ) + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py new file mode 100644 index 00000000..a0616ab6 --- /dev/null +++ b/fast_llm/models/ssm/external/make_hybrid_checkpoint.py @@ -0,0 +1,41 @@ +import gc + +import click +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +@click.command() +@click.option("--identity_index", type=int, required=True) +@click.option("--save_dir", type=str, required=True) +def main(identity_index: int, save_dir: str): + checkpoint = "ServiceNow-AI/Apriel-Nemotron-15b-Thinker" + config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) + + hybrid_block_layout = ["t"] * config.num_hidden_layers + if identity_index >= 0: + hybrid_block_layout[identity_index] = "i" + + hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(), hybrid_block_layout=hybrid_block_layout) + hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config) + hybrid_apriel_model.to(dtype=torch.bfloat16).to(device) + + apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True) + apriel_state_dict = apriel_model.state_dict() + hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False) + + hybrid_apriel_model.save_pretrained(save_dir, save_config=True) + torch.cuda.empty_cache() + del hybrid_apriel_model + del apriel_model + del apriel_state_dict + gc.collect() + + +if __name__ == "__main__": + main() From aafbfb569c1370a683ec4931177edc30a1012e4d Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 12:22:01 +0000 Subject: [PATCH 102/114] nvm --- fast_llm/tensor.py | 1 - tests/common.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index ad2d42d1..84930756 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -234,7 +234,6 @@ def __init__( self.allow_no_grad = allow_no_grad self.lr_scale = lr_scale if isinstance(lr_scale, tuple) else (lr_scale,) - # TODO: note, this pevents the tes_checkpoints to pass for MODEL=llama-mtp, they pass with `self.requires_grad=requires_grad` instead. However, the model export seem to work as expected, at least for hybrid SSM. self.requires_grad = requires_grad and any(lr_scale_ != 0 for lr_scale_ in self.lr_scale) # Ensure the parameter is split in chunks of equal size. Assert.multiple(self.dims[0].size, len(self.lr_scale)) diff --git a/tests/common.py b/tests/common.py index f9cb324e..fe3120c2 100644 --- a/tests/common.py +++ b/tests/common.py @@ -202,8 +202,6 @@ CONFIG_LLAMA_MTP_MEGATRON = None CONFIG_LLAMA_MTP_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [ "model.base_model.prediction_heads=4", - "model.base_model.embeddings_lr_scale=0", - "model.base_model.transformer.per_layer_lr_scale=[0.1,0,0,1,1,.1]", ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] From c7fe8d74af47abc5d3f264349eca4eb4303e8ce9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 19:34:10 +0000 Subject: [PATCH 103/114] nvm --- .../modeling_ssm_hybrid_apriel15b.py | 48 +++++++++++++++---- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index ba798a1c..4b2e5724 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -10,16 +10,12 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from torch import nn -from transformers import GenerationMixin, MistralModel +from transformers import GenerationMixin from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.mistral.modeling_mistral import ( - MistralDecoderLayer, - MistralMLP, - MistralPreTrainedModel, - MistralRMSNorm, -) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack from transformers.utils import LossKwargs, logging from transformers.utils.generic import ModelOutput @@ -759,6 +755,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): config_copy = copy.deepcopy(config) config_copy.num_hidden_layers = 0 super().__init__(config_copy, **kwargs) + self.config = config blocks = [] logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): @@ -779,10 +776,11 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -class AprielHybridPreTrainedModel(MistralPreTrainedModel): +class AprielHybridPreTrainedModel(PreTrainedModel): config_class = AprielSSMHybridConfig base_model_prefix = "model" _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True @@ -791,8 +789,24 @@ class AprielHybridPreTrainedModel(MistralPreTrainedModel): _supports_static_cache = True _supports_attention_backend = True + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + def __init__(self, config, **kwargs): super().__init__(config, **kwargs) self.model = AprielSSMHybridModel(config) @@ -802,6 +816,24 @@ def __init__(self, config, **kwargs): # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + def prepare_inputs_for_generation( self, input_ids, From c285e8de0e831dbe4f8e211cd89f9eed70b03aa8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 20 May 2025 19:51:30 +0000 Subject: [PATCH 104/114] nvm --- tests/common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/common.py b/tests/common.py index fe3120c2..6179957b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -63,8 +63,7 @@ f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", f"model.multi_stage.debug_layer_gradients={_LOG_LEVEL}", - # f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", - f"model.multi_stage.debug_all_param_gradients=0", + f"model.multi_stage.debug_all_param_gradients={_LOG_LEVEL}", "model.multi_stage.debug_tensor_parallel=True", "model.distributed.reproducible_init=True", "model.distributed.timeout=10", From 3eaa240cc8373f4ddce1bf2dc254fd9a0d9d2d3b Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 22 May 2025 16:55:37 +0000 Subject: [PATCH 105/114] modeling --- .../modeling_ssm_hybrid_apriel15b.py | 117 +++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 4b2e5724..777fd3cf 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -1,5 +1,6 @@ import copy from dataclasses import dataclass +from functools import partial from typing import Any, Optional, Union import torch @@ -15,9 +16,15 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_INPUTS_DOCSTRING, + MistralDecoderLayer, + MistralMLP, + MistralModel, + MistralRMSNorm, +) from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import LossKwargs, add_start_docstrings_to_model_forward, can_return_tuple, logging from transformers.utils.generic import ModelOutput from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig @@ -772,6 +779,112 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): # Initialize weights and apply final processing self.post_init() + @can_return_tuple + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # OO: Cache is initialized in the `prepare_inputs_for_generation` method, so this can be removed + # if use_cache and past_key_values is None: + # past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... From 4781d158522c1e452c5e18f77bec1cf0e0fbe1ad Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 26 May 2025 16:37:24 +0000 Subject: [PATCH 106/114] wip --- fast_llm/layers/common/config.py | 451 +++++++++++++- fast_llm/layers/language_model/config.py | 70 ++- fast_llm/layers/language_model/embedding.py | 21 +- fast_llm/layers/language_model/head.py | 17 +- fast_llm/layers/ssm/blocks.py | 55 ++ fast_llm/layers/ssm/discrete_mamba2.py | 27 +- fast_llm/layers/ssm/llamba_block.py | 34 -- fast_llm/layers/ssm/mamba_layer.py | 17 +- fast_llm/layers/transformer/attention.py | 30 +- fast_llm/layers/transformer/config.py | 574 +++++------------- fast_llm/layers/transformer/mlp.py | 25 +- fast_llm/layers/transformer/transformer.py | 48 +- fast_llm/models/auto.py | 2 +- fast_llm/models/gpt/config.py | 39 +- fast_llm/models/gpt/model.py | 4 +- fast_llm/models/hybrid/config.py | 376 ++++++++++++ fast_llm/models/{ssm => hybrid}/conversion.py | 4 +- .../configuration_ssm_hybrid_apriel15b.py | 0 .../modeling_ssm_hybrid_apriel15b.py | 2 +- .../configuration_ssm_hybrid_apriel.py | 0 .../modeling_ssm_hybrid_apriel.py | 2 +- .../apriel_ssm/configuration_ssm_apriel.py | 0 .../apriel_ssm/modeling_ssm_apriel.py | 2 +- .../external/eval/apriel_eval_wrapper.py | 12 +- .../external/eval/run_lm_eval.py | 2 +- .../llamba/configuration_mtp_llamba.py | 0 .../external/llamba/modeling_mtp_llamba.py | 0 .../external/make_hybrid_checkpoint.py | 4 +- .../models/{ssm => hybrid}/huggingface.py | 4 +- fast_llm/models/{ssm => hybrid}/model.py | 55 +- fast_llm/models/{ssm => hybrid}/trainer.py | 4 +- fast_llm/models/ssm/config.py | 235 ------- tests/common.py | 2 +- tests/test_modular_config.py | 37 ++ tests/test_mtp.py | 2 +- tests/test_ssms.py | 12 +- 36 files changed, 1302 insertions(+), 867 deletions(-) create mode 100644 fast_llm/layers/ssm/blocks.py delete mode 100644 fast_llm/layers/ssm/llamba_block.py create mode 100644 fast_llm/models/hybrid/config.py rename fast_llm/models/{ssm => hybrid}/conversion.py (99%) rename fast_llm/models/{ssm => hybrid}/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py (100%) rename fast_llm/models/{ssm => hybrid}/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py (99%) rename fast_llm/models/{ssm => hybrid}/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py (100%) rename fast_llm/models/{ssm => hybrid}/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py (99%) rename fast_llm/models/{ssm => hybrid}/external/apriel_ssm/configuration_ssm_apriel.py (100%) rename fast_llm/models/{ssm => hybrid}/external/apriel_ssm/modeling_ssm_apriel.py (99%) rename fast_llm/models/{ssm => hybrid}/external/eval/apriel_eval_wrapper.py (91%) rename fast_llm/models/{ssm => hybrid}/external/eval/run_lm_eval.py (62%) rename fast_llm/models/{ssm => hybrid}/external/llamba/configuration_mtp_llamba.py (100%) rename fast_llm/models/{ssm => hybrid}/external/llamba/modeling_mtp_llamba.py (100%) rename fast_llm/models/{ssm => hybrid}/external/make_hybrid_checkpoint.py (84%) rename fast_llm/models/{ssm => hybrid}/huggingface.py (83%) rename fast_llm/models/{ssm => hybrid}/model.py (66%) rename fast_llm/models/{ssm => hybrid}/trainer.py (71%) delete mode 100644 fast_llm/models/ssm/config.py create mode 100644 tests/test_modular_config.py diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 1b84eeb2..0f0e1001 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,8 +1,11 @@ import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -11,16 +14,37 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm -@config_class() -class LLMBlockConfig(BaseModelConfig): - _abstract = False +class RotaryEmbeddingType(str, enum.Enum): + none = "none" + default = "default" + llama3 = "llama3" + yarn = "yarn" - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) + +class LLMDimNames: + input_hidden = "input_hidden" + output_hidden = "output_hidden" + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "batch" + # TODO: Distinguish micro-sequence? + sequence_q = "sequence_q" + sequence_q_tp = "sequence_q_tp" + sequence_k = "sequence_k" + hidden = "hidden" + # MLP dimensions + mlp = "mlp" + gate_and_up = "gate_and_up" + composite_gated_mlp = "composite_gated_mlp" + experts = "experts" + top_experts = "top_experts" + shared_experts = "shared_experts" + unshared_experts = "unshared_experts" + composite_expert_mlp = "composite_expert_mlp" + composite_gated_expert_mlp = "composite_gated_expert_mlp" + composite_shared_expert_mlp = "composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" class NormalizationImplementation(str, enum.Enum): @@ -175,3 +199,410 @@ def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": ) else: raise NotImplementedError(self.type) + + +class RoutingType(str, enum.Enum): + topk = "aux_loss" + sinkhorn = "sinkhorn" + + +class AddLinearBiasChoices(str, enum.Enum): + nowhere = "nowhere" + everywhere = "everywhere" + only_attn_qkv = "only_attn_qkv" + + +class BaseBlockSubLayerName(str, enum.Enum): + mlp_1 = "mlp_1" + mlp_2 = "mlp_2" + + +@config_class(registry=True) +class BaseBlockPeftConfig(PeftConfig): + """ + Peft Cofnig that applies to transformer layer. If this is used with GPTBaseModel it is reused for all transformer layers. + Note, this has no effect on the embedding layer, + if you want to freeze the embeddings (and other layers outside the transformer) you need to do so explicitly by setting embedding lr_scale to 0. + """ + + layers: list[BaseBlockSubLayerName] = Field( + default=None, + desc="The layers on which to apply LoRA.", + hint=FieldHint.feature, + ) + + def apply_linear(self, linear: "LinearBase", layer_type: BaseBlockSubLayerName | None = None) -> "LinearLike": + if self.type != PeftType.none: + if layer_type is None or self.layers is None or layer_type in self.layers: + return super().apply_linear(linear) + return linear + + def _validate(self) -> None: + if self.layers is None: + with self._set_implicit_default(): + self.layers = [] + if self.type != PeftType.none: + if BaseBlockSubLayerName.mlp_1 in self.layers or BaseBlockSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") + + +for name in PeftType: + # We need this because we are using the reserved field name `type`. + # TODO: Implement proper dynamic typing. + BaseBlockPeftConfig.register_subclass(name.value, BaseBlockPeftConfig) + + +@config_class() +class BaseBlockConfig(BaseModelConfig): + _abstract = True + peft: BaseBlockPeftConfig = Field( + desc="Configuration for the parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layers architecture.", + hint=FieldHint.architecture, + ) + hidden_dropout: float = Field( + default=0.0, + desc="Dropout applied to the residual connections.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + debug_block: int = Field( + default=0, + desc="Log the output of each operation in each layer.", + hint=FieldHint.logging, + valid=check_field(Assert.geq, 0), + ) + debug_block_memory: bool = Field( + default=False, + desc="Log the memory usage after each operation in each layer.", + hint=FieldHint.logging, + ) + num_experts: int = Field( + default=1, + desc="Number of MLP experts in a Mixture of Expert (MoE) model", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + norm_lr_scale: float | None | list[float | None] = Field( + default=None, + desc="Custom learning rate scale for each normalization layer.", + doc="May be used to freeze some normalization layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + + num_layers: int = Field( + default=12, + desc="Number of layers in the transformer.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + hidden_size: int = Field( + default=1024, + desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + ffn_hidden_size: int = Field( + default=None, + desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) + num_shared_experts: int = Field( + default=0, + desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_unshared_experts: int = Field( + init=False, + desc="Number of MLP experts excluding shared ones", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_experts_per_token: int = Field( + default=1, + desc="Active experts for each token in a MoE model.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + expert_routing_type: RoutingType = Field( + default=RoutingType.topk, + desc="The routing method, i.e., the method used to assign experts to tokens.", + hint=FieldHint.architecture, + ) + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + # Default: hidden_size**-0.5 + # TODO: Allow custom initialization (InitializationConfig?) + init_method_std: float = Field( + default=None, + desc="Default scale for weight initialization. Default: hidden_size**-0.5", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_qkv: float = Field( + default=None, + desc="Scale for the query, key and value weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_qkv: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_qkv: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_attn_proj: float = Field( + default=None, + desc="Scale for the attention projection weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_attn_proj: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_attn_proj: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto + mlp_recompute_level: MLPRecomputeLevel = Field( + default=MLPRecomputeLevel.none, + desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", + hint=FieldHint.performance, + ) + # Use random inits instead of constant values, useful for debugging. + random_bias_init: bool = Field( + default=False, + desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", + hint=FieldHint.testing, + ) + expert_auxiliary_loss_coefficient: float = Field( + default=0.01, + desc="Scale of the load balancing auxiliary loss for topk routing.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + expert_z_loss_coefficient: float = Field( + default=0.0, + desc="Regularize the router during training by applying Z-loss to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + moe_jitter_eps: float = Field( + default=0.0, + desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + router_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate for the MoE router weight.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + dropless_moe: bool = Field( + default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + ) + dropless_dynamic_shape: bool = Field( + default=False, + desc="Use a dynamic shape for dropless MLP instead of the worst-case value." + " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", + hint=FieldHint.expert, + ) + add_linear_biases: bool | AddLinearBiasChoices = Field( + default=True, + desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + self.num_unshared_experts = self.num_experts - self.num_shared_experts + Assert.geq( + self.hidden_dropout, 0 + ) # Do we need to check it here again given that its is already asserted in the config field? + if self.norm_lr_scale is not None: + Assert.geq(self.norm_lr_scale, 0) + + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) + super()._validate() + Assert.leq(self.num_shared_experts, self.num_experts) + Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + + @property + def add_mlp_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + @property + def add_attn_qkv_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.nowhere: + return False + return True + + @property + def add_attn_dense_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str = "") -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.hidden}_{block_name}", self.hidden_size)) + + # MLP dimensions + tensor_space.add_tensor_dim(mlp := TensorDim(f"{LLMDimNames.mlp}_{block_name}", self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim( + gate_and_up := TensorDim(f"{LLMDimNames.gate_and_up}_{block_name}", 2 if self.gated else 1) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_gated_mlp}_{block_name}", (gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(experts := TensorDim(f"{LLMDimNames.experts}_{block_name}", self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_expert_mlp}_{block_name}", (experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_gated_expert_mlp}_{block_name}", (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.top_experts}_{block_name}", self.num_experts_per_token)) + tensor_space.add_tensor_dim( + TensorDim(f"{LLMDimNames.unshared_experts}_{block_name}", self.num_unshared_experts) + ) + + # shared_experts + if self.num_shared_experts: + tensor_space.add_tensor_dim( + shared_experts := TensorDim(f"{LLMDimNames.shared_experts}_{block_name}", self.num_shared_experts) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_shared_expert_mlp}_{block_name}", (shared_experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim( + f"{LLMDimNames.composite_gated_shared_expert_mlp}_{block_name}", + (shared_experts, gate_and_up, mlp), + ) + ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 6e6a8ae5..f1079f12 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,11 +1,12 @@ import typing +import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames -from fast_llm.functional.config import CrossEntropyImpl -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.functional.config import CrossEntropyImpl, TritonConfig +from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert @@ -41,10 +42,10 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): - transformer: TransformerConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) + """ + Base config for language models. + """ + max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -58,8 +59,8 @@ class LanguageModelBaseConfig(BaseModelConfig): valid=check_field(Assert.gt, 0), ) use_position_embeddings: bool = Field( - default=None, - desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", + default=True, + desc="Enable absolute position embeddings.", # Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -174,18 +175,49 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) + # rotary: RotaryConfig = Field( + # desc="Configuration for the rotary positional embeddings.", + # hint=FieldHint.architecture, + # ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + debug: bool = Field( + default=False, + desc="Enable debug mode.", + hint=FieldHint.testing, + ) + embeddings_hidden_dropout: bool = Field( + default=0.0, + desc="Dropout applied to the embeddings.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + head_normalization: NormalizationConfig = Field( + desc="Configuration for the normalization in the head.", + hint=FieldHint.architecture, + ) def _validate(self) -> None: - self.transformer.validate() - with self._set_implicit_default(): - if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.rotary.enabled - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min + # # self.transformer.validate() + # with self._set_implicit_default(): + # if self.use_position_embeddings is None: + # self.use_position_embeddings = not self.rotary.enabled + # if self.init_method_std_embed is None: + # self.init_method_std_embed = self.transformer.init_method_std + # if self.init_method_max_embed is None: + # self.init_method_max_embed = self.transformer.init_method_max + # if self.init_method_min_embed is None: + # self.init_method_min_embed = self.transformer.init_method_min + + if not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + Assert.geq( + self.embeddings_hidden_dropout, 0 + ) # Do we need to check it here again given that its is already asserted in the config field? + super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) @@ -198,7 +230,7 @@ def _validate(self) -> None: Assert.geq(coeff, 0) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) + # self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index e0386d8d..fede81fc 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,8 +7,9 @@ from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert @@ -37,16 +38,16 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if config.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout + self._dropout_p = config.embeddings_hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = tensor_space.get_tensor_dim(LLMDimNames.input_hidden) vocab_dim = tensor_space.get_tensor_dim( LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab ) @@ -76,12 +77,12 @@ def __init__( lr_scale=config.embeddings_lr_scale, ) - # PEFT. - self.word_embeddings_weight = self._config.transformer.peft.apply_weight(self.word_embeddings_weight) - if hasattr(self, "position_embeddings_weight"): - self.position_embeddings_weight = self._config.transformer.peft.apply_weight( - self.position_embeddings_weight - ) + # PEFT: layer freezing should be done by explicitly setting embeddings_lr_scale to 0.0 + # self.word_embeddings_weight = self._config.peft.apply_weight(self.word_embeddings_weight) + # if hasattr(self, "position_embeddings_weight"): + # self.position_embeddings_weight = self._config.peft.apply_weight( + # self.position_embeddings_weight + # ) @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 233887ec..6d6d8d6f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,6 +15,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, LanguageModelDimNames, @@ -44,7 +45,7 @@ def __init__( prediction_distance: int, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer + self._debug_transformer = config.debug self._tie_word_embeddings = config.tie_word_embeddings self._tensor_space = tensor_space @@ -58,13 +59,13 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(LLMDimNames.output_hidden) self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) + self.final_norm = config.head_normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor self._z_loss_factor = config.logit_z_loss @@ -92,12 +93,12 @@ def __init__( self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) - # PEFT. - self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) - if hasattr(self, "output_weights"): - self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) + # PEFT: layer freezing should be done by explicitly setting output_lr_scale to 0.0 + # self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) + # if hasattr(self, "output_weights"): + # self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) - def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: + def _init_output_weights(self, hidden_dim: TensorDim, config: LanguageModelBaseConfig) -> None: # Only the first head defines the output weights if self._tie_word_embeddings or self._prediction_distance > 0: return diff --git a/fast_llm/layers/ssm/blocks.py b/fast_llm/layers/ssm/blocks.py new file mode 100644 index 00000000..35be0eb9 --- /dev/null +++ b/fast_llm/layers/ssm/blocks.py @@ -0,0 +1,55 @@ +import typing + +from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 +from fast_llm.layers.ssm.mamba_layer import MambaLayer +from fast_llm.layers.transformer.transformer import BaseBlock + +if typing.TYPE_CHECKING: + from fast_llm.engine.config_utils.tensor_space import TensorSpace + from fast_llm.models.hybrid.config import MambaBlockConfig + + +class LlambaBlock(BaseBlock): + """ + A transformer-like decoder block with a discrete Mamba 2 mixer, see https://arxiv.org/abs/2502.14458 + """ + + _mixer_module_name = "mixer" + + def __init__( + self, + config: "MambaBlockConfig", + tensor_space: "TensorSpace", + layer_index: int, + block_name: str = "", + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, block_name, return_input) + + def _create_mixer(self): + self.mixer = DiscreteMamba2( + self._config, layer_idx=self._layer_index, tensor_space=self._tensor_space, name=self.block_name + ) + + +class LlambaOneBlock(BaseBlock): + """ + A transformer-like decoder block with a Mamba 1 mixer. + """ + + _mixer_module_name = "mamba1" + + def __init__( + self, + config: "MambaBlockConfig", + tensor_space: "TensorSpace", + layer_index: int, + block_name: str = "", + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, block_name, return_input) + + def _create_mixer(self): + self.mixer = MambaLayer( + self._config, layer_idx=self._layer_index, tensor_space=self._tensor_space, name=self.block_name + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index d4f9f84d..39e0902b 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -31,8 +31,9 @@ class DiscreteMamba2(torch.nn.Module): def __init__( self, config: SSMConfig, - layer_idx: int, + layer_index: int, tensor_space: TensorSpace, + name: str = "", return_input: bool = False, ): """ @@ -46,20 +47,20 @@ def __init__( super().__init__() self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx + self.layer_idx = layer_index self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") - - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) + logger.info(f"Setting lr_scale for layer {layer_index} of type {type(self)}: {mamba_layer_lr_scale}") + + td_inner = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_dim}_{name}") + td_state = tensor_space.get_tensor_dim(f"{SSMDimNames.state_dim}_{name}") + td_model = tensor_space.get_tensor_dim(f"{SSMDimNames.model_dim}_{name}") + td_conv = tensor_space.get_tensor_dim(f"{SSMDimNames.conv_dim}_{name}") + td_n_qk_heads = tensor_space.get_tensor_dim(f"{SSMDimNames.qk_heads}_{name}") + td_n_v_heads = tensor_space.get_tensor_dim(f"{SSMDimNames.v_heads}_{name}") + td_conv_kernel = tensor_space.get_tensor_dim(f"{SSMDimNames.conv_kernel_size}_{name}") + td_inner_proj = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_proj_mamba2}_{name}") self.d_model = td_model.size self.d_inner = td_inner.size diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py deleted file mode 100644 index ee222d6d..00000000 --- a/fast_llm/layers/ssm/llamba_block.py +++ /dev/null @@ -1,34 +0,0 @@ -import typing - -from fast_llm.layers.transformer.transformer import BaseBlock - -if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.layers.ssm.config import SSMConfig - from fast_llm.layers.transformer.config import TransformerConfig - - -class LlambaBlock(BaseBlock): - """ - A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 - """ - - _name = "Llamba block" - _mixer_module_name = "mixer" - - def __init__( - self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", - tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, - return_input: bool = False, - ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) - - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index e44a4e1d..1160630d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -56,14 +56,15 @@ class MambaLayer(torch.nn.Module): def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, + layer_index: int, + name: str = "", return_input: bool = False, ): factory_kwargs = {} super().__init__() self.config: SSMConfig = config - self.layer_idx = layer_idx + self.layer_idx = layer_index self._debug_mode = config.debug_ssm @@ -72,17 +73,17 @@ def __init__( td_inner_proj = tensor_space.get_tensor_dim( SSMDimNames.inner_proj_mamba ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + tdt_rank = tensor_space.get_tensor_dim(f"{SSMDimNames.dt_rank}_{name}") + td_x_proj = tensor_space.get_tensor_dim(f"{SSMDimNames.x_proj_dim}_{name}") + td_state = tensor_space.get_tensor_dim(f"{SSMDimNames.state_dim}_{name}") + td_model = tensor_space.get_tensor_dim(f"{SSMDimNames.model_dim}_{name}") + td_conv_kernel = tensor_space.get_tensor_dim(f"{SSMDimNames.conv_kernel_size}_{name}") self.d_conv = td_conv_kernel.size self.d_inner = td_inner.size self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 0517c49c..60e1ab0a 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -80,6 +80,7 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, layer_index, + block_name: str = "", ): super().__init__() self._config = config @@ -87,7 +88,7 @@ def __init__( Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer + self._debug_transformer = self._config.debug_block self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -101,22 +102,27 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.kv_channels}_{block_name}").size + self._head_groups = self._tensor_space.get_tensor_dim( + f"{TransformerDimNames.head_groups}_{block_name}" + ).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim( + f"{TransformerDimNames.head_groups}_{block_name}" + ).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim( + f"{TransformerDimNames.group_heads}_{block_name}" + ).size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + attention_lr_scale = self._config.attention_lr_scale # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_query}_{block_name}"), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -125,7 +131,7 @@ def __init__( ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_key_value}_{block_name}"), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -136,7 +142,7 @@ def __init__( # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_dense}_{block_name}"), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 9cc9510b..5aa62553 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -10,34 +10,25 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig, PeftConfig, PeftType +from fast_llm.functional.config import TritonConfig +from fast_llm.layers.common.config import ( + BaseBlockConfig, + BaseBlockPeftConfig, + BaseBlockSubLayerName, + LLMDimNames, + PeftType, +) from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - import torch + pass from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta logger = logging.getLogger(__name__) -class RoutingType(str, enum.Enum): - topk = "aux_loss" - sinkhorn = "sinkhorn" - - -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" +class TransformerDimNames(LLMDimNames): # Self-attention dimensions head_groups = "head_groups" group_heads = "group_heads" @@ -47,18 +38,6 @@ class TransformerDimNames: composite_query = "composite_query" composite_key_value = "composite_key_value" composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" class TransformerKwargs: @@ -164,25 +143,23 @@ def _validate(self) -> None: RotaryConfig.register_subclass(name.value, RotaryConfig) -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - -class TransformerSubLayerName(str, enum.Enum): +class TransformerSubLayerName(BaseBlockSubLayerName): # TODO: Use this to replace AddLinearBiasChoices. query = "query" key = "key" value_ = "value" key_value = "key_value" dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" @config_class(registry=True) -class TransformerPeftConfig(PeftConfig): +class TransformerPeftConfig(BaseBlockPeftConfig): + """ + Peft Cofnig that applies to transformer layer. If this is used with GPTBaseModel it is reused for all transformer layers. + Note, this does not freeze layers! + If you want to freeze weights, you need to do so explicitly by setting the corresponding layer's lr_scales (embeddings/mlp etc.) to 0. + """ + layers: list[TransformerSubLayerName] = Field( default=None, desc="The layers on which to apply LoRA.", @@ -206,18 +183,23 @@ def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName linear.weight.requires_grad = False return linear - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.type != PeftType.none and self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.type != PeftType.none and self.freeze_others: - parameter.requires_grad = False - return parameter + # def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + # warnings.warn("TransformerPeftConfig.apply_other is deprecated. Use explicit layer freezing using e.g. learning rate scaling parameters.") + # # if self.type != PeftType.none and self.freeze_others: + # # for parameter in module.parameters(): + # # parameter.requires_grad = False + # # return module + + # def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + # if self.type != PeftType.none and self.freeze_others: + # warnings.warn( + # "Freezing weights with TransformerPeftConfig. Note, this does not freeze the embeddings or output heads, those must be frozen explicitly using their lr_scales." + # ) + # parameter.requires_grad = False + # return parameter def _validate(self) -> None: + super()._validate() if self.layers is None: with self._set_implicit_default(): # Setting the default layers only whee PeFT is enabled @@ -228,9 +210,6 @@ def _validate(self) -> None: else [] ) if self.type != PeftType.none: - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") if TransformerSubLayerName.dense in self.layers: # TODO: Support InputParallelLinear (different output format). raise NotImplementedError("LoRA not supported for attention dense layer.") @@ -257,12 +236,12 @@ def _validate(self) -> None: @config_class() -class TransformerConfig(LLMBlockConfig): +class TransformerConfig(BaseBlockConfig): _abstract = False - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) + # normalization: NormalizationConfig = Field( + # desc="Configuration for the normalization layers architecture.", + # hint=FieldHint.architecture, + # ) rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, @@ -271,176 +250,20 @@ class TransformerConfig(LLMBlockConfig): desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) - num_layers: int = Field( - default=12, - desc="Number of layers in the transformer.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) - head_groups: int = Field( - default=1, - desc="Number of head group for grouped query attention.", - doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - add_linear_biases: bool | AddLinearBiasChoices = Field( - default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", - hint=FieldHint.architecture, - ) - ffn_hidden_size: int = Field( - default=None, - desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - kv_channels: int = Field( - default=None, - desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - num_experts: int = Field( - default=1, - desc="Number of MLP experts in a Mixture of Expert (MoE) model", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_shared_experts: int = Field( - default=0, - desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - num_unshared_experts: int = Field( - init=False, - desc="Number of MLP experts excluding shared ones", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - num_experts_per_token: int = Field( - default=1, - desc="Active experts for each token in a MoE model.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - expert_routing_type: RoutingType = Field( - default=RoutingType.topk, - desc="The routing method, i.e., the method used to assign experts to tokens.", - hint=FieldHint.architecture, - ) - activation_type: ActivationType = Field( - default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, - ) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( + attention_lr_scale: float | None = Field( default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - attention_dropout: float = Field( - default=0.0, - desc="Dropout applied to the attention intermediate states.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - hidden_dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", + desc="Custom learning rate scale for the Attention projection weights.", + doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, + attention_softmax_scale_power: float = Field( + default=0.5, + desc="The scaling power to apply to kv_channel in the attention calculation. " + " Under Standard Parameterization (SP): default to 0.5. " + " Under muP (if scaling kv_channels size): use 1. " + " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) # Use flash attention if possible (fp16 or bf16) use_flash_attention: bool = Field( @@ -458,179 +281,120 @@ class TransformerConfig(LLMBlockConfig): hint=FieldHint.optional, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( - default=MLPRecomputeLevel.none, - desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", - hint=FieldHint.performance, - ) - debug_transformer: int = Field( - default=0, - desc="Log the output of each operation in a transformer layer.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), - ) - debug_transformer_memory: bool = Field( - default=False, - desc="Log the memory usage after each operation in a transformer layer..", - hint=FieldHint.logging, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) - expert_auxiliary_loss_coefficient: float = Field( - default=0.01, - desc="Scale of the load balancing auxiliary loss for topk routing.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - expert_z_loss_coefficient: float = Field( - default=0.0, - desc="Regularize the router during training by applying Z-loss to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - moe_jitter_eps: float = Field( + attention_dropout: float = Field( default=0.0, - desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + desc="Dropout applied to the attention intermediate states.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | list[float | None] = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) - router_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate for the MoE router weight.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_lr_scale: float | None = Field( + kv_channels: int = Field( default=None, - desc="Custom learning rate scale for the Attention projection weights.", - doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( - default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " - " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. " - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - dropless_moe: bool = Field( - default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) - dropless_dynamic_shape: bool = Field( - default=False, - desc="Use a dynamic shape for dropless MLP instead of the worst-case value." - " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", - hint=FieldHint.expert, + num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) + head_groups: int = Field( + default=1, + desc="Number of head group for grouped query attention.", + doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) def _validate(self) -> None: + super()._validate() with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - self.num_unshared_experts = self.num_experts - self.num_shared_experts - - super()._validate() - - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.leq(self.num_shared_experts, self.num_experts) - Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) Assert.multiple(self.num_attention_heads, self.head_groups) Assert.geq(self.attention_dropout, 0) - Assert.geq(self.hidden_dropout, 0) - if isinstance(self.mlp_lr_scale, list): - Assert.eq(len(self.mlp_lr_scale), self.num_experts) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) - elif self.mlp_lr_scale is not None: - Assert.geq(self.mlp_lr_scale, 0) + # with self._set_implicit_default(): + # if self.ffn_hidden_size is None: + # self.ffn_hidden_size = 4 * self.hidden_size + # if self.kv_channels is None: + # self.kv_channels = div(self.hidden_size, self.num_attention_heads) + # if self.activation_type is None: + # self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # if self.init_method_std is None: + # self.init_method_std = self.hidden_size**-0.5 + # if self.init_method_std_qkv is None: + # self.init_method_std_qkv = self.init_method_std + # if self.init_method_std_attn_proj is None: + # self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + # if self.init_method_std_mlp_1 is None: + # self.init_method_std_mlp_1 = self.init_method_std + # if self.init_method_std_mlp_2 is None: + # self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + # if self.init_method_max_qkv is None: + # self.init_method_max_qkv = self.init_method_max + # if self.init_method_min_qkv is None: + # self.init_method_min_qkv = self.init_method_min + # if self.init_method_max_attn_proj is None: + # self.init_method_max_attn_proj = self.init_method_max + # if self.init_method_min_attn_proj is None: + # self.init_method_min_attn_proj = self.init_method_min + # if self.init_method_max_mlp_1 is None: + # self.init_method_max_mlp_1 = self.init_method_max + # if self.init_method_min_mlp_1 is None: + # self.init_method_min_mlp_1 = self.init_method_min + # if self.init_method_max_mlp_2 is None: + # self.init_method_max_mlp_2 = self.init_method_max + # if self.init_method_min_mlp_2 is None: + # self.init_method_min_mlp_2 = self.init_method_min + # if self.init_method_min is not None and self.init_method_max is not None: + # Assert.leq(self.init_method_min, self.init_method_max) + # if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + # Assert.leq(self.init_method_min, self.init_method_max) + # if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + # Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + # if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + # Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + # if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + # Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + # if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + # Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + # self.num_unshared_experts = self.num_experts - self.num_shared_experts + + # super()._validate() + + # # if not TritonConfig.TRITON_ENABLED: + # # warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + + # Assert.leq(self.num_shared_experts, self.num_experts) + # Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + # Assert.multiple(self.num_attention_heads, self.head_groups) + # Assert.geq(self.attention_dropout, 0) @functools.cached_property def projection_size(self): assert self._validated return self.num_attention_heads * self.kv_channels - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - - @property - def add_attn_qkv_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True - - @property - def add_attn_dense_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + # # @property + # def add_mlp_bias(self) -> bool: + # if isinstance(self.add_linear_biases, bool): + # return self.add_linear_biases + # if self.add_linear_biases == AddLinearBiasChoices.everywhere: + # return True + # return False + + # @property + # def add_attn_qkv_bias(self) -> bool: + # if isinstance(self.add_linear_biases, bool): + # return self.add_linear_biases + # if self.add_linear_biases == AddLinearBiasChoices.nowhere: + # return False + # return True + + # @property + # def add_attn_dense_bias(self) -> bool: + # if isinstance(self.add_linear_biases, bool): + # return self.add_linear_biases + # if self.add_linear_biases == AddLinearBiasChoices.everywhere: + # return True + # return False @classmethod def _from_dict( @@ -650,65 +414,47 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str = "") -> None: + super().setup_tensor_space(tensor_space, block_name) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) - # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + f"{TransformerDimNames.head_groups}_{block_name}", + self.head_groups, + tensor if self.head_groups > 1 else None, ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + f"{TransformerDimNames.group_heads}_{block_name}", div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(f"{TransformerDimNames.key_and_value}_{block_name}", 2)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + kv_channels := TensorDim(f"{TransformerDimNames.kv_channels}_{block_name}", self.kv_channels) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(f"{TransformerDimNames.composite_heads}_{block_name}", (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim( + f"{TransformerDimNames.composite_query}_{block_name}", (head_groups, group_heads, kv_channels) + ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim( + f"{TransformerDimNames.composite_key_value}_{block_name}", (key_and_value, head_groups, kv_channels) + ) ) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) - ) + CompositeTensorDim( + f"{TransformerDimNames.composite_dense}_{block_name}", (head_groups, group_heads, kv_channels) ) + ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index c4d8afdc..73afd745 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -7,16 +7,17 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.common.config import BaseBlockConfig from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: BaseBlockConfig, tensor_space: TensorSpace, block_name: str = "", layer_index: int = 0): super().__init__() - self._name = name + self._block_name = block_name self._layer_index = layer_index init_method_1 = init_normal_( @@ -30,8 +31,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") + self._intermediate_dim = tensor_space.get_tensor_dim( + f"{TransformerDimNames.composite_expert_mlp}_{block_name}" + ) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -39,14 +42,12 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale - lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + lr_scale = config.mlp_lr_scale # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_gated_expert_mlp}_{block_name}"), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, @@ -69,9 +70,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: BaseBlockConfig, tensor_space: TensorSpace, block_name: str = "", layer_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name) + super().__init__(config, tensor_space, block_name, layer_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index b51ba1e9..aa9682e7 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,6 +8,7 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.common.config import BaseBlockConfig from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP @@ -26,32 +27,37 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, + config: BaseBlockConfig, + tensor_space: TensorSpace, + layer_index: int, + block_name: str = "", + return_input: bool = False, ): super().__init__() self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout + self.block_name = block_name # this name is used for tensor space setup and corresponds to the block name in the hybrid setup or to "" in the old setup (GPT Model) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input self._layer_index = layer_index - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + self._debug_mode = self._config.debug_block or self._config.debug_block_memory + hidden_dim = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.block_name} mlp", layer_index=layer_index ) - # PEFT. - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) + # PEFT. Layer freezing must be explicit now. + # self.norm_1 = self._config.peft.apply_other(self.norm_1) + # self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod def _create_mixer(self): @@ -67,7 +73,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self.__class__.__name__} {self.block_name} {self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -76,21 +82,21 @@ def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None: - if self._config.debug_transformer_memory: + if self._config.debug_block_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self.name} {name}", str)) - if self._config.debug_transformer and tensor is not None: + if self._config.debug_block and tensor is not None: # TODO: Local vs global log_distributed_tensor( "", tensor if bias is None else tensor + bias, - level=self._config.debug_transformer, + level=self._config.debug_block, meta=self._get_meta(tensor, name, kwargs), distributed=self._tensor_space.distributed, ) log_distributed_grad( "", tensor, - level=self._config.debug_transformer, + level=self._config.debug_block, meta=self._get_meta(tensor, name + " grad", kwargs), distributed=self._tensor_space.distributed, ) @@ -138,13 +144,17 @@ def forward( class TransformerLayer(BaseBlock): - _name = "Transformer layer" _mixer_module_name = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + block_name: str = "", + return_input: bool = False, ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, layer_index, block_name, return_input) def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + self.self_attn = Attention(self._config, self._tensor_space, self._layer_index, self.block_name) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 8f16aaea..96c2917d 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,7 +2,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig +from fast_llm.models.hybrid.config import HybridSSMModelConfig, HybridTrainerConfig from fast_llm.utils import Registry model_registry = Registry[str, FastLLMModelConfig]( diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d9085c67..d6938b7e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -4,10 +4,13 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div @@ -97,6 +100,11 @@ def micro_batch_splits(self) -> int: @config_class() class GPTBaseModelConfig(LanguageModelBaseConfig): + """ + Base model config for GPT models. + This model is built exclusively from transformer layers which share the same config. + """ + _abstract = False # Debug, to get an exact match with megatron init. @@ -104,6 +112,11 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing ) + transformer: TransformerConfig = Field( + desc="Configuration for the transformer architecture.", + hint=FieldHint.architecture, + ) + @classmethod def _from_dict( cls, @@ -124,8 +137,32 @@ def _from_dict( del default["fused_mlp"] return super()._from_dict(default, strict, flat) + def _validate(self) -> None: + if self.debug: + self.transformer.debug_block = True + self.transformer.debug_block_memory = True + self.transformer.validate() + self.use_position_embeddings = not self.transformer.rotary.enabled + self.embeddings_hidden_dropout = self.transformer.hidden_dropout # legacy behavior + self.head_normalization = self.transformer.normalization # legacy behavior + with self._set_implicit_default(): + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + self.transformer.setup_tensor_space(tensor_space) + # Mark the input hidden dimension of the model + tensor_space.add_tensor_dim(TensorDim(LLMDimNames.input_hidden, self.transformer.hidden_size)) + # Mark the output hidden dimension of the model, which is the same for GPT models + tensor_space.add_tensor_dim(TensorDim(LLMDimNames.output_hidden, self.transformer.hidden_size)) + super().setup_tensor_space(tensor_space) + -@config_class() class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b548ab52..eb5f00ed 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -62,9 +62,7 @@ def __init__( if self._config.use_absolute_position_embeddings: self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) if self._config.transformer.rotary.enabled: - self._preprocessors.append( - RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space) - ) + self._preprocessors.append(RotaryEmbeddingPreprocessor(self._config.rotary, self._tensor_space)) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) else: diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py new file mode 100644 index 00000000..060550c5 --- /dev/null +++ b/fast_llm/models/hybrid/config.py @@ -0,0 +1,376 @@ +import logging +import math +import typing + +from blocks import LlambaBlock, LlambaOneBlock + +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.config import LLMDimNames +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import BaseBlock, TransformerLayer +from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.models.gpt.model import GPTInferenceRunner + from fast_llm.models.hybrid.huggingface import HuggingfaceHybridSSMModelForCausalLM + from fast_llm.models.hybrid.model import HybridSSMModel + from fast_llm.models.hybrid.trainer import SSMTrainer + +logger = logging.getLogger(__name__) + + +@config_class(registry=True) +class BlockConfig(Config): + _abstract = True + block_class: typing.ClassVar[type[BaseBlock]] + # config: TransformerConfig | SSMConfig + + lr_scale: list[float] | None = Field( + default=None, + desc="Custom learning rate scale for each layer.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + + def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> None: + raise NotImplementedError() + + @property + def hidden_size(self) -> int: + raise NotImplementedError() + + +@config_class(dynamic_type={BlockConfig: "transformer"}) +class TransformerBlockConfig(BlockConfig, TransformerConfig): + _abstract = False + block_class: typing.ClassVar[type[BaseBlock]] = TransformerLayer + + def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> None: + TransformerConfig.setup_tensor_space(self, tensor_space, block_name) + + +@config_class(dynamic_type={BlockConfig: "discrete_mamba2"}) +class DiscreteMamba2BlockConfig(BlockConfig, SSMConfig): + _abstract = False + block_class: typing.ClassVar[type[BaseBlock]] = LlambaBlock + + hidden_size: int = Field( + default=1024, + desc="Hidden size of the block.", + hint=FieldHint.architecture, + ) + + # def _validate(self): + # self.config.validate() + + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: + + d_inner = int(self.expansion_factor * self.hidden_size) if self.d_inner is None else self.d_inner + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{block_name}", self.hidden_size)) + # Mamba-specific dimensions + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{block_name}", d_inner)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.state_dim}_{block_name}", self.state_size)) + tensor_space.add_tensor_dim( + TensorDim(f"{SSMDimNames.conv_kernel_size}_{block_name}", self.conv_kernel_dimension) + ) + + # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 + headdim = d_inner // self.n_v_heads + Assert.eq(self.n_v_heads, d_inner // headdim) + Assert.eq(d_inner % headdim, 0) + Assert.eq(self.n_v_heads % self.n_qk_heads, 0) + + conv_dim = d_inner + 2 * self.n_qk_heads * self.state_size + inner_proj_dim = 2 * d_inner + 2 * self.n_qk_heads * self.state_size + self.n_v_heads + + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.qk_heads}_{block_name}", self.n_qk_heads)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.v_heads}_{block_name}", self.n_v_heads)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba2}_{block_name}", inner_proj_dim)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.conv_dim}_{block_name}", conv_dim)) + + +@config_class(dynamic_type={BlockConfig: "mamba"}) +class MambaBlockConfig(BlockConfig, SSMConfig): + _abstract = False + block_class: typing.ClassVar[type[BaseBlock]] = LlambaOneBlock + + hidden_size: int = Field( + default=1024, + desc="Hidden size of the block.", + hint=FieldHint.architecture, + ) + + def setup_tensor_space(self, tensor_space: TensorSpace, name: str) -> None: + + if self.dt_rank is None: + mamba_dt_rank = math.ceil(self.hidden_size / 16) + else: + mamba_dt_rank = self.dt_rank + + d_inner = int(self.expansion_factor * self.hidden_size) if self.d_inner is None else self.d_inner + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{name}", self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{name}", d_inner)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.state_dim}_{name}", self.state_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.dt_rank}_{name}", mamba_dt_rank)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.x_proj_dim}_{name}", mamba_dt_rank + self.state_size * 2)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.conv_kernel_size}_{name}", self.conv_kernel_dimension)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba}_{name}", d_inner * 2)) + + +@config_class() +class HybridSSMBaseModelConfig(LanguageModelBaseConfig): + _abstract = False + ############################################################################################ + # Note, transformer and ssm are here for legacy reasons, we should migrate to blocks field + transformer: TransformerConfig = Field( + desc="Configuration for the transformer architecture. Note, having transformer and ssm fields in HybridSSMBaseModelConfig is depricated.", + hint=FieldHint.architecture, + ) + + ssm: SSMConfig = Field( + desc="Configuration for the SSM architecture. Note, having transformer and ssm fields in HybridSSMBaseModelConfig is depricated.", + hint=FieldHint.architecture, + ) + ############################################################################################ + blocks: dict[str, BlockConfig] = Field( + default=None, + desc="Named block configurations that can be referenced in block_pattern.", + hint=FieldHint.architecture, + ) + + hybrid_block_layout: list[str] | None = Field( + default=None, + desc=f"Pattern of blocks to use in the model (still supports the previous depricated format with {SSMBlockType.__members__.values()})", + hint=FieldHint.core, + ) + + default_mtp_type: str | None = Field( + default=None, + desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", + hint=FieldHint.optional, + ) + # TODO: ideally these things should be move to LanguageModelBaseConfig? + # TODO: currently num_layers is defined in TransformerConfig, but ideally this should be migrated to LanguageModelBaseConfig in the future. + # Hence, for now: the num_layers should be set in the first transformer block, if no transformer blocks used we will fallback to num_layers from here. + num_layers: int = Field( + default=12, + desc="Number of layers in the transformer.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + """ + Setup the tensor space for the model. + """ + for block_name, block_config in self.blocks.items(): + block_config.setup_tensor_space(tensor_space, block_name) + # The first layer's hidden dimension is the input hidden dimension of the model + tensor_space.add_tensor_dim( + TensorDim(LLMDimNames.input_hidden, self.blocks[self.hybrid_block_layout[0]].hidden_size) + ) + # Mark the output hidden dimension of the model + tensor_space.add_tensor_dim( + TensorDim(LLMDimNames.output_hidden, self.blocks[self.hybrid_block_layout[-1]].hidden_size) + ) + super().setup_tensor_space(tensor_space) + + def _validate(self): + if self.blocks is not None and self.hybrid_block_layout is not None: + # Validate that all pattern entries refer to valid blocks + for block_name in self.hybrid_block_layout: + if block_name not in self.blocks: + raise ValueError(f"Block name '{block_name}' in block_pattern not found in blocks dictionary") + + first_transformer_block_config: TransformerBlockConfig | None = None + + for block_name, block_config in self.blocks.items(): + if isinstance(block_config, TransformerBlockConfig): + if first_transformer_block_config is None: + first_transformer_block_config = block_config + else: + logger.warning( + f"Found multiple transformer blocks with different number of layers, using num_layers from the first transformer block for all" + ) + block_config._validate() + + if first_transformer_block_config is not None: + num_layers = first_transformer_block_config.config.num_layers + logger.warning( + f"TransformerBlockConfig overwrites BaseModelConfig num_layers, setting num_layers = {num_layers}" + ) + self.num_layers = num_layers + else: + logger.warning( + f"No transformer blocks found in blocks dictionary, using num_layers from BaseModelConfig: {self.num_layers} and falling back to old behavior with hybrid_block_layout containing any of {SSMBlockType.__members__.values()}" + ) + if self.hybrid_block_layout is None: + with self._set_implicit_default(): + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + + if len(self.hybrid_block_layout) != self.transformer.num_layers: + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: + raise ValueError( + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + ) + num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) + logger.warning( + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + ) + self.hybrid_block_layout = self.hybrid_block_layout * num_repeats + + Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) + Assert.custom( + lambda _: all( + block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout + ), + f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", + ) + Assert.custom( + lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, + f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", + ) + # TODO: prepare hybrid_block_layout here + + with self._set_implicit_default(): + if self.init_method_std_embed is None: + self.init_method_std_embed = ( + first_transformer_block_config.config.init_method_std + if first_transformer_block_config is not None + else 0.02 + ) + if self.init_method_max_embed is None: + self.init_method_max_embed = ( + first_transformer_block_config.config.init_method_max + if first_transformer_block_config is not None + else 0.02 + ) + if self.init_method_min_embed is None: + self.init_method_min_embed = ( + first_transformer_block_config.config.init_method_min + if first_transformer_block_config is not None + else 0.02 + ) + + super()._validate() + + +class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "llamba" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.hybrid.conversion import LLambaHuggingfaceCheckpointHandler + + return LLambaHuggingfaceCheckpointHandler + + +class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.hybrid.conversion import AprielSSMHuggingfaceCheckpointHandler + + return AprielSSMHuggingfaceCheckpointHandler + + +class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.hybrid.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler + + return AprielSSMHHybridHuggingfaceCheckpointHandler + + +@config_class() +class HybridSSMModelConfig(FastLLMModelConfig): + _abstract = False + model_name: typing.ClassVar[str] = "hybrid_ssm" + base_model: HybridSSMBaseModelConfig = FieldUpdate() + checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( + LLambaHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + AprielSSMHHybridHuggingfaceCheckpointFormat, + ) + + @classmethod + def get_model_class(cls) -> type["HybridSSMModel"]: + from fast_llm.models.hybrid.model import HybridSSMModel + + return HybridSSMModel + + @classmethod + def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: + from fast_llm.models.hybrid.huggingface import HuggingfaceHybridSSMModelForCausalLM + + return HuggingfaceHybridSSMModelForCausalLM + + def _validate(self): + logger.warning( + "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." + ) + super()._validate() + + +@config_class() +class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): + _abstract = False + model: HybridSSMModelConfig = FieldUpdate() + + +@config_class() +class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() + + @classmethod + def get_trainer_class(cls) -> type["SSMTrainer"]: + from fast_llm.models.hybrid.trainer import SSMTrainer + + return SSMTrainer + + def _validate(self) -> None: + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + # TODO: we dont have inference runner for SSM/Hybrid yet, should return None? + logger.warning( + "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" + ) + + return GPTInferenceRunner diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/hybrid/conversion.py similarity index 99% rename from fast_llm/models/ssm/conversion.py rename to fast_llm/models/hybrid/conversion.py index 357a26c0..ef270145 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/hybrid/conversion.py @@ -20,13 +20,13 @@ from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter -from fast_llm.models.ssm.config import ( +from fast_llm.models.hybrid.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.hybrid.model import HybridSSMModel from fast_llm.utils import Assert if typing.TYPE_CHECKING: diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/hybrid/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py similarity index 100% rename from fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py rename to fast_llm/models/hybrid/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/hybrid/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py similarity index 99% rename from fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py rename to fast_llm/models/hybrid/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 777fd3cf..bc62f241 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/hybrid/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -27,7 +27,7 @@ from transformers.utils import LossKwargs, add_start_docstrings_to_model_forward, can_return_tuple, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.hybrid.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig logger = logging.get_logger(__name__) diff --git a/fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/hybrid/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py similarity index 100% rename from fast_llm/models/ssm/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py rename to fast_llm/models/hybrid/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py diff --git a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/hybrid/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py similarity index 99% rename from fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py rename to fast_llm/models/hybrid/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py index ddb7d0f7..52b8e47e 100644 --- a/fast_llm/models/ssm/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py +++ b/fast_llm/models/hybrid/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -21,7 +21,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( +from fast_llm.models.hybrid.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( ROPE_INIT_FUNCTIONS, AprielSSMHybridConfig, ) diff --git a/fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py b/fast_llm/models/hybrid/external/apriel_ssm/configuration_ssm_apriel.py similarity index 100% rename from fast_llm/models/ssm/external/apriel_ssm/configuration_ssm_apriel.py rename to fast_llm/models/hybrid/external/apriel_ssm/configuration_ssm_apriel.py diff --git a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/hybrid/external/apriel_ssm/modeling_ssm_apriel.py similarity index 99% rename from fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py rename to fast_llm/models/hybrid/external/apriel_ssm/modeling_ssm_apriel.py index 09dc8259..82272e2a 100644 --- a/fast_llm/models/ssm/external/apriel_ssm/modeling_ssm_apriel.py +++ b/fast_llm/models/hybrid/external/apriel_ssm/modeling_ssm_apriel.py @@ -19,7 +19,7 @@ from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.utils.generic import ModelOutput -from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig +from fast_llm.models.hybrid.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig logger = logging.get_logger(__name__) diff --git a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/hybrid/external/eval/apriel_eval_wrapper.py similarity index 91% rename from fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py rename to fast_llm/models/hybrid/external/eval/apriel_eval_wrapper.py index e15de8bb..9ccc7768 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/hybrid/external/eval/apriel_eval_wrapper.py @@ -24,13 +24,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + from fast_llm.models.hybrid.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig self._config = AprielSSMConfig.from_pretrained(pretrained) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + from fast_llm.models.hybrid.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM self._model = AprielSSMForCausalLM.from_pretrained( pretrained, @@ -76,13 +76,13 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + from fast_llm.models.hybrid.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM self._model = AprielSSMHybridForCausalLM.from_pretrained( pretrained, @@ -135,7 +135,7 @@ def __init__(self, pretrained, **kwargs) -> None: def _get_config(self, pretrained: str, **kwargs) -> None: """Get the model configuration.""" - from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( + from fast_llm.models.hybrid.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( AprielSSMHybridConfig, ) @@ -143,7 +143,7 @@ def _get_config(self, pretrained: str, **kwargs) -> None: def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: """Create the model.""" - from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + from fast_llm.models.hybrid.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( AprielSSMHybridForCausalLM, ) diff --git a/fast_llm/models/ssm/external/eval/run_lm_eval.py b/fast_llm/models/hybrid/external/eval/run_lm_eval.py similarity index 62% rename from fast_llm/models/ssm/external/eval/run_lm_eval.py rename to fast_llm/models/hybrid/external/eval/run_lm_eval.py index c910bcc3..b6313cf1 100644 --- a/fast_llm/models/ssm/external/eval/run_lm_eval.py +++ b/fast_llm/models/hybrid/external/eval/run_lm_eval.py @@ -1,6 +1,6 @@ from lm_eval.__main__ import cli_evaluate -from fast_llm.models.ssm.external.eval.apriel_eval_wrapper import ( # noqa: F401 +from fast_llm.models.hybrid.external.eval.apriel_eval_wrapper import ( # noqa: F401 AprielHybridSSMWrapper, AprielSSMWrapper, ) diff --git a/fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py b/fast_llm/models/hybrid/external/llamba/configuration_mtp_llamba.py similarity index 100% rename from fast_llm/models/ssm/external/llamba/configuration_mtp_llamba.py rename to fast_llm/models/hybrid/external/llamba/configuration_mtp_llamba.py diff --git a/fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/hybrid/external/llamba/modeling_mtp_llamba.py similarity index 100% rename from fast_llm/models/ssm/external/llamba/modeling_mtp_llamba.py rename to fast_llm/models/hybrid/external/llamba/modeling_mtp_llamba.py diff --git a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py b/fast_llm/models/hybrid/external/make_hybrid_checkpoint.py similarity index 84% rename from fast_llm/models/ssm/external/make_hybrid_checkpoint.py rename to fast_llm/models/hybrid/external/make_hybrid_checkpoint.py index a0616ab6..2fe15c0d 100644 --- a/fast_llm/models/ssm/external/make_hybrid_checkpoint.py +++ b/fast_llm/models/hybrid/external/make_hybrid_checkpoint.py @@ -4,8 +4,8 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM -from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig -from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM +from fast_llm.models.hybrid.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.hybrid.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/hybrid/huggingface.py similarity index 83% rename from fast_llm/models/ssm/huggingface.py rename to fast_llm/models/hybrid/huggingface.py index 77cd346f..6e818a32 100644 --- a/fast_llm/models/ssm/huggingface.py +++ b/fast_llm/models/hybrid/huggingface.py @@ -2,8 +2,8 @@ from fast_llm.engine.huggingface.config import HuggingfaceModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.hybrid.config import HybridSSMModelConfig +from fast_llm.models.hybrid.model import HybridSSMModel logger = logging.getLogger(__name__) diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/hybrid/model.py similarity index 66% rename from fast_llm/models/ssm/model.py rename to fast_llm/models/hybrid/model.py index 118a195b..f464a42f 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/hybrid/model.py @@ -7,20 +7,17 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType +from fast_llm.models.hybrid.config import HybridSSMBaseModelConfig, HybridSSMModelConfig logger = logging.getLogger(__name__) class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[ConfigType]): """ - A hybrid model that interleaves Transformer and Mamba blocks. - Right now only LlambaBlock is supported. - As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. + A hybrid model that can interleave Transformer, Mamba and other blocks. """ config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig @@ -31,7 +28,7 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed + super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -87,49 +84,19 @@ def get_layers(self) -> list[Layer]: layers = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern - for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer.value: - # Transformer block - layers.append( - TransformerLayer( - self._config.transformer, - self._tensor_space, - layer_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - elif block_type == SSMBlockType.mamba2_discrete.value: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, + for i, block_name in enumerate(self._config.hybrid_block_layout): + BLOCK_CLS = self._config.blocks[block_name].block_class + layers.append( + BLOCK_CLS( + self._config.blocks[block_name], + self._tensor_space, layer_index=i + 1, - tensor_space=self._tensor_space, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), + block_name=block_name, ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba.value: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") - - # Add the output layers + ) layers += self.get_output_layers() return layers diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/hybrid/trainer.py similarity index 71% rename from fast_llm/models/ssm/trainer.py rename to fast_llm/models/hybrid/trainer.py index c0e5be26..55c16ad0 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/hybrid/trainer.py @@ -1,8 +1,8 @@ import typing from fast_llm.models.gpt.trainer import GPTTrainer -from fast_llm.models.ssm.config import HybridTrainerConfig -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.hybrid.config import HybridTrainerConfig +from fast_llm.models.hybrid.model import HybridSSMModel class SSMTrainer[ConfigType: HybridTrainerConfig](GPTTrainer[ConfigType]): diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py deleted file mode 100644 index 0207cfdd..00000000 --- a/fast_llm/models/ssm/config.py +++ /dev/null @@ -1,235 +0,0 @@ -import logging -import math -import typing - -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.models.gpt.model import GPTInferenceRunner - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMModel - from fast_llm.models.ssm.trainer import SSMTrainer - -logger = logging.getLogger(__name__) - - -@config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): - _abstract = False - - ssm: SSMConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) - hybrid_block_layout: list[str] | None = Field( - default=None, - desc=f"Pattern of blocks to use in the model. Availabel types: {SSMBlockType.__members__.values()}", - hint=FieldHint.core, - ) - default_mtp_type: str | None = Field( - default=None, - desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", - hint=FieldHint.optional, - ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - """ - Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. - """ - super().setup_tensor_space(tensor_space) - if ( - not SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout - and not SSMBlockType.mamba.value in self.hybrid_block_layout - ): - raise ValueError( - f"Block pattern must contain at least one '{SSMBlockType.mamba2_discrete.value}' or '{SSMBlockType.mamba.value}', use gpt model for transformer only architectures" - ) - - if self.ssm.dt_rank is None: - mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) - else: - mamba_dt_rank = self.ssm.dt_rank - - d_inner = ( - int(self.ssm.expansion_factor * self.transformer.hidden_size) - if self.ssm.d_inner is None - else self.ssm.d_inner - ) - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, mamba_dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, mamba_dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if SSMBlockType.mamba2_discrete.value in self.hybrid_block_layout: - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - - def _validate(self): - if self.hybrid_block_layout is None: - with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] - - if len(self.hybrid_block_layout) != self.transformer.num_layers: - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) - self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - - super()._validate() - - -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False - name: typing.ClassVar[str] = "llamba" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import LLambaHuggingfaceCheckpointHandler - - return LLambaHuggingfaceCheckpointHandler - - -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False - name: typing.ClassVar[str] = "apriel_ssm" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielSSMHuggingfaceCheckpointHandler - - return AprielSSMHuggingfaceCheckpointHandler - - -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False - name: typing.ClassVar[str] = "apriel_ssm_hybrid" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler - - return AprielSSMHHybridHuggingfaceCheckpointHandler - - -@config_class() -class HybridSSMModelConfig(FastLLMModelConfig): - _abstract = False - model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate(default_factory=HybridSSMBaseModelConfig) - checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( - LLambaHuggingfaceCheckpointFormat, - AprielSSMHuggingfaceCheckpointFormat, - AprielSSMHHybridHuggingfaceCheckpointFormat, - ) - - @classmethod - def get_model_class(cls) -> type["HybridSSMModel"]: - from fast_llm.models.ssm.model import HybridSSMModel - - return HybridSSMModel - - @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - - return HuggingfaceHybridSSMModelForCausalLM - - def _validate(self): - logger.warning( - "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." - ) - super()._validate() - - -@config_class() -class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): - _abstract = False - model: HybridSSMModelConfig = FieldUpdate() - - -@config_class() -class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) - batch: GPTBatchConfig = FieldUpdate(default_factory=GPTBatchConfig) - reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() - - @classmethod - def get_trainer_class(cls) -> type["SSMTrainer"]: - from fast_llm.models.ssm.trainer import SSMTrainer - - return SSMTrainer - - def _validate(self) -> None: - super()._validate() - if (name := self.model.base_model.distillation_model) is None: - Assert.empty(self.reference_models) - else: - Assert.eq(self.reference_models.keys(), {name}) - if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - # if self.model.base_model.distillation_model is not None: - # # TODO: Support loss masking for distillation? - # assert not self.batch.use_loss_masking_spans - for reference_model in self.reference_models.values(): - Assert.none(reference_model.model.base_model.distillation_model) - # TODO: Support more LM head features. - Assert.none(reference_model.model.base_model.cross_entropy_splits) - Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) - Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - # TODO: we dont have inference runner for SSM/Hybrid yet, should return None? - logger.warning( - "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" - ) - - return GPTInferenceRunner diff --git a/tests/common.py b/tests/common.py index 6179957b..11567977 100644 --- a/tests/common.py +++ b/tests/common.py @@ -23,7 +23,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs diff --git a/tests/test_modular_config.py b/tests/test_modular_config.py new file mode 100644 index 00000000..52d2750b --- /dev/null +++ b/tests/test_modular_config.py @@ -0,0 +1,37 @@ +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.models.hybrid.config import HybridSSMBaseModelConfig, MambaBlockConfig, TransformerBlockConfig +from fast_llm.models.hybrid.model import HybridSSMBaseModel + +config = HybridSSMBaseModelConfig( + blocks={ + "transformer_block": TransformerBlockConfig( + transformer=TransformerConfig( + hidden_size=4096, + num_attention_heads=32, + num_layers=10, + ), + ), + "mamba_block": MambaBlockConfig( + ssm=SSMConfig( + state_size=16, + ), + ), + "mamba2_block": MambaBlockConfig( + ssm=SSMConfig( + state_size=16, + ), + ), + }, + hybrid_block_layout=["mamba_block", "mamba2_block", "mamba_block"], +) + +distributed_config = DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + world_size=1, +) + +# Create model +model = HybridSSMBaseModel(config, distributed_config) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index 9f1939f1..ea46ace3 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -20,7 +20,7 @@ try: from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel + from fast_llm.models.hybrid.model import HybridSSMBaseModel except ImportError: MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( None, diff --git a/tests/test_ssms.py b/tests/test_ssms.py index f9303412..ea3732b8 100644 --- a/tests/test_ssms.py +++ b/tests/test_ssms.py @@ -16,14 +16,18 @@ from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.ssm.config import AprielSSMHHybridHuggingfaceCheckpointFormat, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) from tests.common import get_hybrid_config, materialize_meta_tensors try: + from blocks import LlambaBlock + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel + from fast_llm.models.hybrid.model import HybridSSMBaseModel, HybridSSMModel except ImportError: MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( None, @@ -140,7 +144,7 @@ def test_load_from_llamba_checkpoint(distributed_config): def get_hf_apriel_hybrid_out(input_ids, path, format): - from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) From ac4bfa95fda30d9a8055eae252aa2db91626d223 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 27 May 2025 12:32:21 +0000 Subject: [PATCH 107/114] wip --- fast_llm/layers/common/config.py | 142 ++++++++--- fast_llm/layers/ssm/config.py | 17 +- fast_llm/layers/transformer/config.py | 109 ++++----- .../layers/transformer/mixture_of_experts.py | 2 +- fast_llm/models/gpt/config.py | 1 + fast_llm/models/gpt/model.py | 8 +- fast_llm/models/hybrid/config.py | 182 +++++++++----- fast_llm/models/hybrid/model.py | 49 +--- tests/common.py | 4 +- tests/{test_ssms.py => test_hybrid.py} | 228 +++++++++--------- tests/test_modular_config.py | 4 +- 11 files changed, 395 insertions(+), 351 deletions(-) rename tests/{test_ssms.py => test_hybrid.py} (58%) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 0f0e1001..bb1e6319 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -156,15 +156,89 @@ class PeftType(str, enum.Enum): lora = "lora" -@config_class() +@config_class(registry=True) class PeftConfig(BaseModelConfig): - _abstract = False + _abstract = True type: PeftType = Field( default=PeftType.none, desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.", hint=FieldHint.core, ) + + # @classmethod + # def get_subclass(cls, name: str): + # return super().get_subclass(name) + + def validate(self: typing.Self, *, _is_validating: bool = False): + """ + Validate a class and mark it as read-only + This should not be overridden in derived classes. + """ + # Should be handled in `from_dict`, but can fail if instantiating directly. + try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValueError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. + # Assert.is_(expected_class, self.__class__) + Assert.custom(issubclass, expected_class, self.__class__) + + if not self._validated: + try: + self._validate() + except Exception as e: + raise type(e)("\n".join(e.args)) from None + self._validated = True + return self + + +@config_class(dynamic_type={PeftConfig: "none"}) +class EmptyPeftConfig(PeftConfig): + """ + A dummy PeftConfig that does nothing. + """ + + _abstract = False + + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + return linear + + def validate(self: typing.Self, *, _is_validating: bool = False): + """ + Validate a class and mark it as read-only + This should not be overridden in derived classes. + """ + # Should be handled in `from_dict`, but can fail if instantiating directly. + try: + expected_class = self.get_subclass(self.type) + except KeyError as e: + # Delayed instantiation error in `from_dict`. + raise ValueError(*e.args) + + if expected_class is not None: + # Should be handled in `from_dict`, but can fail if instantiating directly. + Assert.is_(self.__class__, expected_class) + + if not self._validated: + try: + self._validate() + except Exception as e: + raise type(e)("\n".join(e.args)) from None + self._validated = True + return self + + +@config_class(dynamic_type={PeftConfig: "lora"}) +class LoRAConfig(PeftConfig): + """ + LoRA configuration. + """ + + _abstract = False rank: int = Field( default=8, desc="The LoRA rank, i.e. the size of the intermediate dimension.", @@ -182,23 +256,18 @@ class PeftConfig(BaseModelConfig): ) def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - if self.type == PeftType.none: - return linear - elif self.type == PeftType.lora: - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) - else: - raise NotImplementedError(self.type) + from fast_llm.layers.common.peft import lora_linear + + # TODO: Init method? + return lora_linear( + linear, + linear.weight.param_init_method, + linear.weight.param_init_method, + self.rank, + self.alpha, + self.dropout, + **kwargs, + ) class RoutingType(str, enum.Enum): @@ -212,19 +281,19 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" -class BaseBlockSubLayerName(str, enum.Enum): +class BaseBlockSubLayerName: mlp_1 = "mlp_1" mlp_2 = "mlp_2" -@config_class(registry=True) -class BaseBlockPeftConfig(PeftConfig): +@config_class(dynamic_type={PeftConfig: "base_lora"}) +class BaseBlockLoRAConfig(LoRAConfig): """ - Peft Cofnig that applies to transformer layer. If this is used with GPTBaseModel it is reused for all transformer layers. - Note, this has no effect on the embedding layer, - if you want to freeze the embeddings (and other layers outside the transformer) you need to do so explicitly by setting embedding lr_scale to 0. + TODO: Add support for MLP. """ + _abstract = False + layers: list[BaseBlockSubLayerName] = Field( default=None, desc="The layers on which to apply LoRA.", @@ -232,31 +301,30 @@ class BaseBlockPeftConfig(PeftConfig): ) def apply_linear(self, linear: "LinearBase", layer_type: BaseBlockSubLayerName | None = None) -> "LinearLike": - if self.type != PeftType.none: - if layer_type is None or self.layers is None or layer_type in self.layers: - return super().apply_linear(linear) + if layer_type is None or self.layers is None or layer_type in self.layers: + return super().apply_linear(linear) return linear def _validate(self) -> None: if self.layers is None: with self._set_implicit_default(): self.layers = [] - if self.type != PeftType.none: - if BaseBlockSubLayerName.mlp_1 in self.layers or BaseBlockSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") + if BaseBlockSubLayerName.mlp_1 in self.layers or BaseBlockSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") -for name in PeftType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - BaseBlockPeftConfig.register_subclass(name.value, BaseBlockPeftConfig) +# for name in PeftType: +# # We need this because we are using the reserved field name `type`. +# # TODO: Implement proper dynamic typing. +# BaseBlockPeftConfig.register_subclass(name.value, BaseBlockPeftConfig) @config_class() class BaseBlockConfig(BaseModelConfig): + _abstract = True - peft: BaseBlockPeftConfig = Field( + peft: PeftConfig = Field( desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 13418254..f7a978e5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,8 +1,6 @@ -import enum - from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig +from fast_llm.layers.common.config import BaseBlockConfig, NormalizationConfig from fast_llm.utils import Assert @@ -21,19 +19,8 @@ class SSMDimNames: v_heads = "v_heads" # Number of V heads -class SSMBlockType(str, enum.Enum): - """ - An enum for the available mamba types for the MLP layer. - """ - - mamba = "m" - mamba2_discrete = "m2d" - mamba2 = "m2" - transformer = "t" - - @config_class() -class SSMConfig(LLMBlockConfig): +class SSMConfig(BaseBlockConfig): _abstract = False # Normalization diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 5aa62553..04710471 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -13,10 +13,10 @@ from fast_llm.functional.config import TritonConfig from fast_llm.layers.common.config import ( BaseBlockConfig, - BaseBlockPeftConfig, + BaseBlockLoRAConfig, BaseBlockSubLayerName, LLMDimNames, - PeftType, + PeftConfig, ) from fast_llm.utils import Assert, div @@ -152,11 +152,11 @@ class TransformerSubLayerName(BaseBlockSubLayerName): dense = "dense" -@config_class(registry=True) -class TransformerPeftConfig(BaseBlockPeftConfig): +@config_class(dynamic_type={PeftConfig: "transformer_lora"}) +class TransformerLoRaConfig(BaseBlockLoRAConfig): """ - Peft Cofnig that applies to transformer layer. If this is used with GPTBaseModel it is reused for all transformer layers. - Note, this does not freeze layers! + LoRa config that applies to transformer layer. If this is used with GPTBaseModel it is reused for all transformer layers. + Note, this does not freeze layers. If you want to freeze weights, you need to do so explicitly by setting the corresponding layer's lr_scales (embeddings/mlp etc.) to 0. """ @@ -165,74 +165,49 @@ class TransformerPeftConfig(BaseBlockPeftConfig): desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if self.type != PeftType.none: - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False + if layer_type is None or self.layers is None or layer_type in self.layers: + if layer_type == TransformerSubLayerName.key: + return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + elif layer_type == TransformerSubLayerName.value_: + return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) + else: + return super().apply_linear(linear) + # elif self.freeze_others: + # linear.weight.requires_grad = False return linear - # def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - # warnings.warn("TransformerPeftConfig.apply_other is deprecated. Use explicit layer freezing using e.g. learning rate scaling parameters.") - # # if self.type != PeftType.none and self.freeze_others: - # # for parameter in module.parameters(): - # # parameter.requires_grad = False - # # return module - - # def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - # if self.type != PeftType.none and self.freeze_others: - # warnings.warn( - # "Freezing weights with TransformerPeftConfig. Note, this does not freeze the embeddings or output heads, those must be frozen explicitly using their lr_scales." - # ) - # parameter.requires_grad = False - # return parameter - def _validate(self) -> None: - super()._validate() if self.layers is None: with self._set_implicit_default(): # Setting the default layers only whee PeFT is enabled # so they don't appear when serializing the default transformer config. - self.layers = ( - [TransformerSubLayerName.query, TransformerSubLayerName.value_] - if self.type == PeftType.lora - else [] - ) - if self.type != PeftType.none: - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + self.layers = [TransformerSubLayerName.query, TransformerSubLayerName.value_] + super()._validate() + if TransformerSubLayerName.dense in self.layers: + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for attention dense layer.") + if ( + sum( + name in self.layers + for name in ( + TransformerSubLayerName.key_value, + TransformerSubLayerName.key, + TransformerSubLayerName.value_, ) + ) + > 1 + ): + raise ValueError( + f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + ) -for name in PeftType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) +# for name in PeftType: +# # We need this because we are using the reserved field name `type`. +# # TODO: Implement proper dynamic typing. +# TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) @config_class() @@ -246,10 +221,10 @@ class TransformerConfig(BaseBlockConfig): desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) - peft: TransformerPeftConfig = Field( - desc="Configuration for the parameter-efficient fine tuning.", - hint=FieldHint.architecture, - ) + # peft: PeftConfig = FieldUpdate( + # desc="Configuration for the parameter-efficient fine tuning.", + # hint=FieldHint.architecture, + # ) attention_lr_scale: float | None = Field( default=None, desc="Custom learning rate scale for the Attention projection weights.", @@ -303,12 +278,12 @@ class TransformerConfig(BaseBlockConfig): ) def _validate(self) -> None: - super()._validate() with self._set_implicit_default(): if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) Assert.multiple(self.num_attention_heads, self.head_groups) Assert.geq(self.attention_dropout, 0) + super()._validate() # with self._set_implicit_default(): # if self.ffn_hidden_size is None: diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 49778c63..8a8fd05e 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -10,9 +10,9 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.config import RoutingType from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import ( - RoutingType, TransformerConfig, TransformerDimNames, TransformerKwargs, diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d6938b7e..29864ac0 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -163,6 +163,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) +@config_class() class GPTModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "gpt" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index eb5f00ed..774544fb 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,16 +10,12 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.common.config import RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLossNames from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor, diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index 060550c5..57acf092 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -1,3 +1,4 @@ +import enum import logging import math import typing @@ -12,7 +13,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.transformer import BaseBlock, TransformerLayer from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig @@ -28,7 +29,7 @@ @config_class(registry=True) -class BlockConfig(Config): +class HybridBlockConfig(Config): _abstract = True block_class: typing.ClassVar[type[BaseBlock]] # config: TransformerConfig | SSMConfig @@ -48,8 +49,8 @@ def hidden_size(self) -> int: raise NotImplementedError() -@config_class(dynamic_type={BlockConfig: "transformer"}) -class TransformerBlockConfig(BlockConfig, TransformerConfig): +@config_class(dynamic_type={HybridBlockConfig: "transformer"}) +class TransformerBlockConfig(HybridBlockConfig, TransformerConfig): _abstract = False block_class: typing.ClassVar[type[BaseBlock]] = TransformerLayer @@ -57,8 +58,8 @@ def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> No TransformerConfig.setup_tensor_space(self, tensor_space, block_name) -@config_class(dynamic_type={BlockConfig: "discrete_mamba2"}) -class DiscreteMamba2BlockConfig(BlockConfig, SSMConfig): +@config_class(dynamic_type={HybridBlockConfig: "discrete_mamba2"}) +class DiscreteMamba2BlockConfig(HybridBlockConfig, SSMConfig): _abstract = False block_class: typing.ClassVar[type[BaseBlock]] = LlambaBlock @@ -98,8 +99,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.conv_dim}_{block_name}", conv_dim)) -@config_class(dynamic_type={BlockConfig: "mamba"}) -class MambaBlockConfig(BlockConfig, SSMConfig): +@config_class(dynamic_type={HybridBlockConfig: "mamba"}) +class MambaBlockConfig(HybridBlockConfig, SSMConfig): _abstract = False block_class: typing.ClassVar[type[BaseBlock]] = LlambaOneBlock @@ -127,22 +128,39 @@ def setup_tensor_space(self, tensor_space: TensorSpace, name: str) -> None: tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba}_{name}", d_inner * 2)) +class HybridBlockType(str, enum.Enum): + """ + An enum for the available block types, legacy format. + """ + + m: MambaBlockConfig + m2d: LlambaBlock + t: TransformerBlockConfig + + @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridBaseModelConfig(LanguageModelBaseConfig): + """ + HybridBaseModelConfig is a configuration class for hybrid models. + Currently it supports two formats for architecture definition: + - the old and deprecated format with transformer and ssm fields (t, m2d, m), in wich case all blocks share the same config; + - and the new format with blocks field, in which case each block can have its own config. + """ + _abstract = False ############################################################################################ - # Note, transformer and ssm are here for legacy reasons, we should migrate to blocks field + # Note, transformer and ssm are here for legacy reasons transformer: TransformerConfig = Field( - desc="Configuration for the transformer architecture. Note, having transformer and ssm fields in HybridSSMBaseModelConfig is depricated.", + desc="Configuration for the transformer architecture. Note, having transformer and ssm fields in HybridBaseModelConfig is depricated.", hint=FieldHint.architecture, ) ssm: SSMConfig = Field( - desc="Configuration for the SSM architecture. Note, having transformer and ssm fields in HybridSSMBaseModelConfig is depricated.", + desc="Configuration for the SSM architecture. Note, having transformer and ssm fields in HybridBaseModelConfig is depricated.", hint=FieldHint.architecture, ) ############################################################################################ - blocks: dict[str, BlockConfig] = Field( + blocks: dict[str, HybridBlockConfig] = Field( default=None, desc="Named block configurations that can be referenced in block_pattern.", hint=FieldHint.architecture, @@ -150,20 +168,20 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): hybrid_block_layout: list[str] | None = Field( default=None, - desc=f"Pattern of blocks to use in the model (still supports the previous depricated format with {SSMBlockType.__members__.values()})", + desc=f"Pattern of blocks to use in the model (still supports the previous depricated format with {HybridBlockType.__members__.keys()})", hint=FieldHint.core, ) default_mtp_type: str | None = Field( default=None, - desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", + desc="Multi-token prediction mixer to use in the model. Can be either one of the blocks, or follow the depricated legacy format: 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - # TODO: ideally these things should be move to LanguageModelBaseConfig? + # TODO: currently num_layers is defined in TransformerConfig, but ideally this should be migrated to LanguageModelBaseConfig in the future. - # Hence, for now: the num_layers should be set in the first transformer block, if no transformer blocks used we will fallback to num_layers from here. + # Hence, for now: the num_layers can be set in the first transformer block, if no transformer blocks used we will fallback to num_layers parameter defined here. num_layers: int = Field( - default=12, + default=None, desc="Number of layers in the transformer.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), @@ -186,82 +204,114 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) def _validate(self): - if self.blocks is not None and self.hybrid_block_layout is not None: - # Validate that all pattern entries refer to valid blocks - for block_name in self.hybrid_block_layout: - if block_name not in self.blocks: - raise ValueError(f"Block name '{block_name}' in block_pattern not found in blocks dictionary") - - first_transformer_block_config: TransformerBlockConfig | None = None - - for block_name, block_config in self.blocks.items(): - if isinstance(block_config, TransformerBlockConfig): - if first_transformer_block_config is None: - first_transformer_block_config = block_config - else: - logger.warning( - f"Found multiple transformer blocks with different number of layers, using num_layers from the first transformer block for all" - ) - block_config._validate() - - if first_transformer_block_config is not None: - num_layers = first_transformer_block_config.config.num_layers - logger.warning( - f"TransformerBlockConfig overwrites BaseModelConfig num_layers, setting num_layers = {num_layers}" - ) - self.num_layers = num_layers - else: + if self.blocks is None: logger.warning( - f"No transformer blocks found in blocks dictionary, using num_layers from BaseModelConfig: {self.num_layers} and falling back to old behavior with hybrid_block_layout containing any of {SSMBlockType.__members__.values()}" + f"Blocks not set, falling back to old behavior with hybrid_block_layout containing any of {HybridBlockType.__members__.keys()}" ) if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] - - if len(self.hybrid_block_layout) != self.transformer.num_layers: - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + logger.warning( + f"No hybrid_block_layout found in HybridBaseModelConfig, using default block {HybridBlockType.m2d}" ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) - self.hybrid_block_layout = self.hybrid_block_layout * num_repeats + self.hybrid_block_layout = [HybridBlockType.m2d] - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) + # Legacy format with t, m, m2d, convert to new format Assert.custom( lambda _: all( - block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout + block_type in HybridBlockType.__members__.keys() for block_type in self.hybrid_block_layout ), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", + f"Invalid block type: {self.hybrid_block_layout}. Must be one of {HybridBlockType.__members__.keys()}", ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", + blocks = {} + for block_type in self.hybrid_block_layout: + if block_type not in blocks: + hybrid_block_config_cls = HybridBlockType[block_type] + if hybrid_block_config_cls == TransformerBlockConfig: + blocks[block_type] = TransformerConfig.from_dict(self.transformer.to_dict()) + elif hybrid_block_config_cls == MambaBlockConfig: + blocks[block_type] = SSMConfig.from_dict(self.ssm.to_dict()) + elif hybrid_block_config_cls == LlambaBlock: + blocks[block_type] = SSMConfig.from_dict(self.ssm.to_dict()) + else: + raise ValueError(f"Invalid block type: {block_type}") + self.blocks = blocks + self.hybrid_block_layout = [HybridBlockType[block_type] for block_type in self.hybrid_block_layout] + + Assert.gt(len(self.hybrid_block_layout), 0, "No blocks found in hybrid_block_layout") + # Validate that all pattern entries refer to valid blocks + for block_name in self.hybrid_block_layout: + if block_name not in self.blocks: + raise ValueError(f"Block name '{block_name}' in block_pattern not found in blocks dictionary") + + first_transformer_block_config: TransformerBlockConfig | None = None + + for block_name, block_config in self.blocks.items(): + if isinstance(block_config, TransformerBlockConfig): + if first_transformer_block_config is None: + first_transformer_block_config = block_config + elif block_config.num_layers != first_transformer_block_config.num_layers: + logger.warning( + f"Found multiple transformer blocks with different number of layers, using num_layers from the first transformer block for all" + ) + block_config._validate() + + # set num_layers from transformer block config if it exists and if num_layers is not set in HybridBaseModelConfig + # i.e. the resolution hierarchy for num_layers is: HybridBaseModelConfig.num_layers > TransformerBlockConfig.num_layers + if first_transformer_block_config is not None: + num_layers = first_transformer_block_config.num_layers + with self._set_implicit_default(): + if self.num_layers is None: + logger.warning( + f"TransformerBlockConfig overwrites BaseModelConfig num_layers, setting num_layers = {num_layers}" + ) + self.num_layers = num_layers + + # make sure that the hybrid_block_layout length matches the num_layers. If it doesn't, repeat the hybrid_block_layout; + if len(self.hybrid_block_layout) != self.num_layers: + if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: + raise ValueError( + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + ) + num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) + logger.warning( + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" ) - # TODO: prepare hybrid_block_layout here + self.hybrid_block_layout = self.hybrid_block_layout * num_repeats + + Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) with self._set_implicit_default(): if self.init_method_std_embed is None: self.init_method_std_embed = ( - first_transformer_block_config.config.init_method_std + first_transformer_block_config.init_method_std if first_transformer_block_config is not None else 0.02 ) if self.init_method_max_embed is None: self.init_method_max_embed = ( - first_transformer_block_config.config.init_method_max + first_transformer_block_config.init_method_max if first_transformer_block_config is not None else 0.02 ) if self.init_method_min_embed is None: self.init_method_min_embed = ( - first_transformer_block_config.config.init_method_min + first_transformer_block_config.init_method_min if first_transformer_block_config is not None else 0.02 ) + if self.prediction_heads > 1: + with self._set_implicit_default(): + if self.default_mtp_type is None: + logger.warning( + f"No default_mtp_type found in HybridBaseModelConfig, using the last block type in hybrid_block_layout: {self.hybrid_block_layout[-1]}" + ) + self.default_mtp_type = self.hybrid_block_layout[-1] + else: + if self.default_mtp_type not in self.hybrid_block_layout: + raise ValueError( + f"default_mtp_type {self.default_mtp_type} not found in hybrid_block_layout {self.hybrid_block_layout}" + ) super()._validate() @@ -302,7 +352,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: class HybridSSMModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate() + base_model: HybridBaseModelConfig = FieldUpdate() checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( LLambaHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, diff --git a/fast_llm/models/hybrid/model.py b/fast_llm/models/hybrid/model.py index f464a42f..c8024564 100644 --- a/fast_llm/models/hybrid/model.py +++ b/fast_llm/models/hybrid/model.py @@ -6,26 +6,23 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.hybrid.config import HybridSSMBaseModelConfig, HybridSSMModelConfig +from fast_llm.models.hybrid.config import HybridBaseModelConfig, HybridSSMModelConfig logger = logging.getLogger(__name__) -class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[ConfigType]): +class HybridSSMBaseModel[ConfigType: HybridBaseModelConfig](GPTBaseModel[ConfigType]): """ A hybrid model that can interleave Transformer, Mamba and other blocks. """ - config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig + config_class: typing.ClassVar[type[HybridBaseModelConfig]] = HybridBaseModelConfig _is_setup: bool = False def __init__( self, - config: HybridSSMBaseModelConfig, + config: HybridBaseModelConfig, distributed_config: DistributedConfig, ): @@ -39,39 +36,19 @@ def get_output_layers(self) -> list[Layer]: layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: - block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] + block_name = self._config.default_mtp_type + assert block_name in self._config.blocks, f"Block {block_name} not found in config" + BLOCK_CLS = self._config.blocks[block_name].block_class for i in range(1, self._config.prediction_heads): - if block_type == SSMBlockType.transformer.value: - layers.append( - TransformerLayer( - self._config.transformer, - self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), - return_input=i != self._config.prediction_heads - 1, - ) - ) - elif block_type == SSMBlockType.mamba2_discrete.value: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba.value: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, + layers.append( + BLOCK_CLS( + self._config.blocks[block_name], + self._tensor_space, layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, return_input=i != self._config.prediction_heads - 1, + block_name=block_name, ) - layers.append(mamba_block) - else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers diff --git a/tests/common.py b/tests/common.py index 11567977..06003d9b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -23,7 +23,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.hybrid.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import HybridBaseModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs @@ -449,7 +449,7 @@ def materialize_meta_tensors(model, tensor_space): def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): - config = HybridSSMBaseModelConfig( + config = HybridBaseModelConfig( transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), ssm=SSMConfig(), hybrid_block_layout=hybrid_block_layout, diff --git a/tests/test_ssms.py b/tests/test_hybrid.py similarity index 58% rename from tests/test_ssms.py rename to tests/test_hybrid.py index ea3732b8..bb0f6972 100644 --- a/tests/test_ssms.py +++ b/tests/test_hybrid.py @@ -1,25 +1,14 @@ -import pathlib from functools import partial import pytest import torch -from fast_llm.config import NoAutoValidate -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames -from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.hybrid.config import ( - AprielSSMHHybridHuggingfaceCheckpointFormat, - LLambaHuggingfaceCheckpointFormat, -) +from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import LLambaHuggingfaceCheckpointFormat from tests.common import get_hybrid_config, materialize_meta_tensors try: @@ -27,7 +16,7 @@ from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.hybrid.model import HybridSSMBaseModel, HybridSSMModel + from fast_llm.models.hybrid.model import HybridSSMBaseModel except ImportError: MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( None, @@ -77,123 +66,124 @@ def get_hf_llamba_out(input_ids, path, format): return output, parameter_sum -@pytest.mark.slow -@pytest.mark.skipif( - not run_test or LMHeadModel is None, - reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", -) -def test_load_from_llamba_checkpoint(distributed_config): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. - """ - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) - hf_logits = hf_logits["logits"].cpu() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - # model = GPTModel.from_pretrained(checkpoint_config) - assert model.config.base_model.vocab_size == vocab_size - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(distributed_config) - batch_config.validate() - schedule_runner = ScheduleRunner( - config=schedule_config, - multi_stage=model, - distributed_config=model.distributed.config, - ) - schedule = Schedule( - multi_stage=model, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=model.distributed.config, - phase=PhaseType.inference, - ) - schedule_runner.setup(model.distributed, optimizer=None) - - common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] - - losses, success, metrics = schedule_runner.run_step( - iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True - ) - - logits = input_data[0][1]["logits"].cpu() - assert torch.allclose(logits, hf_logits, atol=1e-2) - +# @pytest.mark.slow +# @pytest.mark.skipif( +# not run_test or LMHeadModel is None, +# reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", +# ) +# def test_load_from_llamba_checkpoint(distributed_config): +# """ +# Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. +# """ +# vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json +# batch_size = 2 +# seq_length = 32 -def get_hf_apriel_hybrid_out(input_ids, path, format): - from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM +# path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") +# format = LLambaHuggingfaceCheckpointFormat + +# x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") +# hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) +# hf_logits = hf_logits["logits"].cpu() + +# # Create checkpoint load config +# checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) +# # Initialize model +# model = HybridSSMModel.from_pretrained(checkpoint_config) +# param_sum = 0 +# for stage in model.stages: +# for fsdp in stage.fsdps: +# if hasattr(fsdp, "_weight_shard"): +# param_sum += torch.sum(fsdp._weight_shard).item() +# assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + +# # model = GPTModel.from_pretrained(checkpoint_config) +# assert model.config.base_model.vocab_size == vocab_size +# schedule_config = ScheduleConfig() +# with NoAutoValidate(): +# batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) +# batch_config.setup(distributed_config) +# batch_config.validate() +# schedule_runner = ScheduleRunner( +# config=schedule_config, +# multi_stage=model, +# distributed_config=model.distributed.config, +# ) +# schedule = Schedule( +# multi_stage=model, +# batch_config=batch_config, +# schedule_config=schedule_config, +# distributed_config=model.distributed.config, +# phase=PhaseType.inference, +# ) +# schedule_runner.setup(model.distributed, optimizer=None) - model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") - parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) - print(f"Parameter sum: {parameter_sum}") - output = model(input_ids) - del model - torch.cuda.empty_cache() - return output, parameter_sum +# common_kwargs = { +# TransformerKwargs.sequence_first: True, +# TransformerKwargs.grad_output: False, +# } +# input_data = [(x, common_kwargs)] +# losses, success, metrics = schedule_runner.run_step( +# iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True +# ) -@pytest.mark.slow -@pytest.mark.skipif( - not run_test - and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), - reason=f"Skipping because no CUDA available or Mamba not installed", -) -def test_load_from_hybridssm_checkpoint(distributed_config): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. - """ - vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 +# logits = input_data[0][1]["logits"].cpu() +# assert torch.allclose(logits, hf_logits, atol=1e-2) + + +# def get_hf_apriel_hybrid_out(input_ids, path, format): +# from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + +# model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") +# parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) +# print(f"Parameter sum: {parameter_sum}") +# output = model(input_ids) +# del model +# torch.cuda.empty_cache() +# return output, parameter_sum + + +# @pytest.mark.slow +# @pytest.mark.skipif( +# not run_test +# and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), +# reason=f"Skipping because no CUDA available or Mamba not installed", +# ) +# def test_load_from_hybridssm_checkpoint(distributed_config): +# """ +# Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. +# """ +# vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json +# batch_size = 2 +# seq_length = 32 - path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") - format = AprielSSMHHybridHuggingfaceCheckpointFormat +# path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") +# format = AprielSSMHHybridHuggingfaceCheckpointFormat - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) - hf_logits = hf_logits["logits"].cpu() +# x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") +# hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) +# hf_logits = hf_logits["logits"].cpu() - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 +# # Create checkpoint load config +# checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) +# # Initialize model +# model = HybridSSMModel.from_pretrained(checkpoint_config) +# param_sum = 0 +# for stage in model.stages: +# for fsdp in stage.fsdps: +# if hasattr(fsdp, "_weight_shard"): +# param_sum += torch.sum(fsdp._weight_shard).item() +# assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 +# test legacy behavior @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( "hybrid_block_layout,LAYER_CLS", [ - ([SSMBlockType.mamba, SSMBlockType.transformer], MambaLayer), - ([SSMBlockType.mamba2_discrete, SSMBlockType.transformer], DiscreteMamba2), + (["m", "t"], MambaLayer), + (["m2d", "t"], DiscreteMamba2), ], ids=["mamba", "discrete_mamba2"], ) diff --git a/tests/test_modular_config.py b/tests/test_modular_config.py index 52d2750b..38a137a2 100644 --- a/tests/test_modular_config.py +++ b/tests/test_modular_config.py @@ -1,10 +1,10 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.models.hybrid.config import HybridSSMBaseModelConfig, MambaBlockConfig, TransformerBlockConfig +from fast_llm.models.hybrid.config import HybridBaseModelConfig, MambaBlockConfig, TransformerBlockConfig from fast_llm.models.hybrid.model import HybridSSMBaseModel -config = HybridSSMBaseModelConfig( +config = HybridBaseModelConfig( blocks={ "transformer_block": TransformerBlockConfig( transformer=TransformerConfig( From 45008b5c87e11ed95fb0b95d0f15ba1e42ee75c2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 27 May 2025 19:20:33 +0000 Subject: [PATCH 108/114] wip --- fast_llm/config.py | 4 +- fast_llm/layers/common/config.py | 55 +-------------- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/ssm/blocks.py | 4 +- fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba_layer.py | 10 ++- fast_llm/layers/transformer/transformer.py | 8 +-- fast_llm/models/hybrid/config.py | 70 +++++++++--------- tests/test_hybrid.py | 82 ++++++++++++---------- 9 files changed, 91 insertions(+), 148 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 380100e3..d6e9a4b2 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -380,8 +380,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: if expected_class is not None: # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.is_(self.__class__, expected_class) - + # TODO: is this ok? i.e. we want the assigned class to be a subclass of the expected class, not neccessarily exactly the same class. + Assert.custom(issubclass, expected_class, self.__class__) if not self._validated: try: self._validate() diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index bb1e6319..9ba6103b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -158,7 +158,7 @@ class PeftType(str, enum.Enum): @config_class(registry=True) class PeftConfig(BaseModelConfig): - _abstract = True + _abstract = False type: PeftType = Field( default=PeftType.none, @@ -166,35 +166,6 @@ class PeftConfig(BaseModelConfig): hint=FieldHint.core, ) - # @classmethod - # def get_subclass(cls, name: str): - # return super().get_subclass(name) - - def validate(self: typing.Self, *, _is_validating: bool = False): - """ - Validate a class and mark it as read-only - This should not be overridden in derived classes. - """ - # Should be handled in `from_dict`, but can fail if instantiating directly. - try: - expected_class = self.get_subclass(self.type) - except KeyError as e: - # Delayed instantiation error in `from_dict`. - raise ValueError(*e.args) - - if expected_class is not None: - # Should be handled in `from_dict`, but can fail if instantiating directly. - # Assert.is_(expected_class, self.__class__) - Assert.custom(issubclass, expected_class, self.__class__) - - if not self._validated: - try: - self._validate() - except Exception as e: - raise type(e)("\n".join(e.args)) from None - self._validated = True - return self - @config_class(dynamic_type={PeftConfig: "none"}) class EmptyPeftConfig(PeftConfig): @@ -207,30 +178,6 @@ class EmptyPeftConfig(PeftConfig): def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": return linear - def validate(self: typing.Self, *, _is_validating: bool = False): - """ - Validate a class and mark it as read-only - This should not be overridden in derived classes. - """ - # Should be handled in `from_dict`, but can fail if instantiating directly. - try: - expected_class = self.get_subclass(self.type) - except KeyError as e: - # Delayed instantiation error in `from_dict`. - raise ValueError(*e.args) - - if expected_class is not None: - # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.is_(self.__class__, expected_class) - - if not self._validated: - try: - self._validate() - except Exception as e: - raise type(e)("\n".join(e.args)) from None - self._validated = True - return self - @config_class(dynamic_type={PeftConfig: "lora"}) class LoRAConfig(PeftConfig): diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index f1079f12..38650466 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -189,7 +189,7 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Enable debug mode.", hint=FieldHint.testing, ) - embeddings_hidden_dropout: bool = Field( + embeddings_hidden_dropout: float = Field( default=0.0, desc="Dropout applied to the embeddings.", hint=FieldHint.feature, diff --git a/fast_llm/layers/ssm/blocks.py b/fast_llm/layers/ssm/blocks.py index 35be0eb9..d08c122f 100644 --- a/fast_llm/layers/ssm/blocks.py +++ b/fast_llm/layers/ssm/blocks.py @@ -28,7 +28,7 @@ def __init__( def _create_mixer(self): self.mixer = DiscreteMamba2( - self._config, layer_idx=self._layer_index, tensor_space=self._tensor_space, name=self.block_name + self._config, layer_index=self._layer_index, tensor_space=self._tensor_space, name=self.block_name ) @@ -51,5 +51,5 @@ def __init__( def _create_mixer(self): self.mixer = MambaLayer( - self._config, layer_idx=self._layer_index, tensor_space=self._tensor_space, name=self.block_name + self._config, layer_index=self._layer_index, tensor_space=self._tensor_space, name=self.block_name ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 39e0902b..a907b535 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,7 +10,6 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ -from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -49,8 +48,7 @@ def __init__( bias = config.add_bias_linear self.layer_idx = layer_index self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + mamba_layer_lr_scale = self.config.mamba_lr_scale logger.info(f"Setting lr_scale for layer {layer_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_dim}_{name}") diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 1160630d..aa774aba 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,7 +9,6 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ -from fast_llm.utils import get_lr_scale """ Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. @@ -69,9 +68,9 @@ def __init__( self._debug_mode = config.debug_ssm # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) + td_inner = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_dim}_{name}") td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba + f"{SSMDimNames.inner_proj_mamba}_{name}" ) # TensorDim("D_inner_2", self.d_inner * 2) tdt_rank = tensor_space.get_tensor_dim(f"{SSMDimNames.dt_rank}_{name}") td_x_proj = tensor_space.get_tensor_dim(f"{SSMDimNames.x_proj_dim}_{name}") @@ -83,8 +82,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None - mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + mamba_layer_lr_scale = self.config.mamba_lr_scale self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), @@ -107,7 +105,7 @@ def __init__( td_x_proj, weight_init_method=kaiming_init_(td_inner.size), bias=False, - layer_lr_scale=mamba_layer_lr_scale, + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index aa9682e7..0618821b 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,9 +8,9 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.common.config import BaseBlockConfig +from fast_llm.layers.common.config import BaseBlockConfig, LLMDimNames from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -44,7 +44,7 @@ def __init__( self._layer_index = layer_index self._debug_mode = self._config.debug_block or self._config.debug_block_memory - hidden_dim = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") + hidden_dim = self._tensor_space.get_tensor_dim(f"{LLMDimNames.hidden}_{block_name}") # Note, layer_lr_scale does not impact the norms self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) @@ -52,7 +52,7 @@ def __init__( self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.block_name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.block_name}", layer_index=layer_index ) # PEFT. Layer freezing must be explicit now. diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index 57acf092..eda91087 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -40,14 +40,15 @@ class HybridBlockConfig(Config): doc="May be used to freeze some layers by setting their scale to zero.", hint=FieldHint.feature, ) + hidden_size: int = Field( + default=1024, + desc="Hidden size of the block.", + hint=FieldHint.architecture, + ) def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> None: raise NotImplementedError() - @property - def hidden_size(self) -> int: - raise NotImplementedError() - @config_class(dynamic_type={HybridBlockConfig: "transformer"}) class TransformerBlockConfig(HybridBlockConfig, TransformerConfig): @@ -63,12 +64,6 @@ class DiscreteMamba2BlockConfig(HybridBlockConfig, SSMConfig): _abstract = False block_class: typing.ClassVar[type[BaseBlock]] = LlambaBlock - hidden_size: int = Field( - default=1024, - desc="Hidden size of the block.", - hint=FieldHint.architecture, - ) - # def _validate(self): # self.config.validate() @@ -76,6 +71,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None d_inner = int(self.expansion_factor * self.hidden_size) if self.d_inner is None else self.d_inner # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.hidden}_{block_name}", self.hidden_size)) tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{block_name}", self.hidden_size)) # Mamba-specific dimensions tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{block_name}", d_inner)) @@ -104,13 +100,7 @@ class MambaBlockConfig(HybridBlockConfig, SSMConfig): _abstract = False block_class: typing.ClassVar[type[BaseBlock]] = LlambaOneBlock - hidden_size: int = Field( - default=1024, - desc="Hidden size of the block.", - hint=FieldHint.architecture, - ) - - def setup_tensor_space(self, tensor_space: TensorSpace, name: str) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: if self.dt_rank is None: mamba_dt_rank = math.ceil(self.hidden_size / 16) @@ -119,23 +109,28 @@ def setup_tensor_space(self, tensor_space: TensorSpace, name: str) -> None: d_inner = int(self.expansion_factor * self.hidden_size) if self.d_inner is None else self.d_inner # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{name}", self.hidden_size)) - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{name}", d_inner)) - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.state_dim}_{name}", self.state_size)) - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.dt_rank}_{name}", mamba_dt_rank)) - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.x_proj_dim}_{name}", mamba_dt_rank + self.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.conv_kernel_size}_{name}", self.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba}_{name}", d_inner * 2)) + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.hidden}_{block_name}", self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{block_name}", self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{block_name}", d_inner)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.state_dim}_{block_name}", self.state_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.dt_rank}_{block_name}", mamba_dt_rank)) + tensor_space.add_tensor_dim( + TensorDim(f"{SSMDimNames.x_proj_dim}_{block_name}", mamba_dt_rank + self.state_size * 2) + ) + tensor_space.add_tensor_dim( + TensorDim(f"{SSMDimNames.conv_kernel_size}_{block_name}", self.conv_kernel_dimension) + ) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba}_{block_name}", d_inner * 2)) -class HybridBlockType(str, enum.Enum): +class HybridBlockType(enum.Enum): """ An enum for the available block types, legacy format. """ - m: MambaBlockConfig - m2d: LlambaBlock - t: TransformerBlockConfig + m = MambaBlockConfig + m2d = DiscreteMamba2BlockConfig + t = TransformerBlockConfig @config_class() @@ -181,7 +176,7 @@ class HybridBaseModelConfig(LanguageModelBaseConfig): # TODO: currently num_layers is defined in TransformerConfig, but ideally this should be migrated to LanguageModelBaseConfig in the future. # Hence, for now: the num_layers can be set in the first transformer block, if no transformer blocks used we will fallback to num_layers parameter defined here. num_layers: int = Field( - default=None, + default=12, desc="Number of layers in the transformer.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), @@ -225,23 +220,22 @@ def _validate(self): blocks = {} for block_type in self.hybrid_block_layout: if block_type not in blocks: - hybrid_block_config_cls = HybridBlockType[block_type] + hybrid_block_config_cls = HybridBlockType[block_type].value if hybrid_block_config_cls == TransformerBlockConfig: - blocks[block_type] = TransformerConfig.from_dict(self.transformer.to_dict()) + blocks[block_type] = TransformerBlockConfig.from_dict(self.transformer.to_dict()) elif hybrid_block_config_cls == MambaBlockConfig: - blocks[block_type] = SSMConfig.from_dict(self.ssm.to_dict()) - elif hybrid_block_config_cls == LlambaBlock: - blocks[block_type] = SSMConfig.from_dict(self.ssm.to_dict()) + blocks[block_type] = MambaBlockConfig.from_dict(self.ssm.to_dict()) + elif hybrid_block_config_cls == DiscreteMamba2BlockConfig: + blocks[block_type] = DiscreteMamba2BlockConfig.from_dict(self.ssm.to_dict()) else: raise ValueError(f"Invalid block type: {block_type}") self.blocks = blocks - self.hybrid_block_layout = [HybridBlockType[block_type] for block_type in self.hybrid_block_layout] - Assert.gt(len(self.hybrid_block_layout), 0, "No blocks found in hybrid_block_layout") + Assert.gt(len(self.hybrid_block_layout), 0) # Validate that all pattern entries refer to valid blocks for block_name in self.hybrid_block_layout: if block_name not in self.blocks: - raise ValueError(f"Block name '{block_name}' in block_pattern not found in blocks dictionary") + raise ValueError(f"Block name '{block_name}' not found in blocks dictionary") first_transformer_block_config: TransformerBlockConfig | None = None @@ -253,7 +247,7 @@ def _validate(self): logger.warning( f"Found multiple transformer blocks with different number of layers, using num_layers from the first transformer block for all" ) - block_config._validate() + block_config.validate() # set num_layers from transformer block config if it exists and if num_layers is not set in HybridBaseModelConfig # i.e. the resolution hierarchy for num_layers is: HybridBaseModelConfig.num_layers > TransformerBlockConfig.num_layers diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index bb0f6972..e0e41903 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -1,11 +1,10 @@ -from functools import partial - import pytest import torch from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat from fast_llm.models.hybrid.config import LLambaHuggingfaceCheckpointFormat @@ -177,56 +176,63 @@ def get_hf_llamba_out(input_ids, path, format): # assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 -# test legacy behavior +# test legacy behavior of using m and m2d +# @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") +# @pytest.mark.parametrize( +# "hybrid_block_layout,LAYER_CLS", +# [ +# (["m"], MambaLayer), +# (["m2d"], DiscreteMamba2), +# ], +# ids=["mamba", "discrete_mamba2"], +# ) +# def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): +# hybrid_config: HybridBaseModelConfig = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) +# tensor_space = TensorSpace(distributed_config=distributed_config) +# hybrid_config.setup_tensor_space(tensor_space) +# layer = LAYER_CLS(hybrid_config.ssm, layer_index=0, tensor_space=tensor_space, name=hybrid_block_layout[0]) +# tensor_space.setup(distributed) +# materialize_meta_tensors(layer, tensor_space) +# layer.to(distributed.device) + +# batch_size = 2 +# seq_length = 32 +# hidden_size = hybrid_config.transformer.hidden_size +# x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) + +# # Run forward pass +# output, _ = layer(x, {}) + +# loss = output.sum() +# loss.backward() +# # Basic shape checkss +# assert output.shape == x.shape +# assert not torch.isnan(output).any() +# assert not torch.isinf(output).any() + + +# test legacy behavior of using m and m2d @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( - "hybrid_block_layout,LAYER_CLS", + "hybrid_block_layout", [ - (["m", "t"], MambaLayer), - (["m2d", "t"], DiscreteMamba2), + (["m"]), + (["m2d"]), ], ids=["mamba", "discrete_mamba2"], ) -def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): +def test_mamba_block(distributed_config, distributed, hybrid_block_layout): hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) tensor_space = TensorSpace(distributed_config=distributed_config) - hybrid_config.setup_tensor_space(tensor_space) - layer = LAYER_CLS(hybrid_config.ssm, layer_idx=0, tensor_space=tensor_space) - tensor_space.setup(distributed) - materialize_meta_tensors(layer, tensor_space) - layer.to(distributed.device) - - batch_size = 2 - seq_length = 32 - hidden_size = hybrid_config.transformer.hidden_size - x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) - - # Run forward pass - output, _ = layer(x, {}) - - loss = output.sum() - loss.backward() - # Basic shape checkss - assert output.shape == x.shape - assert not torch.isnan(output).any() - assert not torch.isinf(output).any() - - -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -def test_mamba_block(distributed_config, distributed): - hybrid_config = get_hybrid_config(hybrid_block_layout=["m", "t"]) - tensor_space = TensorSpace(distributed_config=distributed_config) tensor_space.setup(distributed) hybrid_config.setup_tensor_space(tensor_space) layer_idx = 0 - - mixer_cls = partial(MambaLayer, layer_idx=layer_idx) - block = LlambaBlock( - hybrid_config.transformer, + BLOCK_CLS = hybrid_config.blocks[hybrid_block_layout[0]].block_class + block = BLOCK_CLS( hybrid_config.ssm, - mixer_cls=mixer_cls, tensor_space=tensor_space, layer_index=layer_idx, + block_name=hybrid_block_layout[0], ) materialize_meta_tensors(block, tensor_space) From a37895423e018968b92ab6179c4f29f80ef06be3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 28 May 2025 15:41:27 +0000 Subject: [PATCH 109/114] wip hybrid block architecture --- fast_llm/layers/common/config.py | 25 ++- fast_llm/layers/language_model/config.py | 8 + fast_llm/layers/ssm/blocks.py | 2 +- fast_llm/models/auto.py | 4 +- fast_llm/models/gpt/config.py | 8 - fast_llm/models/hybrid/config.py | 33 +-- fast_llm/models/hybrid/conversion.py | 24 +-- fast_llm/models/hybrid/huggingface.py | 16 +- fast_llm/models/hybrid/model.py | 10 +- fast_llm/models/hybrid/trainer.py | 4 +- tests/common.py | 4 +- tests/test_hybrid.py | 257 ++++++++++------------- tests/test_modular_config.py | 4 +- tests/test_mtp.py | 6 +- 14 files changed, 198 insertions(+), 207 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9ba6103b..96affa36 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,3 +1,4 @@ +import abc import enum import typing @@ -158,7 +159,6 @@ class PeftType(str, enum.Enum): @config_class(registry=True) class PeftConfig(BaseModelConfig): - _abstract = False type: PeftType = Field( default=PeftType.none, @@ -166,6 +166,22 @@ class PeftConfig(BaseModelConfig): hint=FieldHint.core, ) + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is PeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return EmptyPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + @config_class(dynamic_type={PeftConfig: "none"}) class EmptyPeftConfig(PeftConfig): @@ -175,8 +191,8 @@ class EmptyPeftConfig(PeftConfig): _abstract = False - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - return linear + def apply_linear(self, *args, **kwargs) -> "LinearLike": + return args[0] @config_class(dynamic_type={PeftConfig: "lora"}) @@ -270,8 +286,9 @@ def _validate(self) -> None: @config_class() class BaseBlockConfig(BaseModelConfig): - _abstract = True + _abstract = False peft: PeftConfig = Field( + # default_factory=lambda: PeftConfig(type=PeftType.none), desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 38650466..c6096bb6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -199,6 +199,10 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the normalization in the head.", hint=FieldHint.architecture, ) + # Debug, to get an exact match with megatron init. + use_megatron_initialization: bool = Field( + default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing + ) def _validate(self) -> None: # # self.transformer.validate() @@ -261,4 +265,8 @@ def from_flat_dict( cls._handle_renamed_field(default, "normalization_type", "type") cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + + if "match_megatron" in default: + assert "use_megatron_initialization" not in default + default["use_megatron_initialization"] = default.pop("match_megatron") return super().from_flat_dict(default, strict) diff --git a/fast_llm/layers/ssm/blocks.py b/fast_llm/layers/ssm/blocks.py index d08c122f..d0521c32 100644 --- a/fast_llm/layers/ssm/blocks.py +++ b/fast_llm/layers/ssm/blocks.py @@ -50,6 +50,6 @@ def __init__( super().__init__(config, tensor_space, layer_index, block_name, return_input) def _create_mixer(self): - self.mixer = MambaLayer( + self.mamba1 = MambaLayer( self._config, layer_index=self._layer_index, tensor_space=self._tensor_space, name=self.block_name ) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 96c2917d..8fc0a09c 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,7 +2,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig -from fast_llm.models.hybrid.config import HybridSSMModelConfig, HybridTrainerConfig +from fast_llm.models.hybrid.config import HybridModelConfig, HybridTrainerConfig from fast_llm.utils import Registry model_registry = Registry[str, FastLLMModelConfig]( @@ -12,7 +12,7 @@ for model in [ GPTModelConfig, CustomModelConfig, - HybridSSMModelConfig, + HybridModelConfig, ] }, ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 29864ac0..583d4703 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -107,11 +107,6 @@ class GPTBaseModelConfig(LanguageModelBaseConfig): _abstract = False - # Debug, to get an exact match with megatron init. - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) - transformer: TransformerConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, @@ -127,9 +122,6 @@ def _from_dict( # TODO v0.3: Remove backward compatibility fix if "transposed_mlp_weight" in default: assert default.pop("transposed_mlp_weight") - if "match_megatron" in default: - assert "use_megatron_initialization" not in default - default["use_megatron_initialization"] = default.pop("match_megatron") if "layer_norm_impl" in default: assert "normalization_implementation" not in default default["normalization_implementation"] = default.pop("layer_norm_impl") diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index eda91087..a5e2f6e1 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -21,8 +21,8 @@ if typing.TYPE_CHECKING: from fast_llm.models.gpt.model import GPTInferenceRunner - from fast_llm.models.hybrid.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.hybrid.model import HybridSSMModel + from fast_llm.models.hybrid.huggingface import HuggingfaceHybridModelForCausalLM + from fast_llm.models.hybrid.model import HybridModel from fast_llm.models.hybrid.trainer import SSMTrainer logger = logging.getLogger(__name__) @@ -46,9 +46,6 @@ class HybridBlockConfig(Config): hint=FieldHint.architecture, ) - def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> None: - raise NotImplementedError() - @config_class(dynamic_type={HybridBlockConfig: "transformer"}) class TransformerBlockConfig(HybridBlockConfig, TransformerConfig): @@ -94,6 +91,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba2}_{block_name}", inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.conv_dim}_{block_name}", conv_dim)) + SSMConfig.setup_tensor_space(self, tensor_space, block_name) + @config_class(dynamic_type={HybridBlockConfig: "mamba"}) class MambaBlockConfig(HybridBlockConfig, SSMConfig): @@ -122,6 +121,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None ) tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba}_{block_name}", d_inner * 2)) + SSMConfig.setup_tensor_space(self, tensor_space, block_name) + class HybridBlockType(enum.Enum): """ @@ -343,7 +344,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: @config_class() -class HybridSSMModelConfig(FastLLMModelConfig): +class HybridModelConfig(FastLLMModelConfig): _abstract = False model_name: typing.ClassVar[str] = "hybrid_ssm" base_model: HybridBaseModelConfig = FieldUpdate() @@ -354,32 +355,32 @@ class HybridSSMModelConfig(FastLLMModelConfig): ) @classmethod - def get_model_class(cls) -> type["HybridSSMModel"]: - from fast_llm.models.hybrid.model import HybridSSMModel + def get_model_class(cls) -> type["HybridModel"]: + from fast_llm.models.hybrid.model import HybridModel - return HybridSSMModel + return HybridModel @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: - from fast_llm.models.hybrid.huggingface import HuggingfaceHybridSSMModelForCausalLM + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridModelForCausalLM"]: + from fast_llm.models.hybrid.huggingface import HuggingfaceHybridModelForCausalLM - return HuggingfaceHybridSSMModelForCausalLM + return HuggingfaceHybridModelForCausalLM def _validate(self): logger.warning( - "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." + "HybridModelConfig is being instantiated. This model is experimental and may not work as expected." ) super()._validate() @config_class() -class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): +class PretrainedHybridModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: HybridSSMModelConfig = FieldUpdate() + model: HybridModelConfig = FieldUpdate() @config_class() -class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): +class HybridTrainerConfig(PretrainedHybridModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate() batch: GPTBatchConfig = FieldUpdate() reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() diff --git a/fast_llm/models/hybrid/conversion.py b/fast_llm/models/hybrid/conversion.py index ef270145..73f1bceb 100644 --- a/fast_llm/models/hybrid/conversion.py +++ b/fast_llm/models/hybrid/conversion.py @@ -23,10 +23,10 @@ from fast_llm.models.hybrid.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, AprielSSMHuggingfaceCheckpointFormat, - HybridSSMModelConfig, + HybridModelConfig, LLambaHuggingfaceCheckpointFormat, ) -from fast_llm.models.hybrid.model import HybridSSMModel +from fast_llm.models.hybrid.model import HybridModel from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -40,8 +40,8 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): If block_pattern is provided, it will export/import it as-is. """ - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig _default_block_type: str = SSMBlockType.mamba2_discrete.value @classmethod @@ -67,8 +67,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig @classmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -209,8 +209,8 @@ def _get_weight_and_bias_converters( class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat _hf_prefix: str = "backbone" @@ -414,8 +414,8 @@ class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandle Lamba-like configs, pure SSM models. """ - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat @classmethod @@ -536,8 +536,8 @@ class AprielSSMHHybridHuggingfaceCheckpointHandler( Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. """ - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat _default_block_type: str = SSMBlockType.mamba2_discrete.value diff --git a/fast_llm/models/hybrid/huggingface.py b/fast_llm/models/hybrid/huggingface.py index 6e818a32..8191a5a2 100644 --- a/fast_llm/models/hybrid/huggingface.py +++ b/fast_llm/models/hybrid/huggingface.py @@ -1,21 +1,21 @@ import logging -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from fast_llm.models.hybrid.config import HybridSSMModelConfig -from fast_llm.models.hybrid.model import HybridSSMModel +from fast_llm.models.hybrid.config import HybridModelConfig +from fast_llm.models.hybrid.model import HybridModel logger = logging.getLogger(__name__) class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): model_type = "fast_llm_ssm" - model_config_class = HybridSSMModelConfig - fast_llm_config: HybridSSMModelConfig + model_config_class = HybridModelConfig + fast_llm_config: HybridModelConfig -class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): +class HuggingfaceHybridModelForCausalLM(HuggingfaceGPTModelForCausalLM): config_class = HuggingfaceSSMModelConfig config: HuggingfaceSSMModelConfig - model_class = HybridSSMModel - _fast_llm_model: HybridSSMModel + model_class = HybridModel + _fast_llm_model: HybridModel diff --git a/fast_llm/models/hybrid/model.py b/fast_llm/models/hybrid/model.py index c8024564..87a731c5 100644 --- a/fast_llm/models/hybrid/model.py +++ b/fast_llm/models/hybrid/model.py @@ -7,12 +7,12 @@ from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.hybrid.config import HybridBaseModelConfig, HybridSSMModelConfig +from fast_llm.models.hybrid.config import HybridBaseModelConfig, HybridModelConfig logger = logging.getLogger(__name__) -class HybridSSMBaseModel[ConfigType: HybridBaseModelConfig](GPTBaseModel[ConfigType]): +class HybridBaseModel[ConfigType: HybridBaseModelConfig](GPTBaseModel[ConfigType]): """ A hybrid model that can interleave Transformer, Mamba and other blocks. """ @@ -79,10 +79,10 @@ def get_layers(self) -> list[Layer]: return layers -class HybridSSMModel[ConfigType: HybridSSMModelConfig](FastLLMModel[ConfigType]): +class HybridModel[ConfigType: HybridModelConfig](FastLLMModel[ConfigType]): """ A hybrid model that combines Transformer and SSM blocks. """ - config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig - base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel + config_class: typing.ClassVar[type[HybridModelConfig]] = HybridModelConfig + base_model_class: typing.ClassVar[type[HybridBaseModel]] = HybridBaseModel diff --git a/fast_llm/models/hybrid/trainer.py b/fast_llm/models/hybrid/trainer.py index 55c16ad0..9e489e89 100644 --- a/fast_llm/models/hybrid/trainer.py +++ b/fast_llm/models/hybrid/trainer.py @@ -2,9 +2,9 @@ from fast_llm.models.gpt.trainer import GPTTrainer from fast_llm.models.hybrid.config import HybridTrainerConfig -from fast_llm.models.hybrid.model import HybridSSMModel +from fast_llm.models.hybrid.model import HybridModel class SSMTrainer[ConfigType: HybridTrainerConfig](GPTTrainer[ConfigType]): config_class: typing.ClassVar[type[HybridTrainerConfig]] = HybridTrainerConfig - model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel + model_class: typing.ClassVar[type[HybridModel]] = HybridModel diff --git a/tests/common.py b/tests/common.py index 06003d9b..94175792 100644 --- a/tests/common.py +++ b/tests/common.py @@ -36,7 +36,7 @@ FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0 REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0 _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) -TEST_MODEL = os.environ.get("MODEL", "llama") +TEST_MODEL = os.environ.get("MODEL", "llamba") ARTIFACT_PATH = "runs/0/artifacts" @@ -204,7 +204,7 @@ ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] +CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout=['m2d','m2d']"] CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index e0e41903..ea5c0a32 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -1,23 +1,31 @@ +import pathlib + import pytest import torch +from fast_llm.config import NoAutoValidate +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.hybrid.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) +from fast_llm.models.hybrid.model import HybridModel from tests.common import get_hybrid_config, materialize_meta_tensors try: - from blocks import LlambaBlock - - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.hybrid.model import HybridSSMBaseModel + from fast_llm.models.hybrid.model import HybridBaseModel except ImportError: - MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( + MambaLayer, LlambaBlock, HybridBaseModel, DiscreteMamba2 = ( None, None, None, @@ -65,150 +73,115 @@ def get_hf_llamba_out(input_ids, path, format): return output, parameter_sum -# @pytest.mark.slow -# @pytest.mark.skipif( -# not run_test or LMHeadModel is None, -# reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", -# ) -# def test_load_from_llamba_checkpoint(distributed_config): -# """ -# Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. -# """ -# vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json -# batch_size = 2 -# seq_length = 32 +@pytest.mark.slow +@pytest.mark.skipif( + not run_test or LMHeadModel is None, + reason=f"Skipping because one of the following: cartesia_pytorch.Llamba not installed or no CUDA available or Mamba not installed", +) +def test_load_from_llamba_checkpoint(distributed_config): + """ + Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. + """ + vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + batch_size = 2 + seq_length = 32 -# path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") -# format = LLambaHuggingfaceCheckpointFormat - -# x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") -# hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) -# hf_logits = hf_logits["logits"].cpu() - -# # Create checkpoint load config -# checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) -# # Initialize model -# model = HybridSSMModel.from_pretrained(checkpoint_config) -# param_sum = 0 -# for stage in model.stages: -# for fsdp in stage.fsdps: -# if hasattr(fsdp, "_weight_shard"): -# param_sum += torch.sum(fsdp._weight_shard).item() -# assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - -# # model = GPTModel.from_pretrained(checkpoint_config) -# assert model.config.base_model.vocab_size == vocab_size -# schedule_config = ScheduleConfig() -# with NoAutoValidate(): -# batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) -# batch_config.setup(distributed_config) -# batch_config.validate() -# schedule_runner = ScheduleRunner( -# config=schedule_config, -# multi_stage=model, -# distributed_config=model.distributed.config, -# ) -# schedule = Schedule( -# multi_stage=model, -# batch_config=batch_config, -# schedule_config=schedule_config, -# distributed_config=model.distributed.config, -# phase=PhaseType.inference, -# ) -# schedule_runner.setup(model.distributed, optimizer=None) + path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") + format = LLambaHuggingfaceCheckpointFormat + + x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") + hf_logits, parameter_sum_hf = get_hf_llamba_out(x, path, format) + hf_logits = hf_logits["logits"].cpu() + + # Create checkpoint load config + checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) + # Initialize model + model = HybridModel.from_pretrained(checkpoint_config) + param_sum = 0 + for stage in model.stages: + for fsdp in stage.fsdps: + if hasattr(fsdp, "_weight_shard"): + param_sum += torch.sum(fsdp._weight_shard).item() + assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + + # model = GPTModel.from_pretrained(checkpoint_config) + assert model.config.base_model.vocab_size == vocab_size + schedule_config = ScheduleConfig() + with NoAutoValidate(): + batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) + batch_config.setup(distributed_config) + batch_config.validate() + schedule_runner = ScheduleRunner( + config=schedule_config, + multi_stage=model, + distributed_config=model.distributed.config, + ) + schedule = Schedule( + multi_stage=model, + batch_config=batch_config, + schedule_config=schedule_config, + distributed_config=model.distributed.config, + phase=PhaseType.inference, + ) + schedule_runner.setup(model.distributed, optimizer=None) -# common_kwargs = { -# TransformerKwargs.sequence_first: True, -# TransformerKwargs.grad_output: False, -# } -# input_data = [(x, common_kwargs)] + common_kwargs = { + TransformerKwargs.sequence_first: True, + TransformerKwargs.grad_output: False, + } + input_data = [(x, common_kwargs)] -# losses, success, metrics = schedule_runner.run_step( -# iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True -# ) + losses, success, metrics = schedule_runner.run_step( + iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True + ) -# logits = input_data[0][1]["logits"].cpu() -# assert torch.allclose(logits, hf_logits, atol=1e-2) - - -# def get_hf_apriel_hybrid_out(input_ids, path, format): -# from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM - -# model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") -# parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) -# print(f"Parameter sum: {parameter_sum}") -# output = model(input_ids) -# del model -# torch.cuda.empty_cache() -# return output, parameter_sum - - -# @pytest.mark.slow -# @pytest.mark.skipif( -# not run_test -# and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), -# reason=f"Skipping because no CUDA available or Mamba not installed", -# ) -# def test_load_from_hybridssm_checkpoint(distributed_config): -# """ -# Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. -# """ -# vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json -# batch_size = 2 -# seq_length = 32 + logits = input_data[0][1]["logits"].cpu() + assert torch.allclose(logits, hf_logits, atol=1e-2) -# path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") -# format = AprielSSMHHybridHuggingfaceCheckpointFormat -# x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") -# hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) -# hf_logits = hf_logits["logits"].cpu() +def get_hf_apriel_hybrid_out(input_ids, path, format): + from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM -# # Create checkpoint load config -# checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) -# # Initialize model -# model = HybridSSMModel.from_pretrained(checkpoint_config) -# param_sum = 0 -# for stage in model.stages: -# for fsdp in stage.fsdps: -# if hasattr(fsdp, "_weight_shard"): -# param_sum += torch.sum(fsdp._weight_shard).item() -# assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") + parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) + print(f"Parameter sum: {parameter_sum}") + output = model(input_ids) + del model + torch.cuda.empty_cache() + return output, parameter_sum -# test legacy behavior of using m and m2d -# @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -# @pytest.mark.parametrize( -# "hybrid_block_layout,LAYER_CLS", -# [ -# (["m"], MambaLayer), -# (["m2d"], DiscreteMamba2), -# ], -# ids=["mamba", "discrete_mamba2"], -# ) -# def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): -# hybrid_config: HybridBaseModelConfig = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) -# tensor_space = TensorSpace(distributed_config=distributed_config) -# hybrid_config.setup_tensor_space(tensor_space) -# layer = LAYER_CLS(hybrid_config.ssm, layer_index=0, tensor_space=tensor_space, name=hybrid_block_layout[0]) -# tensor_space.setup(distributed) -# materialize_meta_tensors(layer, tensor_space) -# layer.to(distributed.device) +@pytest.mark.slow +@pytest.mark.skipif( + not run_test + and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), + reason=f"Skipping because no CUDA available or Mamba not installed", +) +def test_load_from_hybridssm_checkpoint(): + """ + Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. + """ + vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + batch_size = 2 + seq_length = 32 -# batch_size = 2 -# seq_length = 32 -# hidden_size = hybrid_config.transformer.hidden_size -# x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) + path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") + format = AprielSSMHHybridHuggingfaceCheckpointFormat -# # Run forward pass -# output, _ = layer(x, {}) + x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") + hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) + hf_logits = hf_logits["logits"].cpu() -# loss = output.sum() -# loss.backward() -# # Basic shape checkss -# assert output.shape == x.shape -# assert not torch.isnan(output).any() -# assert not torch.isinf(output).any() + # Create checkpoint load config + checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) + # Initialize model + model = HybridModel.from_pretrained(checkpoint_config) + param_sum = 0 + for stage in model.stages: + for fsdp in stage.fsdps: + if hasattr(fsdp, "_weight_shard"): + param_sum += torch.sum(fsdp._weight_shard).item() + assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 # test legacy behavior of using m and m2d @@ -264,7 +237,7 @@ def test_mamba_block(distributed_config, distributed, hybrid_block_layout): ) def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layout): hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) - model = HybridSSMBaseModel(hybrid_config, distributed_config) + model = HybridBaseModel(hybrid_config, distributed_config) distributed = Distributed(distributed_config) model.setup(distributed) tensor_space = model._tensor_space @@ -311,7 +284,7 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo # @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") # def test_hybrid_model_inference(distributed_config, hybrid_config): # hybrid_config.ssm.use_fast_path = False -# model = HybridSSMBaseModel(hybrid_config, distributed_config) +# model = HybridBaseModel(hybrid_config, distributed_config) # distributed = Distributed(distributed_config) # model.setup(distributed) # tensor_space = model._tensor_space diff --git a/tests/test_modular_config.py b/tests/test_modular_config.py index 38a137a2..b740a43c 100644 --- a/tests/test_modular_config.py +++ b/tests/test_modular_config.py @@ -2,7 +2,7 @@ from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.hybrid.config import HybridBaseModelConfig, MambaBlockConfig, TransformerBlockConfig -from fast_llm.models.hybrid.model import HybridSSMBaseModel +from fast_llm.models.hybrid.model import HybridBaseModel config = HybridBaseModelConfig( blocks={ @@ -34,4 +34,4 @@ ) # Create model -model = HybridSSMBaseModel(config, distributed_config) +model = HybridBaseModel(config, distributed_config) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index ea46ace3..6df35102 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -20,9 +20,9 @@ try: from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.hybrid.model import HybridSSMBaseModel + from fast_llm.models.hybrid.model import HybridBaseModel except ImportError: - MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( + MambaLayer, HybridBaseModel, DiscreteMamba2 = ( None, None, None, @@ -147,7 +147,7 @@ def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_he hybrid_config = get_hybrid_config( hybrid_block_layout=hybrid_block_layout, prediction_heads=prediction_heads, default_mtp_type=default_mtp_type ) - model = HybridSSMBaseModel(hybrid_config, distributed_config) + model = HybridBaseModel(hybrid_config, distributed_config) distributed = Distributed(distributed_config) model.setup(distributed) tensor_space = model._tensor_space From 38fc5290c58ad8b92cd2e7bff3013aa18b8b96d9 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 29 May 2025 03:17:26 +0000 Subject: [PATCH 110/114] wip --- fast_llm/layers/language_model/config.py | 3 --- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/hybrid/config.py | 30 +++++++++++++++++------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c6096bb6..9a8b45b8 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -266,7 +266,4 @@ def from_flat_dict( cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - if "match_megatron" in default: - assert "use_megatron_initialization" not in default - default["use_megatron_initialization"] = default.pop("match_megatron") return super().from_flat_dict(default, strict) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 774544fb..f5569396 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -160,7 +160,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.input_hidden) hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index a5e2f6e1..63b1ec7b 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -30,9 +30,13 @@ @config_class(registry=True) class HybridBlockConfig(Config): - _abstract = True block_class: typing.ClassVar[type[BaseBlock]] - # config: TransformerConfig | SSMConfig + + type: str | None = Field( + default="transformer", + desc="The config class name.", + hint=FieldHint.feature, + ) lr_scale: list[float] | None = Field( default=None, @@ -46,6 +50,17 @@ class HybridBlockConfig(Config): hint=FieldHint.architecture, ) + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is HybridBlockConfig and cls.get_subclass(default.get("type")) is None: + raise ValueError(f"Block type not set in {cls}") + return super()._from_dict(default, strict=strict, flat=flat) + @config_class(dynamic_type={HybridBlockConfig: "transformer"}) class TransformerBlockConfig(HybridBlockConfig, TransformerConfig): @@ -177,7 +192,7 @@ class HybridBaseModelConfig(LanguageModelBaseConfig): # TODO: currently num_layers is defined in TransformerConfig, but ideally this should be migrated to LanguageModelBaseConfig in the future. # Hence, for now: the num_layers can be set in the first transformer block, if no transformer blocks used we will fallback to num_layers parameter defined here. num_layers: int = Field( - default=12, + default=None, desc="Number of layers in the transformer.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), @@ -252,13 +267,12 @@ def _validate(self): # set num_layers from transformer block config if it exists and if num_layers is not set in HybridBaseModelConfig # i.e. the resolution hierarchy for num_layers is: HybridBaseModelConfig.num_layers > TransformerBlockConfig.num_layers - if first_transformer_block_config is not None: + if first_transformer_block_config is not None and self.num_layers is None: num_layers = first_transformer_block_config.num_layers with self._set_implicit_default(): - if self.num_layers is None: - logger.warning( - f"TransformerBlockConfig overwrites BaseModelConfig num_layers, setting num_layers = {num_layers}" - ) + logger.warning( + f"TransformerBlockConfig overwrites BaseModelConfig num_layers, setting num_layers = {num_layers}" + ) self.num_layers = num_layers # make sure that the hybrid_block_layout length matches the num_layers. If it doesn't, repeat the hybrid_block_layout; From e5534fdef2b08534266896f32283fa87c7bba8c5 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Jun 2025 14:11:56 +0000 Subject: [PATCH 111/114] wip --- fast_llm/layers/common/config.py | 8 ++++++++ fast_llm/layers/transformer/transformer.py | 1 + fast_llm/models/hybrid/config.py | 14 +------------- fast_llm/models/hybrid/model.py | 3 ++- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 96affa36..bc496d88 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -319,6 +319,14 @@ class BaseBlockConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + + lr_scale: float = Field( + default=1.0, + desc="Custom learning rate scale for full block. note, ", + doc="May be used to freeze some layers by setting their scale to zero. Note, in non-hybrid models (GPT model) all layers share same config and setting lr_scale to 0 will freeze all layers. Consider using norm_lr_scale, mlp_lr_scale etc. instead.", + hint=FieldHint.feature, + ) + norm_lr_scale: float | None | list[float | None] = Field( default=None, desc="Custom learning rate scale for each normalization layer.", diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 0618821b..e2dd484e 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -50,6 +50,7 @@ def __init__( self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) self._create_mixer() + self.lr_scale = self._config.lr_scale self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.block_name}", layer_index=layer_index diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index 63b1ec7b..21ca0148 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -30,26 +30,14 @@ @config_class(registry=True) class HybridBlockConfig(Config): + _abstract = True block_class: typing.ClassVar[type[BaseBlock]] - type: str | None = Field( default="transformer", desc="The config class name.", hint=FieldHint.feature, ) - lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) - hidden_size: int = Field( - default=1024, - desc="Hidden size of the block.", - hint=FieldHint.architecture, - ) - @classmethod def _from_dict( cls, diff --git a/fast_llm/models/hybrid/model.py b/fast_llm/models/hybrid/model.py index 87a731c5..1f387f4b 100644 --- a/fast_llm/models/hybrid/model.py +++ b/fast_llm/models/hybrid/model.py @@ -6,6 +6,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.transformer.transformer import BaseBlock from fast_llm.models.gpt.model import GPTBaseModel from fast_llm.models.hybrid.config import HybridBaseModelConfig, HybridModelConfig @@ -62,7 +63,7 @@ def get_layers(self) -> list[Layer]: # Create blocks according to pattern for i, block_name in enumerate(self._config.hybrid_block_layout): - BLOCK_CLS = self._config.blocks[block_name].block_class + BLOCK_CLS: BaseBlock = self._config.blocks[block_name].block_class layers.append( BLOCK_CLS( self._config.blocks[block_name], From 6860c431eb6ddee8373aea46a04ffe9a333b0863 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Jun 2025 14:55:14 +0000 Subject: [PATCH 112/114] added lr scales per block --- fast_llm/layers/common/config.py | 4 ++-- fast_llm/layers/ssm/discrete_mamba2.py | 5 ++++- fast_llm/layers/ssm/mamba_layer.py | 5 ++++- fast_llm/layers/transformer/attention.py | 5 +++-- fast_llm/layers/transformer/mixture_of_experts.py | 12 ++++++------ fast_llm/layers/transformer/mlp.py | 6 ++++-- fast_llm/layers/transformer/transformer.py | 10 +++++++--- 7 files changed, 30 insertions(+), 17 deletions(-) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index bc496d88..d4188e11 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -320,8 +320,8 @@ class BaseBlockConfig(BaseModelConfig): valid=check_field(Assert.gt, 0), ) - lr_scale: float = Field( - default=1.0, + lr_scale: float | None = Field( + default=None, desc="Custom learning rate scale for full block. note, ", doc="May be used to freeze some layers by setting their scale to zero. Note, in non-hybrid models (GPT model) all layers share same config and setting lr_scale to 0 will freeze all layers. Consider using norm_lr_scale, mlp_lr_scale etc. instead.", hint=FieldHint.feature, diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index a907b535..da57e917 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,6 +10,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -48,7 +49,9 @@ def __init__( bias = config.add_bias_linear self.layer_idx = layer_index self._return_input = return_input - mamba_layer_lr_scale = self.config.mamba_lr_scale + + layer_lr_scale = config.lr_scale if config.lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {layer_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_dim}_{name}") diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index aa774aba..76be3c4d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. @@ -82,7 +83,9 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - mamba_layer_lr_scale = self.config.mamba_lr_scale + + layer_lr_scale = self.config.lr_scale if self.config.lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 60e1ab0a..f267b1bb 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -117,7 +117,8 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") - attention_lr_scale = self._config.attention_lr_scale + layer_lr_scale = self._config.lr_scale if self._config.lr_scale else None + attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 8a8fd05e..f4fc8cf9 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -47,7 +47,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s super().__init__(config, tensor_space, name) self._config = config self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory + self._debug_mode = self._config.debug_block or self._config.debug_block_memory self._num_experts = config.num_experts self._experts_per_token = config.num_experts_per_token @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = self._config.lr_scale if self._config.lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( @@ -229,15 +229,15 @@ def _debug_log( kwargs: dict[str, typing.Any], global_: bool = True, ) -> None: - if self._config.debug_transformer_memory: + if self._config.debug_block_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: + if self._config.debug_block and tensor is not None: # TODO: Local vs global meta = self._get_meta(tensor, name, dim_name, kwargs) log_distributed_tensor( "", tensor.view_as(meta), - level=self._config.debug_transformer, + level=self._config.debug_block, meta=meta, distributed=self._tensor_space.distributed, global_=global_, @@ -246,7 +246,7 @@ def _debug_log( log_distributed_grad( "", tensor, - level=self._config.debug_transformer, + level=self._config.debug_block, meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), distributed=self._tensor_space.distributed, grad_fn=lambda tensor_: tensor_.view_as(meta), diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 73afd745..03bfba22 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -11,7 +11,7 @@ from fast_llm.layers.common.linear import LinearBase from fast_llm.layers.transformer.config import TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): @@ -42,7 +42,9 @@ def __init__(self, config: BaseBlockConfig, tensor_space: TensorSpace, block_nam self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - lr_scale = config.mlp_lr_scale + layer_lr_scale = config.lr_scale if config.lr_scale else None + mlp_lr_scale = tuple(config.lr_scale) if isinstance(config.lr_scale, list) else config.lr_scale + lr_scale = get_lr_scale(mlp_lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index e2dd484e..85a9b465 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -15,6 +15,7 @@ from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -46,11 +47,14 @@ def __init__( self._debug_mode = self._config.debug_block or self._config.debug_block_memory hidden_dim = self._tensor_space.get_tensor_dim(f"{LLMDimNames.hidden}_{block_name}") # Note, layer_lr_scale does not impact the norms - self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) - self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=self._config.norm_lr_scale) + + layer_lr_scale = self._config.lr_scale if self._config.lr_scale else None + norm_lr_scale = get_lr_scale(self._config.norm_lr_scale, layer_lr_scale) + + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=norm_lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=norm_lr_scale) self._create_mixer() - self.lr_scale = self._config.lr_scale self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( self._config, self._tensor_space, f"{self.block_name}", layer_index=layer_index From 7178407b63d4239c02ec593c242f77349dc29662 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 3 Jun 2025 18:20:19 +0000 Subject: [PATCH 113/114] weight sharing --- fast_llm/models/hybrid/config.py | 41 +++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index 21ca0148..75f93465 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -2,6 +2,7 @@ import logging import math import typing +from abc import abstractmethod from blocks import LlambaBlock, LlambaOneBlock @@ -38,6 +39,16 @@ class HybridBlockConfig(Config): hint=FieldHint.feature, ) + share_weights: bool = Field( + default=False, + desc="Whether to share weights between blocks. If True, blocks with the same name will share weights.", + hint=FieldHint.optional, + ) + + @abstractmethod + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: + pass + @classmethod def _from_dict( cls, @@ -243,8 +254,25 @@ def _validate(self): first_transformer_block_config: TransformerBlockConfig | None = None + ### Weight sharing ### + # handle share_weights by renaming blocks with shared weights. Layer names are used for setting tensor dimensions. + blocks = {} + hybrid_block_layout = [] + for i, block_name in enumerate(self.hybrid_block_layout): + block_config = self.blocks[block_name] + if not block_config.share_weights: + logger.info(f"Weight sharing disable for block {block_name}, renaming to {block_name}_{i}") + block_name = f"{block_name}_{i}" + else: + logger.info(f"Weight sharing enabled for block {block_name}") + blocks[block_name] = block_config + hybrid_block_layout.append(block_name) + self.blocks = blocks + self.hybrid_block_layout = hybrid_block_layout + ###\Weight sharing ### + for block_name, block_config in self.blocks.items(): - if isinstance(block_config, TransformerBlockConfig): + if isinstance(block_config, TransformerBlockConfig) and self.num_layers is None: if first_transformer_block_config is None: first_transformer_block_config = block_config elif block_config.num_layers != first_transformer_block_config.num_layers: @@ -265,17 +293,18 @@ def _validate(self): # make sure that the hybrid_block_layout length matches the num_layers. If it doesn't, repeat the hybrid_block_layout; if len(self.hybrid_block_layout) != self.num_layers: - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: + if self.num_layers % len(self.hybrid_block_layout) != 0: raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.num_layers}" ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) + num_repeats = int(self.num_layers // len(self.hybrid_block_layout)) logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" + f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times with weight sharing between repeats." ) self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) + Assert.eq(len(self.hybrid_block_layout), self.num_layers) + logger.info(f"Hybrid block layout: {self.hybrid_block_layout}") with self._set_implicit_default(): if self.init_method_std_embed is None: From 0553a4bd43b1e3a850ba4c759aea903d2d6be4aa Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 4 Jun 2025 14:30:32 +0000 Subject: [PATCH 114/114] test --- fast_llm/layers/language_model/config.py | 27 ++---- fast_llm/layers/ssm/discrete_mamba2.py | 10 +- fast_llm/models/gpt/config.py | 13 ++- fast_llm/models/hybrid/config.py | 117 ++++++++++++++++++----- fast_llm/models/hybrid/model.py | 14 ++- tests/test_config.py | 58 ++++++++++- 6 files changed, 182 insertions(+), 57 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9a8b45b8..4fb74eab 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -59,7 +59,7 @@ class LanguageModelBaseConfig(BaseModelConfig): valid=check_field(Assert.gt, 0), ) use_position_embeddings: bool = Field( - default=True, + default=None, desc="Enable absolute position embeddings.", # Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, ) @@ -175,10 +175,6 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) - # rotary: RotaryConfig = Field( - # desc="Configuration for the rotary positional embeddings.", - # hint=FieldHint.architecture, - # ) full_precision_residual: bool = Field( default=False, desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", @@ -190,12 +186,13 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.testing, ) embeddings_hidden_dropout: float = Field( - default=0.0, + default=None, desc="Dropout applied to the embeddings.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - head_normalization: NormalizationConfig = Field( + head_normalization: NormalizationConfig | None = Field( + default=None, desc="Configuration for the normalization in the head.", hint=FieldHint.architecture, ) @@ -205,16 +202,11 @@ class LanguageModelBaseConfig(BaseModelConfig): ) def _validate(self) -> None: - # # self.transformer.validate() - # with self._set_implicit_default(): - # if self.use_position_embeddings is None: - # self.use_position_embeddings = not self.rotary.enabled - # if self.init_method_std_embed is None: - # self.init_method_std_embed = self.transformer.init_method_std - # if self.init_method_max_embed is None: - # self.init_method_max_embed = self.transformer.init_method_max - # if self.init_method_min_embed is None: - # self.init_method_min_embed = self.transformer.init_method_min + with self._set_implicit_default(): + if self.embeddings_hidden_dropout is None: + self.embeddings_hidden_dropout = 0.0 + if self.head_normalization is None: + self.head_normalization = NormalizationConfig() if not TritonConfig.TRITON_ENABLED: warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") @@ -234,7 +226,6 @@ def _validate(self) -> None: Assert.geq(coeff, 0) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - # self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index da57e917..b01afb03 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import causal_conv1d import einops @@ -8,10 +9,13 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale +if typing.TYPE_CHECKING: + from fast_llm.layers.ssm.config import SSMConfig + logger = logging.getLogger(__name__) """ @@ -30,7 +34,7 @@ class DiscreteMamba2(torch.nn.Module): def __init__( self, - config: SSMConfig, + config: "SSMConfig", layer_index: int, tensor_space: TensorSpace, name: str = "", @@ -45,7 +49,7 @@ def __init__( """ # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__() - self.config: SSMConfig = config + self.config: "SSMConfig" = config bias = config.add_bias_linear self.layer_idx = layer_index self._return_input = return_input diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 583d4703..035fb4bb 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,4 +1,5 @@ import functools +import logging import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class @@ -14,6 +15,8 @@ from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div +logger = logging.getLogger(__name__) + if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel @@ -134,10 +137,14 @@ def _validate(self) -> None: self.transformer.debug_block = True self.transformer.debug_block_memory = True self.transformer.validate() - self.use_position_embeddings = not self.transformer.rotary.enabled - self.embeddings_hidden_dropout = self.transformer.hidden_dropout # legacy behavior - self.head_normalization = self.transformer.normalization # legacy behavior + with self._set_implicit_default(): + if self.head_normalization is None: + self.head_normalization = self.transformer.normalization + if self.embeddings_hidden_dropout is None: + self.embeddings_hidden_dropout = self.transformer.hidden_dropout + if self.use_position_embeddings is None: + self.use_position_embeddings = not self.transformer.rotary.enabled if self.init_method_std_embed is None: self.init_method_std_embed = self.transformer.init_method_std if self.init_method_max_embed is None: diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py index 75f93465..44ebe1a0 100644 --- a/fast_llm/models/hybrid/config.py +++ b/fast_llm/models/hybrid/config.py @@ -3,8 +3,7 @@ import math import typing from abc import abstractmethod - -from blocks import LlambaBlock, LlambaOneBlock +from collections import Counter from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig @@ -16,11 +15,11 @@ from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.transformer import BaseBlock, TransformerLayer from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.layers.transformer.transformer import BaseBlock from fast_llm.models.gpt.model import GPTInferenceRunner from fast_llm.models.hybrid.huggingface import HuggingfaceHybridModelForCausalLM from fast_llm.models.hybrid.model import HybridModel @@ -29,10 +28,35 @@ logger = logging.getLogger(__name__) +def _get_llamba_block(): + """Lazy import to avoid loading heavy dependencies during config validation.""" + from fast_llm.layers.ssm.blocks import LlambaBlock + + return LlambaBlock + + +def _get_llamba_one_block(): + """Lazy import to avoid loading heavy dependencies during config validation.""" + from fast_llm.layers.ssm.blocks import LlambaOneBlock + + return LlambaOneBlock + + +def _get_transformer_block(): + """Lazy import to avoid loading heavy dependencies during config validation.""" + from fast_llm.layers.transformer.transformer import TransformerLayer + + return TransformerLayer + + @config_class(registry=True) class HybridBlockConfig(Config): _abstract = True - block_class: typing.ClassVar[type[BaseBlock]] + + @abstractmethod + def block_class(self) -> type["BaseBlock"]: + raise NotImplementedError("Subclasses must implement block_class") + type: str | None = Field( default="transformer", desc="The config class name.", @@ -64,7 +88,10 @@ def _from_dict( @config_class(dynamic_type={HybridBlockConfig: "transformer"}) class TransformerBlockConfig(HybridBlockConfig, TransformerConfig): _abstract = False - block_class: typing.ClassVar[type[BaseBlock]] = TransformerLayer + + @property + def block_class(self) -> type["BaseBlock"]: + return _get_transformer_block() def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> None: TransformerConfig.setup_tensor_space(self, tensor_space, block_name) @@ -73,7 +100,10 @@ def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> No @config_class(dynamic_type={HybridBlockConfig: "discrete_mamba2"}) class DiscreteMamba2BlockConfig(HybridBlockConfig, SSMConfig): _abstract = False - block_class: typing.ClassVar[type[BaseBlock]] = LlambaBlock + + @property + def block_class(self) -> type["BaseBlock"]: + return _get_llamba_block() # def _validate(self): # self.config.validate() @@ -111,7 +141,10 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None @config_class(dynamic_type={HybridBlockConfig: "mamba"}) class MambaBlockConfig(HybridBlockConfig, SSMConfig): _abstract = False - block_class: typing.ClassVar[type[BaseBlock]] = LlambaOneBlock + + @property + def block_class(self) -> type["BaseBlock"]: + return _get_llamba_one_block() def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: @@ -170,7 +203,7 @@ class HybridBaseModelConfig(LanguageModelBaseConfig): hint=FieldHint.architecture, ) ############################################################################################ - blocks: dict[str, HybridBlockConfig] = Field( + blocks: dict[str, HybridBlockConfig] | None = Field( default=None, desc="Named block configurations that can be referenced in block_pattern.", hint=FieldHint.architecture, @@ -188,6 +221,18 @@ class HybridBaseModelConfig(LanguageModelBaseConfig): hint=FieldHint.optional, ) + _hybrid_block_layout: list[str] | None = Field( + init=False, + desc="Internal representation of the block layout.", + hint=FieldHint.derived, + ) + + _blocks: dict[str, HybridBlockConfig] | None = Field( + init=False, + desc="Internal representation of the blocks.", + hint=FieldHint.derived, + ) + # TODO: currently num_layers is defined in TransformerConfig, but ideally this should be migrated to LanguageModelBaseConfig in the future. # Hence, for now: the num_layers can be set in the first transformer block, if no transformer blocks used we will fallback to num_layers parameter defined here. num_layers: int = Field( @@ -197,6 +242,14 @@ class HybridBaseModelConfig(LanguageModelBaseConfig): valid=check_field(Assert.geq, 0), ) + @property + def block_layout(self) -> list[str]: + return self._hybrid_block_layout + + @property + def registered_blocks(self) -> dict[str, HybridBlockConfig]: + return self._blocks + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ Setup the tensor space for the model. @@ -258,24 +311,31 @@ def _validate(self): # handle share_weights by renaming blocks with shared weights. Layer names are used for setting tensor dimensions. blocks = {} hybrid_block_layout = [] + block_count = Counter(self.hybrid_block_layout) for i, block_name in enumerate(self.hybrid_block_layout): block_config = self.blocks[block_name] if not block_config.share_weights: - logger.info(f"Weight sharing disable for block {block_name}, renaming to {block_name}_{i}") - block_name = f"{block_name}_{i}" + if block_count[block_name] > 1: + logger.info(f"Weight sharing disabled for block {block_name}, renaming to {block_name}_{i}") + block_name = f"{block_name}_{i}" + else: + logger.info(f"Weight sharing disabled for block {block_name}, no renaming needed") else: logger.info(f"Weight sharing enabled for block {block_name}") blocks[block_name] = block_config hybrid_block_layout.append(block_name) - self.blocks = blocks - self.hybrid_block_layout = hybrid_block_layout + with self._set_implicit_default(): + # self.blocks = blocks + # self.hybrid_block_layout = hybrid_block_layout + self._hybrid_block_layout = hybrid_block_layout + self._blocks = blocks ###\Weight sharing ### - for block_name, block_config in self.blocks.items(): - if isinstance(block_config, TransformerBlockConfig) and self.num_layers is None: + for block_name, block_config in self._blocks.items(): + if isinstance(block_config, TransformerBlockConfig): if first_transformer_block_config is None: first_transformer_block_config = block_config - elif block_config.num_layers != first_transformer_block_config.num_layers: + elif self.num_layers is None and block_config.num_layers != first_transformer_block_config.num_layers: logger.warning( f"Found multiple transformer blocks with different number of layers, using num_layers from the first transformer block for all" ) @@ -292,21 +352,32 @@ def _validate(self): self.num_layers = num_layers # make sure that the hybrid_block_layout length matches the num_layers. If it doesn't, repeat the hybrid_block_layout; - if len(self.hybrid_block_layout) != self.num_layers: - if self.num_layers % len(self.hybrid_block_layout) != 0: + if len(self._hybrid_block_layout) != self.num_layers: + if self.num_layers % len(self._hybrid_block_layout) != 0: raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.num_layers}" + f"hybrid_block_layout length {len(self._hybrid_block_layout)} does not match num_layers {self.num_layers}" ) - num_repeats = int(self.num_layers // len(self.hybrid_block_layout)) + num_repeats = int(self.num_layers // len(self._hybrid_block_layout)) logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times with weight sharing between repeats." + f"hybrid_block_layout length {len(self._hybrid_block_layout)} does not match num_layers {self.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times with weight sharing between repeats." ) - self.hybrid_block_layout = self.hybrid_block_layout * num_repeats + self._hybrid_block_layout = self._hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.num_layers) - logger.info(f"Hybrid block layout: {self.hybrid_block_layout}") + Assert.eq(len(self._hybrid_block_layout), self.num_layers) + logger.info(f"Hybrid block layout: {self._hybrid_block_layout}") with self._set_implicit_default(): + if self.use_position_embeddings is None: + if first_transformer_block_config is not None: + self.use_position_embeddings = not first_transformer_block_config.rotary.enabled + self.embeddings_hidden_dropout = first_transformer_block_config.hidden_dropout + else: + self.use_position_embeddings = False + self.embeddings_hidden_dropout = 0.0 + logger.warning( + f"No transformer block config found in HybridBaseModelConfig, setting use_position_embeddings to False" + ) + if self.init_method_std_embed is None: self.init_method_std_embed = ( first_transformer_block_config.init_method_std diff --git a/fast_llm/models/hybrid/model.py b/fast_llm/models/hybrid/model.py index 1f387f4b..cee3f4d2 100644 --- a/fast_llm/models/hybrid/model.py +++ b/fast_llm/models/hybrid/model.py @@ -38,14 +38,14 @@ def get_output_layers(self) -> list[Layer]: if self._config.prediction_heads > 1: block_name = self._config.default_mtp_type - assert block_name in self._config.blocks, f"Block {block_name} not found in config" - BLOCK_CLS = self._config.blocks[block_name].block_class + assert block_name in self._config.registered_blocks, f"Block {block_name} not found in config" + BLOCK_CLS = self._config.registered_blocks[block_name].block_class for i in range(1, self._config.prediction_heads): layers.append( BLOCK_CLS( - self._config.blocks[block_name], + self._config.registered_blocks[block_name], self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + layer_index=len(self._config.block_layout), return_input=i != self._config.prediction_heads - 1, block_name=block_name, ) @@ -62,16 +62,14 @@ def get_layers(self) -> list[Layer]: layers = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern - for i, block_name in enumerate(self._config.hybrid_block_layout): + for i, block_name in enumerate(self._config.block_layout): BLOCK_CLS: BaseBlock = self._config.blocks[block_name].block_class layers.append( BLOCK_CLS( self._config.blocks[block_name], self._tensor_space, layer_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), + return_input=(i == len(self._config.block_layout) - 1 and self._config.prediction_heads > 1), block_name=block_name, ) ) diff --git a/tests/test_config.py b/tests/test_config.py index 80bed418..62069162 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,6 +13,7 @@ from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.auto import trainer_registry from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig +from fast_llm.models.hybrid.config import HybridBaseModelConfig from fast_llm.utils import Assert, check_equal_nested from tests.common import TEST_RESULTS_PATH @@ -132,11 +133,11 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { # rotary: Don't override nested. "normalization": {"implementation": "triton"}, # Update non-default nested - "peft": {"freeze_others": False}, # Update default nested, non-architecture "hidden_size": 512, # Override, affects derived value (kv channels) "head_groups": 1, # Override to default }, "vocab_size": 1000, + "head_normalization": {"type": "rms_norm"}, } pretrained_config = PretrainedGPTModelConfig.from_dict( { @@ -159,7 +160,6 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { "normalization": {"type": "rms_norm", "implementation": "triton"}, "rotary": {"type": "default"}, - "peft": {"freeze_others": False}, "num_layers": 12, "hidden_size": 512, "ffn_hidden_size": 4096, @@ -169,8 +169,62 @@ def test_pretrained_config(load_config: ModelConfigType): }, "tie_word_embeddings": False, "vocab_size": 1000, + "head_normalization": {"type": "rms_norm"}, } else: expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) + + +# TODO: add test for hybrid pretrained config as above +def test_hybrid_block_modular_config(): + + config = { + "blocks": { + "bob_shared": { + "type": "transformer", + "hidden_size": 512, + "share_weights": True, + }, + "mamba_non_shared": { + "type": "discrete_mamba2", + "state_size": 16, + "expansion_factor": 2, + "hidden_size": 512, + "share_weights": False, + }, + }, + "hybrid_block_layout": ["bob_shared", "mamba_non_shared", "bob_shared", "mamba_non_shared"], + "num_layers": 8, + } + + modular_config = HybridBaseModelConfig.from_dict(config) + modular_config.validate() + Assert.eq(modular_config.hybrid_block_layout, ["bob_shared", "mamba_non_shared", "bob_shared", "mamba_non_shared"]) + Assert.eq( + modular_config.block_layout, + [ + "bob_shared", + "mamba_non_shared_1", + "bob_shared", + "mamba_non_shared_3", + "bob_shared", + "mamba_non_shared_1", + "bob_shared", + "mamba_non_shared_3", + ], + ) # with num_layers = 8, the block_layout should be 8 blocks with repeated pattern of ["bob_shared", "mamba_non_shared", "bob_shared", "mamba_non_shared"] with names abjusted for weight sharing + for block_name in modular_config.block_layout: + Assert.custom( + lambda _: block_name in modular_config.registered_blocks, + f"Block {block_name} not found in registered blocks", + ) + serialized = modular_config.to_dict() + reconstructed = HybridBaseModelConfig.from_dict(serialized) + Assert.eq(reconstructed.to_dict(), config) + reconstructed.validate() + + +if __name__ == "__main__": + pytest.main([__file__])