diff --git a/fast_llm/config.py b/fast_llm/config.py index 380100e3..f34b3fe6 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -380,8 +380,8 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T: if expected_class is not None: # Should be handled in `from_dict`, but can fail if instantiating directly. - Assert.is_(self.__class__, expected_class) - + # TODO: is this ok? i.e. we want the assigned class to be a subclass of the expected class, not neccessarily exactly the same class. + Assert.custom(issubclass, expected_class, self.__class__) if not self._validated: try: self._validate() @@ -720,7 +720,7 @@ def _get_class_name(cls) -> str: @classmethod def from_dict( cls, - default: "Config| dict[str, typing.Any]]", + default: "Config| dict[str, typing.Any]", *updates: "Config| dict[str | tuple[str, ...], typing.Any]", strict: bool = True, update_type: UpdateType = UpdateType.override, diff --git a/fast_llm/engine/optimizer/learning_rate.py b/fast_llm/engine/optimizer/learning_rate.py index bf11038a..c6912e4f 100644 --- a/fast_llm/engine/optimizer/learning_rate.py +++ b/fast_llm/engine/optimizer/learning_rate.py @@ -120,19 +120,19 @@ def create_schedule_from_config(config: LearningRateScheduleConfig) -> LearningR begin_step = 0 for stage_arg_str in config.schedule.split(";"): try: - for stage_type, num_steps, lr, *stage_args in stage_arg_str.split(","): - assert begin_step is not None - num_steps = int(num_steps) - end_step = None if num_steps < 0 else begin_step + num_steps - kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} - if len(stage_args) > 0: - kwargs["end_lr"] = float(stage_args[0]) - if len(stage_args) > 1: - kwargs["power"] = float(stage_args[1]) - if len(stage_args) > 2: - raise ValueError(stage_args[2:]) - stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) - begin_step = end_step + stage_type, num_steps, lr, *stage_args = stage_arg_str.split(",") + assert begin_step is not None + num_steps = int(num_steps) + end_step = None if num_steps < 0 else begin_step + num_steps + kwargs = {"begin_step": begin_step, "end_step": end_step, "lr": float(lr)} + if len(stage_args) > 0: + kwargs["end_lr"] = float(stage_args[0]) + if len(stage_args) > 1: + kwargs["power"] = float(stage_args[1]) + if len(stage_args) > 2: + raise ValueError(stage_args[2:]) + stages.append(_STAGE_TYPE_MAP[stage_type](**kwargs)) + begin_step = end_step except Exception: raise ValueError(f'Cannot parse optimizer stage definition "{stage_arg_str}"') return LearningRateSchedule(stages) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 054c26c3..d4188e11 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,8 +1,12 @@ +import abc import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace +from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -11,6 +15,39 @@ from fast_llm.layers.common.normalization import LayerNorm, RMSNorm +class RotaryEmbeddingType(str, enum.Enum): + none = "none" + default = "default" + llama3 = "llama3" + yarn = "yarn" + + +class LLMDimNames: + input_hidden = "input_hidden" + output_hidden = "output_hidden" + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "batch" + # TODO: Distinguish micro-sequence? + sequence_q = "sequence_q" + sequence_q_tp = "sequence_q_tp" + sequence_k = "sequence_k" + hidden = "hidden" + # MLP dimensions + mlp = "mlp" + gate_and_up = "gate_and_up" + composite_gated_mlp = "composite_gated_mlp" + experts = "experts" + top_experts = "top_experts" + shared_experts = "shared_experts" + unshared_experts = "unshared_experts" + composite_expert_mlp = "composite_expert_mlp" + composite_gated_expert_mlp = "composite_gated_expert_mlp" + composite_shared_expert_mlp = "composite_shared_expert_mlp" + composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" + + class NormalizationImplementation(str, enum.Enum): """ An enum for the available implementations of layer norm. @@ -68,7 +105,7 @@ class NormalizationConfig(BaseModelConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": from fast_llm.layers.common.normalization import LayerNorm, RMSNorm from fast_llm.tensor import init_uniform_ @@ -77,6 +114,7 @@ def get_layer(self, hidden_dim: "TensorDim") -> "LayerNorm | RMSNorm": "eps": self.epsilon, "implementation": self.implementation, "zero_centered": self.zero_centered, + "lr_scale": lr_scale, } if self.initialization_range: mean = 0 if self.zero_centered else 1 @@ -119,15 +157,51 @@ class PeftType(str, enum.Enum): lora = "lora" -@config_class() +@config_class(registry=True) class PeftConfig(BaseModelConfig): - _abstract = False type: PeftType = Field( default=PeftType.none, desc="The type of parameter-efficient fine tuning to use Only LoRA is supported at the moment.", hint=FieldHint.core, ) + + @abc.abstractmethod + def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is PeftConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return EmptyPeftConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={PeftConfig: "none"}) +class EmptyPeftConfig(PeftConfig): + """ + A dummy PeftConfig that does nothing. + """ + + _abstract = False + + def apply_linear(self, *args, **kwargs) -> "LinearLike": + return args[0] + + +@config_class(dynamic_type={PeftConfig: "lora"}) +class LoRAConfig(PeftConfig): + """ + LoRA configuration. + """ + + _abstract = False rank: int = Field( default=8, desc="The LoRA rank, i.e. the size of the intermediate dimension.", @@ -145,20 +219,430 @@ class PeftConfig(BaseModelConfig): ) def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - if self.type == PeftType.none: - return linear - elif self.type == PeftType.lora: - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, + from fast_llm.layers.common.peft import lora_linear + + # TODO: Init method? + return lora_linear( + linear, + linear.weight.param_init_method, + linear.weight.param_init_method, + self.rank, + self.alpha, + self.dropout, + **kwargs, + ) + + +class RoutingType(str, enum.Enum): + topk = "aux_loss" + sinkhorn = "sinkhorn" + + +class AddLinearBiasChoices(str, enum.Enum): + nowhere = "nowhere" + everywhere = "everywhere" + only_attn_qkv = "only_attn_qkv" + + +class BaseBlockSubLayerName: + mlp_1 = "mlp_1" + mlp_2 = "mlp_2" + + +@config_class(dynamic_type={PeftConfig: "base_lora"}) +class BaseBlockLoRAConfig(LoRAConfig): + """ + TODO: Add support for MLP. + """ + + _abstract = False + + layers: list[BaseBlockSubLayerName] = Field( + default=None, + desc="The layers on which to apply LoRA.", + hint=FieldHint.feature, + ) + + def apply_linear(self, linear: "LinearBase", layer_type: BaseBlockSubLayerName | None = None) -> "LinearLike": + if layer_type is None or self.layers is None or layer_type in self.layers: + return super().apply_linear(linear) + return linear + + def _validate(self) -> None: + if self.layers is None: + with self._set_implicit_default(): + self.layers = [] + if BaseBlockSubLayerName.mlp_1 in self.layers or BaseBlockSubLayerName.mlp_2 in self.layers: + # TODO: Add MLP support. + raise NotImplementedError("LoRA not supported for MLP.") + + +# for name in PeftType: +# # We need this because we are using the reserved field name `type`. +# # TODO: Implement proper dynamic typing. +# BaseBlockPeftConfig.register_subclass(name.value, BaseBlockPeftConfig) + + +@config_class() +class BaseBlockConfig(BaseModelConfig): + + _abstract = False + peft: PeftConfig = Field( + # default_factory=lambda: PeftConfig(type=PeftType.none), + desc="Configuration for the parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layers architecture.", + hint=FieldHint.architecture, + ) + hidden_dropout: float = Field( + default=0.0, + desc="Dropout applied to the residual connections.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + debug_block: int = Field( + default=0, + desc="Log the output of each operation in each layer.", + hint=FieldHint.logging, + valid=check_field(Assert.geq, 0), + ) + debug_block_memory: bool = Field( + default=False, + desc="Log the memory usage after each operation in each layer.", + hint=FieldHint.logging, + ) + num_experts: int = Field( + default=1, + desc="Number of MLP experts in a Mixture of Expert (MoE) model", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for full block. note, ", + doc="May be used to freeze some layers by setting their scale to zero. Note, in non-hybrid models (GPT model) all layers share same config and setting lr_scale to 0 will freeze all layers. Consider using norm_lr_scale, mlp_lr_scale etc. instead.", + hint=FieldHint.feature, + ) + + norm_lr_scale: float | None | list[float | None] = Field( + default=None, + desc="Custom learning rate scale for each normalization layer.", + doc="May be used to freeze some normalization layers by setting their scale to zero.", + hint=FieldHint.feature, + ) + mlp_lr_scale: float | None | list[float | None] = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + + num_layers: int = Field( + default=12, + desc="Number of layers in the transformer.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + hidden_size: int = Field( + default=1024, + desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + ffn_hidden_size: int = Field( + default=None, + desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) + num_shared_experts: int = Field( + default=0, + desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_unshared_experts: int = Field( + init=False, + desc="Number of MLP experts excluding shared ones", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_experts_per_token: int = Field( + default=1, + desc="Active experts for each token in a MoE model.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + expert_routing_type: RoutingType = Field( + default=RoutingType.topk, + desc="The routing method, i.e., the method used to assign experts to tokens.", + hint=FieldHint.architecture, + ) + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + # Default: hidden_size**-0.5 + # TODO: Allow custom initialization (InitializationConfig?) + init_method_std: float = Field( + default=None, + desc="Default scale for weight initialization. Default: hidden_size**-0.5", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max: float | None = Field( + default=None, + desc="Max value for clamping initialized weights. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min: float | None = Field( + default=None, + desc="Min value for clamping initialized weights. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_qkv: float = Field( + default=None, + desc="Scale for the query, key and value weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_qkv: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_qkv: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_attn_proj: float = Field( + default=None, + desc="Scale for the attention projection weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_attn_proj: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_attn_proj: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto + mlp_recompute_level: MLPRecomputeLevel = Field( + default=MLPRecomputeLevel.none, + desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", + hint=FieldHint.performance, + ) + # Use random inits instead of constant values, useful for debugging. + random_bias_init: bool = Field( + default=False, + desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", + hint=FieldHint.testing, + ) + expert_auxiliary_loss_coefficient: float = Field( + default=0.01, + desc="Scale of the load balancing auxiliary loss for topk routing.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + expert_z_loss_coefficient: float = Field( + default=0.0, + desc="Regularize the router during training by applying Z-loss to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + moe_jitter_eps: float = Field( + default=0.0, + desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + router_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate for the MoE router weight.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + dropless_moe: bool = Field( + default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + ) + dropless_dynamic_shape: bool = Field( + default=False, + desc="Use a dynamic shape for dropless MLP instead of the worst-case value." + " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", + hint=FieldHint.expert, + ) + add_linear_biases: bool | AddLinearBiasChoices = Field( + default=True, + desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + hint=FieldHint.architecture, + ) + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + if self.init_method_std is None: + self.init_method_std = self.hidden_size**-0.5 + if self.init_method_std_qkv is None: + self.init_method_std_qkv = self.init_method_std + if self.init_method_std_attn_proj is None: + self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_qkv is None: + self.init_method_max_qkv = self.init_method_max + if self.init_method_min_qkv is None: + self.init_method_min_qkv = self.init_method_min + if self.init_method_max_attn_proj is None: + self.init_method_max_attn_proj = self.init_method_max + if self.init_method_min_attn_proj is None: + self.init_method_min_attn_proj = self.init_method_min + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min is not None and self.init_method_max is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min, self.init_method_max) + if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + self.num_unshared_experts = self.num_experts - self.num_shared_experts + Assert.geq( + self.hidden_dropout, 0 + ) # Do we need to check it here again given that its is already asserted in the config field? + if self.norm_lr_scale is not None: + Assert.geq(self.norm_lr_scale, 0) + + if isinstance(self.mlp_lr_scale, list): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) + super()._validate() + Assert.leq(self.num_shared_experts, self.num_experts) + Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + + @property + def add_mlp_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + @property + def add_attn_qkv_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.nowhere: + return False + return True + + @property + def add_attn_dense_bias(self) -> bool: + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str = "") -> None: + tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.hidden}_{block_name}", self.hidden_size)) + + # MLP dimensions + tensor_space.add_tensor_dim(mlp := TensorDim(f"{LLMDimNames.mlp}_{block_name}", self.ffn_hidden_size, tensor)) + tensor_space.add_tensor_dim( + gate_and_up := TensorDim(f"{LLMDimNames.gate_and_up}_{block_name}", 2 if self.gated else 1) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_gated_mlp}_{block_name}", (gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(experts := TensorDim(f"{LLMDimNames.experts}_{block_name}", self.num_experts)) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_expert_mlp}_{block_name}", (experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_gated_expert_mlp}_{block_name}", (experts, gate_and_up, mlp)) + ) + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.top_experts}_{block_name}", self.num_experts_per_token)) + tensor_space.add_tensor_dim( + TensorDim(f"{LLMDimNames.unshared_experts}_{block_name}", self.num_unshared_experts) + ) + + # shared_experts + if self.num_shared_experts: + tensor_space.add_tensor_dim( + shared_experts := TensorDim(f"{LLMDimNames.shared_experts}_{block_name}", self.num_shared_experts) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim(f"{LLMDimNames.composite_shared_expert_mlp}_{block_name}", (shared_experts, mlp)) + ) + tensor_space.add_tensor_dim( + CompositeTensorDim( + f"{LLMDimNames.composite_gated_shared_expert_mlp}_{block_name}", + (shared_experts, gate_and_up, mlp), + ) ) - else: - raise NotImplementedError(self.type) diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 848abb97..5f30beae 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -155,6 +155,7 @@ def __init__( weight_init_method=None, bias_init_method=init_zeros_, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -193,12 +194,14 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=bias_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape @@ -236,6 +239,7 @@ def __init__( implementation: NormalizationImplementation = NormalizationImplementation.auto, weight_init_method=None, zero_centered: bool = False, + lr_scale: float | None = None, ): super().__init__() assert hidden_dim.parallel_dim is None @@ -269,6 +273,7 @@ def __init__( init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=True, + lr_scale=lr_scale, ) self.normalized_shape = self.weight.shape diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 2d5fd843..4fb74eab 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,11 +1,12 @@ import typing +import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames -from fast_llm.functional.config import CrossEntropyImpl -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.functional.config import CrossEntropyImpl, TritonConfig +from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert @@ -41,10 +42,10 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): - transformer: TransformerConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) + """ + Base config for language models. + """ + max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", @@ -59,7 +60,7 @@ class LanguageModelBaseConfig(BaseModelConfig): ) use_position_embeddings: bool = Field( default=None, - desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", + desc="Enable absolute position embeddings.", # Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -155,27 +156,76 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + debug: bool = Field( + default=False, + desc="Enable debug mode.", + hint=FieldHint.testing, + ) + embeddings_hidden_dropout: float = Field( + default=None, + desc="Dropout applied to the embeddings.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + head_normalization: NormalizationConfig | None = Field( + default=None, + desc="Configuration for the normalization in the head.", + hint=FieldHint.architecture, + ) + # Debug, to get an exact match with megatron init. + use_megatron_initialization: bool = Field( + default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing + ) def _validate(self) -> None: - self.transformer.validate() with self._set_implicit_default(): - if self.use_position_embeddings is None: - self.use_position_embeddings = not self.transformer.rotary.enabled - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min + if self.embeddings_hidden_dropout is None: + self.embeddings_hidden_dropout = 0.0 + if self.head_normalization is None: + self.head_normalization = NormalizationConfig() + + if not TritonConfig.TRITON_ENABLED: + warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + Assert.geq( + self.embeddings_hidden_dropout, 0 + ) # Do we need to check it here again given that its is already asserted in the config field? + super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction not supported with distillation.") + if isinstance(self.prediction_loss_coefficient, list): + Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) + for coeff in self.prediction_loss_coefficient: + Assert.geq(coeff, 0) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) # Embedding dimensions @@ -206,4 +256,5 @@ def from_flat_dict( cls._handle_renamed_field(default, "normalization_type", "type") cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") + return super().from_flat_dict(default, strict) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1d9406ed..fede81fc 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -7,8 +7,9 @@ from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ from fast_llm.utils import Assert @@ -37,16 +38,16 @@ def __init__( self._tensor_space = tensor_space self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if config.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout + self._dropout_p = config.embeddings_hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = tensor_space.get_tensor_dim(LLMDimNames.input_hidden) vocab_dim = tensor_space.get_tensor_dim( LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab ) @@ -62,6 +63,7 @@ def __init__( min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( @@ -72,14 +74,15 @@ def __init__( max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, + lr_scale=config.embeddings_lr_scale, ) - # PEFT. - self.word_embeddings_weight = self._config.transformer.peft.apply_weight(self.word_embeddings_weight) - if hasattr(self, "position_embeddings_weight"): - self.position_embeddings_weight = self._config.transformer.peft.apply_weight( - self.position_embeddings_weight - ) + # PEFT: layer freezing should be done by explicitly setting embeddings_lr_scale to 0.0 + # self.word_embeddings_weight = self._config.peft.apply_weight(self.word_embeddings_weight) + # if hasattr(self, "position_embeddings_weight"): + # self.position_embeddings_weight = self._config.peft.apply_weight( + # self.position_embeddings_weight + # ) @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None) -> torch.Tensor: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index d6d1b8a5..6d6d8d6f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -15,6 +15,7 @@ from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, LanguageModelDimNames, @@ -44,7 +45,7 @@ def __init__( prediction_distance: int, ): super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer + self._debug_transformer = config.debug self._tie_word_embeddings = config.tie_word_embeddings self._tensor_space = tensor_space @@ -58,10 +59,13 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(LLMDimNames.output_hidden) + self._loss_coefficient = ( + config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) + self.final_norm = config.head_normalization.get_layer(hidden_dim) self._logits_scale_factor = config.logits_scale_factor self._z_loss_factor = config.logit_z_loss @@ -89,12 +93,12 @@ def __init__( self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) - # PEFT. - self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) - if hasattr(self, "output_weights"): - self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) + # PEFT: layer freezing should be done by explicitly setting output_lr_scale to 0.0 + # self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) + # if hasattr(self, "output_weights"): + # self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) - def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: + def _init_output_weights(self, hidden_dim: TensorDim, config: LanguageModelBaseConfig) -> None: # Only the first head defines the output weights if self._tie_word_embeddings or self._prediction_distance > 0: return @@ -109,6 +113,7 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), + lr_scale=config.output_lr_scale, ) def forward( @@ -139,7 +144,7 @@ def forward( else: if self.training: # Backward hook to compute the gradient of the loss - shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0) + shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, self._loss_coefficient) # MTP: Return shared_hidden to be used by the next head. return shared_hidden diff --git a/fast_llm/layers/ssm/blocks.py b/fast_llm/layers/ssm/blocks.py new file mode 100644 index 00000000..d0521c32 --- /dev/null +++ b/fast_llm/layers/ssm/blocks.py @@ -0,0 +1,55 @@ +import typing + +from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 +from fast_llm.layers.ssm.mamba_layer import MambaLayer +from fast_llm.layers.transformer.transformer import BaseBlock + +if typing.TYPE_CHECKING: + from fast_llm.engine.config_utils.tensor_space import TensorSpace + from fast_llm.models.hybrid.config import MambaBlockConfig + + +class LlambaBlock(BaseBlock): + """ + A transformer-like decoder block with a discrete Mamba 2 mixer, see https://arxiv.org/abs/2502.14458 + """ + + _mixer_module_name = "mixer" + + def __init__( + self, + config: "MambaBlockConfig", + tensor_space: "TensorSpace", + layer_index: int, + block_name: str = "", + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, block_name, return_input) + + def _create_mixer(self): + self.mixer = DiscreteMamba2( + self._config, layer_index=self._layer_index, tensor_space=self._tensor_space, name=self.block_name + ) + + +class LlambaOneBlock(BaseBlock): + """ + A transformer-like decoder block with a Mamba 1 mixer. + """ + + _mixer_module_name = "mamba1" + + def __init__( + self, + config: "MambaBlockConfig", + tensor_space: "TensorSpace", + layer_index: int, + block_name: str = "", + return_input: bool = False, + ): + super().__init__(config, tensor_space, layer_index, block_name, return_input) + + def _create_mixer(self): + self.mamba1 = MambaLayer( + self._config, layer_index=self._layer_index, tensor_space=self._tensor_space, name=self.block_name + ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 25ad3d22..f7a978e5 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,7 +1,6 @@ -from fast_llm.config import Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.config import BaseBlockConfig, NormalizationConfig from fast_llm.utils import Assert @@ -21,7 +20,7 @@ class SSMDimNames: @config_class() -class SSMConfig(BaseModelConfig): +class SSMConfig(BaseBlockConfig): _abstract = False # Normalization @@ -53,7 +52,8 @@ class SSMConfig(BaseModelConfig): desc="Whether to use bias in SSM layers", hint=FieldHint.architecture, ) - dt_rank: int = Field( + + dt_rank: None | int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, @@ -102,12 +102,22 @@ class SSMConfig(BaseModelConfig): valid=check_field(Assert.gt, 0), ) + d_inner: None | int = Field( + default=None, + desc="Inner dimension for Mamba2 blocks.", + hint=FieldHint.core, + ) + mamba_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for Mamba blocks.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu - if self.dt_rank is None: - self.dt_rank = -1 # set to -1, it will be overwrittem in ssm validation super()._validate() Assert.geq(self.dt_max, self.dt_min) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 49dacb91..b01afb03 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,4 +1,6 @@ +import logging import math +import typing import causal_conv1d import einops @@ -7,8 +9,14 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.utils import get_lr_scale + +if typing.TYPE_CHECKING: + from fast_llm.layers.ssm.config import SSMConfig + +logger = logging.getLogger(__name__) """ This code is adapted fropm https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py @@ -26,9 +34,10 @@ class DiscreteMamba2(torch.nn.Module): def __init__( self, - config: SSMConfig, - layer_idx: int, + config: "SSMConfig", + layer_index: int, tensor_space: TensorSpace, + name: str = "", return_input: bool = False, ): """ @@ -40,19 +49,23 @@ def __init__( """ # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} super().__init__() - self.config: SSMConfig = config + self.config: "SSMConfig" = config bias = config.add_bias_linear - self.layer_idx = layer_idx + self.layer_idx = layer_index self._return_input = return_input - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_mamba2) + layer_lr_scale = config.lr_scale if config.lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + logger.info(f"Setting lr_scale for layer {layer_index} of type {type(self)}: {mamba_layer_lr_scale}") + + td_inner = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_dim}_{name}") + td_state = tensor_space.get_tensor_dim(f"{SSMDimNames.state_dim}_{name}") + td_model = tensor_space.get_tensor_dim(f"{SSMDimNames.model_dim}_{name}") + td_conv = tensor_space.get_tensor_dim(f"{SSMDimNames.conv_dim}_{name}") + td_n_qk_heads = tensor_space.get_tensor_dim(f"{SSMDimNames.qk_heads}_{name}") + td_n_v_heads = tensor_space.get_tensor_dim(f"{SSMDimNames.v_heads}_{name}") + td_conv_kernel = tensor_space.get_tensor_dim(f"{SSMDimNames.conv_kernel_size}_{name}") + td_inner_proj = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_proj_mamba2}_{name}") self.d_model = td_model.size self.d_inner = td_inner.size @@ -67,12 +80,19 @@ def __init__( # TODO: double check innitializations # Projections - self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size)) + self.in_proj = Linear( + td_model, + td_inner_proj, + bias=bias, + weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, + ) self.z_bias = ( ParameterMeta.from_dims( (td_inner,), weight_decay=False, init_method=init_zeros_, + lr_scale=mamba_layer_lr_scale, ) if not bias else 0.0 @@ -84,14 +104,18 @@ def __init__( init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 + lr_scale=mamba_layer_lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (td_conv,), init_method=bias_init_method(self.conv1d_weight), lr_scale=mamba_layer_lr_scale ) - self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight)) # D "skip" parameter self.D = ParameterMeta.from_dims( (td_n_qk_heads,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) # out_proj @@ -100,6 +124,7 @@ def __init__( td_model, bias=bias, weight_init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) @property diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py deleted file mode 100644 index ee222d6d..00000000 --- a/fast_llm/layers/ssm/llamba_block.py +++ /dev/null @@ -1,34 +0,0 @@ -import typing - -from fast_llm.layers.transformer.transformer import BaseBlock - -if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.layers.ssm.config import SSMConfig - from fast_llm.layers.transformer.config import TransformerConfig - - -class LlambaBlock(BaseBlock): - """ - A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 - """ - - _name = "Llamba block" - _mixer_module_name = "mixer" - - def __init__( - self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", - tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, - return_input: bool = False, - ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) - - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 4704b522..76be3c4d 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -9,6 +9,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.utils import get_lr_scale """ Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. @@ -55,33 +56,37 @@ class MambaLayer(torch.nn.Module): def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, + layer_index: int, + name: str = "", return_input: bool = False, ): factory_kwargs = {} super().__init__() self.config: SSMConfig = config - self.layer_idx = layer_idx + self.layer_idx = layer_index self._debug_mode = config.debug_ssm # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) + td_inner = tensor_space.get_tensor_dim(f"{SSMDimNames.inner_dim}_{name}") td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba + f"{SSMDimNames.inner_proj_mamba}_{name}" ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + tdt_rank = tensor_space.get_tensor_dim(f"{SSMDimNames.dt_rank}_{name}") + td_x_proj = tensor_space.get_tensor_dim(f"{SSMDimNames.x_proj_dim}_{name}") + td_state = tensor_space.get_tensor_dim(f"{SSMDimNames.state_dim}_{name}") + td_model = tensor_space.get_tensor_dim(f"{SSMDimNames.model_dim}_{name}") + td_conv_kernel = tensor_space.get_tensor_dim(f"{SSMDimNames.conv_kernel_size}_{name}") self.d_conv = td_conv_kernel.size self.d_inner = td_inner.size self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size + layer_lr_scale = self.config.lr_scale if self.config.lr_scale else None + mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) + self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), init_method=kaiming_init_(td_model.size), @@ -90,6 +95,7 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), init_method=kaiming_init_(td_inner.size), + lr_scale=mamba_layer_lr_scale, ) self.conv1d_bias = None @@ -102,6 +108,7 @@ def __init__( td_x_proj, weight_init_method=kaiming_init_(td_inner.size), bias=False, + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -110,6 +117,7 @@ def __init__( self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), init_method=kaiming_init_(tdt_rank.size), + lr_scale=mamba_layer_lr_scale, ) self.dt_proj_bias = ParameterMeta.from_dims( @@ -117,12 +125,14 @@ def __init__( init_method=init_dtprojbias( self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs ), + lr_scale=mamba_layer_lr_scale, ) self.A_log = ParameterMeta.from_dims( (td_inner, td_state), weight_decay=False, init_method=init_A(self.d_state, self.d_inner), + lr_scale=mamba_layer_lr_scale, ) # D "skip" parameter @@ -130,6 +140,7 @@ def __init__( (td_inner,), weight_decay=False, init_method=init_ones_, + lr_scale=mamba_layer_lr_scale, ) self.out_proj = Linear( @@ -137,6 +148,7 @@ def __init__( td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. weight_init_method=kaiming_init_(td_model.size), + lr_scale=mamba_layer_lr_scale, **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 0b442f66..f267b1bb 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -17,7 +17,7 @@ ) from fast_llm.logging import log_distributed_grad, log_distributed_tensor from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -80,6 +80,7 @@ def __init__( config: TransformerConfig, tensor_space: TensorSpace, layer_index, + block_name: str = "", ): super().__init__() self._config = config @@ -87,7 +88,7 @@ def __init__( Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer + self._debug_transformer = self._config.debug_block self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -101,45 +102,54 @@ def __init__( max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.kv_channels}_{block_name}").size + self._head_groups = self._tensor_space.get_tensor_dim( + f"{TransformerDimNames.head_groups}_{block_name}" + ).global_size + self._local_head_groups = self._tensor_space.get_tensor_dim( + f"{TransformerDimNames.head_groups}_{block_name}" + ).size + self._local_heads_per_group = self._tensor_space.get_tensor_dim( + f"{TransformerDimNames.group_heads}_{block_name}" + ).size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") + + layer_lr_scale = self._config.lr_scale if self._config.lr_scale else None + attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_query}_{block_name}"), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_key_value}_{block_name}"), bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_dense}_{block_name}"), hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=self._config.attention_lr_scale, + lr_scale=attention_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e7ef0b15..6a1cf8cf 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -10,34 +10,25 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import NormalizationConfig, PeftConfig, PeftType +from fast_llm.functional.config import TritonConfig +from fast_llm.layers.common.config import ( + BaseBlockConfig, + BaseBlockLoRAConfig, + BaseBlockSubLayerName, + LLMDimNames, + PeftConfig, +) from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: - import torch + pass from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta logger = logging.getLogger(__name__) -class RoutingType(str, enum.Enum): - topk = "aux_loss" - sinkhorn = "sinkhorn" - - -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" +class TransformerDimNames(LLMDimNames): # Self-attention dimensions head_groups = "head_groups" group_heads = "group_heads" @@ -47,18 +38,6 @@ class TransformerDimNames: composite_query = "composite_query" composite_key_value = "composite_key_value" composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" class TransformerKwargs: @@ -170,277 +149,96 @@ class AddLinearBiasChoices(str, enum.Enum): only_attn_qkv = "only_attn_qkv" -class TransformerSubLayerName(str, enum.Enum): +class TransformerSubLayerName(BaseBlockSubLayerName): # TODO: Use this to replace AddLinearBiasChoices. query = "query" key = "key" value_ = "value" key_value = "key_value" dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): +@config_class(dynamic_type={PeftConfig: "transformer_lora"}) +class TransformerLoRaConfig(BaseBlockLoRAConfig): + """ + LoRa config that applies to transformer layer. If this is used with GPTBaseModel it is reused for all transformer layers. + Note, this does not freeze layers. + If you want to freeze weights, you need to do so explicitly by setting the corresponding layer's lr_scales (embeddings/mlp etc.) to 0. + """ + layers: list[TransformerSubLayerName] = Field( default=None, desc="The layers on which to apply LoRA.", hint=FieldHint.feature, ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if self.type != PeftType.none: - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False + if layer_type is None or self.layers is None or layer_type in self.layers: + if layer_type == TransformerSubLayerName.key: + return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) + elif layer_type == TransformerSubLayerName.value_: + return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) + else: + return super().apply_linear(linear) + # elif self.freeze_others: + # linear.weight.requires_grad = False return linear - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.type != PeftType.none and self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.type != PeftType.none and self.freeze_others: - parameter.requires_grad = False - return parameter - def _validate(self) -> None: if self.layers is None: with self._set_implicit_default(): # Setting the default layers only whee PeFT is enabled # so they don't appear when serializing the default transformer config. - self.layers = ( - [TransformerSubLayerName.query, TransformerSubLayerName.value_] - if self.type == PeftType.lora - else [] - ) - if self.type != PeftType.none: - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + self.layers = [TransformerSubLayerName.query, TransformerSubLayerName.value_] + super()._validate() + if TransformerSubLayerName.dense in self.layers: + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for attention dense layer.") + if ( + sum( + name in self.layers + for name in ( + TransformerSubLayerName.key_value, + TransformerSubLayerName.key, + TransformerSubLayerName.value_, ) - - -for name in PeftType: - # We need this because we are using the reserved field name `type`. - # TODO: Implement proper dynamic typing. - TransformerPeftConfig.register_subclass(name.value, TransformerPeftConfig) + ) + > 1 + ): + raise ValueError( + f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." + ) @config_class() -class TransformerConfig(BaseModelConfig): +class TransformerConfig(BaseBlockConfig): _abstract = False - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) + # normalization: NormalizationConfig = Field( + # desc="Configuration for the normalization layers architecture.", + # hint=FieldHint.architecture, + # ) rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, ) - peft: TransformerPeftConfig = Field( - desc="Configuration for the parameter-efficient fine tuning.", - hint=FieldHint.architecture, - ) - num_layers: int = Field( - default=12, - desc="Number of layers in the transformer.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) - head_groups: int = Field( - default=1, - desc="Number of head group for grouped query attention.", - doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - add_linear_biases: bool | AddLinearBiasChoices = Field( - default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", - hint=FieldHint.architecture, - ) - ffn_hidden_size: int = Field( - default=None, - desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - kv_channels: int = Field( - default=None, - desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - num_experts: int = Field( - default=1, - desc="Number of MLP experts in a Mixture of Expert (MoE) model", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_shared_experts: int = Field( - default=0, - desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - num_unshared_experts: int = Field( - init=False, - desc="Number of MLP experts excluding shared ones", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - num_experts_per_token: int = Field( - default=1, - desc="Active experts for each token in a MoE model.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - expert_routing_type: RoutingType = Field( - default=RoutingType.topk, - desc="The routing method, i.e., the method used to assign experts to tokens.", - hint=FieldHint.architecture, - ) - activation_type: ActivationType = Field( - default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, - ) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( + # peft: PeftConfig = FieldUpdate( + # desc="Configuration for the parameter-efficient fine tuning.", + # hint=FieldHint.architecture, + # ) + attention_lr_scale: float | None = Field( default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - attention_dropout: float = Field( - default=0.0, - desc="Dropout applied to the attention intermediate states.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - hidden_dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", + desc="Custom learning rate scale for the Attention projection weights.", + doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, + attention_softmax_scale_power: float = Field( + default=0.5, + desc="The scaling power to apply to kv_channel in the attention calculation. " + " Under Standard Parameterization (SP): default to 0.5. " + " Under muP (if scaling kv_channels size): use 1. " + " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", + valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) # Use flash attention if possible (fp16 or bf16) use_flash_attention: bool = Field( @@ -458,179 +256,120 @@ class TransformerConfig(BaseModelConfig): hint=FieldHint.optional, valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( - default=MLPRecomputeLevel.none, - desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", - hint=FieldHint.performance, - ) - debug_transformer: int = Field( - default=0, - desc="Log the output of each operation in a transformer layer.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), - ) - debug_transformer_memory: bool = Field( - default=False, - desc="Log the memory usage after each operation in a transformer layer..", - hint=FieldHint.logging, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) - expert_auxiliary_loss_coefficient: float = Field( - default=0.01, - desc="Scale of the load balancing auxiliary loss for topk routing.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - expert_z_loss_coefficient: float = Field( - default=0.0, - desc="Regularize the router during training by applying Z-loss to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - moe_jitter_eps: float = Field( + attention_dropout: float = Field( default=0.0, - desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + desc="Dropout applied to the attention intermediate states.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | list[float | None] = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) - router_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate for the MoE router weight.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_lr_scale: float | None = Field( + kv_channels: int = Field( default=None, - desc="Custom learning rate scale for the Attention projection weights.", - doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( - default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " - " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. " - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - dropless_moe: bool = Field( - default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) - dropless_dynamic_shape: bool = Field( - default=False, - desc="Use a dynamic shape for dropless MLP instead of the worst-case value." - " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", - hint=FieldHint.expert, + num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) + head_groups: int = Field( + default=1, + desc="Number of head group for grouped query attention.", + doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), ) def _validate(self) -> None: with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size if self.kv_channels is None: self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - self.num_unshared_experts = self.num_experts - self.num_shared_experts - - super()._validate() - - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.leq(self.num_shared_experts, self.num_experts) - Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) Assert.multiple(self.num_attention_heads, self.head_groups) Assert.geq(self.attention_dropout, 0) - Assert.geq(self.hidden_dropout, 0) + super()._validate() - if isinstance(self.mlp_lr_scale, list): - Assert.eq(len(self.mlp_lr_scale), self.num_experts) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) - elif self.mlp_lr_scale is not None: - Assert.geq(self.mlp_lr_scale, 0) + # with self._set_implicit_default(): + # if self.ffn_hidden_size is None: + # self.ffn_hidden_size = 4 * self.hidden_size + # if self.kv_channels is None: + # self.kv_channels = div(self.hidden_size, self.num_attention_heads) + # if self.activation_type is None: + # self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # if self.init_method_std is None: + # self.init_method_std = self.hidden_size**-0.5 + # if self.init_method_std_qkv is None: + # self.init_method_std_qkv = self.init_method_std + # if self.init_method_std_attn_proj is None: + # self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + # if self.init_method_std_mlp_1 is None: + # self.init_method_std_mlp_1 = self.init_method_std + # if self.init_method_std_mlp_2 is None: + # self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + # if self.init_method_max_qkv is None: + # self.init_method_max_qkv = self.init_method_max + # if self.init_method_min_qkv is None: + # self.init_method_min_qkv = self.init_method_min + # if self.init_method_max_attn_proj is None: + # self.init_method_max_attn_proj = self.init_method_max + # if self.init_method_min_attn_proj is None: + # self.init_method_min_attn_proj = self.init_method_min + # if self.init_method_max_mlp_1 is None: + # self.init_method_max_mlp_1 = self.init_method_max + # if self.init_method_min_mlp_1 is None: + # self.init_method_min_mlp_1 = self.init_method_min + # if self.init_method_max_mlp_2 is None: + # self.init_method_max_mlp_2 = self.init_method_max + # if self.init_method_min_mlp_2 is None: + # self.init_method_min_mlp_2 = self.init_method_min + # if self.init_method_min is not None and self.init_method_max is not None: + # Assert.leq(self.init_method_min, self.init_method_max) + # if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + # Assert.leq(self.init_method_min, self.init_method_max) + # if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: + # Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) + # if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: + # Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + # if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + # Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + # if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + # Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + # self.num_unshared_experts = self.num_experts - self.num_shared_experts + + # super()._validate() + + # # if not TritonConfig.TRITON_ENABLED: + # # warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") + + # Assert.leq(self.num_shared_experts, self.num_experts) + # Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + # Assert.multiple(self.num_attention_heads, self.head_groups) + # Assert.geq(self.attention_dropout, 0) @functools.cached_property def projection_size(self): assert self._validated return self.num_attention_heads * self.kv_channels - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - - @property - def add_attn_qkv_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True - - @property - def add_attn_dense_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + # # @property + # def add_mlp_bias(self) -> bool: + # if isinstance(self.add_linear_biases, bool): + # return self.add_linear_biases + # if self.add_linear_biases == AddLinearBiasChoices.everywhere: + # return True + # return False + + # @property + # def add_attn_qkv_bias(self) -> bool: + # if isinstance(self.add_linear_biases, bool): + # return self.add_linear_biases + # if self.add_linear_biases == AddLinearBiasChoices.nowhere: + # return False + # return True + + # @property + # def add_attn_dense_bias(self) -> bool: + # if isinstance(self.add_linear_biases, bool): + # return self.add_linear_biases + # if self.add_linear_biases == AddLinearBiasChoices.everywhere: + # return True + # return False @classmethod def _from_dict( @@ -650,65 +389,47 @@ def _from_dict( cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) return super()._from_dict(default, strict, flat) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str = "") -> None: + super().setup_tensor_space(tensor_space, block_name) tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) - # Self-attention dimensions tensor_space.add_tensor_dim( head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None + f"{TransformerDimNames.head_groups}_{block_name}", + self.head_groups, + tensor if self.head_groups > 1 else None, ) ) tensor_space.add_tensor_dim( group_heads := TensorDim( - TransformerDimNames.group_heads, + f"{TransformerDimNames.group_heads}_{block_name}", div(self.num_attention_heads, self.head_groups), None if self.head_groups > 1 else tensor, ) ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) + tensor_space.add_tensor_dim(key_and_value := TensorDim(f"{TransformerDimNames.key_and_value}_{block_name}", 2)) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) + kv_channels := TensorDim(f"{TransformerDimNames.kv_channels}_{block_name}", self.kv_channels) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) + CompositeTensorDim(f"{TransformerDimNames.composite_heads}_{block_name}", (head_groups, group_heads)) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) + CompositeTensorDim( + f"{TransformerDimNames.composite_query}_{block_name}", (head_groups, group_heads, kv_channels) + ) ) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) + CompositeTensorDim( + f"{TransformerDimNames.composite_key_value}_{block_name}", (key_and_value, head_groups, kv_channels) + ) ) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) - ) + CompositeTensorDim( + f"{TransformerDimNames.composite_dense}_{block_name}", (head_groups, group_heads, kv_channels) ) + ) def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: use_flash_attention = self.use_flash_attention and distributed_config.training_dtype in ( diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f..f4fc8cf9 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -10,9 +10,9 @@ from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.config import RoutingType from fast_llm.layers.common.linear import Linear from fast_llm.layers.transformer.config import ( - RoutingType, TransformerConfig, TransformerDimNames, TransformerKwargs, @@ -21,7 +21,7 @@ from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale logger = logging.getLogger(__name__) @@ -40,14 +40,14 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." super().__init__(config, tensor_space, name) self._config = config self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory + self._debug_mode = self._config.debug_block or self._config.debug_block_memory self._num_experts = config.num_experts self._experts_per_token = config.num_experts_per_token @@ -59,6 +59,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps + layer_lr_scale = self._config.lr_scale if self._config.lr_scale else None + router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + self.router = Linear( tensor_space.get_tensor_dim(TransformerDimNames.hidden), tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), @@ -66,7 +69,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max ), - lr_scale=config.router_lr_scale, + lr_scale=router_lr_scale, ) dropless_moe = config.dropless_moe if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: @@ -226,15 +229,15 @@ def _debug_log( kwargs: dict[str, typing.Any], global_: bool = True, ) -> None: - if self._config.debug_transformer_memory: + if self._config.debug_block_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: + if self._config.debug_block and tensor is not None: # TODO: Local vs global meta = self._get_meta(tensor, name, dim_name, kwargs) log_distributed_tensor( "", tensor.view_as(meta), - level=self._config.debug_transformer, + level=self._config.debug_block, meta=meta, distributed=self._tensor_space.distributed, global_=global_, @@ -243,7 +246,7 @@ def _debug_log( log_distributed_grad( "", tensor, - level=self._config.debug_transformer, + level=self._config.debug_block, meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), distributed=self._tensor_space.distributed, grad_fn=lambda tensor_: tensor_.view_as(meta), diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index 1c38705f..03bfba22 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -7,16 +7,18 @@ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.common.config import BaseBlockConfig from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerSubLayerName from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert +from fast_llm.utils import Assert, get_lr_scale class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: BaseBlockConfig, tensor_space: TensorSpace, block_name: str = "", layer_index: int = 0): super().__init__() - self._name = name + self._block_name = block_name + self._layer_index = layer_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -29,8 +31,10 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = tensor_space.get_tensor_dim(f"{TransformerDimNames.hidden}_{block_name}") + self._intermediate_dim = tensor_space.get_tensor_dim( + f"{TransformerDimNames.composite_expert_mlp}_{block_name}" + ) self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -38,14 +42,18 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + layer_lr_scale = config.lr_scale if config.lr_scale else None + mlp_lr_scale = tuple(config.lr_scale) if isinstance(config.lr_scale, list) else config.lr_scale + lr_scale = get_lr_scale(mlp_lr_scale, layer_lr_scale) + # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space.get_tensor_dim(f"{TransformerDimNames.composite_gated_expert_mlp}_{block_name}"), bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, @@ -55,7 +63,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, - lr_scale=tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale, + lr_scale=lr_scale, ) # PEFT. @@ -64,9 +72,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"): + def __init__(self, config: BaseBlockConfig, tensor_space: TensorSpace, block_name: str = "", layer_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name) + super().__init__(config, tensor_space, block_name, layer_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 40dd2e00..85a9b465 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,12 +8,14 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.layers.common.config import BaseBlockConfig, LLMDimNames from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta +from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -26,30 +28,41 @@ class BaseBlock(Layer, abc.ABC): _mixer_module_name = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, + config: BaseBlockConfig, + tensor_space: TensorSpace, + layer_index: int, + block_name: str = "", + return_input: bool = False, ): super().__init__() self._config: TransformerConfig = config self._tensor_space: TensorSpace = tensor_space self._dropout_p: float = self._config.hidden_dropout + self.block_name = block_name # this name is used for tensor space setup and corresponds to the block name in the hybrid setup or to "" in the old setup (GPT Model) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input self._layer_index = layer_index - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + self._debug_mode = self._config.debug_block or self._config.debug_block_memory + hidden_dim = self._tensor_space.get_tensor_dim(f"{LLMDimNames.hidden}_{block_name}") + # Note, layer_lr_scale does not impact the norms + + layer_lr_scale = self._config.lr_scale if self._config.lr_scale else None + norm_lr_scale = get_lr_scale(self._config.norm_lr_scale, layer_lr_scale) + + self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale=norm_lr_scale) + self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale=norm_lr_scale) self._create_mixer() self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp" + self._config, self._tensor_space, f"{self.block_name}", layer_index=layer_index ) - # PEFT. - self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) + # PEFT. Layer freezing must be explicit now. + # self.norm_1 = self._config.peft.apply_other(self.norm_1) + # self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod def _create_mixer(self): @@ -65,7 +78,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self.__class__.__name__} {self.block_name} {self._layer_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -74,21 +87,21 @@ def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None: - if self._config.debug_transformer_memory: + if self._config.debug_block_memory: log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self.name} {name}", str)) - if self._config.debug_transformer and tensor is not None: + if self._config.debug_block and tensor is not None: # TODO: Local vs global log_distributed_tensor( "", tensor if bias is None else tensor + bias, - level=self._config.debug_transformer, + level=self._config.debug_block, meta=self._get_meta(tensor, name, kwargs), distributed=self._tensor_space.distributed, ) log_distributed_grad( "", tensor, - level=self._config.debug_transformer, + level=self._config.debug_block, meta=self._get_meta(tensor, name + " grad", kwargs), distributed=self._tensor_space.distributed, ) @@ -136,13 +149,17 @@ def forward( class TransformerLayer(BaseBlock): - _name = "Transformer layer" _mixer_module_name = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, + config: TransformerConfig, + tensor_space: TensorSpace, + layer_index: int, + block_name: str = "", + return_input: bool = False, ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, layer_index, block_name, return_input) def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + self.self_attn = Attention(self._config, self._tensor_space, self._layer_index, self.block_name) diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 8f16aaea..8fc0a09c 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,7 +2,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig -from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig +from fast_llm.models.hybrid.config import HybridModelConfig, HybridTrainerConfig from fast_llm.utils import Registry model_registry = Registry[str, FastLLMModelConfig]( @@ -12,7 +12,7 @@ for model in [ GPTModelConfig, CustomModelConfig, - HybridSSMModelConfig, + HybridModelConfig, ] }, ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d9085c67..035fb4bb 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,16 +1,22 @@ import functools +import logging import typing from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.config import LLMDimNames from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds from fast_llm.utils import Assert, div +logger = logging.getLogger(__name__) + if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel @@ -97,11 +103,16 @@ def micro_batch_splits(self) -> int: @config_class() class GPTBaseModelConfig(LanguageModelBaseConfig): + """ + Base model config for GPT models. + This model is built exclusively from transformer layers which share the same config. + """ + _abstract = False - # Debug, to get an exact match with megatron init. - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing + transformer: TransformerConfig = Field( + desc="Configuration for the transformer architecture.", + hint=FieldHint.architecture, ) @classmethod @@ -114,9 +125,6 @@ def _from_dict( # TODO v0.3: Remove backward compatibility fix if "transposed_mlp_weight" in default: assert default.pop("transposed_mlp_weight") - if "match_megatron" in default: - assert "use_megatron_initialization" not in default - default["use_megatron_initialization"] = default.pop("match_megatron") if "layer_norm_impl" in default: assert "normalization_implementation" not in default default["normalization_implementation"] = default.pop("layer_norm_impl") @@ -124,6 +132,35 @@ def _from_dict( del default["fused_mlp"] return super()._from_dict(default, strict, flat) + def _validate(self) -> None: + if self.debug: + self.transformer.debug_block = True + self.transformer.debug_block_memory = True + self.transformer.validate() + + with self._set_implicit_default(): + if self.head_normalization is None: + self.head_normalization = self.transformer.normalization + if self.embeddings_hidden_dropout is None: + self.embeddings_hidden_dropout = self.transformer.hidden_dropout + if self.use_position_embeddings is None: + self.use_position_embeddings = not self.transformer.rotary.enabled + if self.init_method_std_embed is None: + self.init_method_std_embed = self.transformer.init_method_std + if self.init_method_max_embed is None: + self.init_method_max_embed = self.transformer.init_method_max + if self.init_method_min_embed is None: + self.init_method_min_embed = self.transformer.init_method_min + super()._validate() + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + self.transformer.setup_tensor_space(tensor_space) + # Mark the input hidden dimension of the model + tensor_space.add_tensor_dim(TensorDim(LLMDimNames.input_hidden, self.transformer.hidden_size)) + # Mark the output hidden dimension of the model, which is the same for GPT models + tensor_space.add_tensor_dim(TensorDim(LLMDimNames.output_hidden, self.transformer.hidden_size)) + super().setup_tensor_space(tensor_space) + @config_class() class GPTModelConfig(FastLLMModelConfig): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b548ab52..f5569396 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,16 +10,12 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.common.config import RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) +from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLossNames from fast_llm.layers.transformer.preprocessing import ( BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor, @@ -62,9 +58,7 @@ def __init__( if self._config.use_absolute_position_embeddings: self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) if self._config.transformer.rotary.enabled: - self._preprocessors.append( - RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space) - ) + self._preprocessors.append(RotaryEmbeddingPreprocessor(self._config.rotary, self._tensor_space)) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) else: @@ -166,7 +160,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.input_hidden) hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/models/hybrid/config.py b/fast_llm/models/hybrid/config.py new file mode 100644 index 00000000..44ebe1a0 --- /dev/null +++ b/fast_llm/models/hybrid/config.py @@ -0,0 +1,523 @@ +import enum +import logging +import math +import typing +from abc import abstractmethod +from collections import Counter + +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.config import LLMDimNames +from fast_llm.layers.language_model.config import LanguageModelBaseConfig +from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.transformer.transformer import BaseBlock + from fast_llm.models.gpt.model import GPTInferenceRunner + from fast_llm.models.hybrid.huggingface import HuggingfaceHybridModelForCausalLM + from fast_llm.models.hybrid.model import HybridModel + from fast_llm.models.hybrid.trainer import SSMTrainer + +logger = logging.getLogger(__name__) + + +def _get_llamba_block(): + """Lazy import to avoid loading heavy dependencies during config validation.""" + from fast_llm.layers.ssm.blocks import LlambaBlock + + return LlambaBlock + + +def _get_llamba_one_block(): + """Lazy import to avoid loading heavy dependencies during config validation.""" + from fast_llm.layers.ssm.blocks import LlambaOneBlock + + return LlambaOneBlock + + +def _get_transformer_block(): + """Lazy import to avoid loading heavy dependencies during config validation.""" + from fast_llm.layers.transformer.transformer import TransformerLayer + + return TransformerLayer + + +@config_class(registry=True) +class HybridBlockConfig(Config): + _abstract = True + + @abstractmethod + def block_class(self) -> type["BaseBlock"]: + raise NotImplementedError("Subclasses must implement block_class") + + type: str | None = Field( + default="transformer", + desc="The config class name.", + hint=FieldHint.feature, + ) + + share_weights: bool = Field( + default=False, + desc="Whether to share weights between blocks. If True, blocks with the same name will share weights.", + hint=FieldHint.optional, + ) + + @abstractmethod + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: + pass + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is HybridBlockConfig and cls.get_subclass(default.get("type")) is None: + raise ValueError(f"Block type not set in {cls}") + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={HybridBlockConfig: "transformer"}) +class TransformerBlockConfig(HybridBlockConfig, TransformerConfig): + _abstract = False + + @property + def block_class(self) -> type["BaseBlock"]: + return _get_transformer_block() + + def setup_tensor_space(self, tensor_space: "TensorSpace", block_name: str) -> None: + TransformerConfig.setup_tensor_space(self, tensor_space, block_name) + + +@config_class(dynamic_type={HybridBlockConfig: "discrete_mamba2"}) +class DiscreteMamba2BlockConfig(HybridBlockConfig, SSMConfig): + _abstract = False + + @property + def block_class(self) -> type["BaseBlock"]: + return _get_llamba_block() + + # def _validate(self): + # self.config.validate() + + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: + + d_inner = int(self.expansion_factor * self.hidden_size) if self.d_inner is None else self.d_inner + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.hidden}_{block_name}", self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{block_name}", self.hidden_size)) + # Mamba-specific dimensions + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{block_name}", d_inner)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.state_dim}_{block_name}", self.state_size)) + tensor_space.add_tensor_dim( + TensorDim(f"{SSMDimNames.conv_kernel_size}_{block_name}", self.conv_kernel_dimension) + ) + + # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 + headdim = d_inner // self.n_v_heads + Assert.eq(self.n_v_heads, d_inner // headdim) + Assert.eq(d_inner % headdim, 0) + Assert.eq(self.n_v_heads % self.n_qk_heads, 0) + + conv_dim = d_inner + 2 * self.n_qk_heads * self.state_size + inner_proj_dim = 2 * d_inner + 2 * self.n_qk_heads * self.state_size + self.n_v_heads + + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.qk_heads}_{block_name}", self.n_qk_heads)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.v_heads}_{block_name}", self.n_v_heads)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba2}_{block_name}", inner_proj_dim)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.conv_dim}_{block_name}", conv_dim)) + + SSMConfig.setup_tensor_space(self, tensor_space, block_name) + + +@config_class(dynamic_type={HybridBlockConfig: "mamba"}) +class MambaBlockConfig(HybridBlockConfig, SSMConfig): + _abstract = False + + @property + def block_class(self) -> type["BaseBlock"]: + return _get_llamba_one_block() + + def setup_tensor_space(self, tensor_space: TensorSpace, block_name: str) -> None: + + if self.dt_rank is None: + mamba_dt_rank = math.ceil(self.hidden_size / 16) + else: + mamba_dt_rank = self.dt_rank + + d_inner = int(self.expansion_factor * self.hidden_size) if self.d_inner is None else self.d_inner + # Hidden dimension + tensor_space.add_tensor_dim(TensorDim(f"{LLMDimNames.hidden}_{block_name}", self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.model_dim}_{block_name}", self.hidden_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_dim}_{block_name}", d_inner)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.state_dim}_{block_name}", self.state_size)) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.dt_rank}_{block_name}", mamba_dt_rank)) + tensor_space.add_tensor_dim( + TensorDim(f"{SSMDimNames.x_proj_dim}_{block_name}", mamba_dt_rank + self.state_size * 2) + ) + tensor_space.add_tensor_dim( + TensorDim(f"{SSMDimNames.conv_kernel_size}_{block_name}", self.conv_kernel_dimension) + ) + tensor_space.add_tensor_dim(TensorDim(f"{SSMDimNames.inner_proj_mamba}_{block_name}", d_inner * 2)) + + SSMConfig.setup_tensor_space(self, tensor_space, block_name) + + +class HybridBlockType(enum.Enum): + """ + An enum for the available block types, legacy format. + """ + + m = MambaBlockConfig + m2d = DiscreteMamba2BlockConfig + t = TransformerBlockConfig + + +@config_class() +class HybridBaseModelConfig(LanguageModelBaseConfig): + """ + HybridBaseModelConfig is a configuration class for hybrid models. + Currently it supports two formats for architecture definition: + - the old and deprecated format with transformer and ssm fields (t, m2d, m), in wich case all blocks share the same config; + - and the new format with blocks field, in which case each block can have its own config. + """ + + _abstract = False + ############################################################################################ + # Note, transformer and ssm are here for legacy reasons + transformer: TransformerConfig = Field( + desc="Configuration for the transformer architecture. Note, having transformer and ssm fields in HybridBaseModelConfig is depricated.", + hint=FieldHint.architecture, + ) + + ssm: SSMConfig = Field( + desc="Configuration for the SSM architecture. Note, having transformer and ssm fields in HybridBaseModelConfig is depricated.", + hint=FieldHint.architecture, + ) + ############################################################################################ + blocks: dict[str, HybridBlockConfig] | None = Field( + default=None, + desc="Named block configurations that can be referenced in block_pattern.", + hint=FieldHint.architecture, + ) + + hybrid_block_layout: list[str] | None = Field( + default=None, + desc=f"Pattern of blocks to use in the model (still supports the previous depricated format with {HybridBlockType.__members__.keys()})", + hint=FieldHint.core, + ) + + default_mtp_type: str | None = Field( + default=None, + desc="Multi-token prediction mixer to use in the model. Can be either one of the blocks, or follow the depricated legacy format: 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", + hint=FieldHint.optional, + ) + + _hybrid_block_layout: list[str] | None = Field( + init=False, + desc="Internal representation of the block layout.", + hint=FieldHint.derived, + ) + + _blocks: dict[str, HybridBlockConfig] | None = Field( + init=False, + desc="Internal representation of the blocks.", + hint=FieldHint.derived, + ) + + # TODO: currently num_layers is defined in TransformerConfig, but ideally this should be migrated to LanguageModelBaseConfig in the future. + # Hence, for now: the num_layers can be set in the first transformer block, if no transformer blocks used we will fallback to num_layers parameter defined here. + num_layers: int = Field( + default=None, + desc="Number of layers in the transformer.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + + @property + def block_layout(self) -> list[str]: + return self._hybrid_block_layout + + @property + def registered_blocks(self) -> dict[str, HybridBlockConfig]: + return self._blocks + + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: + """ + Setup the tensor space for the model. + """ + for block_name, block_config in self.blocks.items(): + block_config.setup_tensor_space(tensor_space, block_name) + # The first layer's hidden dimension is the input hidden dimension of the model + tensor_space.add_tensor_dim( + TensorDim(LLMDimNames.input_hidden, self.blocks[self.hybrid_block_layout[0]].hidden_size) + ) + # Mark the output hidden dimension of the model + tensor_space.add_tensor_dim( + TensorDim(LLMDimNames.output_hidden, self.blocks[self.hybrid_block_layout[-1]].hidden_size) + ) + super().setup_tensor_space(tensor_space) + + def _validate(self): + if self.blocks is None: + logger.warning( + f"Blocks not set, falling back to old behavior with hybrid_block_layout containing any of {HybridBlockType.__members__.keys()}" + ) + if self.hybrid_block_layout is None: + with self._set_implicit_default(): + logger.warning( + f"No hybrid_block_layout found in HybridBaseModelConfig, using default block {HybridBlockType.m2d}" + ) + self.hybrid_block_layout = [HybridBlockType.m2d] + + # Legacy format with t, m, m2d, convert to new format + Assert.custom( + lambda _: all( + block_type in HybridBlockType.__members__.keys() for block_type in self.hybrid_block_layout + ), + f"Invalid block type: {self.hybrid_block_layout}. Must be one of {HybridBlockType.__members__.keys()}", + ) + blocks = {} + for block_type in self.hybrid_block_layout: + if block_type not in blocks: + hybrid_block_config_cls = HybridBlockType[block_type].value + if hybrid_block_config_cls == TransformerBlockConfig: + blocks[block_type] = TransformerBlockConfig.from_dict(self.transformer.to_dict()) + elif hybrid_block_config_cls == MambaBlockConfig: + blocks[block_type] = MambaBlockConfig.from_dict(self.ssm.to_dict()) + elif hybrid_block_config_cls == DiscreteMamba2BlockConfig: + blocks[block_type] = DiscreteMamba2BlockConfig.from_dict(self.ssm.to_dict()) + else: + raise ValueError(f"Invalid block type: {block_type}") + self.blocks = blocks + + Assert.gt(len(self.hybrid_block_layout), 0) + # Validate that all pattern entries refer to valid blocks + for block_name in self.hybrid_block_layout: + if block_name not in self.blocks: + raise ValueError(f"Block name '{block_name}' not found in blocks dictionary") + + first_transformer_block_config: TransformerBlockConfig | None = None + + ### Weight sharing ### + # handle share_weights by renaming blocks with shared weights. Layer names are used for setting tensor dimensions. + blocks = {} + hybrid_block_layout = [] + block_count = Counter(self.hybrid_block_layout) + for i, block_name in enumerate(self.hybrid_block_layout): + block_config = self.blocks[block_name] + if not block_config.share_weights: + if block_count[block_name] > 1: + logger.info(f"Weight sharing disabled for block {block_name}, renaming to {block_name}_{i}") + block_name = f"{block_name}_{i}" + else: + logger.info(f"Weight sharing disabled for block {block_name}, no renaming needed") + else: + logger.info(f"Weight sharing enabled for block {block_name}") + blocks[block_name] = block_config + hybrid_block_layout.append(block_name) + with self._set_implicit_default(): + # self.blocks = blocks + # self.hybrid_block_layout = hybrid_block_layout + self._hybrid_block_layout = hybrid_block_layout + self._blocks = blocks + ###\Weight sharing ### + + for block_name, block_config in self._blocks.items(): + if isinstance(block_config, TransformerBlockConfig): + if first_transformer_block_config is None: + first_transformer_block_config = block_config + elif self.num_layers is None and block_config.num_layers != first_transformer_block_config.num_layers: + logger.warning( + f"Found multiple transformer blocks with different number of layers, using num_layers from the first transformer block for all" + ) + block_config.validate() + + # set num_layers from transformer block config if it exists and if num_layers is not set in HybridBaseModelConfig + # i.e. the resolution hierarchy for num_layers is: HybridBaseModelConfig.num_layers > TransformerBlockConfig.num_layers + if first_transformer_block_config is not None and self.num_layers is None: + num_layers = first_transformer_block_config.num_layers + with self._set_implicit_default(): + logger.warning( + f"TransformerBlockConfig overwrites BaseModelConfig num_layers, setting num_layers = {num_layers}" + ) + self.num_layers = num_layers + + # make sure that the hybrid_block_layout length matches the num_layers. If it doesn't, repeat the hybrid_block_layout; + if len(self._hybrid_block_layout) != self.num_layers: + if self.num_layers % len(self._hybrid_block_layout) != 0: + raise ValueError( + f"hybrid_block_layout length {len(self._hybrid_block_layout)} does not match num_layers {self.num_layers}" + ) + num_repeats = int(self.num_layers // len(self._hybrid_block_layout)) + logger.warning( + f"hybrid_block_layout length {len(self._hybrid_block_layout)} does not match num_layers {self.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times with weight sharing between repeats." + ) + self._hybrid_block_layout = self._hybrid_block_layout * num_repeats + + Assert.eq(len(self._hybrid_block_layout), self.num_layers) + logger.info(f"Hybrid block layout: {self._hybrid_block_layout}") + + with self._set_implicit_default(): + if self.use_position_embeddings is None: + if first_transformer_block_config is not None: + self.use_position_embeddings = not first_transformer_block_config.rotary.enabled + self.embeddings_hidden_dropout = first_transformer_block_config.hidden_dropout + else: + self.use_position_embeddings = False + self.embeddings_hidden_dropout = 0.0 + logger.warning( + f"No transformer block config found in HybridBaseModelConfig, setting use_position_embeddings to False" + ) + + if self.init_method_std_embed is None: + self.init_method_std_embed = ( + first_transformer_block_config.init_method_std + if first_transformer_block_config is not None + else 0.02 + ) + if self.init_method_max_embed is None: + self.init_method_max_embed = ( + first_transformer_block_config.init_method_max + if first_transformer_block_config is not None + else 0.02 + ) + if self.init_method_min_embed is None: + self.init_method_min_embed = ( + first_transformer_block_config.init_method_min + if first_transformer_block_config is not None + else 0.02 + ) + + if self.prediction_heads > 1: + with self._set_implicit_default(): + if self.default_mtp_type is None: + logger.warning( + f"No default_mtp_type found in HybridBaseModelConfig, using the last block type in hybrid_block_layout: {self.hybrid_block_layout[-1]}" + ) + self.default_mtp_type = self.hybrid_block_layout[-1] + else: + if self.default_mtp_type not in self.hybrid_block_layout: + raise ValueError( + f"default_mtp_type {self.default_mtp_type} not found in hybrid_block_layout {self.hybrid_block_layout}" + ) + super()._validate() + + +class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "llamba" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.hybrid.conversion import LLambaHuggingfaceCheckpointHandler + + return LLambaHuggingfaceCheckpointHandler + + +class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.hybrid.conversion import AprielSSMHuggingfaceCheckpointHandler + + return AprielSSMHuggingfaceCheckpointHandler + + +class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): + support_optimizer: typing.ClassVar[bool] = False + name: typing.ClassVar[str] = "apriel_ssm_hybrid" + + @classmethod + def get_handler_class(cls) -> type[CheckpointHandler]: + from fast_llm.models.hybrid.conversion import AprielSSMHHybridHuggingfaceCheckpointHandler + + return AprielSSMHHybridHuggingfaceCheckpointHandler + + +@config_class() +class HybridModelConfig(FastLLMModelConfig): + _abstract = False + model_name: typing.ClassVar[str] = "hybrid_ssm" + base_model: HybridBaseModelConfig = FieldUpdate() + checkpoint_formats = FastLLMModelConfig.checkpoint_formats + ( + LLambaHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + AprielSSMHHybridHuggingfaceCheckpointFormat, + ) + + @classmethod + def get_model_class(cls) -> type["HybridModel"]: + from fast_llm.models.hybrid.model import HybridModel + + return HybridModel + + @classmethod + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridModelForCausalLM"]: + from fast_llm.models.hybrid.huggingface import HuggingfaceHybridModelForCausalLM + + return HuggingfaceHybridModelForCausalLM + + def _validate(self): + logger.warning( + "HybridModelConfig is being instantiated. This model is experimental and may not work as expected." + ) + super()._validate() + + +@config_class() +class PretrainedHybridModelConfig(PretrainedFastLLMModelConfig): + _abstract = False + model: HybridModelConfig = FieldUpdate() + + +@config_class() +class HybridTrainerConfig(PretrainedHybridModelConfig, TrainerConfig): + data: GPTDataConfig = FieldUpdate() + batch: GPTBatchConfig = FieldUpdate() + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() + + @classmethod + def get_trainer_class(cls) -> type["SSMTrainer"]: + from fast_llm.models.hybrid.trainer import SSMTrainer + + return SSMTrainer + + def _validate(self) -> None: + super()._validate() + if (name := self.model.base_model.distillation_model) is None: + Assert.empty(self.reference_models) + else: + Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: + Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + # if self.model.base_model.distillation_model is not None: + # # TODO: Support loss masking for distillation? + # assert not self.batch.use_loss_masking_spans + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.distillation_model) + # TODO: Support more LM head features. + Assert.none(reference_model.model.base_model.cross_entropy_splits) + Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) + Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + # TODO: we dont have inference runner for SSM/Hybrid yet, should return None? + logger.warning( + "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" + ) + + return GPTInferenceRunner diff --git a/fast_llm/models/hybrid/conversion.py b/fast_llm/models/hybrid/conversion.py new file mode 100644 index 00000000..73f1bceb --- /dev/null +++ b/fast_llm/models/hybrid/conversion.py @@ -0,0 +1,584 @@ +import json +import os +import pathlib +import typing + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import ( + ConstantExportParamConverter, + ConstantImportParamConverter, + IgnoreImportWeightConverter, + MappedConfigParamConverter, + ParamConverter, + RenameParamConverter, + SplitWeightConverter, + WeightConverter, +) +from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.functional.config import ActivationType +from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter +from fast_llm.models.hybrid.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, + AprielSSMHuggingfaceCheckpointFormat, + HybridModelConfig, + LLambaHuggingfaceCheckpointFormat, +) +from fast_llm.models.hybrid.model import HybridModel +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + pass + + +class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + """ + This is a temporary solution for importing/exporting hybrid models. Since there is no standard solution for this in HF, we just use the block_pattern. + If block_pattern is None, it will multiply the provided default block type by the number of layers and export/import it. + If block_pattern is provided, it will export/import it as-is. + """ + + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _import_config(cls, config): + cls.num_layers = config["n_layer"] if "n_layer" in config else config["num_hidden_layers"] + cls.block_pattern = config.get("hybrid_block_layout", None) + return super()._import_config(config) + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + if cls.block_pattern is not None: + block_converter = RenameParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + export_names=(("hybrid_block_layout",),), + ) + else: + block_converter = ConstantImportParamConverter( + fast_llm_names=(("hybrid_block_layout",),), + fast_llm_value=[cls._default_block_type] * cls.num_layers, + ) + + return super()._create_config_converters() + [block_converter] + + +class CommonSSMHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "state_size"),), + export_names=( + ( + "ssm_cfg", + "d_state", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "n_v_heads"),), + export_names=( + ( + "ssm_cfg", + "n_v_heads", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "n_qk_heads"),), + export_names=( + ( + "ssm_cfg", + "n_qk_heads", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "expansion_factor"),), + export_names=( + ( + "ssm_cfg", + "expand", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "chunk_size"),), + export_names=( + ( + "ssm_cfg", + "chunk_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "add_bias_linear"),), + export_names=( + ( + "ssm_cfg", + "bias", + ), + ), + ), + MappedConfigParamConverter( + fast_llm_names=(("ssm", "activation_type"),), + export_names=( + ( + "ssm_cfg", + "activation", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() or [] + + num_layers = self._model.config.base_model.transformer.num_layers + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"model.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"model.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter(f"layers.{i+1}.mixer.D", f"model.layers.{i}.mixer.D", self._model.config.base_model) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", f"model.layers.{i}.mixer.z_bias", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"model.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"model.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + return converters + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + +class LLambaHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat + _hf_prefix: str = "backbone" + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + """ + Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json + """ + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) + ), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("n_layer",),), + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=( + ( + "mlp_cfg", + "act_fn", + ), + ), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "add_linear_biases"),), + export_names=( + ( + "mlp_cfg", + "bias", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=( + ( + "mlp_cfg", + "intermediate_size", + ), + ), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("d_model",),), + ), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_embeddings",),), + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + # not using super() because LLamba model is called backbone in the checkpoints + converters = [] + num_layers = self._model.config.base_model.transformer.num_layers + norm_bias: bool = False + ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + + # Embedding and output + if self._model.config.base_model.tie_word_embeddings: + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append(IgnoreImportWeightConverter((), f"{self._hf_prefix}.lm_head.weight")) + else: + converters.append( + WeightConverter("layers.0.word_embeddings_weight", f"{self._hf_prefix}.embedding.weight") + ) + converters.append( + WeightConverter(f"layers.{num_layers + 1}.output_weights", f"{self._hf_prefix}.lm_head.weight") + ) + + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", f"{self._hf_prefix}.final_layernorm", norm_bias + ) + + for i in range(num_layers): + # SSM + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.in_proj", f"{self._hf_prefix}.layers.{i}.mixer.in_proj", ssm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.out_proj", f"{self._hf_prefix}.layers.{i}.mixer.out_proj", ssm_bias + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.D", f"{self._hf_prefix}.layers.{i}.mixer.D", self._model.config.base_model + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.z_bias", + f"{self._hf_prefix}.layers.{i}.mixer.z_bias", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_weight", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.weight", + self._model.config.base_model, + ) + ) + converters.append( + WeightConverter( + f"layers.{i+1}.mixer.conv1d_bias", + f"{self._hf_prefix}.layers.{i}.mixer.conv1d.bias", + self._model.config.base_model, + ) + ) + + # Norm + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_1", f"{self._hf_prefix}.layers.{i}.input_layernorm", norm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_2", f"{self._hf_prefix}.layers.{i}.post_attention_layernorm", norm_bias + ) + + # MLP + converters += self._get_mlp_converters(f"layers.{i+1}", f"{self._hf_prefix}.layers.{i}") + + return converters + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + def _get_weight_and_bias_converters( + self, + fast_llm_prefix: str | tuple[str, ...], + hf_prefix: str | tuple[str, ...], + use_bias: bool, + cls=WeightConverter, + ) -> list[WeightConverter]: + if isinstance(fast_llm_prefix, str): + fast_llm_prefix = (fast_llm_prefix,) + if isinstance(hf_prefix, str): + hf_prefix = (hf_prefix,) + converters = [ + cls( + tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), + tuple(f"{prefix}.weight" for prefix in hf_prefix), + self._model.config.base_model, + ) + ] + if use_bias: + converters.append( + cls( + tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), + tuple(f"{prefix}.bias" for prefix in hf_prefix), + self._model.config.base_model, + ) + ) + return converters + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHuggingfaceCheckpointHandler(CommonSSMHuggingfaceCheckpointHandler): + """ + Lamba-like configs, pure SSM models. + """ + + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHuggingfaceCheckpointFormat + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("vocab_size",),), + export_names=(("vocab_size",),), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), + MappedConfigParamConverter( + fast_llm_names=(("transformer", "activation_type"),), + export_names=(("hidden_act",),), + fast_llm_value=ActivationType.from_hf_name, + export_value=lambda activation_type: activation_type.hf_name, + ), + RenameParamConverter( + fast_llm_names=(("transformer", "num_layers"),), + export_names=(("num_hidden_layers",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "hidden_size"),), + export_names=(("hidden_size",),), + ), + RenameParamConverter( + fast_llm_names=(("transformer", "ffn_hidden_size"),), + export_names=(("intermediate_size",),), + ), + ConstantImportParamConverter( + fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm + ), + RenameParamConverter( + fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),) + ), + ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), + ConstantImportParamConverter(fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value=False), + RenameParamConverter( + fast_llm_names=(("tie_word_embeddings",),), + export_names=(("tie_word_embeddings",),), + ), + ] + + def _create_weight_converters(self) -> list[WeightConverter]: + converters = super()._create_weight_converters() + num_layers = self._model.config.base_model.transformer.num_layers + norm_bias: bool = False + + # Embedding and output + if self._model.config.base_model.tie_word_embeddings: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) + else: + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) + + # Final norm + converters += self._get_weight_and_bias_converters( + f"layers.{num_layers + 1}.final_norm", "model.norm", norm_bias + ) + + for i in range(num_layers): + # Norm + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_1", f"model.layers.{i}.input_layernorm", norm_bias + ) + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.norm_2", f"model.layers.{i}.post_attention_layernorm", norm_bias + ) + + # MLP + converters += self._get_mlp_converters(f"layers.{i+1}", f"model.layers.{i}") + + return converters + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) + + +class AprielSSMHHybridHuggingfaceCheckpointHandler( + HybridModelCheckpointHandler, # handles the block structure parameter + CommonSSMHuggingfaceCheckpointHandler, # handles the SSM layers + CommonLlamaHuggingfaceCheckpointHandler, # handles the LLama layers +): + """ + Lamba-like configs, models that interleave LLama like layers with LLamba-like SSM layers. + """ + + _model: HybridModel + _model_class: typing.ClassVar[FastLLMModelConfig] = HybridModelConfig + format: typing.ClassVar[type[CheckpointFormat]] = AprielSSMHHybridHuggingfaceCheckpointFormat + _default_block_type: str = SSMBlockType.mamba2_discrete.value + + @classmethod + def _create_config_converters(cls) -> list[ParamConverter]: + return super()._create_config_converters() + [ + RenameParamConverter( + fast_llm_names=(("ssm", "d_inner"),), + export_names=(("ssm_cfg", "d_inner"),), + ), + ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False), + ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False), + ] + + def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: + linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases + return [ + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + linear_bias, + SplitWeightConverter, + ), + *self._get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + linear_bias, + MLPLayer2Converter, + ), + ] + + @classmethod + def _load_config(cls, directory: pathlib.Path | str) -> dict: + if not os.path.exists(directory / "config.json"): + raise FileNotFoundError(f"config.json not found in {directory}") + with open(directory / "config.json") as f: + config = json.load(f) + Assert.eq(config["model_type"], cls.get_huggingface_model_type()) + return config + + @classmethod + def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: + with open(directory / "config.json", "w") as f: + json.dump(config, f) diff --git a/fast_llm/models/hybrid/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/hybrid/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py new file mode 100644 index 00000000..bc2e603c --- /dev/null +++ b/fast_llm/models/hybrid/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -0,0 +1,21 @@ +from transformers import MistralConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class AprielSSMHybridConfig(MistralConfig): + def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): + super().__init__(**kwargs) + self.hybrid_block_layout = hybrid_block_layout + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } diff --git a/fast_llm/models/hybrid/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/hybrid/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py new file mode 100644 index 00000000..bc62f241 --- /dev/null +++ b/fast_llm/models/hybrid/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -0,0 +1,1093 @@ +import copy +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_INPUTS_DOCSTRING, + MistralDecoderLayer, + MistralMLP, + MistralModel, + MistralRMSNorm, +) +from transformers.processing_utils import Unpack +from transformers.utils import LossKwargs, add_start_docstrings_to_model_forward, can_return_tuple, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.hybrid.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig + +logger = logging.get_logger(__name__) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + inference_params=None, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if inference_params is not None and inference_params.seqlen_offset > 0: + # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + +class AprielHybridIdentity(nn.Module): + def __init__(self, config: AprielSSMHybridConfig): + super().__init__() + self.config = config + + def forward(self, hidden_states: torch.Tensor, **kwargs): + return (hidden_states,) + + +class AprielSSMHybridModel(MistralModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, **kwargs): + config_copy = copy.deepcopy(config) + config_copy.num_hidden_layers = 0 + super().__init__(config_copy, **kwargs) + self.config = config + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx)) + elif type == "t": + blocks.append(MistralDecoderLayer(config, layer_idx)) + elif type == "i": + blocks.append(AprielHybridIdentity(config)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # OO: Cache is initialized in the `prepare_inputs_for_generation` method, so this can be removed + # if use_cache and past_key_values is None: + # past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielHybridPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + + +class AprielSSMHybridForCausalLM(AprielHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return AprielHybridCausalOutput( + loss=loss, + logits=logits, + all_hidden_states=outputs.hidden_states, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/hybrid/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py b/fast_llm/models/hybrid/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py new file mode 100644 index 00000000..1d230bb6 --- /dev/null +++ b/fast_llm/models/hybrid/external/apriel_hybrid/configuration_ssm_hybrid_apriel.py @@ -0,0 +1,448 @@ +import math +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + + +def _compute_default_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://arxiv.org/abs/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # No need to keep BC with yarn, unreleased when this new pattern was created. + if len(rope_kwargs) > 0: + raise ValueError( + f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}" + ) + + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + # Apriel: Use original max_position_embeddings instead of max_position_embeddings + max_position_embeddings = config.rope_scaling.get( + "original_max_position_embeddings", config.max_position_embeddings + ) + factor = config.rope_scaling["factor"] + + # Sets the attention factor as suggested in the paper + attention_factor = config.rope_scaling.get("attention_factor") + if attention_factor is None: + attention_factor = 0.1 * math.log(factor) + 1.0 + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = config.rope_scaling.get("beta_fast") or 32 + beta_slow = config.rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings): + """Find dimension range bounds based on rotations""" + low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + + return inv_freq, attention_factor + + +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): + """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}") + + if optional_keys is not None: + unused_keys = received_keys - required_keys - optional_keys + else: + unused_keys = received_keys - required_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}") + + +def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + +def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): + rope_scaling = config.rope_scaling + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "original_max_position_embeddings"} + optional_keys = {"attention_factor", "beta_fast", "beta_slow"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "default": _compute_default_rope_parameters, + "yarn": _compute_yarn_parameters, +} + +# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. +ROPE_VALIDATION_FUNCTIONS = { + "default": _validate_default_rope_parameters, + "yarn": _validate_yarn_parameters, +} + + +def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None): + """ + Validate the RoPE config arguments, given a `PretrainedConfig` object + """ + rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig` + if rope_scaling is None: + return + + # BC: "rope_type" was originally "type" + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default")) + validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type) + if validation_fn is not None: + validation_fn(config, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'" + ) + + +class AprielSSMHybridConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Apriel model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`AprielModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Apriel-5B-Base supports up to 16384 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to + understand more about it. This value is necessary to ensure exact reproducibility of the pretraining + results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'yarn'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'yarn', 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + head_dim (`int`, *optional*): + The attention head dimension. If None, it will default to hidden_size // num_attention_heads + ```python + >>> from transformers import AprielModel, AprielConfig + >>> # Initializing an Apriel Apriel-5B-Base style configuration + >>> configuration = AprielConfig() + >>> # Initializing a model from the Apriel-5B-Base style configuration + >>> model = AprielModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "apriel_ssm_hybrid" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `AprielModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + head_dim=None, + hybrid_block_layout=["m2d"], + ssm_cfg=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.hybrid_block_layout = hybrid_block_layout + if len(hybrid_block_layout) == 1: + self.hybrid_block_layout = [hybrid_block_layout[0]] * self.num_hidden_layers + assert len(self.hybrid_block_layout) == self.num_hidden_layers + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + ssm_defaults = { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_conv": 4, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } + self.ssm_cfg = ssm_cfg or ssm_defaults + for k, v in ssm_defaults.items(): + if k not in self.ssm_cfg: + self.ssm_cfg[k] = v diff --git a/fast_llm/models/hybrid/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py b/fast_llm/models/hybrid/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py new file mode 100644 index 00000000..52b8e47e --- /dev/null +++ b/fast_llm/models/hybrid/external/apriel_hybrid/modeling_ssm_hybrid_apriel.py @@ -0,0 +1,1568 @@ +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from torch import nn +from transformers import GenerationMixin +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.hybrid.external.apriel_hybrid.configuration_ssm_hybrid_apriel import ( + ROPE_INIT_FUNCTIONS, + AprielSSMHybridConfig, +) + +logger = logging.get_logger(__name__) + + +class HybridMambaAttentionStaticCache(Cache): + def __init__(self, config: AprielSSMHybridConfig, batch_size, max_length, dtype=torch.float16, device=None): + super().__init__() # config, batch_size, max_length, device, dtype) + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + + self.batch_size = batch_size + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = config.max_position_embeddings if max_length is None else max_length + + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + cache_shape = (self.batch_size, self.num_key_value_heads, max_length, self.head_dim) + + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + new_layer_conv_state = torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + + new_layer_ssm_state = torch.zeros( + batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype + ) + new_layer_key_cache = None # torch.zeros((0,), dtype=dtype, device=device) + new_layer_value_cache = None # torch.zeros((0,), dtype=dtype, device=device) + else: + # Attention or MLP layer + new_layer_conv_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_ssm_state = None # torch.tensor((0,), dtype=dtype, device=device) + new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) + self.transformer_layers.append(i) + + # if not is_torchdynamo_compiling(): + # self.register_buffer(f"key_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # self.register_buffer(f"value_cache_{i}", torch.zeros(cache_shape, dtype=dtype, device=device)) + # new_layer_key_cache = getattr(self, f"key_cache_{i}") + # new_layer_value_cache = getattr(self, f"value_cache_{i}") + # torch._dynamo.mark_static_address(new_layer_key_cache) + # torch._dynamo.mark_static_address(new_layer_value_cache) + # self.register_buffer(f"conv_states_{i}", new_layer_conv_state) + # self.register_buffer(f"ssm_states_{i}", new_layer_ssm_state) + # torch._dynamo.mark_static_address(new_layer_conv_state) + # torch._dynamo.mark_static_address(new_layer_ssm_state) + # new_layer_ssm_state = getattr(self, f"ssm_states_{i}") + # new_layer_conv_state = getattr(self, f"conv_states_{i}") + + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + self.conv_states.append(new_layer_conv_state) + self.ssm_states.append(new_layer_ssm_state) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + + cache_position = cache_kwargs.get("cache_position") + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if cache_position is None: + k_out.copy_(key_states) + v_out.copy_(value_states) + else: + # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to + # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place + # operation, that avoids copies and uses less memory. + try: + k_out.index_copy_(2, cache_position, key_states) + v_out.index_copy_(2, cache_position, value_states) + except NotImplementedError: + # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states + + return k_out, v_out + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = None) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx is None: + layer_idx = self.transformer_layers[0] + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_block_layout + self.has_previous_state = False # only used by mamba + intermediate_size = config.ssm_cfg["d_inner"] + ssm_state_size = config.ssm_cfg["d_state"] + conv_kernel_size = config.ssm_cfg["d_conv"] + self.n_qk_heads = config.ssm_cfg["n_qk_heads"] + assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" + self.head_d = intermediate_size // self.n_qk_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "m2d": + # Mamba layer + self.conv_states += [ + torch.zeros( + batch_size, + conv_kernel_size, + intermediate_size + 2 * self.n_qk_heads * ssm_state_size, + device=device, + dtype=dtype, + ).transpose(1, 2) + ] + self.ssm_states += [ + torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +@dataclass +class AprielHybridCausalOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + attention_weights: Optional[torch.FloatTensor] = None + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class AprielRotaryEmbedding(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class AprielAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +# This is from LLmaba/Mohawk: https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, device=device)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, device=device)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward( + self, + u, + return_mixer_matrix=False, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + inference_params=None, + **kwargs, + ): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + ssm_state, conv_state = None, None + if past_key_value is not None: + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + if inference_params is not None and inference_params.seqlen_offset > 0: + # States are updated inplace + # TODO: make sure inference_params with seqlen_offset are properly initialized + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _, _ = self.step(u, ssm_state, conv_state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if ssm_state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if ssm_state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(ssm_state is not None), + ) + + if ssm_state is not None: + y, ssm_state_update = result + ssm_state.copy_(ssm_state_update) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, ssm_state, conv_state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state_new = self.convolutional_step(xBC, conv_state) + conv_state.copy_(conv_state_new) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + ssm_state = ssm_state.to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=ssm_state, # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, ssm_state, conv_state + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # device = self.in_proj.weight.device + # # conv_state: + # conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + # conv_state = torch.zeros( + # batch_size, + # self.d_conv, + # self.conv1d.weight.shape[0], + # device=device, + # dtype=conv_dtype, + # ).transpose(1, 2) + # # ssm_state: + # ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + # ssm_state = torch.zeros( + # batch_size, + # self.n_v_heads, + # self.headdim, + # self.d_state, + # device=device, + # dtype=ssm_dtype, + # ) + # return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + # if self.layer_idx not in inference_params.ssm_states: + # inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + # batch_size, inference_params.max_seqlen, dtype=torch.float32 + # ) + # Get states + ssm_states = inference_params.ssm_states[self.layer_idx] + conv_states = inference_params.conv_states[self.layer_idx] + if initialize_states: + ssm_states.zero_() + conv_states.zero_() + return ssm_states, conv_states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = AprielAttention(config=config, layer_idx=layer_idx) + + self.mlp = AprielMLP(config) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class AprielSSMDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMHybridConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + + def forward( + self, hidden_states: torch.Tensor, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + **kwargs, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + # outputs["hidden_states"] = hidden_states + outputs = (hidden_states,) + + return outputs + + # def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + # """Allocate inference cache for the model.""" + # if getattr(self.mixer, "allocate_inference_cache", None) is None: + # return + # return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMHybridConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMHybridConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer", "AprielSSMDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMHybridModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] + Args: + config: AprielSSMHybridConfig + """ + + def __init__(self, config: AprielSSMHybridConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) + blocks = [] + logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, type in enumerate(config.hybrid_block_layout): + if type == "m2d": + blocks.append(AprielSSMDecoderLayer(config, layer_idx, **factory_kwargs)) + elif type == "t": + blocks.append(AprielDecoderLayer(config, layer_idx)) + else: + raise ValueError(f"Invalid block type: {type}") + self.layers = nn.ModuleList(blocks) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.gradient_checkpointing = False + self.rotary_emb = AprielRotaryEmbedding(config=config) + self.has_transformer_layers = any(type == "t" for type in config.hybrid_block_layout) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # def allocate_inference_cache(self, *args, **kwargs): + # """Allocate inference cache for the model.""" + # cache = {} + # for i, layer in enumerate(self.layers): + # if isinstance(layer, AprielSSMDecoderLayer): + # cache[i] = layer.allocate_inference_cache(*args, **kwargs) + # return cache + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + inference_params=None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + # past_key_values = HybridMambaAttentionDynamicCache() + logger.warning_once( + "Hybrid Apriel requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None and self.has_transformer_layers: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None and self.has_transformer_layers: + position_ids = cache_position.unsqueeze(0) + + causal_mask = ( + self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions) + if self.has_transformer_layers + else None + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) if self.has_transformer_layers else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + inference_params=inference_params, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions and isinstance(decoder_layer, AprielDecoderLayer): + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) or isinstance( + past_key_values, HybridMambaAttentionStaticCache + ) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMHybridForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + self.model = AprielSSMHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache) + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + # "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + num_last_tokens=0, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + # past_key_values is None if prepare_inputs_for_generation is not called, which is the case when we evaluate without calling generate (non-generation tasks) + # Its generally ok if cache is nto instantiated in this case, since we do single pass per sample anyways, a warning will be triggered in the model + outputs: BaseModelOutputWithPast = self.model( + input_ids, + return_hidden_states=return_hidden_states, + position_ids=position_ids, + past_key_values=past_key_values, + **kwargs, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return AprielHybridCausalOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs.hidden_states, + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + ) + + +__all__ = [ + "AprielSSMHybridForCausalLM", + "AprielSSMHybridModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/hybrid/external/apriel_ssm/configuration_ssm_apriel.py b/fast_llm/models/hybrid/external/apriel_ssm/configuration_ssm_apriel.py new file mode 100644 index 00000000..6943a312 --- /dev/null +++ b/fast_llm/models/hybrid/external/apriel_ssm/configuration_ssm_apriel.py @@ -0,0 +1,103 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Apriel SSM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import is_torch_available, logging + +logger = logging.get_logger(__name__) + +if is_torch_available(): + pass + + +class AprielSSMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`AprielModel`]. It is used to instantiate an Apriel + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Apriel-5B-Base. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + .... + ```""" + + model_type = "apriel_ssm" + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + mlp_bias=False, + rms_norm_eps=1e-5, + ssm_cfg: dict = None, + head_dim: int = 128, + **kwargs, + ): + self.vocab_size = vocab_size + # self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + # self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + # self.rope_theta = rope_theta + self.mlp_bias = mlp_bias + self.head_dim = head_dim + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 24, + "n_qk_heads": 24, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + "d_inner": 24 * self.head_dim, # num_heads * head_dim + } + if self.head_dim != self.ssm_cfg["d_inner"] // self.ssm_cfg["n_qk_heads"]: + logger.warning("Head dim is not equal to d_inner // n_qk_heads.") + + +__all__ = ["AprielConfig"] diff --git a/fast_llm/models/hybrid/external/apriel_ssm/modeling_ssm_apriel.py b/fast_llm/models/hybrid/external/apriel_ssm/modeling_ssm_apriel.py new file mode 100644 index 00000000..82272e2a --- /dev/null +++ b/fast_llm/models/hybrid/external/apriel_ssm/modeling_ssm_apriel.py @@ -0,0 +1,743 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange, repeat +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined +from mamba_ssm.utils.generation import GenerationMixin +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from transformers.utils.generic import ModelOutput + +from fast_llm.models.hybrid.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class AprielRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6, device=None, dtype=None, **kwargs): + """ + AprielRMSNorm is equivalent to T5LayerNorm + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(AprielRMSNorm) + + +class AprielMLP(nn.Module): + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias, **factory_kwargs) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def segsum(x): + """More stable segment sum calculation.""" + # [1, 2, 3] + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + # [[1, 1, 1], [2, 2, 2], [3, 3, 3]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + # [[0, 0, 0], [2, 0, 0], [3, 3, 0]] + x_segsum = torch.cumsum(x, dim=-2) + # [[0, 0, 0], [2, 0, 0], [5, 3, 0]] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + Arguments: + A_log: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + T: (batch, n_heads, length, length) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + # Compute: + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + # Add D: + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +class DiscreteMamba2(nn.Module): + def __init__( + self, + d_model, + d_state=64, + n_qk_heads=32, + n_v_heads=32, + d_conv=4, + expand=1, + activation="identity", + bias=False, + conv_bias=True, + chunk_size=128, + layer_idx=None, + device=None, + dtype=None, + d_inner=None, + **kwargs, # Absorb kwarg for general module + ): + """ + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. + Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = self.expand * self.d_model if d_inner is None else d_inner + self.n_qk_heads = n_qk_heads + self.n_v_heads = n_v_heads + self.headdim = self.d_inner // self.n_v_heads + assert self.n_v_heads == self.d_inner // self.headdim + assert self.d_inner % self.headdim == 0 + assert self.n_v_heads % self.n_qk_heads == 0 + self.activation = activation + self.chunk_size = chunk_size + self.layer_idx = layer_idx + self.bias = bias + self.kwargs = kwargs + + # Projections + self.in_proj = nn.Linear( + self.d_model, + 2 * self.d_inner + 2 * self.n_qk_heads * self.d_state + self.n_v_heads, + bias=bias, + **factory_kwargs, + ) + self.z_bias = ( + nn.Parameter(torch.zeros(self.d_inner, **factory_kwargs)) if not bias else 0 + ) # make sure z_bias always exists + + # Convolutional layer + conv_dim = self.d_inner + 2 * self.n_qk_heads * self.d_state + self.conv_bias = conv_bias + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + bias=conv_bias, + kernel_size=d_conv, + groups=conv_dim, + padding=d_conv - 1, + **factory_kwargs, + ) + + # Activation after conv + if self.activation == "identity": + self.act = nn.Identity() + elif self.activation in ["silu", "swish"]: + self.act = nn.SiLU() + else: + raise ValueError(f"Unknown activation {self.activation}") + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.n_v_heads, **factory_kwargs)) + self.D._optim = {"weight_decay": 0.0} + + # out_proj + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def forward(self, u, return_mixer_matrix=False, inference_params=None, **kwargs): + """ + u: (B, L, D) + Returns: same shape as u + """ + outputs = {} + # assert state is None + batch, seqlen, dim = u.shape + + state = None + if inference_params is not None: + state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # States are updated inplace + u = u.squeeze(1) if len(u.shape) == 3 else u + out, _ = self.step(u, state) + out = out.unsqueeze(1) if len(u.shape) == 2 else out + return {"hidden_states": out} + + # Hacky way to initialize state during inference + chunk_size = self.chunk_size if state is None else seqlen + + # Pad input to nearest multiple of chunklen + padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size + u = F.pad(u, (0, 0, 0, padded_len - seqlen)) + + # Project input + xBCzA_log = self.in_proj(u) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + if state is not None: + # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l") + state["conv"].copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) + + # Convolutional layer + xBC = self.convolutional_forward(xBC, padded_len) + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads) + B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads) + C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads) + + # SSM forward + result = mamba_chunk_scan_combined( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=A_log, + dt_softplus=True, + A=-torch.ones(self.n_v_heads, device=A_log.device), + B=B, + C=C, + chunk_size=chunk_size, + # initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation + return_final_states=(state is not None), + ) + + if state is not None: + y, ssm_state = result + state["ssm"].copy_(ssm_state) + else: + y = result + + Du = torch.einsum("h,blhp->blhp", self.D, x) + y = rearrange(y + Du, "b l h p -> b l (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + outputs["hidden_states"] = out[:, :seqlen, :] + + if return_mixer_matrix: + outputs["transfer_matrix"] = materialize_mixer(A_log=A_log, B=B, C=C, D=self.D)[..., :seqlen, :seqlen] + return outputs + + def step(self, u, state, **kwargs): + """ + u: (B D) + state: dict of states + Returns: same shape as u + """ + + # Project input + xBCzA_log = self.in_proj(u.squeeze(1)) + xBC, z, A_log = torch.split( + xBCzA_log, + [ + self.d_inner + 2 * self.n_qk_heads * self.d_state, + self.d_inner, + self.n_v_heads, + ], + dim=-1, + ) + + xBC, conv_state = self.convolutional_step(xBC, state["conv"]) + state["conv"].copy_(conv_state) # update state in place + + x, B, C = torch.split( + xBC, + [ + self.d_inner, + self.n_qk_heads * self.d_state, + self.n_qk_heads * self.d_state, + ], + dim=-1, + ) + + x = rearrange(x, "b (h s) -> b h s", h=self.n_v_heads) + B = rearrange(B, "b (h s) -> b h s", h=self.n_qk_heads) + C = rearrange(C, "b (h s) -> b h s", h=self.n_qk_heads) + + state["ssm"] = state["ssm"].to(x.dtype) + zeros = torch.zeros((self.n_v_heads, self.headdim), device=A_log.device).to(dtype=x.dtype) + ones = torch.ones((self.n_v_heads, self.headdim, self.d_state), device=A_log.device).to(dtype=x.dtype) + y = selective_state_update( + x=x / F.softplus(A_log).to(x.dtype).unsqueeze(-1), + dt=repeat(A_log, "b h -> b h p", p=self.headdim), + dt_softplus=True, + A=-ones, + B=B, + C=C, + state=state["ssm"], # will be updated in place + dt_bias=zeros, + D=zeros, + ) + + y = y + self.D[:, None] * x + y = rearrange(y, "b h p -> b (h p)") + + # Norm and gate + out = self.out_proj(y * F.silu(z + self.z_bias)) + + return out, state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.in_proj.weight.device + # conv_state: + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, + self.d_conv, + self.conv1d.weight.shape[0], + device=device, + dtype=conv_dtype, + ).transpose(1, 2) + # ssm_state: + ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, + self.n_v_heads, + self.headdim, + self.d_state, + device=device, + dtype=ssm_dtype, + ) + return {"conv": conv_state, "ssm": ssm_state} + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + """ + conv_state: (batch, d_conv, conv1d.weight.shape[0]) + ssm_state: (batch, n_qk_heads, headdim, d_state) + """ + assert self.layer_idx is not None + # Allocate memory if not exists + if self.layer_idx not in inference_params.key_value_memory_dict: + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + batch_size, inference_params.max_seqlen, dtype=torch.float32 + ) + # Get states + states = inference_params.key_value_memory_dict[self.layer_idx] + if initialize_states: + states["conv"].zero_() + states["ssm"].zero_() + return states + + def convolutional_forward(self, xBC, padded_len): + if causal_conv1d_fn is None or self.activation not in [ + "silu", + "swish", + "identity", + ]: + xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2)) + else: + xBC = causal_conv1d_fn( + xBC.transpose(1, 2), + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + activation=None if self.activation == "identity" else self.activation, + ).transpose(1, 2) + return xBC + + def convolutional_step(self, xBC, conv_state): + # Convolutional layer + conv_state = conv_state.to(xBC.dtype) + if causal_conv1d_update: + xBC = causal_conv1d_update( + xBC, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation if self.activation != "identity" else None, + ) + else: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = xBC + xBC = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv_bias: + xBC = xBC + self.conv1d.bias + xBC = self.act(xBC).to(xBC.dtype) # Some activations change dtype + + return xBC, conv_state + + +class AprielDecoderLayer(nn.Module): + def __init__(self, config: AprielSSMConfig, layer_idx: int, device=None, dtype=None, **kwargs): + super().__init__(**kwargs) + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = config.hidden_size + + self.mixer = DiscreteMamba2( + d_model=config.hidden_size, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + self.mlp = AprielMLP(config, **factory_kwargs) + self.input_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.post_attention_layernorm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + + def forward( + self, hidden_states: torch.Tensor, inference_params=None, **kwargs + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + outputs = {} + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +APRIEL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`AprielSSMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMPreTrainedModel(PreTrainedModel): + config_class = AprielSSMConfig + base_model_prefix = "model" + _no_split_modules = ["AprielDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return getattr(self, self.base_model_prefix).allocate_inference_cache(*args, **kwargs) + + +APRIEL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Apriel Model outputting raw hidden-states without any specific head on top.", + APRIEL_START_DOCSTRING, +) +class AprielSSMModel(AprielSSMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AprielDecoderLayer`] + Args: + config: AprielSSMConfig + """ + + def __init__(self, config: AprielSSMConfig, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, **factory_kwargs) + self.layers = nn.ModuleList( + [AprielDecoderLayer(config, layer_idx, **factory_kwargs) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = AprielRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **factory_kwargs) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + @add_start_docstrings_to_model_forward(APRIEL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ) -> Union[tuple, BaseModelOutputWithPast]: + + hidden_states = self.embed_tokens(input_ids) + + # decoder layers + outputs = { + "last_hidden_state": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + + layer_outputs = decoder_layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + outputs["last_hidden_state"] = self.norm(hidden_states) + return outputs + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class AprielSSMForCausalLM(AprielSSMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config, device=None, dtype=None, **kwargs): + super().__init__(config, device=device, dtype=dtype, **kwargs) + self.model = AprielSSMModel(config, device=device, dtype=dtype) + self.vocab_size = config.vocab_size + factory_kwargs = {"device": device, "dtype": dtype} + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[tuple, CausalLMOutputWithPast]: + + outputs = self.model( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + if outputs["last_hidden_state"] is not None and return_logits: + logits = self.lm_head(outputs["last_hidden_state"]).float() + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=outputs["last_hidden_state"], + ) + + def generate(self, *args, **kwargs): + """ + This is a wrapper to make sure we comply with the HF generation interface for eval harness + """ + return super().generate(*args, **kwargs) + + +__all__ = [ + "AprielSSMForCausalLM", + "AprielModel", + "AprielSSMPreTrainedModel", +] diff --git a/fast_llm/models/hybrid/external/eval/apriel_eval_wrapper.py b/fast_llm/models/hybrid/external/eval/apriel_eval_wrapper.py new file mode 100644 index 00000000..9ccc7768 --- /dev/null +++ b/fast_llm/models/hybrid/external/eval/apriel_eval_wrapper.py @@ -0,0 +1,180 @@ +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("apriel_ssm") +class AprielSSMWrapper(HFLM): + """Wrapper for AprielSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.hybrid.external.apriel_ssm.configuration_ssm_apriel import AprielSSMConfig + + self._config = AprielSSMConfig.from_pretrained(pretrained) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.hybrid.external.apriel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM + + self._model = AprielSSMForCausalLM.from_pretrained( + pretrained, + device=self._device, + dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + trust_remote_code=True, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + """Generate text from the model.""" + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). + # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + **generation_kwargs, + ) + + +@register_model("apriel_hybrid_ssm") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.hybrid.external.apriel_hybrid.configuration_ssm_hybrid_apriel import AprielSSMHybridConfig + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) + + +@register_model("apriel_hybrid_ssm_15b") +class AprielHybridSSMWrapper(HFLM): + """Wrapper for AprielHybridSSM model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.hybrid.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import ( + AprielSSMHybridConfig, + ) + + self._config = AprielSSMHybridConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.hybrid.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import ( + AprielSSMHybridForCausalLM, + ) + + self._model = AprielSSMHybridForCausalLM.from_pretrained( + pretrained, + device=self._device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + **kwargs, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) diff --git a/fast_llm/models/hybrid/external/eval/run_lm_eval.py b/fast_llm/models/hybrid/external/eval/run_lm_eval.py new file mode 100644 index 00000000..b6313cf1 --- /dev/null +++ b/fast_llm/models/hybrid/external/eval/run_lm_eval.py @@ -0,0 +1,9 @@ +from lm_eval.__main__ import cli_evaluate + +from fast_llm.models.hybrid.external.eval.apriel_eval_wrapper import ( # noqa: F401 + AprielHybridSSMWrapper, + AprielSSMWrapper, +) + +if __name__ == "__main__": + cli_evaluate() diff --git a/fast_llm/models/hybrid/external/llamba/configuration_mtp_llamba.py b/fast_llm/models/hybrid/external/llamba/configuration_mtp_llamba.py new file mode 100644 index 00000000..b8173b73 --- /dev/null +++ b/fast_llm/models/hybrid/external/llamba/configuration_mtp_llamba.py @@ -0,0 +1,94 @@ +from enum import Enum + +from transformers.configuration_utils import PretrainedConfig + + +class StateUpdateKernel(Enum): + ssu_verification = "ssu_verification" # selective scan for multi-token verification, not implemented yet + cs = "chunk_scan" # see https://proceedings.mlr.press/v262/wu24a.html + ssu = "standard" # usual one token per time-step inference using selective-scan update, no verification + + +class MTPLlambaConfig(PretrainedConfig): + r"""Configuration class for the CustomMamba model. + + This configuration is used to instantiate the CustomMamba model according to the specified arguments, + defining the model architecture. + + Args: + vocab_size (`int`, *optional*, defaults to 128256): + Vocabulary size of the model. + tie_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + pad_vocab_size_multiple (`int`, *optional*, defaults to 8): + Pad the vocabulary size up to the next multiple of this value. + lm_head_bias (`bool`, *optional*, defaults to `False`): + Whether the LM head includes a bias term. + d_model (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + lm_head_prenorm (`str`, *optional*, defaults to "rms"): + Normalization type for LM head. + n_layer (`int`, *optional*, defaults to 32): + Number of layers in the model. + resid_dropout (`float`, *optional*, defaults to 0.0): + Dropout rate for residual connections. + norm_epsilon (`float`, *optional*, defaults to 1e-5): + Epsilon value used for normalization layers. + mlp_cfg (`dict`, *optional*): + Configuration for the MLP (Multi-Layer Perceptron) layer, including intermediate size, activation function, and whether to use bias. + ssm_cfg (`dict`, *optional*): + Configuration for the SSM (State Space Model) layer, including d_state, number of heads, expansion, and other parameters. + + """ + + model_type = "llamba" + + def __init__( + self, + vocab_size: int, + d_model: int, + tie_embeddings: bool = False, + pad_vocab_size_multiple: int = 8, + lm_head_bias: bool = False, + n_layer: int = 32, + resid_dropout: float = 0.0, + norm_epsilon: float = 1e-5, + mlp_cfg: dict = None, + ssm_cfg: dict = None, + prediction_heads=1, + state_update_kernel: StateUpdateKernel = StateUpdateKernel.cs, + **kwargs, + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.tie_embeddings = tie_embeddings + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.lm_head_bias = lm_head_bias + self.d_model = d_model + self.n_layer = n_layer + self.resid_dropout = resid_dropout + self.norm_epsilon = norm_epsilon + self.prediction_heads = prediction_heads + assert ( + state_update_kernel != StateUpdateKernel.ssu_verification + ), "Only chunk scan and standard modes are supported for now" + self.state_update_kernel = state_update_kernel + + # MLP (Multi-Layer Perceptron) Config + self.mlp_cfg = mlp_cfg or { + "intermediate_size": 14336, + "bias": False, + "act_fn": "silu", + } + + # SSM (State Space Model) Config + self.ssm_cfg = ssm_cfg or { + "d_state": 64, + "n_v_heads": 32, + "n_qk_heads": 32, + "expand": 1, + "chunk_size": 128, + "activation": "identity", + "bias": False, + } diff --git a/fast_llm/models/hybrid/external/llamba/modeling_mtp_llamba.py b/fast_llm/models/hybrid/external/llamba/modeling_mtp_llamba.py new file mode 100644 index 00000000..6d9746db --- /dev/null +++ b/fast_llm/models/hybrid/external/llamba/modeling_mtp_llamba.py @@ -0,0 +1,389 @@ +# Copyright (c) 2024, Kevin Li, Aviv Bick. + +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin +from mamba_ssm.utils.generation import GenerationMixin +from torch import Tensor, nn +from transformers.activations import ACT2FN +from transformers.utils.generic import ModelOutput + +from .configuration_mtp_llamba import MTPLlambaConfig as LlambaConfig +from .discrete_mamba2 import DiscreteMamba2 + + +class LlamaRMSNorm(nn.Module): + """LlamaRMSNorm (taken from transformers.models.llama.modeling_llama.LlamaRMSNorm).""" + + def __init__(self, hidden_size, eps=1e-6, factory_kwargs=None): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size, **factory_kwargs)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + """ + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + """Set the extra representation of the module.""" + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class LlamaMLP(nn.Module): + """LlamaMLP (taken from transformers.models.llama.modeling_llama.LlamaMLP).""" + + def __init__(self, hidden_size, intermediate_size, bias, act_fn, factory_kwargs=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias, **factory_kwargs) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias, **factory_kwargs) + self.act_fn = ACT2FN[act_fn] + + def forward(self, x): + """ + Args: + x: torch.Tensor of shape (batch_size, seq_len, hidden_size). + + Returns: + torch.Tensor of shape (batch_size, seq_len, hidden_size). + """ + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +@dataclass +class CustomMambaCausalLMOutput(ModelOutput): + """Custom output class for MambaLMHeadModel.""" + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None + last_hidden_state: Optional[torch.FloatTensor] = None + + +class MTPLlambaLMHeadModel(nn.Module, GenerationMixin, PyTorchModelHubMixin): + """MambaLM model with a language modeling head on top (linear layer).""" + + def __init__(self, config, initializer_cfg=None, device=None, dtype=None, **kwargs) -> None: + super().__init__() + + # Load config + if not isinstance(config, LlambaConfig): + config = LlambaConfig(**config) + self.config = config + + # Factory kwargs + factory_kwargs = {"device": device, "dtype": dtype} + + # Pad vocab size to be a multiple of pad_vocab_size_multiple + vocab_size = config.vocab_size + pad_vocab_size_multiple = config.pad_vocab_size_multiple + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.config.vocab_size = vocab_size + + # Mixer model + self.backbone = MixerModel( + input_size=vocab_size, + config=self.config, + initializer_cfg=initializer_cfg, + **factory_kwargs, + ) + + # MTP heads + self.mtp_heads = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=layer_idx, + ).to(device) + for layer_idx in range(config.n_layer, config.n_layer + config.prediction_heads - 1) + ] + ) + + self.mtp_norms = nn.ModuleList( + [ + LlamaRMSNorm(config.d_model, eps=config.norm_epsilon, factory_kwargs=factory_kwargs) + for _ in range(config.prediction_heads - 1) + ] + ) + # LM head + if not self.config.tie_embeddings: + self.lm_head = nn.Linear( + in_features=self.config.d_model, + out_features=self.config.vocab_size, + bias=self.config.lm_head_bias, + **factory_kwargs, + ) + else: + self.lm_head = lambda x: x @ self.backbone.embedding.weight.t() + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + + mtps = { + i + self.config.n_layer: layer.allocate_inference_cache(*args, **kwargs) + for i, layer in enumerate(self.mtp_heads) + } + return {**self.backbone.allocate_inference_cache(*args, **kwargs), **mtps} + + def forward( + self, + input_ids, + position_ids=None, + return_hidden_states=False, + return_logits=True, + inference_params=None, + num_last_tokens=0, + ): + """ + Args: + input_ids: torch.Tensor of shape (batch_size, seq_len), + position_ids: torch.Tensor of shape (batch_size, seq_len), optional, not used (just for compatibility), + return_hidden_states: bool, optional, + return_logits: bool, optional, whether to compute the logits with the LM head, + inference_params: dict, optional, the model's inference cache, + num_last_tokens: int, optional. If > 0, only return the logits for the last n tokens. + + Returns: + CustomMambaCausalLMOutput. + + """ + outputs = self.backbone( + input_ids, + return_hidden_states=return_hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + + # MTP heads processing + latents = [] + hidden_states = outputs["last_hidden_state"] + hidden_states_before_last = outputs["hidden_state_before_last"] + + # last layer already has layer norm applied + latents.append(hidden_states) + + # Process through MTP heads + for i, mtp_head in enumerate(self.mtp_heads): + mtp_outputs = mtp_head( + hidden_states_before_last, + inference_params=inference_params, + position_ids=position_ids, + ) + mtp_hidden_states = mtp_outputs["hidden_states"] + latents.append(self.mtp_norms[i](mtp_hidden_states)) + + # Stack the latents to get (batch_size, seq_len, num_prediction_heads, hidden_size) + stacked_latents = torch.stack(latents, dim=-2) + + if return_logits: + if isinstance(self.lm_head, nn.Linear): + # Apply lm_head to each prediction head's output + logits = self.lm_head(stacked_latents).float() + else: + # Using the tied embedding weights + logits = self.lm_head(stacked_latents) + + outputs["logits"] = logits if num_last_tokens == 0 else logits[:, -num_last_tokens:] + else: + outputs["logits"] = None + + return CustomMambaCausalLMOutput( + loss=None, + logits=outputs["logits"], + all_hidden_states=outputs["all_hidden_states"], + last_hidden_state=stacked_latents, + ) + + def save_pretrained(self, save_directory): + """ + Minimal implementation of save_pretrained for MambaLMHeadModel. + Save the model and its configuration file to a directory. + """ + # Ensure save_directory exists + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + # Save the model's state_dict + model_path = os.path.join(save_directory, "pytorch_model.bin") + torch.save(self.state_dict(), model_path) + + # Save the configuration of the model + config_path = os.path.join(save_directory, "config.json") + with open(config_path, "w") as f: + json.dump(self.config.to_dict(), f) + + +class MixerModel(nn.Module): + """Mixer model with a stack of Mixer layers.""" + + def __init__(self, input_size, config=None, device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.config = config + self.embedding = nn.Embedding(input_size, self.config.d_model, **factory_kwargs) + + self.layers = nn.ModuleList( + [ + Block( + config=config, + factory_kwargs=factory_kwargs, + layer_idx=i, + ).to(device) + for i in range(self.config.n_layer) + ] + ) + + self.final_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, + eps=self.config.norm_epsilon, + factory_kwargs=factory_kwargs, + ) + + return + + def allocate_inference_cache(self, *args, **kwargs): + """Allocate inference cache for the model.""" + return {i: layer.allocate_inference_cache(*args, **kwargs) for i, layer in enumerate(self.layers)} + + def forward( + self, + input_ids, + return_hidden_states=False, + inference_params=None, + position_ids=None, + ): + """Run the model.""" + # Start running the layers + hidden_states = self.embedding(input_ids) + + # Initialize outputs + outputs = { + "last_hidden_state": None, + "hidden_state_before_last": None, + "all_hidden_states": (hidden_states,) if return_hidden_states else (), + } + + # Run the layers + for layer in self.layers: + layer_outputs = layer( + hidden_states, + inference_params=inference_params, + position_ids=position_ids, + ) + if layer == self.layers[-1]: + outputs["hidden_state_before_last"] = hidden_states + # Record outputs + hidden_states = layer_outputs["hidden_states"] + if return_hidden_states: + outputs["all_hidden_states"] += (hidden_states,) + + # Last layer, apply layer norm + outputs["last_hidden_state"] = self.final_layernorm(hidden_states) + return outputs + + +class Block(nn.Module): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection. + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + + def __init__(self, config, factory_kwargs, layer_idx, **kwargs): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + # Mixer + self.mixer = DiscreteMamba2( + d_model=self.config.d_model, + layer_idx=layer_idx, + **config.ssm_cfg, + **factory_kwargs, + ) + + # Other components + self.input_layernorm = LlamaRMSNorm(hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs) + self.post_attention_layernorm = LlamaRMSNorm( + hidden_size=self.config.d_model, eps=1e-5, factory_kwargs=factory_kwargs + ) + self.mlp = LlamaMLP( + hidden_size=self.config.d_model, + **config.mlp_cfg, + factory_kwargs=factory_kwargs, + ) + + def forward( + self, + hidden_states: Tensor, + inference_params=None, + **kwargs, + ): + """ + Pass the input through the encoder layer. + + Args: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + inference_params: dict, optional, + + Returns: + dict with keys: + hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + mamba_hidden_states: torch.Tensor of shape (batch_size, seq_len, hidden_size), + transfer_matrix: torch.Tensor of shape (batch_size, seq_len, seq_len). + """ + outputs = {} + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Apply Mixer + mixer_outputs = self.mixer( + hidden_states, + inference_params=inference_params, + ) + + hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs["hidden_states"] = hidden_states + + return outputs + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the model.""" + if getattr(self.mixer, "allocate_inference_cache", None) is None: + return + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/fast_llm/models/hybrid/external/make_hybrid_checkpoint.py b/fast_llm/models/hybrid/external/make_hybrid_checkpoint.py new file mode 100644 index 00000000..2fe15c0d --- /dev/null +++ b/fast_llm/models/hybrid/external/make_hybrid_checkpoint.py @@ -0,0 +1,41 @@ +import gc + +import click +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from fast_llm.models.hybrid.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig +from fast_llm.models.hybrid.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +@click.command() +@click.option("--identity_index", type=int, required=True) +@click.option("--save_dir", type=str, required=True) +def main(identity_index: int, save_dir: str): + checkpoint = "ServiceNow-AI/Apriel-Nemotron-15b-Thinker" + config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True) + + hybrid_block_layout = ["t"] * config.num_hidden_layers + if identity_index >= 0: + hybrid_block_layout[identity_index] = "i" + + hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(), hybrid_block_layout=hybrid_block_layout) + hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config) + hybrid_apriel_model.to(dtype=torch.bfloat16).to(device) + + apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True) + apriel_state_dict = apriel_model.state_dict() + hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False) + + hybrid_apriel_model.save_pretrained(save_dir, save_config=True) + torch.cuda.empty_cache() + del hybrid_apriel_model + del apriel_model + del apriel_state_dict + gc.collect() + + +if __name__ == "__main__": + main() diff --git a/fast_llm/models/hybrid/huggingface.py b/fast_llm/models/hybrid/huggingface.py new file mode 100644 index 00000000..8191a5a2 --- /dev/null +++ b/fast_llm/models/hybrid/huggingface.py @@ -0,0 +1,21 @@ +import logging + +from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM +from fast_llm.models.hybrid.config import HybridModelConfig +from fast_llm.models.hybrid.model import HybridModel + +logger = logging.getLogger(__name__) + + +class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): + model_type = "fast_llm_ssm" + model_config_class = HybridModelConfig + fast_llm_config: HybridModelConfig + + +class HuggingfaceHybridModelForCausalLM(HuggingfaceGPTModelForCausalLM): + config_class = HuggingfaceSSMModelConfig + config: HuggingfaceSSMModelConfig + model_class = HybridModel + _fast_llm_model: HybridModel diff --git a/fast_llm/models/hybrid/model.py b/fast_llm/models/hybrid/model.py new file mode 100644 index 00000000..cee3f4d2 --- /dev/null +++ b/fast_llm/models/hybrid/model.py @@ -0,0 +1,87 @@ +import logging +import typing + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.language_model.embedding import LanguageModelEmbedding +from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.models.hybrid.config import HybridBaseModelConfig, HybridModelConfig + +logger = logging.getLogger(__name__) + + +class HybridBaseModel[ConfigType: HybridBaseModelConfig](GPTBaseModel[ConfigType]): + """ + A hybrid model that can interleave Transformer, Mamba and other blocks. + """ + + config_class: typing.ClassVar[type[HybridBaseModelConfig]] = HybridBaseModelConfig + _is_setup: bool = False + + def __init__( + self, + config: HybridBaseModelConfig, + distributed_config: DistributedConfig, + ): + + super().__init__(config, distributed_config) + + def get_output_layers(self) -> list[Layer]: + """ + Get the output layers of the model. + This includes the language model head and any additional heads specified in the configuration. + """ + layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + + if self._config.prediction_heads > 1: + block_name = self._config.default_mtp_type + assert block_name in self._config.registered_blocks, f"Block {block_name} not found in config" + BLOCK_CLS = self._config.registered_blocks[block_name].block_class + for i in range(1, self._config.prediction_heads): + layers.append( + BLOCK_CLS( + self._config.registered_blocks[block_name], + self._tensor_space, + layer_index=len(self._config.block_layout), + return_input=i != self._config.prediction_heads - 1, + block_name=block_name, + ) + ) + layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) + + return layers + + def get_layers(self) -> list[Layer]: + """ + Create a list of layers for the model, interleaving Transformer and Mamba blocks + according to the block pattern. + """ + layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + + # Create blocks according to pattern + for i, block_name in enumerate(self._config.block_layout): + BLOCK_CLS: BaseBlock = self._config.blocks[block_name].block_class + layers.append( + BLOCK_CLS( + self._config.blocks[block_name], + self._tensor_space, + layer_index=i + 1, + return_input=(i == len(self._config.block_layout) - 1 and self._config.prediction_heads > 1), + block_name=block_name, + ) + ) + layers += self.get_output_layers() + + return layers + + +class HybridModel[ConfigType: HybridModelConfig](FastLLMModel[ConfigType]): + """ + A hybrid model that combines Transformer and SSM blocks. + """ + + config_class: typing.ClassVar[type[HybridModelConfig]] = HybridModelConfig + base_model_class: typing.ClassVar[type[HybridBaseModel]] = HybridBaseModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/hybrid/trainer.py similarity index 55% rename from fast_llm/models/ssm/trainer.py rename to fast_llm/models/hybrid/trainer.py index c0e5be26..9e489e89 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/hybrid/trainer.py @@ -1,10 +1,10 @@ import typing from fast_llm.models.gpt.trainer import GPTTrainer -from fast_llm.models.ssm.config import HybridTrainerConfig -from fast_llm.models.ssm.model import HybridSSMModel +from fast_llm.models.hybrid.config import HybridTrainerConfig +from fast_llm.models.hybrid.model import HybridModel class SSMTrainer[ConfigType: HybridTrainerConfig](GPTTrainer[ConfigType]): config_class: typing.ClassVar[type[HybridTrainerConfig]] = HybridTrainerConfig - model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel + model_class: typing.ClassVar[type[HybridModel]] = HybridModel diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py deleted file mode 100644 index 771a4fca..00000000 --- a/fast_llm/models/ssm/config.py +++ /dev/null @@ -1,168 +0,0 @@ -import logging -import math -import typing - -from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig -from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMModel - from fast_llm.models.ssm.trainer import SSMTrainer - -logger = logging.getLogger(__name__) - - -@config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): - _abstract = False - - ssm: SSMConfig = Field( - desc="Configuration for the transformer architecture.", - hint=FieldHint.architecture, - ) - hybrid_block_layout: list[str] = Field( - default_factory=lambda: ["m2"], - desc="Pattern of blocks to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for Discrete Mamba2", - hint=FieldHint.architecture, - ) - default_mtp_type: str | None = Field( - default=None, - desc="Multi-token prediction mixer to use in the model. 't' for Transformer, 'm' for Mamba1, 'm2' for discrete Mamba2. If None, will use the last block type in `hybrid_block_layout`.", - hint=FieldHint.optional, - ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - """ - Setup the tensor space for the model. - Some of these can be setup directly in the layer config, but keeping them here for clarity. - """ - super().setup_tensor_space(tensor_space) - if not "m2" in self.hybrid_block_layout and not "m" in self.hybrid_block_layout: - raise ValueError( - "Block pattern must contain at least one 'm' or 'm2', use gpt model for transformer only architectures" - ) - - if self.ssm.dt_rank < 0: - mamba_dt_rank = math.ceil(self.transformer.hidden_size / 16) - else: - mamba_dt_rank = self.ssm.dt_rank - - d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.model_dim, self.transformer.hidden_size)) - # Mamba-specific dimensions - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_dim, d_inner)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.state_dim, self.ssm.state_size)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.dt_rank, mamba_dt_rank)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim, mamba_dt_rank + self.ssm.state_size * 2)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_kernel_size, self.ssm.conv_kernel_dimension)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba, d_inner * 2)) - - if "m2" in self.hybrid_block_layout or self.default_mtp_type == "m2": - # Mamba2 specific dimensions - # as per https://github.com/cartesia-ai/edge/blob/a0e121ebed3d2324c6d762b0e211a08d62583681/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py#L66C3-L66C4 - headdim = d_inner // self.ssm.n_v_heads - Assert.eq(self.ssm.n_v_heads, d_inner // headdim) - Assert.eq(d_inner % headdim, 0) - Assert.eq(self.ssm.n_v_heads % self.ssm.n_qk_heads, 0) - - conv_dim = d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size - inner_proj_dim = 2 * d_inner + 2 * self.ssm.n_qk_heads * self.ssm.state_size + self.ssm.n_v_heads - - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.head_dim, headdim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.qk_heads, self.ssm.n_qk_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.v_heads, self.ssm.n_v_heads)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) - tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) - - def _validate(self): - len_block_layout = len(self.hybrid_block_layout) - if len_block_layout != self.transformer.num_layers: - if self.transformer.num_layers % len_block_layout != 0: - raise ValueError( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len_block_layout) - logger.warning( - f"hybrid_block_layout length {len_block_layout} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) - self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - - Assert.eq(len_block_layout, self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in ["t", "m", "m2"] for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be 't' or 'm' or 'm2'", - ) - Assert.custom( - lambda _: self.default_mtp_type in ["t", "m", "m2", None], - f"Invalid MTP type: {self.default_mtp_type}. Must be 't' or 'm' or 'm2' or None", - ) - - super()._validate() - - -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False - name: typing.ClassVar[str] = "llamba" - - @classmethod - def get_handler_class(cls) -> type[CheckpointHandler]: - from fast_llm.models.ssm.conversion import LLambaHuggingfaceCheckpointHandler - - return LLambaHuggingfaceCheckpointHandler - - -@config_class() -class HybridSSMModelConfig(FastLLMModelConfig): - _abstract = False - model_name: typing.ClassVar[str] = "hybrid_ssm" - base_model: HybridSSMBaseModelConfig = FieldUpdate() - checkpoint_formats = FastLLMModelConfig.checkpoint_formats + (LLambaHuggingfaceCheckpointFormat,) - - @classmethod - def get_model_class(cls) -> type["HybridSSMModel"]: - from fast_llm.models.ssm.model import HybridSSMModel - - return HybridSSMModel - - @classmethod - def get_huggingface_model_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: - from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - - return HuggingfaceHybridSSMModelForCausalLM - - def _validate(self): - logger.warning( - "HybridSSMModelConfig is being instantiated. This model is experimental and may not work as expected." - ) - super()._validate() - - -@config_class() -class PretrainedHybridSSMModelConfig(PretrainedFastLLMModelConfig): - _abstract = False - model: HybridSSMModelConfig = FieldUpdate() - - -@config_class() -class HybridTrainerConfig(PretrainedHybridSSMModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate() - batch: GPTBatchConfig = FieldUpdate() - - @classmethod - def get_trainer_class(cls) -> type["SSMTrainer"]: - from fast_llm.models.ssm.trainer import SSMTrainer - - return SSMTrainer diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py deleted file mode 100644 index 190b2ffa..00000000 --- a/fast_llm/models/ssm/conversion.py +++ /dev/null @@ -1,284 +0,0 @@ -import json -import os -import pathlib -import typing - -from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import ( - ConstantImportParamConverter, - IgnoreImportWeightConverter, - MappedConfigParamConverter, - ParamConverter, - RenameParamConverter, - SplitWeightConverter, - WeightConverter, -) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler -from fast_llm.engine.multi_stage.config import FastLLMModelConfig -from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import NormalizationType -from fast_llm.models.gpt.conversion import MLPLayer2Converter -from fast_llm.models.ssm.config import HybridSSMModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMModel -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - pass - - -class LLambaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): - _model: HybridSSMModel - _model_class: typing.ClassVar[FastLLMModelConfig] = HybridSSMModelConfig - format: typing.ClassVar[type[CheckpointFormat]] = LLambaHuggingfaceCheckpointFormat - - @classmethod - def _create_config_converters(cls) -> list[ParamConverter]: - """ - Create config converters for the model, see args under https://huggingface.co/cartesia-ai/Llamba-8B/blob/main/config.json - """ - return super()._create_config_converters() + [ - ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False), - RenameParamConverter( - fast_llm_names=(("transformer", "num_layers"),), - export_names=(("n_layer",),), - ), - ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True), - # TODO: is there an equivalen of pad_vocab_size_multiple in FastLLM, does it matter? - RenameParamConverter( - fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - RenameParamConverter( - fast_llm_names=(("ssm", "normalization", "epsilon"),), export_names=(("norm_epsilon",),) - ), - ConstantImportParamConverter( - fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm - ), - RenameParamConverter( - fast_llm_names=(("vocab_size",),), - export_names=(("vocab_size",),), - ), - RenameParamConverter( - fast_llm_names=(("tie_word_embeddings",),), - export_names=(("tie_embeddings",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "hidden_size"),), - export_names=(("d_model",),), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "ffn_hidden_size"),), - export_names=( - ( - "mlp_cfg", - "intermediate_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("transformer", "add_linear_biases"),), - export_names=( - ( - "mlp_cfg", - "bias", - ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("transformer", "activation_type"),), - export_names=( - ( - "mlp_cfg", - "act_fn", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - RenameParamConverter( - fast_llm_names=(("ssm", "state_size"),), - export_names=( - ( - "ssm_cfg", - "d_state", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "n_v_heads"),), - export_names=( - ( - "ssm_cfg", - "n_v_heads", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "n_qk_heads"),), - export_names=( - ( - "ssm_cfg", - "n_qk_heads", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "expansion_factor"),), - export_names=( - ( - "ssm_cfg", - "expand", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "chunk_size"),), - export_names=( - ( - "ssm_cfg", - "chunk_size", - ), - ), - ), - RenameParamConverter( - fast_llm_names=(("ssm", "add_bias_linear"),), - export_names=( - ( - "ssm_cfg", - "bias", - ), - ), - ), - MappedConfigParamConverter( - fast_llm_names=(("ssm", "activation_type"),), - export_names=( - ( - "ssm_cfg", - "activation", - ), - ), - fast_llm_value=ActivationType.from_hf_name, - export_value=lambda activation_type: activation_type.hf_name, - ), - ] - - def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] - num_layers = self._model.config.base_model.transformer.num_layers - norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear - - # Embedding and output - if self._model.config.base_model.tie_word_embeddings: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(IgnoreImportWeightConverter((), "lm_head.weight")) - else: - converters.append(WeightConverter("layers.0.word_embeddings_weight", "backbone.embedding.weight")) - converters.append(WeightConverter(f"layers.{num_layers + 1}.output_weights", "lm_head.weight")) - - # Final norm - converters += self._get_weight_and_bias_converters( - f"layers.{num_layers + 1}.final_norm", "backbone.final_layernorm", norm_bias - ) - - for i in range(num_layers): - # SSM - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.in_proj", f"backbone.layers.{i}.mixer.in_proj", ssm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.mixer.out_proj", f"backbone.layers.{i}.mixer.out_proj", ssm_bias - ) - converters.append( - WeightConverter(f"layers.{i+1}.mixer.D", f"backbone.layers.{i}.mixer.D", self._model.config.base_model) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.z_bias", f"backbone.layers.{i}.mixer.z_bias", self._model.config.base_model - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_weight", - f"backbone.layers.{i}.mixer.conv1d.weight", - self._model.config.base_model, - ) - ) - converters.append( - WeightConverter( - f"layers.{i+1}.mixer.conv1d_bias", - f"backbone.layers.{i}.mixer.conv1d.bias", - self._model.config.base_model, - ) - ) - - # Norm - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_1", f"backbone.layers.{i}.input_layernorm", norm_bias - ) - converters += self._get_weight_and_bias_converters( - f"layers.{i+1}.norm_2", f"backbone.layers.{i}.post_attention_layernorm", norm_bias - ) - - # MLP - converters += self._get_mlp_converters(f"layers.{i+1}", f"backbone.layers.{i}") - - return converters - - def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases - return [ - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", - (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - linear_bias, - SplitWeightConverter, - ), - *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_2", - f"{hf_prefix}.mlp.down_proj", - linear_bias, - MLPLayer2Converter, - ), - ] - - def _get_weight_and_bias_converters( - self, - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - ) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - cls( - tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - tuple(f"{prefix}.weight" for prefix in hf_prefix), - self._model.config.base_model, - ) - ] - if use_bias: - converters.append( - cls( - tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - tuple(f"{prefix}.bias" for prefix in hf_prefix), - self._model.config.base_model, - ) - ) - return converters - - @classmethod - def _load_config(cls, directory: pathlib.Path | str) -> dict: - if not os.path.exists(directory / "config.json"): - raise FileNotFoundError(f"config.json not found in {directory}") - with open(directory / "config.json") as f: - config = json.load(f) - Assert.eq(config["model_type"], cls.get_huggingface_model_type()) - return config - - @classmethod - def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None: - with open(directory / "config.json", "w") as f: - json.dump(config, f) diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py deleted file mode 100644 index 77cd346f..00000000 --- a/fast_llm/models/ssm/huggingface.py +++ /dev/null @@ -1,21 +0,0 @@ -import logging - -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig -from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM -from fast_llm.models.ssm.config import HybridSSMModelConfig -from fast_llm.models.ssm.model import HybridSSMModel - -logger = logging.getLogger(__name__) - - -class HuggingfaceSSMModelConfig(HuggingfaceModelConfig): - model_type = "fast_llm_ssm" - model_config_class = HybridSSMModelConfig - fast_llm_config: HybridSSMModelConfig - - -class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): - config_class = HuggingfaceSSMModelConfig - config: HuggingfaceSSMModelConfig - model_class = HybridSSMModel - _fast_llm_model: HybridSSMModel diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py deleted file mode 100644 index 6ff6c5f5..00000000 --- a/fast_llm/models/ssm/model.py +++ /dev/null @@ -1,143 +0,0 @@ -import logging -import typing - -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer -from fast_llm.models.gpt.model import GPTBaseModel -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig - -logger = logging.getLogger(__name__) - - -class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[ConfigType]): - """ - A hybrid model that interleaves Transformer and Mamba blocks. - Right now only LlambaBlock is supported. - As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. - """ - - config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig - _is_setup: bool = False - - def __init__( - self, - config: HybridSSMBaseModelConfig, - distributed_config: DistributedConfig, - ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed - super().__init__(config, distributed_config) - - def get_output_layers(self) -> list[Layer]: - """ - Get the output layers of the model. - This includes the language model head and any additional heads specified in the configuration. - """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] - - if self._config.prediction_heads > 1: - block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - for i in range(1, self._config.prediction_heads): - if block_type == "t": - layers.append( - TransformerLayer( - self._config.transformer, - self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), - return_input=i != self._config.prediction_heads - 1, - ) - ) - elif block_type == "m2": - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == "m": - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - else: - raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'") - layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) - - return layers - - def get_layers(self) -> list[Layer]: - """ - Create a list of layers for the model, interleaving Transformer and Mamba blocks - according to the block pattern. - """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] - - # Create blocks according to pattern - for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == "t": - # Transformer block - layers.append( - TransformerLayer( - self._config.transformer, - self._tensor_space, - layer_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - elif block_type == "m2": - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - elif block_type == "m": - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - else: - raise ValueError(f"Invalid block type: {block_type}. Must be 't' or 'm' or 'm2'") - - # Add the output layers - layers += self.get_output_layers() - - return layers - - -class HybridSSMModel[ConfigType: HybridSSMModelConfig](FastLLMModel[ConfigType]): - """ - A hybrid model that combines Transformer and SSM blocks. - """ - - config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig - base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel diff --git a/fast_llm/utils.py b/fast_llm/utils.py index d89c9d76..04cfb001 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -339,6 +339,24 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) +def get_lr_scale( + lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None +) -> float | None | tuple[float | None, ...]: + """ + Combine module and layer lr_scale. + If one is None, return the other. + """ + if lr_scale is None: + return layer_lr_scale + if layer_lr_scale is None: + return lr_scale + if isinstance(lr_scale, float): + return lr_scale * layer_lr_scale + if isinstance(lr_scale, tuple): + return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) + raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") + + class Interrupter: def __init__(self, enabled: bool = True, signals: typing.Sequence[int] = (signal.SIGINT, signal.SIGTERM)): self._enabled = enabled diff --git a/tests/common.py b/tests/common.py index 6179957b..94175792 100644 --- a/tests/common.py +++ b/tests/common.py @@ -23,7 +23,7 @@ Qwen2GPTHuggingfaceCheckpointFormat, Starcoder2GPTHuggingfaceCheckpointFormat, ) -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import HybridBaseModelConfig, LLambaHuggingfaceCheckpointFormat from fast_llm.tools.train import CliTrainingConfig from tests.compare_tensor_logs import CompareConfig, compare_tensor_logs @@ -36,7 +36,7 @@ FORCE_REUSE_RESULTS = int(os.environ.get("FORCE_REUSE_RESULTS", 0)) != 0 REUSE_RESULTS = FORCE_REUSE_RESULTS or int(os.environ.get("REUSE_RESULTS", 0)) != 0 _LOG_LEVEL = int(os.environ.get("LOG_LEVEL", 13)) -TEST_MODEL = os.environ.get("MODEL", "llama") +TEST_MODEL = os.environ.get("MODEL", "llamba") ARTIFACT_PATH = "runs/0/artifacts" @@ -204,7 +204,7 @@ ] CONFIG_LLAMA_MTP_COMMON = CONFIG_LLAMA_MTP_FAST_LLM + ["model.distributed.training_dtype=bf16"] -CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout==['t','m']"] +CONFIG_LLAMBA_FAST_LLM = CONFIG_LLAMA_FAST_LLM + ["model.base_model.hybrid_block_layout=['m2d','m2d']"] CONFIG_LLAMBA_MEGATRON = CONFIG_LLAMA_MEGATRON + [] CONFIG_LLAMBA_COMMON = CONFIG_LLAMBA_FAST_LLM @@ -449,7 +449,7 @@ def materialize_meta_tensors(model, tensor_space): def get_hybrid_config(hybrid_block_layout=["t", "m"], prediction_heads=1, default_mtp_type=None): - config = HybridSSMBaseModelConfig( + config = HybridBaseModelConfig( transformer=TransformerConfig(num_layers=len(hybrid_block_layout)), ssm=SSMConfig(), hybrid_block_layout=hybrid_block_layout, diff --git a/tests/test_config.py b/tests/test_config.py index 80bed418..62069162 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,6 +13,7 @@ from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.models.auto import trainer_registry from fast_llm.models.gpt.config import GPTModelConfig, PretrainedGPTModelConfig +from fast_llm.models.hybrid.config import HybridBaseModelConfig from fast_llm.utils import Assert, check_equal_nested from tests.common import TEST_RESULTS_PATH @@ -132,11 +133,11 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { # rotary: Don't override nested. "normalization": {"implementation": "triton"}, # Update non-default nested - "peft": {"freeze_others": False}, # Update default nested, non-architecture "hidden_size": 512, # Override, affects derived value (kv channels) "head_groups": 1, # Override to default }, "vocab_size": 1000, + "head_normalization": {"type": "rms_norm"}, } pretrained_config = PretrainedGPTModelConfig.from_dict( { @@ -159,7 +160,6 @@ def test_pretrained_config(load_config: ModelConfigType): "transformer": { "normalization": {"type": "rms_norm", "implementation": "triton"}, "rotary": {"type": "default"}, - "peft": {"freeze_others": False}, "num_layers": 12, "hidden_size": 512, "ffn_hidden_size": 4096, @@ -169,8 +169,62 @@ def test_pretrained_config(load_config: ModelConfigType): }, "tie_word_embeddings": False, "vocab_size": 1000, + "head_normalization": {"type": "rms_norm"}, } else: expected_config["base_model"] = base_model_update check_equal_nested(serialized_config, expected_config) + + +# TODO: add test for hybrid pretrained config as above +def test_hybrid_block_modular_config(): + + config = { + "blocks": { + "bob_shared": { + "type": "transformer", + "hidden_size": 512, + "share_weights": True, + }, + "mamba_non_shared": { + "type": "discrete_mamba2", + "state_size": 16, + "expansion_factor": 2, + "hidden_size": 512, + "share_weights": False, + }, + }, + "hybrid_block_layout": ["bob_shared", "mamba_non_shared", "bob_shared", "mamba_non_shared"], + "num_layers": 8, + } + + modular_config = HybridBaseModelConfig.from_dict(config) + modular_config.validate() + Assert.eq(modular_config.hybrid_block_layout, ["bob_shared", "mamba_non_shared", "bob_shared", "mamba_non_shared"]) + Assert.eq( + modular_config.block_layout, + [ + "bob_shared", + "mamba_non_shared_1", + "bob_shared", + "mamba_non_shared_3", + "bob_shared", + "mamba_non_shared_1", + "bob_shared", + "mamba_non_shared_3", + ], + ) # with num_layers = 8, the block_layout should be 8 blocks with repeated pattern of ["bob_shared", "mamba_non_shared", "bob_shared", "mamba_non_shared"] with names abjusted for weight sharing + for block_name in modular_config.block_layout: + Assert.custom( + lambda _: block_name in modular_config.registered_blocks, + f"Block {block_name} not found in registered blocks", + ) + serialized = modular_config.to_dict() + reconstructed = HybridBaseModelConfig.from_dict(serialized) + Assert.eq(reconstructed.to_dict(), config) + reconstructed.validate() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_ssms.py b/tests/test_hybrid.py similarity index 76% rename from tests/test_ssms.py rename to tests/test_hybrid.py index f1ef5654..16ee10e7 100644 --- a/tests/test_ssms.py +++ b/tests/test_hybrid.py @@ -1,5 +1,4 @@ import pathlib -from functools import partial import pytest import torch @@ -15,22 +14,23 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBatchConfig, LlamaGPTHuggingfaceCheckpointFormat -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat +from fast_llm.models.hybrid.config import ( + AprielSSMHHybridHuggingfaceCheckpointFormat, + LLambaHuggingfaceCheckpointFormat, +) +from fast_llm.models.hybrid.model import HybridModel from tests.common import get_hybrid_config, materialize_meta_tensors try: - from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 - from fast_llm.layers.ssm.llamba_block import LlambaBlock from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel, HybridSSMModel + from fast_llm.models.hybrid.model import HybridBaseModel except ImportError: - MambaLayer, LlambaBlock, HybridSSMBaseModel, DiscreteMamba2 = ( + MambaLayer, LlambaBlock, HybridBaseModel, DiscreteMamba2 = ( None, None, None, None, ) - # Mamba not installed, skipping tests try: from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel as LMHeadModel @@ -96,7 +96,7 @@ def test_load_from_llamba_checkpoint(distributed_config): # Create checkpoint load config checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) + model = HybridModel.from_pretrained(checkpoint_config) param_sum = 0 for stage in model.stages: for fsdp in stage.fsdps: @@ -139,57 +139,75 @@ def test_load_from_llamba_checkpoint(distributed_config): assert torch.allclose(logits, hf_logits, atol=1e-2) +def get_hf_apriel_hybrid_out(input_ids, path, format): + from fast_llm.models.hybrid.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridForCausalLM + + model = AprielSSMHybridForCausalLM.from_pretrained(path, strict=True).to("cuda") + parameter_sum = sum(p.detach().cpu().numpy().sum() for p in model.parameters()) + print(f"Parameter sum: {parameter_sum}") + output = model(input_ids) + del model + torch.cuda.empty_cache() + return output, parameter_sum + + +@pytest.mark.slow +@pytest.mark.skipif( + not run_test + and not pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug").exists(), + reason=f"Skipping because no CUDA available or Mamba not installed", +) +def test_load_from_hybridssm_checkpoint(): + """ + Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. + """ + vocab_size = 131072 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json + batch_size = 2 + seq_length = 32 + + path = pathlib.Path("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug") + format = AprielSSMHHybridHuggingfaceCheckpointFormat + + x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") + hf_logits, parameter_sum_hf = get_hf_apriel_hybrid_out(x, path, format) + hf_logits = hf_logits["logits"].cpu() + + # Create checkpoint load config + checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) + # Initialize model + model = HybridModel.from_pretrained(checkpoint_config) + param_sum = 0 + for stage in model.stages: + for fsdp in stage.fsdps: + if hasattr(fsdp, "_weight_shard"): + param_sum += torch.sum(fsdp._weight_shard).item() + assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 + + +# test legacy behavior of using m and m2d # TODO: Speed up this test or bring it back as an integration test. @pytest.mark.skip(reason="Too slow.") @pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") @pytest.mark.parametrize( - "hybrid_block_layout,LAYER_CLS", + "hybrid_block_layout", [ - (["m", "t"], MambaLayer), - (["m2", "t"], DiscreteMamba2), + (["m"]), + (["m2d"]), ], ids=["mamba", "discrete_mamba2"], ) -def test_mamba_layer(distributed_config, distributed, hybrid_block_layout, LAYER_CLS): +def test_mamba_block(distributed_config, distributed, hybrid_block_layout): hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) tensor_space = TensorSpace(distributed_config=distributed_config) - hybrid_config.setup_tensor_space(tensor_space) - layer = LAYER_CLS(hybrid_config.ssm, layer_idx=0, tensor_space=tensor_space) - tensor_space.setup(distributed) - materialize_meta_tensors(layer, tensor_space) - layer.to(distributed.device) - - batch_size = 2 - seq_length = 32 - hidden_size = hybrid_config.transformer.hidden_size - x = torch.randn(batch_size, seq_length, hidden_size, device=distributed.device) - - # Run forward pass - output, _ = layer(x, {}) - - loss = output.sum() - loss.backward() - # Basic shape checkss - assert output.shape == x.shape - assert not torch.isnan(output).any() - assert not torch.isinf(output).any() - - -@pytest.mark.skipif(not run_test, reason="No CUDA available or Mamba not installed") -def test_mamba_block(distributed_config, distributed): - hybrid_config = get_hybrid_config(hybrid_block_layout=["m", "t"]) - tensor_space = TensorSpace(distributed_config=distributed_config) tensor_space.setup(distributed) hybrid_config.setup_tensor_space(tensor_space) layer_idx = 0 - - mixer_cls = partial(MambaLayer, layer_idx=layer_idx) - block = LlambaBlock( - hybrid_config.transformer, + BLOCK_CLS = hybrid_config.blocks[hybrid_block_layout[0]].block_class + block = BLOCK_CLS( hybrid_config.ssm, - mixer_cls=mixer_cls, tensor_space=tensor_space, layer_index=layer_idx, + block_name=hybrid_block_layout[0], ) materialize_meta_tensors(block, tensor_space) @@ -215,13 +233,13 @@ def test_mamba_block(distributed_config, distributed): ("hybrid_block_layout"), [ (["m", "t"]), - (["m2", "t"]), + (["m2d", "t"]), ], ids=["mamba", "discrete_mamba2"], ) def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layout): hybrid_config = get_hybrid_config(hybrid_block_layout=hybrid_block_layout) - model = HybridSSMBaseModel(hybrid_config, distributed_config) + model = HybridBaseModel(hybrid_config, distributed_config) distributed = Distributed(distributed_config) model.setup(distributed) tensor_space = model._tensor_space @@ -268,7 +286,7 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo # @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") # def test_hybrid_model_inference(distributed_config, hybrid_config): # hybrid_config.ssm.use_fast_path = False -# model = HybridSSMBaseModel(hybrid_config, distributed_config) +# model = HybridBaseModel(hybrid_config, distributed_config) # distributed = Distributed(distributed_config) # model.setup(distributed) # tensor_space = model._tensor_space @@ -302,3 +320,6 @@ def test_hybrid_model_train_with_fast_mode(distributed_config, hybrid_block_layo # }, # losses=losses, # ) + +if __name__ == "__main__": + pytest.main(["-s", __file__]) diff --git a/tests/test_modular_config.py b/tests/test_modular_config.py new file mode 100644 index 00000000..b740a43c --- /dev/null +++ b/tests/test_modular_config.py @@ -0,0 +1,37 @@ +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.models.hybrid.config import HybridBaseModelConfig, MambaBlockConfig, TransformerBlockConfig +from fast_llm.models.hybrid.model import HybridBaseModel + +config = HybridBaseModelConfig( + blocks={ + "transformer_block": TransformerBlockConfig( + transformer=TransformerConfig( + hidden_size=4096, + num_attention_heads=32, + num_layers=10, + ), + ), + "mamba_block": MambaBlockConfig( + ssm=SSMConfig( + state_size=16, + ), + ), + "mamba2_block": MambaBlockConfig( + ssm=SSMConfig( + state_size=16, + ), + ), + }, + hybrid_block_layout=["mamba_block", "mamba2_block", "mamba_block"], +) + +distributed_config = DistributedConfig( + tensor_parallel=1, + pipeline_parallel=1, + world_size=1, +) + +# Create model +model = HybridBaseModel(config, distributed_config) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index edce4e74..6df35102 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -9,6 +9,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.ssm.config import SSMBlockType from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.layers.transformer.transformer import TransformerLayer from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -19,9 +20,9 @@ try: from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 from fast_llm.layers.ssm.mamba_layer import MambaLayer - from fast_llm.models.ssm.model import HybridSSMBaseModel + from fast_llm.models.hybrid.model import HybridBaseModel except ImportError: - MambaLayer, HybridSSMBaseModel, DiscreteMamba2 = ( + MambaLayer, HybridBaseModel, DiscreteMamba2 = ( None, None, None, @@ -135,18 +136,18 @@ def test_transformer_mtp(config_dict: dict[str, typing.Any]): @pytest.mark.parametrize( ("hybrid_block_layout", "prediction_heads", "default_mtp_type"), [ - (["m", "t"], 1, None), - (["t", "m"], 2, None), - (["m", "t"], 2, None), - (["t", "m2"], 3, None), - (["t", "m2"], 3, "m"), + ([SSMBlockType.mamba.value, SSMBlockType.transformer.value], 1, None), + ([SSMBlockType.transformer.value, SSMBlockType.mamba.value], 2, None), + ([SSMBlockType.mamba.value, SSMBlockType.transformer.value], 2, None), + ([SSMBlockType.transformer.value, SSMBlockType.mamba2_discrete.value], 3, None), + ([SSMBlockType.transformer.value, SSMBlockType.mamba2_discrete.value], 3, SSMBlockType.mamba.value), ], ) def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_heads, default_mtp_type): hybrid_config = get_hybrid_config( hybrid_block_layout=hybrid_block_layout, prediction_heads=prediction_heads, default_mtp_type=default_mtp_type ) - model = HybridSSMBaseModel(hybrid_config, distributed_config) + model = HybridBaseModel(hybrid_config, distributed_config) distributed = Distributed(distributed_config) model.setup(distributed) tensor_space = model._tensor_space @@ -154,7 +155,11 @@ def test_hybrid_model_mtp(distributed_config, hybrid_block_layout, prediction_he model.to("cuda") num_heads, num_mtp_blocks = 0, 0 - str_block_mapping = {"t": TransformerLayer, "m": MambaLayer, "m2": DiscreteMamba2} + str_block_mapping = { + SSMBlockType.transformer: TransformerLayer, + SSMBlockType.mamba: MambaLayer, + SSMBlockType.mamba2_discrete: DiscreteMamba2, + } mtp_block_type = default_mtp_type or hybrid_block_layout[-1] for block in model.get_output_layers(): if isinstance(block, LanguageModelHead):