From 53639168bd3f4d5128157f4db934336eef463574 Mon Sep 17 00:00:00 2001 From: xuebi Date: Fri, 31 Oct 2025 14:17:08 +0800 Subject: [PATCH 1/9] update: init m2 Signed-off-by: xuebi --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/minimax_m2.md | 71 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 5 + .../models/minimax_m2/__init__.py | 29 + .../minimax_m2/configuration_minimax_m2.py | 208 +++++ .../models/minimax_m2/modeling_minimax_m2.py | 725 ++++++++++++++++++ .../models/minimax_m2/modular_minimax_m2.py | 226 ++++++ 9 files changed, 1269 insertions(+) create mode 100644 docs/source/en/model_doc/minimax_m2.md create mode 100644 src/transformers/models/minimax_m2/__init__.py create mode 100644 src/transformers/models/minimax_m2/configuration_minimax_m2.py create mode 100644 src/transformers/models/minimax_m2/modeling_minimax_m2.py create mode 100644 src/transformers/models/minimax_m2/modular_minimax_m2.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0198cdd33711..6f3d7b5d22f7 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -600,6 +600,8 @@ title: MegatronGPT2 - local: model_doc/minimax title: MiniMax + - local: model_doc/minimax_m2 + title: MiniMax-M2 - local: model_doc/ministral title: Ministral - local: model_doc/mistral diff --git a/docs/source/en/model_doc/minimax_m2.md b/docs/source/en/model_doc/minimax_m2.md new file mode 100644 index 000000000000..81df4a083975 --- /dev/null +++ b/docs/source/en/model_doc/minimax_m2.md @@ -0,0 +1,71 @@ + + + +# MiniMax-M2 + +## Overview + +The MiniMax-M2 model was proposed in []() by . + + +The abstract from the paper is the following: + + + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). 
+ +## Usage examples + + + +## MiniMaxM2Config + +[[autodoc]] MiniMaxM2Config + +## MiniMaxM2ForCausalLM + +[[autodoc]] MiniMaxM2ForCausalLM + +## MiniMaxM2ForQuestionAnswering + +[[autodoc]] MiniMaxM2ForQuestionAnswering + +## MiniMaxM2Model + +[[autodoc]] MiniMaxM2Model + - forward + +## MiniMaxM2PreTrainedModel + +[[autodoc]] MiniMaxM2PreTrainedModel + - forward + +## MiniMaxM2ForSequenceClassification + +[[autodoc]] MiniMaxM2ForSequenceClassification + +## MiniMaxM2ForTokenClassification + +[[autodoc]] MiniMaxM2ForTokenClassification \ No newline at end of file diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5630063f92ec..d9179c892b59 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -216,6 +216,7 @@ from .mgp_str import * from .mimi import * from .minimax import * + from .minimax_m2 import * from .ministral import * from .mistral import * from .mistral3 import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e2e84a445ef..876d8f19224f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -258,6 +258,7 @@ ("mgp-str", "MgpstrConfig"), ("mimi", "MimiConfig"), ("minimax", "MiniMaxConfig"), + ("minimax_m2", "MiniMaxM2Config"), ("ministral", "MinistralConfig"), ("mistral", "MistralConfig"), ("mistral3", "Mistral3Config"), @@ -711,6 +712,7 @@ ("mgp-str", "MGP-STR"), ("mimi", "Mimi"), ("minimax", "MiniMax"), + ("minimax_m2", "MiniMax-M2"), ("ministral", "Ministral"), ("mistral", "Mistral"), ("mistral3", "Mistral3"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 197029464efd..8cd084f1eae8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -259,6 +259,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mgp-str", "MgpstrForSceneTextRecognition"), ("mimi", "MimiModel"), ("minimax", "MiniMaxModel"), + ("minimax_m2", "MiniMaxM2Model"), ("ministral", "MinistralModel"), ("mistral", "MistralModel"), ("mistral3", "Mistral3Model"), @@ -710,6 +711,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mega", "MegaForCausalLM"), ("megatron-bert", "MegatronBertForCausalLM"), ("minimax", "MiniMaxForCausalLM"), + ("minimax_m2", "MiniMaxM2ForCausalLM"), ("ministral", "MinistralForCausalLM"), ("mistral", "MistralForCausalLM"), ("mixtral", "MixtralForCausalLM"), @@ -1272,6 +1274,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mega", "MegaForSequenceClassification"), ("megatron-bert", "MegatronBertForSequenceClassification"), ("minimax", "MiniMaxForSequenceClassification"), + ("minimax_m2", "MiniMaxM2ForSequenceClassification"), ("ministral", "MinistralForSequenceClassification"), ("mistral", "MistralForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), @@ -1372,6 +1375,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mega", "MegaForQuestionAnswering"), ("megatron-bert", "MegatronBertForQuestionAnswering"), ("minimax", "MiniMaxForQuestionAnswering"), + ("minimax_m2", "MiniMaxM2ForQuestionAnswering"), ("ministral", "MinistralForQuestionAnswering"), ("mistral", "MistralForQuestionAnswering"), ("mixtral", "MixtralForQuestionAnswering"), @@ -1488,6 +1492,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mega", 
"MegaForTokenClassification"), ("megatron-bert", "MegatronBertForTokenClassification"), ("minimax", "MiniMaxForTokenClassification"), + ("minimax_m2", "MiniMaxM2ForTokenClassification"), ("ministral", "MinistralForTokenClassification"), ("mistral", "MistralForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), diff --git a/src/transformers/models/minimax_m2/__init__.py b/src/transformers/models/minimax_m2/__init__.py new file mode 100644 index 000000000000..819d6dd9c568 --- /dev/null +++ b/src/transformers/models/minimax_m2/__init__.py @@ -0,0 +1,29 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_minimax_m2 import * + from .modeling_minimax_m2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py new file mode 100644 index 000000000000..cd33f3150a30 --- /dev/null +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -0,0 +1,208 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params + + +class MiniMaxM2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an + MiniMaxM2 model according to the specified arguments, defining the model architecture. 
Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1. + + [minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B) + [minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + + ```python + >>> from transformers import MiniMaxM2Model, MiniMaxM2Config + + >>> # Initializing a MiniMaxM2 7B style configuration + >>> configuration = MiniMaxM2Config() + + >>> # Initializing a model from the MiniMaxM2 7B style configuration + >>> model = MiniMaxM2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minimax_m2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.block_sparse_moe.experts.*.w1": "colwise", + "layers.*.block_sparse_moe.experts.*.w2": "rowwise", + "layers.*.block_sparse_moe.experts.*.w3": "colwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_experts": "num_local_experts", + } + + def __init__( + self, + vocab_size: Optional[int] = 32000, + hidden_size: Optional[int] = 4096, + intermediate_size: Optional[int] = 14336, + num_hidden_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_key_value_heads: Optional[int] = 8, + head_dim: Optional[int] = None, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 4096 * 32, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-5, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = False, + sliding_window: Optional[int] = None, + attention_dropout: Optional[float] = 0.0, + num_experts_per_tok: Optional[int] = 2, + num_local_experts: Optional[int] = 8, + output_router_logits: Optional[bool] = False, + router_aux_loss_coef: Optional[float] = 0.001, + router_jitter_noise: Optional[float] = 0.0, + rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = 
sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_dropout = attention_dropout + self.head_dim = head_dim + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 1000000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["MiniMaxM2Config"] diff --git a/src/transformers/models/minimax_m2/modeling_minimax_m2.py b/src/transformers/models/minimax_m2/modeling_minimax_m2.py new file mode 100644 index 000000000000..9a4d2e0a40ab --- /dev/null +++ b/src/transformers/models/minimax_m2/modeling_minimax_m2.py @@ -0,0 +1,725 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/minimax_m2/modular_minimax_m2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_minimax_m2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections.abc import Callable +from typing import Optional, Union, Unpack + +import torch +from torch import nn + +from transformers.utils.generic import check_model_inputs + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import OutputRecorder +from .configuration_minimax_m2 import MiniMaxM2Config + + +class MiniMaxM2MLP(nn.Module): + def __init__(self, config: MiniMaxM2Config): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MiniMaxM2Experts(nn.ModuleList): + """ + ModuleList of experts. 
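+
+    Each expert is a `MiniMaxM2MLP`. The forward pass scatters the flattened token states to the experts
+    selected by the router (`top_k_index`), scales each expert's output by the corresponding `top_k_weights`,
+    and accumulates the results back into a single tensor via `index_add_`.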
+ """ + + def __init__(self, config: MiniMaxM2Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + for _ in range(self.num_experts): + self.append(MiniMaxM2MLP(config)) + + def forward( + self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + ) -> torch.Tensor: + """ + Args: + hidden_states: (batch_size * sequence_length, hidden_dim) + selected_experts: (batch_size * sequence_length, top_k) + routing_weights: (batch_size * sequence_length, top_k) + Returns: + (batch_size * sequence_length, hidden_dim) + """ + final_hidden_states = torch.zeros_like(hidden_states) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) + current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) + current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + return final_hidden_states + + +class MiniMaxM2SparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.jitter_noise = config.router_jitter_noise + self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + self.experts = MiniMaxM2Experts(config) + self.e_score_correction_bias = nn.Parameter(torch.zeros((config.num_local_experts), dtype=torch.float32)) + + def route_tokens_to_experts(self, router_logits): + routing_weights = torch.nn.functional.sigmoid(router_logits.float()) + scores_for_choice = routing_weights + self.e_score_correction_bias + _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + top_k_weights = routing_weights.gather(1, top_k_index) + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + return top_k_index, top_k_weights.to(router_logits.dtype) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + router_logits = self.gate(hidden_states) + top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) + hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) + hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return hidden_states + + +@use_kernel_forward_from_hub("RMSNorm") +class MiniMaxM2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxM2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class 
MiniMaxM2RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: MiniMaxM2Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[MiniMaxM2Config] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "rotary_dim", None) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
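+
+    Example (an illustrative sketch; the shapes are made up, and `cos`/`sin` would normally come from
+    `MiniMaxM2RotaryEmbedding` rather than being constructed by hand):
+
+    ```python
+    >>> import torch
+
+    >>> q = torch.randn(1, 4, 6, 64)  # [batch, num_attention_heads, seq_len, head_dim]
+    >>> k = torch.randn(1, 2, 6, 64)  # [batch, num_key_value_heads, seq_len, head_dim]
+    >>> cos, sin = torch.ones(1, 6, 64), torch.zeros(1, 6, 64)  # identity rotation, for illustration only
+    >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
+    >>> q_embed.shape
+    torch.Size([1, 4, 6, 64])
+    ```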
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Keep half or full tensor for later concatenation + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + return q_embed, k_embed + + +class MiniMaxM2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MiniMaxM2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps) + self.k_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_key_value_heads, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + key_states = key_states.view(hidden_shape) + query_states = query_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, 
+ attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MiniMaxM2DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: MiniMaxM2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MiniMaxM2Attention(config, layer_idx) + + self.block_sparse_moe = MiniMaxM2SparseMoeBlock(config) + self.input_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class MiniMaxM2PreTrainedModel(PreTrainedModel): + config: MiniMaxM2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MiniMaxM2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), + "hidden_states": MiniMaxM2DecoderLayer, + "attentions": MiniMaxM2Attention, + } + + +@auto_docstring +class MiniMaxM2Model(MiniMaxM2PreTrainedModel): + def __init__(self, config: MiniMaxM2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MiniMaxM2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MiniMaxM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MiniMaxM2RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: 
Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask + causal_mask = mask_function( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
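+
+    Example (an illustrative sketch; the shapes below are made up for demonstration):
+
+    ```python
+    >>> import torch
+
+    >>> # router logits for 2 layers, each of shape [batch_size * sequence_length, num_experts]
+    >>> gate_logits = (torch.randn(6, 4), torch.randn(6, 4))
+    >>> aux_loss = load_balancing_loss_func(gate_logits, num_experts=4, top_k=2)
+    >>> aux_loss.shape
+    torch.Size([])
+    ```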
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class MiniMaxM2ForCausalLM(MiniMaxM2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = MiniMaxM2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniMaxM2ForCausalLM + + >>> model = MiniMaxM2ForCausalLM.from_pretrained("mistralai/MiniMaxM2-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/MiniMaxM2-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class MiniMaxM2ForSequenceClassification(GenericForSequenceClassification, MiniMaxM2PreTrainedModel): + pass + + +class MiniMaxM2ForTokenClassification(GenericForTokenClassification, MiniMaxM2PreTrainedModel): + pass + + +class MiniMaxM2ForQuestionAnswering(GenericForQuestionAnswering, MiniMaxM2PreTrainedModel): + pass + + +__all__ = [ + "MiniMaxM2ForCausalLM", + "MiniMaxM2ForQuestionAnswering", + "MiniMaxM2Model", + "MiniMaxM2PreTrainedModel", + "MiniMaxM2ForSequenceClassification", + "MiniMaxM2ForTokenClassification", +] diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py new file mode 100644 index 000000000000..a031b097a5ec --- /dev/null +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2025 the MiniMax AI Team and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Unpack +import torch +from torch import nn + +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...cache_utils import Cache + + +from ..mixtral.configuration_mixtral import MixtralConfig +from ..mixtral.modeling_mixtral import ( + MixtralDecoderLayer, + MixtralExperts, + MixtralForCausalLM, + MixtralForQuestionAnswering, + MixtralForSequenceClassification, + MixtralForTokenClassification, + MixtralMLP, + MixtralModel, + MixtralPreTrainedModel, + MixtralRMSNorm, + MixtralSparseMoeBlock, +) +from ..glm4_moe.modeling_glm4_moe import ( + Glm4MoeAttention, + Glm4MoeRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) + + + +class MiniMaxM2Config(MixtralConfig): + model_type = "minimax_m2" + + +class MiniMaxM2MLP(MixtralMLP): + pass + + +class MiniMaxM2Experts(MixtralExperts): + pass + + +class MiniMaxM2SparseMoeBlock(MixtralSparseMoeBlock): + def __init__(self, config): + nn.Module.__init__(self) + self.top_k = config.num_experts_per_tok + self.jitter_noise = config.router_jitter_noise + self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) + self.experts = MiniMaxM2Experts(config) + self.e_score_correction_bias = nn.Parameter(torch.zeros((config.num_local_experts), dtype=torch.float32)) + + def route_tokens_to_experts(self, router_logits): + routing_weights = torch.nn.functional.sigmoid(router_logits.float()) + scores_for_choice = routing_weights + self.e_score_correction_bias + _, top_k_index = torch.topk(scores_for_choice, self.top_k, dim=-1, sorted=False) + top_k_weights = routing_weights.gather(1, top_k_index) + top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) + return top_k_index, top_k_weights.to(router_logits.dtype) + + +class MiniMaxM2RMSNorm(MixtralRMSNorm): + pass + + +class MiniMaxM2RotaryEmbedding(Glm4MoeRotaryEmbedding): + @staticmethod + def compute_default_rope_parameters( + config: Optional[MiniMaxM2Config] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
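+
+        Example (illustrative numbers only): with `rotary_dim=4` and `rope_theta=10000.0`, the inverse
+        frequencies are `1.0 / 10000.0 ** (torch.arange(0, 4, 2) / 4)`, i.e. `tensor([1.0, 0.01])`.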
+ """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "rotary_dim", None) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + +class MiniMaxM2Attention(Glm4MoeAttention): + def __init__(self, config: MiniMaxM2Config, layer_idx: int): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + self.use_qk_norm = config.use_qk_norm + if self.use_qk_norm: + self.q_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_attention_heads, eps=config.rms_norm_eps) + self.k_norm = MiniMaxM2RMSNorm(self.head_dim * config.num_key_value_heads, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.use_qk_norm: # main diff from Llama + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + key_states = key_states.view(hidden_shape) + query_states = query_states.view(hidden_shape) + value_states = value_states.view(hidden_shape) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MiniMaxM2DecoderLayer(MixtralDecoderLayer): + 
pass + + +class MiniMaxM2PreTrainedModel(MixtralPreTrainedModel): + pass + + +class MiniMaxM2Model(MixtralModel): + pass + + +class MiniMaxM2ForCausalLM(MixtralForCausalLM): + pass + + +class MiniMaxM2ForSequenceClassification(MixtralForSequenceClassification): + pass + + +class MiniMaxM2ForTokenClassification(MixtralForTokenClassification): + pass + + +class MiniMaxM2ForQuestionAnswering(MixtralForQuestionAnswering): + pass + + +__all__ = [ + "MiniMaxM2Config", + "MiniMaxM2ForCausalLM", + "MiniMaxM2ForQuestionAnswering", + "MiniMaxM2Model", + "MiniMaxM2PreTrainedModel", + "MiniMaxM2ForSequenceClassification", + "MiniMaxM2ForTokenClassification", +] From 261fe5cc1e047207188cb3864372a9eafa9dc682 Mon Sep 17 00:00:00 2001 From: xuebi Date: Fri, 31 Oct 2025 15:27:02 +0800 Subject: [PATCH 2/9] update: docs and config Signed-off-by: xuebi --- docs/source/en/model_doc/minimax.md | 2 + docs/source/en/model_doc/minimax_m2.md | 32 ++- .../minimax_m2/configuration_minimax_m2.py | 13 +- .../models/minimax_m2/modeling_minimax_m2.py | 4 +- .../models/minimax_m2/modular_minimax_m2.py | 203 ++++++++++++++++-- 5 files changed, 224 insertions(+), 30 deletions(-) diff --git a/docs/source/en/model_doc/minimax.md b/docs/source/en/model_doc/minimax.md index bd98761bfed1..a2d5a781c251 100644 --- a/docs/source/en/model_doc/minimax.md +++ b/docs/source/en/model_doc/minimax.md @@ -17,6 +17,8 @@ rendered properly in your Markdown viewer. # MiniMax +> [MiniMax-M2](https://huggingface.co/docs/transformers/en/model_doc/minimax_m2) was released on 2025‑10‑27. We recommend using MiniMax‑M2 for most use cases due to better overall performance. + ## Overview The MiniMax-Text-01 model was proposed in [MiniMax-01: Scaling Foundation Models with Lightning Attention](https://huggingface.co/papers/2501.08313) by MiniMax, Aonian Li, Bangwei Gong, Bo Yang, Boji Shan, Chang Liu, Cheng Zhu, Chunhao Zhang, Congchao Guo, Da Chen, Dong Li, Enwei Jiao, Gengxin Li, Guojun Zhang, Haohai Sun, Houze Dong, Jiadai Zhu, Jiaqi Zhuang, Jiayuan Song, Jin Zhu, Jingtao Han, Jingyang Li, Junbin Xie, Junhao Xu, Junjie Yan, Kaishun Zhang, Kecheng Xiao, Kexi Kang, Le Han, Leyang Wang, Lianfei Yu, Liheng Feng, Lin Zheng, Linbo Chai, Long Xing, Meizhi Ju, Mingyuan Chi, Mozhi Zhang, Peikai Huang, Pengcheng Niu, Pengfei Li, Pengyu Zhao, Qi Yang, Qidi Xu, Qiexiang Wang, Qin Wang, Qiuhui Li, Ruitao Leng, Shengmin Shi, Shuqi Yu, Sichen Li, Songquan Zhu, Tao Huang, Tianrun Liang, Weigao Sun, Weixuan Sun, Weiyu Cheng, Wenkai Li, Xiangjun Song, Xiao Su, Xiaodong Han, Xinjie Zhang, Xinzhu Hou, Xu Min, Xun Zou, Xuyang Shen, Yan Gong, Yingjie Zhu, Yipeng Zhou, Yiran Zhong, Yongyi Hu, Yuanxiang Fan, Yue Yu, Yufeng Yang, Yuhao Li, Yunan Huang, Yunji Li, Yunpeng Huang, Yunzhi Xu, Yuxin Mao, Zehan Li, Zekang Li, Zewei Tao, Zewen Ying, Zhaoyang Cong, Zhen Qin, Zhenhua Fan, Zhihang Yu, Zhuo Jiang, Zijia Wu. diff --git a/docs/source/en/model_doc/minimax_m2.md b/docs/source/en/model_doc/minimax_m2.md index 81df4a083975..f72b796c30a3 100644 --- a/docs/source/en/model_doc/minimax_m2.md +++ b/docs/source/en/model_doc/minimax_m2.md @@ -22,23 +22,35 @@ limitations under the License. ## Overview -The MiniMax-M2 model was proposed in []() by . - +MiniMax-M2 redefines efficiency for agents. It's a compact, fast, and cost-effective MoE model (230 billion total parameters with 10 billion active parameters) built for elite performance in coding and agentic tasks, all while maintaining powerful general intelligence. 
With just 10 billion activated parameters, MiniMax-M2 provides the sophisticated, end-to-end tool use performance expected from today's leading models, but in a streamlined form factor that makes deployment and scaling easier than ever. -The abstract from the paper is the following: +For more details refer to the [release blog post](https://www.minimax.io/news/minimax-m2). - +## Usage examples -Tips: +```python +from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig - +model = AutoModelForCausalLM.from_pretrained("MiniMaxAI/MiniMax-M2", device_map="auto") -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-M2") -## Usage examples +generation_config = GenerationConfig.from_pretrained("MiniMaxAI/MiniMax-M2") + +messages = [ + {"role": "user", "content": "What is your favourite condiment?"}, + {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"}, + {"role": "user", "content": "Do you have mayonnaise recipes?"} +] + +model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda") + +generated_ids = model.generate(model_inputs, max_new_tokens=100, generation_config=generation_config) + +response = tokenizer.batch_decode(generated_ids)[0] - +print(response) +``` ## MiniMaxM2Config diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py index cd33f3150a30..ee9e76e69e08 100644 --- a/src/transformers/models/minimax_m2/configuration_minimax_m2.py +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -29,10 +29,9 @@ class MiniMaxM2Config(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. It is used to instantiate an MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the MiniMaxM2-7B-v0.1 or MiniMaxM2-7B-Instruct-v0.1. + with the defaults will yield a similar configuration to that of the MiniMaxM2. - [minimax_m2ai/MiniMaxM2-8x7B](https://huggingface.co/minimax_m2ai/MiniMaxM2-8x7B) - [minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1](https://huggingface.co/minimax_m2ai/MiniMaxM2-7B-Instruct-v0.1) + [MiniMaxAI/MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PreTrainedConfig`] for more information. @@ -99,14 +98,16 @@ class MiniMaxM2Config(PreTrainedConfig): Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE with longer `max_position_embeddings`. + rotary_dim (`int`, *optional*, defaults to `None`): + The dimension of the rotary embeddings. If not specified, will default to `head_dim`. 
```python >>> from transformers import MiniMaxM2Model, MiniMaxM2Config - >>> # Initializing a MiniMaxM2 7B style configuration + >>> # Initializing a MiniMaxM2 style configuration >>> configuration = MiniMaxM2Config() - >>> # Initializing a model from the MiniMaxM2 7B style configuration + >>> # Initializing a model from the MiniMaxM2 style configuration >>> model = MiniMaxM2Model(configuration) >>> # Accessing the model configuration @@ -160,6 +161,7 @@ def __init__( router_aux_loss_coef: Optional[float] = 0.001, router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, + rotary_dim: Optional[int] = 64, **kwargs, ): self.vocab_size = vocab_size @@ -190,6 +192,7 @@ def __init__( # Try to set `rope_scaling` if available, otherwise use `rope_parameters` rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters + self.rotary_dim = rotary_dim # Validate the correctness of rotary position embeddings parameters rope_theta = kwargs.get("rope_theta", 1000000.0) diff --git a/src/transformers/models/minimax_m2/modeling_minimax_m2.py b/src/transformers/models/minimax_m2/modeling_minimax_m2.py index 9a4d2e0a40ab..72d04ee208cf 100644 --- a/src/transformers/models/minimax_m2/modeling_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modeling_minimax_m2.py @@ -25,8 +25,6 @@ import torch from torch import nn -from transformers.utils.generic import check_model_inputs - from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin @@ -43,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder +from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_minimax_m2 import MiniMaxM2Config diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index a031b097a5ec..31ce8cfe3853 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -13,16 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Unpack +from collections.abc import Callable +from typing import Optional, Unpack + import torch from torch import nn +from ...cache_utils import Cache +from ...configuration_utils import PreTrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -from ...cache_utils import Cache - - -from ..mixtral.configuration_mixtral import MixtralConfig +from ..glm4_moe.modeling_glm4_moe import ( + Glm4MoeAttention, + Glm4MoeRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) from ..mixtral.modeling_mixtral import ( MixtralDecoderLayer, MixtralExperts, @@ -36,17 +43,189 @@ MixtralRMSNorm, MixtralSparseMoeBlock, ) -from ..glm4_moe.modeling_glm4_moe import ( - Glm4MoeAttention, - Glm4MoeRotaryEmbedding, - apply_rotary_pos_emb, - eager_attention_forward, -) +class MiniMaxM2Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniMaxM2Model`]. 
It is used to instantiate an + MiniMaxM2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MiniMaxM2. + + [MiniMaxAI/MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniMaxM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniMaxM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. MiniMaxM2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + rotary_dim (`int`, *optional*, defaults to `None`): + The dimension of the rotary embeddings. If not specified, will default to `head_dim`. + + ```python + >>> from transformers import MiniMaxM2Model, MiniMaxM2Config + + >>> # Initializing a MiniMaxM2 style configuration + >>> configuration = MiniMaxM2Config() + + >>> # Initializing a model from the MiniMaxM2 style configuration + >>> model = MiniMaxM2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" -class MiniMaxM2Config(MixtralConfig): model_type = "minimax_m2" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.block_sparse_moe.experts.*.w1": "colwise", + "layers.*.block_sparse_moe.experts.*.w2": "rowwise", + "layers.*.block_sparse_moe.experts.*.w3": "colwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_experts": "num_local_experts", + } + + def __init__( + self, + vocab_size: Optional[int] = 32000, + hidden_size: Optional[int] = 4096, + intermediate_size: Optional[int] = 14336, + num_hidden_layers: Optional[int] = 32, + num_attention_heads: Optional[int] = 32, + num_key_value_heads: Optional[int] = 8, + head_dim: Optional[int] = None, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 4096 * 32, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[int] = 1e-5, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + tie_word_embeddings: Optional[bool] = False, + sliding_window: Optional[int] = None, + attention_dropout: Optional[float] = 0.0, + num_experts_per_tok: Optional[int] = 2, + num_local_experts: Optional[int] = 8, + output_router_logits: Optional[bool] = False, + router_aux_loss_coef: Optional[float] = 0.001, + router_jitter_noise: Optional[float] = 0.0, + rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, + rotary_dim: Optional[int] = 64, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = 
max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.attention_dropout = attention_dropout + self.head_dim = head_dim + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + # Try to set `rope_scaling` if available, otherwise use `rope_parameters` + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or rope_parameters + self.rotary_dim = rotary_dim + + # Validate the correctness of rotary position embeddings parameters + rope_theta = kwargs.get("rope_theta", 1000000.0) + standardize_rope_params(self, rope_theta=rope_theta) + rope_config_validation(self) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) class MiniMaxM2MLP(MixtralMLP): From ac4613cadefebd1565ac5b4c8fd1dd8ed881bb5f Mon Sep 17 00:00:00 2001 From: xuebi Date: Fri, 31 Oct 2025 15:27:23 +0800 Subject: [PATCH 3/9] update: init minimax-m2 test Signed-off-by: xuebi --- tests/models/minimax_m2/__init__.py | 0 .../minimax_m2/test_modeling_minimax_m2.py | 198 ++++++++++++++++++ 2 files changed, 198 insertions(+) create mode 100644 tests/models/minimax_m2/__init__.py create mode 100644 tests/models/minimax_m2/test_modeling_minimax_m2.py diff --git a/tests/models/minimax_m2/__init__.py b/tests/models/minimax_m2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/minimax_m2/test_modeling_minimax_m2.py b/tests/models/minimax_m2/test_modeling_minimax_m2.py new file mode 100644 index 000000000000..44528d7a3f3d --- /dev/null +++ b/tests/models/minimax_m2/test_modeling_minimax_m2.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2025 the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch MiniMaxM2 model.""" + +import unittest + +import pytest + +from transformers import is_torch_available +from transformers.testing_utils import ( + Expectations, + require_flash_attn, + require_torch, + require_torch_accelerator, + require_torch_gpu, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ( + MiniMaxM2ForCausalLM, + MiniMaxM2Model, + ) + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class MiniMaxM2ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = MiniMaxM2Model + + +@require_torch +class MiniMaxM2ModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = MiniMaxM2ModelTester + + # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="MiniMaxM2 flash attention does not support right padding") + + # Ignore copy + def test_load_balancing_loss(self): + r""" + Let's make sure we can actually compute the loss and do a backward on it. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.num_local_experts = 8 + config.output_router_logits = True + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + model = MiniMaxM2ForCausalLM(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask) + self.assertEqual(result.router_logits[0].shape, (91, config.num_local_experts)) + torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(2, dtype=torch.float32), rtol=1e-2, atol=1e-2) + + # First, we make sure that adding padding tokens doesn't change the loss + # loss(input_ids, attention_mask=None) == loss(input_ids + padding, attention_mask=attention_mask_with_padding) + pad_length = 1000 + # Add padding tokens (assume that pad_token_id=1) to input_ids + padding_block = torch.ones(input_ids.shape[0], pad_length, dtype=torch.int32).to(torch_device) + padded_input_ids = torch.cat((padding_block, input_ids), dim=1) # this is to simulate padding to the left + padded_attention_mask = padded_input_ids.ne(1).to(torch_device) + + padded_result = model(padded_input_ids, attention_mask=padded_attention_mask) + torch.testing.assert_close(result.aux_loss.cpu(), padded_result.aux_loss.cpu(), rtol=1e-4, atol=1e-4) + + # We make sure that the loss of including padding tokens != the loss without padding tokens + # if attention_mask=None --> we don't exclude padding tokens + include_padding_result = model(padded_input_ids, attention_mask=None) + + # This is to mimic torch.testing.assert_not_close + self.assertNotAlmostEqual(include_padding_result.aux_loss.item(), result.aux_loss.item()) + + +@require_torch +class MiniMaxM2IntegrationTest(unittest.TestCase): + @slow + @require_torch_accelerator + def test_small_model_logits(self): + model_id = "hf-internal-testing/MiniMaxM2-tiny" + dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device) + + model = MiniMaxM2ForCausalLM.from_pretrained( + model_id, + 
dtype=torch.bfloat16, + ).to(torch_device) + # TODO: might need to tweak it in case the logits do not match on our daily runners + # these logits have been obtained with the original megablocks implementation. + # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4 + # considering differences in hardware processing and potential deviations in output. + # fmt: off + EXPECTED_LOGITS = Expectations( + { + ("cuda", 7): torch.Tensor([[0.1640, 0.1621, 0.6093], [-0.8906, -0.1640, -0.6093], [0.1562, 0.1250, 0.7226]]).to(torch_device), + ("cuda", 8): torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(torch_device), + ("rocm", 9): torch.Tensor([[0.1641, 0.1621, 0.6094], [-0.8906, -0.1631, -0.6094], [0.1572, 0.1260, 0.7227]]).to(torch_device), + } + ) + # fmt: on + expected_logit = EXPECTED_LOGITS.get_expectation() + + with torch.no_grad(): + logits = model(dummy_input).logits + + logits = logits.float() + + torch.testing.assert_close(logits[0, :3, :3], expected_logit, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(logits[1, :3, :3], expected_logit, atol=1e-3, rtol=1e-3) + + @slow + @require_torch_accelerator + def test_small_model_logits_batched(self): + model_id = "hf-internal-testing/MiniMaxM2-tiny" + dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device) + attention_mask = dummy_input.ne(0).to(torch.long) + + model = MiniMaxM2ForCausalLM.from_pretrained( + model_id, + dtype=torch.bfloat16, + ).to(torch_device) + + # TODO: might need to tweak it in case the logits do not match on our daily runners + # + # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4. + # + # considering differences in hardware processing and potential deviations in generated text. + + EXPECTED_LOGITS_LEFT_UNPADDED = Expectations( + { + ("xpu", 3): [[0.2236, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7070, 0.2461]], + ("cuda", 7): [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]], + ("cuda", 8): [[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]], + ("rocm", 9): [[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]], + } + ) + expected_left_unpadded = torch.tensor(EXPECTED_LOGITS_LEFT_UNPADDED.get_expectation(), device=torch_device) + + EXPECTED_LOGITS_RIGHT_UNPADDED = Expectations( + { + ("xpu", 3): [[0.2178, 0.1270, -0.1641], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]], + ("cuda", 7): [[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]], + ("cuda", 8): [[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]], + ("rocm", 9): [[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]], + } + ) + expected_right_unpadded = torch.tensor(EXPECTED_LOGITS_RIGHT_UNPADDED.get_expectation(), device=torch_device) + + with torch.no_grad(): + logits = model(dummy_input, attention_mask=attention_mask).logits + logits = logits.float() + + torch.testing.assert_close( + logits[0, -3:, -3:], + expected_left_unpadded, + atol=1e-3, + rtol=1e-3, + ) + torch.testing.assert_close( + logits[1, -3:, -3:], + expected_right_unpadded, + atol=1e-3, + rtol=1e-3, + ) From 3421fe74a7d588db021b0284fd1991636d7f89f1 Mon Sep 17 00:00:00 2001 From: xuebi Date: Fri, 31 Oct 2025 18:29:37 +0800 Subject: [PATCH 4/9] update: fix tests Signed-off-by: xuebi --- .../models/minimax_m2/configuration_minimax_m2.py | 2 ++ src/transformers/models/minimax_m2/modeling_minimax_m2.py | 2 +- 
src/transformers/models/minimax_m2/modular_minimax_m2.py | 4 +++- tests/models/minimax_m2/test_modeling_minimax_m2.py | 4 ++++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py index ee9e76e69e08..d5f82baa0d3a 100644 --- a/src/transformers/models/minimax_m2/configuration_minimax_m2.py +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -162,6 +162,7 @@ def __init__( router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, rotary_dim: Optional[int] = 64, + use_qk_norm: Optional[bool] = True, **kwargs, ): self.vocab_size = vocab_size @@ -193,6 +194,7 @@ def __init__( rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters self.rotary_dim = rotary_dim + self.use_qk_norm = use_qk_norm # Validate the correctness of rotary position embeddings parameters rope_theta = kwargs.get("rope_theta", 1000000.0) diff --git a/src/transformers/models/minimax_m2/modeling_minimax_m2.py b/src/transformers/models/minimax_m2/modeling_minimax_m2.py index 72d04ee208cf..e631046a8b19 100644 --- a/src/transformers/models/minimax_m2/modeling_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modeling_minimax_m2.py @@ -105,7 +105,7 @@ def __init__(self, config): self.jitter_noise = config.router_jitter_noise self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) self.experts = MiniMaxM2Experts(config) - self.e_score_correction_bias = nn.Parameter(torch.zeros((config.num_local_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts)) def route_tokens_to_experts(self, router_logits): routing_weights = torch.nn.functional.sigmoid(router_logits.float()) diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index 31ce8cfe3853..0f9ee77e025b 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -182,6 +182,7 @@ def __init__( router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, rotary_dim: Optional[int] = 64, + use_qk_norm: Optional[bool] = True, **kwargs, ): self.vocab_size = vocab_size @@ -213,6 +214,7 @@ def __init__( rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters self.rotary_dim = rotary_dim + self.use_qk_norm = use_qk_norm # Validate the correctness of rotary position embeddings parameters rope_theta = kwargs.get("rope_theta", 1000000.0) @@ -243,7 +245,7 @@ def __init__(self, config): self.jitter_noise = config.router_jitter_noise self.gate = nn.Linear(config.hidden_size, config.num_local_experts, bias=False) self.experts = MiniMaxM2Experts(config) - self.e_score_correction_bias = nn.Parameter(torch.zeros((config.num_local_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(config.num_local_experts)) def route_tokens_to_experts(self, router_logits): routing_weights = torch.nn.functional.sigmoid(router_logits.float()) diff --git a/tests/models/minimax_m2/test_modeling_minimax_m2.py b/tests/models/minimax_m2/test_modeling_minimax_m2.py index 44528d7a3f3d..661700f784c6 100644 --- a/tests/models/minimax_m2/test_modeling_minimax_m2.py +++ 
b/tests/models/minimax_m2/test_modeling_minimax_m2.py @@ -42,6 +42,10 @@ class MiniMaxM2ModelTester(CausalLMModelTester): + def __init__(self, parent): + super().__init__(parent) + self.rotary_dim = self.head_dim + if is_torch_available(): base_model_class = MiniMaxM2Model From 3a5df7a98cf39edbac2df81aa3cd7413dbea985a Mon Sep 17 00:00:00 2001 From: xuebi Date: Tue, 4 Nov 2025 15:39:30 +0800 Subject: [PATCH 5/9] update: use partial_rotary_factor Signed-off-by: xuebi --- .../minimax_m2/configuration_minimax_m2.py | 11 +++-- .../models/minimax_m2/modeling_minimax_m2.py | 4 +- .../models/minimax_m2/modular_minimax_m2.py | 41 ++++--------------- 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py index d5f82baa0d3a..8b103ceb9fac 100644 --- a/src/transformers/models/minimax_m2/configuration_minimax_m2.py +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -161,7 +161,7 @@ def __init__( router_aux_loss_coef: Optional[float] = 0.001, router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, - rotary_dim: Optional[int] = 64, + rotary_dim: Optional[int] = None, use_qk_norm: Optional[bool] = True, **kwargs, ): @@ -193,14 +193,19 @@ def __init__( # Try to set `rope_scaling` if available, otherwise use `rope_parameters` rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters - self.rotary_dim = rotary_dim - self.use_qk_norm = use_qk_norm # Validate the correctness of rotary position embeddings parameters rope_theta = kwargs.get("rope_theta", 1000000.0) standardize_rope_params(self, rope_theta=rope_theta) rope_config_validation(self) + self.rotary_dim = rotary_dim if rotary_dim is not None else head_dim + self.use_qk_norm = use_qk_norm + self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1.0) + + if self.head_dim is not None: + self.partial_rotary_factor = self.rotary_dim / self.head_dim + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/minimax_m2/modeling_minimax_m2.py b/src/transformers/models/minimax_m2/modeling_minimax_m2.py index e631046a8b19..b42d3b95e067 100644 --- a/src/transformers/models/minimax_m2/modeling_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modeling_minimax_m2.py @@ -187,7 +187,9 @@ def compute_default_rope_parameters( post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
""" base = config.rope_parameters["rope_theta"] - dim = getattr(config, "rotary_dim", None) + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index 0f9ee77e025b..7481abcd69b9 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -181,7 +181,7 @@ def __init__( router_aux_loss_coef: Optional[float] = 0.001, router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, - rotary_dim: Optional[int] = 64, + rotary_dim: Optional[int] = None, use_qk_norm: Optional[bool] = True, **kwargs, ): @@ -213,14 +213,19 @@ def __init__( # Try to set `rope_scaling` if available, otherwise use `rope_parameters` rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters - self.rotary_dim = rotary_dim - self.use_qk_norm = use_qk_norm # Validate the correctness of rotary position embeddings parameters rope_theta = kwargs.get("rope_theta", 1000000.0) standardize_rope_params(self, rope_theta=rope_theta) rope_config_validation(self) + self.rotary_dim = rotary_dim if rotary_dim is not None else head_dim + self.use_qk_norm = use_qk_norm + self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1.0) + + if self.head_dim is not None: + self.partial_rotary_factor = self.rotary_dim / self.head_dim + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, @@ -261,35 +266,7 @@ class MiniMaxM2RMSNorm(MixtralRMSNorm): class MiniMaxM2RotaryEmbedding(Glm4MoeRotaryEmbedding): - @staticmethod - def compute_default_rope_parameters( - config: Optional[MiniMaxM2Config] = None, - device: Optional["torch.device"] = None, - seq_len: Optional[int] = None, - ) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PreTrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
- """ - base = config.rope_parameters["rope_theta"] - dim = getattr(config, "rotary_dim", None) - - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor + pass class MiniMaxM2Attention(Glm4MoeAttention): From cb17f622f2f508efd4160f0d4d5667103cdb553f Mon Sep 17 00:00:00 2001 From: xuebi Date: Wed, 5 Nov 2025 17:14:44 +0800 Subject: [PATCH 6/9] update: some fix Signed-off-by: xuebi --- docs/source/en/model_doc/minimax_m2.md | 11 +++++------ .../models/minimax_m2/configuration_minimax_m2.py | 9 ++++----- .../models/minimax_m2/modular_minimax_m2.py | 9 ++++----- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/docs/source/en/model_doc/minimax_m2.md b/docs/source/en/model_doc/minimax_m2.md index f72b796c30a3..c5c4d12944e6 100644 --- a/docs/source/en/model_doc/minimax_m2.md +++ b/docs/source/en/model_doc/minimax_m2.md @@ -59,25 +59,24 @@ print(response) ## MiniMaxM2ForCausalLM [[autodoc]] MiniMaxM2ForCausalLM + - forward ## MiniMaxM2ForQuestionAnswering [[autodoc]] MiniMaxM2ForQuestionAnswering + - forward ## MiniMaxM2Model [[autodoc]] MiniMaxM2Model - forward -## MiniMaxM2PreTrainedModel - -[[autodoc]] MiniMaxM2PreTrainedModel - - forward - ## MiniMaxM2ForSequenceClassification [[autodoc]] MiniMaxM2ForSequenceClassification + - forward ## MiniMaxM2ForTokenClassification -[[autodoc]] MiniMaxM2ForTokenClassification \ No newline at end of file +[[autodoc]] MiniMaxM2ForTokenClassification + - forward diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py index 8b103ceb9fac..c5a8ff2f60c4 100644 --- a/src/transformers/models/minimax_m2/configuration_minimax_m2.py +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -98,8 +98,8 @@ class MiniMaxM2Config(PreTrainedConfig): Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE with longer `max_position_embeddings`. - rotary_dim (`int`, *optional*, defaults to `None`): - The dimension of the rotary embeddings. If not specified, will default to `head_dim`. + use_qk_norm (`Optional`, *optional*, defaults to `True`): + Whether to use layer normalization on the query and key states. 
```python >>> from transformers import MiniMaxM2Model, MiniMaxM2Config @@ -161,7 +161,6 @@ def __init__( router_aux_loss_coef: Optional[float] = 0.001, router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, - rotary_dim: Optional[int] = None, use_qk_norm: Optional[bool] = True, **kwargs, ): @@ -199,12 +198,12 @@ def __init__( standardize_rope_params(self, rope_theta=rope_theta) rope_config_validation(self) - self.rotary_dim = rotary_dim if rotary_dim is not None else head_dim + rotary_dim = kwargs.pop("rotary_dim", head_dim) self.use_qk_norm = use_qk_norm self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1.0) if self.head_dim is not None: - self.partial_rotary_factor = self.rotary_dim / self.head_dim + self.partial_rotary_factor = rotary_dim / self.head_dim super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index 7481abcd69b9..dd408667bd28 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -118,8 +118,8 @@ class MiniMaxM2Config(PreTrainedConfig): Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE with longer `max_position_embeddings`. - rotary_dim (`int`, *optional*, defaults to `None`): - The dimension of the rotary embeddings. If not specified, will default to `head_dim`. + use_qk_norm (`Optional`, *optional*, defaults to `True`): + Whether to use layer normalization on the query and key states. ```python >>> from transformers import MiniMaxM2Model, MiniMaxM2Config @@ -181,7 +181,6 @@ def __init__( router_aux_loss_coef: Optional[float] = 0.001, router_jitter_noise: Optional[float] = 0.0, rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None, - rotary_dim: Optional[int] = None, use_qk_norm: Optional[bool] = True, **kwargs, ): @@ -219,12 +218,12 @@ def __init__( standardize_rope_params(self, rope_theta=rope_theta) rope_config_validation(self) - self.rotary_dim = rotary_dim if rotary_dim is not None else head_dim + rotary_dim = kwargs.pop("rotary_dim", head_dim) self.use_qk_norm = use_qk_norm self.partial_rotary_factor = kwargs.pop("partial_rotary_factor", 1.0) if self.head_dim is not None: - self.partial_rotary_factor = self.rotary_dim / self.head_dim + self.partial_rotary_factor = rotary_dim / self.head_dim super().__init__( pad_token_id=pad_token_id, From f6775d805dad068f7864181ca7715fe405dc5a6b Mon Sep 17 00:00:00 2001 From: xuebi Date: Wed, 5 Nov 2025 17:40:24 +0800 Subject: [PATCH 7/9] fix: import Unpack from processing_utils Signed-off-by: xuebi --- src/transformers/models/minimax_m2/modeling_minimax_m2.py | 3 ++- src/transformers/models/minimax_m2/modular_minimax_m2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/minimax_m2/modeling_minimax_m2.py b/src/transformers/models/minimax_m2/modeling_minimax_m2.py index b42d3b95e067..f75ac830902b 100644 --- a/src/transformers/models/minimax_m2/modeling_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modeling_minimax_m2.py @@ -20,7 +20,7 @@ # limitations under the License. 
from collections.abc import Callable -from typing import Optional, Union, Unpack +from typing import Optional, Union import torch from torch import nn @@ -40,6 +40,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_minimax_m2 import MiniMaxM2Config diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index dd408667bd28..955b6f3b89f3 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -14,7 +14,7 @@ # limitations under the License. from collections.abc import Callable -from typing import Optional, Unpack +from typing import Optional import torch from torch import nn @@ -24,6 +24,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ..glm4_moe.modeling_glm4_moe import ( Glm4MoeAttention, Glm4MoeRotaryEmbedding, From 73904ee73749c4cff2112b885fa05cfb42558031 Mon Sep 17 00:00:00 2001 From: Roger Young <42564206+rogeryoungh@users.noreply.github.com> Date: Thu, 6 Nov 2025 12:42:48 +0800 Subject: [PATCH 8/9] update: apply suggestions from code review Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> --- docs/source/en/model_doc/minimax_m2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/minimax_m2.md b/docs/source/en/model_doc/minimax_m2.md index c5c4d12944e6..51c168b2f29b 100644 --- a/docs/source/en/model_doc/minimax_m2.md +++ b/docs/source/en/model_doc/minimax_m2.md @@ -22,7 +22,7 @@ limitations under the License. ## Overview -MiniMax-M2 redefines efficiency for agents. It's a compact, fast, and cost-effective MoE model (230 billion total parameters with 10 billion active parameters) built for elite performance in coding and agentic tasks, all while maintaining powerful general intelligence. With just 10 billion activated parameters, MiniMax-M2 provides the sophisticated, end-to-end tool use performance expected from today's leading models, but in a streamlined form factor that makes deployment and scaling easier than ever. +MiniMax-M2 is a compact, fast, and cost-effective MoE model (230 billion total parameters with 10 billion active parameters) built for elite performance in coding and agentic tasks, all while maintaining powerful general intelligence. With just 10 billion activated parameters, MiniMax-M2 provides the sophisticated, end-to-end tool use performance expected from today's leading models, but in a streamlined form factor that makes deployment and scaling easier than ever. For more details refer to the [release blog post](https://www.minimax.io/news/minimax-m2). 
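To make the `rotary_dim` → `partial_rotary_factor` handling from patches 5 and 6 easier to check, the minimal standalone sketch below reproduces the partial-RoPE frequency computation that `compute_default_rope_parameters` performs after those patches. The concrete numbers (`head_dim=128`, `rotary_dim=64`, `rope_theta=1e6`) are illustrative assumptions, not the released MiniMax-M2 configuration; only the formula mirrors the diff.

```python
# Minimal sketch of the partial rotary embedding maths, assuming head_dim=128,
# rotary_dim=64 and rope_theta=1e6 purely for illustration; the real values come
# from MiniMaxM2Config.
import torch

head_dim, rotary_dim, rope_theta = 128, 64, 1_000_000.0
partial_rotary_factor = rotary_dim / head_dim  # what the config stores after patch 6

dim = int(head_dim * partial_rotary_factor)    # == rotary_dim
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))

print(partial_rotary_factor)  # 0.5 -> only half of each head dimension is rotated
print(inv_freq.shape)         # torch.Size([32]), i.e. dim // 2 frequencies
```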
From 6b7e39732764ed7db3872ef19c5c23cb205f4f23 Mon Sep 17 00:00:00 2001 From: xuebi Date: Thu, 6 Nov 2025 13:13:06 +0800 Subject: [PATCH 9/9] update: remove MiniMaxM2DecoderLayer and MiniMaxM2MLP Signed-off-by: xuebi --- .../models/minimax_m2/modular_minimax_m2.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index 955b6f3b89f3..231625e413c8 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -32,13 +32,11 @@ eager_attention_forward, ) from ..mixtral.modeling_mixtral import ( - MixtralDecoderLayer, MixtralExperts, MixtralForCausalLM, MixtralForQuestionAnswering, MixtralForSequenceClassification, MixtralForTokenClassification, - MixtralMLP, MixtralModel, MixtralPreTrainedModel, MixtralRMSNorm, @@ -235,10 +233,6 @@ def __init__( ) -class MiniMaxM2MLP(MixtralMLP): - pass - - class MiniMaxM2Experts(MixtralExperts): pass @@ -345,10 +339,6 @@ def forward( return attn_output, attn_weights -class MiniMaxM2DecoderLayer(MixtralDecoderLayer): - pass - - class MiniMaxM2PreTrainedModel(MixtralPreTrainedModel): pass
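The sigmoid-based routing introduced in `route_tokens_to_experts`, together with the zero-initialised `e_score_correction_bias` buffer from patch 4, can be sanity-checked in isolation with the sketch below. It is an approximation under stated assumptions: the token count, `num_local_experts=8`, `num_experts_per_tok=2`, and the assumption that the correction bias only influences expert selection (not the mixing weights) are illustrative choices and may not match the merged implementation.

```python
# Standalone sketch of sigmoid top-k routing with a score-correction bias.
# All shapes and hyper-parameters are assumptions for illustration only.
import torch

num_tokens, hidden_size = 4, 16
num_local_experts, num_experts_per_tok = 8, 2

hidden_states = torch.randn(num_tokens, hidden_size)
gate = torch.nn.Linear(hidden_size, num_local_experts, bias=False)
e_score_correction_bias = torch.zeros(num_local_experts)  # a buffer in the patch, zeros at init

router_logits = gate(hidden_states)             # (tokens, experts)
scores = torch.sigmoid(router_logits.float())   # sigmoid routing, not softmax
# Assumed: the bias shifts *which* experts are picked, while the raw scores weight them.
_, selected_experts = torch.topk(scores + e_score_correction_bias, num_experts_per_tok, dim=-1)
routing_weights = scores.gather(-1, selected_experts)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

print(selected_experts.shape, routing_weights.shape)  # both torch.Size([4, 2])
```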