diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 51c9f68e43a..5c83b980d4e 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -576,7 +576,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
- elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
+ elif config.architectures[0] in (
+ "DeepseekV3ForCausalLM",
+ "DeepseekV2ForCausalLM",
+ "Glm4MoeForCausalLM",
+ ):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
index dba1f3943b9..4ed69009014 100644
--- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py
+++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py
@@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
elif (
config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"
+ or config.architectures[0] == "Glm4MoeForCausalLM"
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 42afaeac0e8..55d5d1f9dea 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -574,6 +574,7 @@ Specified using `--task generate`.
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinkg`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4MoeForCausalLM` | GLM-4-MoE | T | `THUDM/GLM-4-MoE-100B-A10B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d2e70e291df..fcd43430d3b 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -357,6 +357,7 @@ def check_available_online(
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501
+ "Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4-MoE-100B-A10B", min_transformers_version="4.54"), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
@@ -475,6 +476,8 @@ def check_available_online(
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
+ "Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4-MoE",
+ speculative_model="THUDM/GLM-4-MoE"),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
diff --git a/vllm/config.py b/vllm/config.py
index 6c56ac1eec8..2deed558949 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2515,7 +2515,8 @@ def __post_init__(self):
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
- "mlp_speculator", "draft_model", "deepseek_mtp"]
+ "mlp_speculator", "draft_model", "deepseek_mtp",
+ "glm4_moe_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
@@ -2656,7 +2657,13 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
-
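+    # GLM-4 MoE ships its MTP layers inside the base checkpoint, so the
+    # config is repurposed as a draft-model config (as for DeepSeek-V3 above).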
+ if hf_config.architectures[0] == "Glm4MoeForCausalLM":
+ hf_config.model_type = "glm4_moe_mtp"
+ n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
+ hf_config.update({
+ "n_predict": n_predict,
+            "architectures": ["Glm4MoeMTPModel"]
+ })
if hf_config.architectures[0] == "MiMoForCausalLM":
hf_config.model_type = "mimo_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
@@ -2665,8 +2672,6 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
- return hf_config
-
return hf_config
def __post_init__(self):
@@ -2683,10 +2688,8 @@ def __post_init__(self):
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config and \
- (self.target_model_config.hf_text_config.model_type \
- == "deepseek_v3" or
- self.target_model_config.hf_text_config.model_type \
- == "mimo"):
+ (self.target_model_config.hf_text_config.model_type in
+ ('deepseek_v3', 'mimo', 'glm4_moe')):
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
@@ -2775,8 +2778,10 @@ def __post_init__(self):
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
- elif (self.draft_model_config.hf_config.model_type ==
- "deepseek_mtp"):
+        elif (self.draft_model_config.hf_config.model_type in
+              ("deepseek_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 7b73060e349..3d7ee91198c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1421,7 +1421,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
is_ngram_enabled = True
elif speculative_method == "medusa":
is_medusa_enabled = True
- elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
+ elif speculative_method in ("eagle", "eagle3", "deepseek_mtp",
+ "glm4_moe_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
new file mode 100644
index 00000000000..a7db4fc7750
--- /dev/null
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -0,0 +1,670 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The ZhipuAI Team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only GLM-4-MOE model compatible with HuggingFace weights."""
+import typing
+from collections.abc import Callable, Iterable
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import Attention
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+ get_tensor_model_parallel_world_size)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class Glm4MoeMLP(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ hidden_act: str,
+ quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_up_proj = MergedColumnParallelLinear(
+ hidden_size, [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=f"{prefix}.down_proj")
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class Glm4MoE(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = self.ep_group.rank()
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts: int = config.n_routed_experts
+ self.n_shared_experts: int = config.n_shared_experts
+
+ if config.hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now.")
+
+ self.gate = ReplicatedLinear(config.hidden_size,
+ config.n_routed_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate")
+
+        # noaux_tc is not set in the current transformers config, so the
+        # expert score correction bias is created here explicitly.
+ self.gate.e_score_correction_bias = (nn.Parameter(
+ torch.empty(config.n_routed_experts)))
+
+ # Load balancing settings.
+ vllm_config = get_current_vllm_config()
+ parallel_config = vllm_config.parallel_config
+ self.enable_eplb = enable_eplb
+
+ self.n_redundant_experts = parallel_config.num_redundant_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_physical_experts = (self.n_logical_experts +
+ self.n_redundant_experts)
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = (self.ep_rank *
+ self.n_local_physical_experts)
+ self.physical_expert_end = (self.physical_expert_start +
+ self.n_local_physical_experts)
+
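+        # Routed experts use grouped top-k routing with sigmoid scoring and
+        # the score correction bias created above.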
+ self.experts = FusedMoE(
+ num_experts=config.n_routed_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func="sigmoid",
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts)
+
+ if config.n_shared_experts is not None:
+ intermediate_size = (config.moe_intermediate_size *
+ config.n_shared_experts)
+ self.shared_experts = Glm4MoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ reduce_results=self.experts.must_reduce_shared_expert_outputs(
+ ),
+ prefix=f"{prefix}.shared_experts",
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+        shared_output = None
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+ router_logits, _ = self.gate(hidden_states)
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits) * self.routed_scaling_factor
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+ if self.tp_size > 1:
+ final_hidden_states = (
+ self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states))
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Glm4MoeAttention(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ rope_theta: float = 10000,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ max_position_embeddings: int = 131072,
+ head_dim: Optional[int] = None,
+ rms_norm_eps: float = 1e-05,
+ qkv_bias: bool = False,
+ use_qk_norm: bool = False,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+ self.use_qk_norm = use_qk_norm
+
+ self.qkv_proj = QKVParallelLinear(hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj")
+
+ self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj")
+
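+        # Partial RoPE: only part of each head dim is rotated (default 0.5).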
+ partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=max_position_embeddings,
+ base=rope_theta,
+ rope_scaling=rope_scaling,
+ partial_rotary_factor=partial_rotary_factor,
+ )
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ if self.use_qk_norm:
+ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ if self.use_qk_norm:
+ q = self.q_norm(q.reshape(-1, self.num_heads,
+ self.head_dim)).reshape(q.shape)
+ k = self.k_norm(k.reshape(-1, self.num_kv_heads,
+ self.head_dim)).reshape(k.shape)
+
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Glm4MoeDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ enable_eplb: bool = False,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 10000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ max_position_embeddings = getattr(config, "max_position_embeddings",
+ 131072)
+ # DecoderLayers are created with `make_layers` which passes the prefix
+ # with the layer's index.
+ layer_idx = int(prefix.split(sep='.')[-1])
+ self.layer_idx = layer_idx
+
+ self.self_attn = Glm4MoeAttention(
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ max_position_embeddings=max_position_embeddings,
+ head_dim=config.head_dim,
+ rms_norm_eps=config.rms_norm_eps,
+ qkv_bias=config.attention_bias,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ use_qk_norm=config.use_qk_norm,
+ )
+
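+        # Layers before first_k_dense_replace stay dense; the rest use MoE.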
+ if (config.n_routed_experts is not None
+ and layer_idx >= config.first_k_dense_replace):
+ self.mlp = Glm4MoE(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ enable_eplb=enable_eplb,
+ )
+ else:
+ self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(positions=positions,
+ hidden_states=hidden_states)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
+@support_torch_compile
+class Glm4MoeModel(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ enable_eplb = vllm_config.parallel_config.enable_eplb
+ self.config = config
+
+ self.vocab_size = config.vocab_size
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens")
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Glm4MoeDecoderLayer(
+ config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ enable_eplb=enable_eplb,
+ ),
+ prefix=f"{prefix}.layers")
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states, residual = layer(positions, hidden_states, residual)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype,
+ device: torch.device) -> IntermediateTensors:
+ return IntermediateTensors({
+ "hidden_states":
+ torch.zeros((batch_size, self.config.hidden_size),
+ dtype=dtype,
+ device=device),
+ "residual":
+ torch.zeros((batch_size, self.config.hidden_size),
+ dtype=dtype,
+ device=device),
+ })
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts)
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+
+                    # Mark this as an expert weight so it is not picked up
+                    # again by the generic (non-expert) weight loading below.
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or not
+ # here since otherwise we may skip experts with other
+ # available replicas.
+ weight_loader = typing.cast(Callable[..., bool],
+ param.weight_loader)
+ success = weight_loader(param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True)
+ if success:
+ name = name_mapped
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
+
+
+class Glm4MoeForCausalLM(nn.Module, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Glm4MoeModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+ else:
+ self.lm_head = PPMissingLayer()
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+ self.expert_weights = []
+
+ # Set MoE hyperparameters
+ self.num_moe_layers = (config.num_hidden_layers -
+ config.first_k_dense_replace)
+ self.num_expert_groups = config.n_group
+
+ self.moe_layers: list[FusedMoE] = []
+        example_moe = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+            assert isinstance(layer, Glm4MoeDecoderLayer)
+            if isinstance(layer.mlp, Glm4MoE):
+                # Pick the last MoE layer on this rank as the reference,
+                # since the leading layers may be dense.
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            raise RuntimeError("No Glm4MoE layer found in model.layers.")
+
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_shared_experts = example_moe.n_shared_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+ def set_eplb_state(
+ self,
+ expert_load_view: torch.Tensor,
+ logical_to_physical_map: torch.Tensor,
+ logical_replica_count: torch.Tensor,
+ ) -> None:
+ for layer_idx, layer in enumerate(self.moe_layers):
+ # Register the expert weights.
+ self.expert_weights.append(layer.get_expert_weights())
+ layer.set_eplb_state(
+ moe_layer_idx=layer_idx,
+ expert_load_view=expert_load_view,
+ logical_to_physical_map=logical_to_physical_map,
+ logical_replica_count=logical_replica_count,
+ )
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py
new file mode 100644
index 00000000000..dde060c3561
--- /dev/null
+++ b/vllm/model_executor/models/glm4_moe_mtp.py
@@ -0,0 +1,285 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.config import CacheConfig, VllmConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .deepseek_v2 import get_spec_layer_idx_from_weight_name
+from .glm4_moe import Glm4MoeDecoderLayer
+from .interfaces import SupportsPP
+from .utils import maybe_prefix
+
+
+class SharedHead(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.norm(hidden_states)
+
+
+class Glm4MoeMultiTokenPredictorLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ prefix: str,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.eh_proj = nn.Linear(config.hidden_size * 2,
+ config.hidden_size,
+ bias=False)
+ self.shared_head = SharedHead(config=config, quant_config=quant_config)
+ self.mtp_block = Glm4MoeDecoderLayer(config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ spec_step_index: int = 0,
+ ) -> torch.Tensor:
+ assert inputs_embeds is not None
+ # masking inputs at position 0, as not needed by MTP
+ inputs_embeds[positions == 0] = 0
+ inputs_embeds = self.enorm(inputs_embeds)
+ previous_hidden_states = self.hnorm(previous_hidden_states)
+
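+        # Fuse the normed token embedding with the previous hidden state.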
+ hidden_states = self.eh_proj(
+ torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
+
+ hidden_states, residual = self.mtp_block(positions=positions,
+ hidden_states=hidden_states,
+ residual=None)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Glm4MoeMultiTokenPredictor(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ self.mtp_start_layer_idx = config.num_hidden_layers
+ self.num_mtp_layers = config.num_nextn_predict_layers
+ # to map the exact layer index from weights
+ self.layers = torch.nn.ModuleDict({
+ str(idx):
+ Glm4MoeMultiTokenPredictorLayer(
+ config,
+ f"{prefix}.layers.{idx}",
+ cache_config=vllm_config.cache_config,
+ quant_config=vllm_config.quant_config,
+ )
+ for idx in range(self.mtp_start_layer_idx,
+ self.mtp_start_layer_idx + self.num_mtp_layers)
+ })
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ )
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ current_step_idx = (spec_step_idx % self.num_mtp_layers)
+ return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
+ input_ids,
+ positions,
+ previous_hidden_states,
+ inputs_embeds,
+ current_step_idx,
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ current_step_idx = (spec_step_idx % self.num_mtp_layers)
+ mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+ current_step_idx)]
+ logits = self.logits_processor(mtp_layer.shared_head.head,
+ mtp_layer.shared_head(hidden_states),
+ sampling_metadata)
+ return logits
+
+
+class Glm4MoeMTP(nn.Module, SupportsPP):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.config = vllm_config.model_config.hf_config
+ self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config,
+ prefix=maybe_prefix(
+ prefix, "model"))
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ previous_hidden_states: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ spec_step_idx: int = 0,
+ ) -> torch.Tensor:
+ hidden_states = self.model(input_ids, positions,
+ previous_hidden_states, inputs_embeds,
+ spec_step_idx)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ spec_step_idx: int = 0,
+ ) -> Optional[torch.Tensor]:
+ return self.model.compute_logits(hidden_states, sampling_metadata,
+ spec_step_idx)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts)
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
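+            # Skip weights that do not belong to the extra MTP layers.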
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is None:
+ continue
+ name = self._rewrite_spec_layer_name(spec_layer, name)
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+                    # Per the DeepSeek-V3 technical report, MTP modules share
+                    # the embedding layer; only the first copy is loaded.
+ if (spec_layer != self.model.mtp_start_layer_idx
+ and ".layers" not in name):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
+ """
+ Rewrite the weight name to match the format of the original model.
+ Add .mtp_block for modules in transformer layer block for spec layer
+ and rename shared layer weights to be top level.
+ """
+ spec_layer_weight_names = [
+ "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
+ ]
+ shared_weight_names = ["embed_tokens"]
+ spec_layer_weight = False
+ shared_weight = False
+ for weight_name in spec_layer_weight_names:
+ if weight_name in name:
+ spec_layer_weight = True
+ if weight_name in shared_weight_names:
+ shared_weight = True
+ break
+ if not spec_layer_weight:
+ # treat rest weights as weights for transformer layer block
+ name = name.replace(f"model.layers.{spec_layer}.",
+ f"model.layers.{spec_layer}.mtp_block.")
+ elif shared_weight:
+ # treat shared weights as top level weights
+ name = name.replace(f"model.layers.{spec_layer}.", "model.")
+ return name
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index bc936500bdc..1332609ff82 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -66,6 +66,7 @@
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
+ "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
@@ -248,6 +249,7 @@
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
+ "Glm4MoeMTPForCausalLM": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index b2926dbd185..6b6943d7643 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -77,7 +77,8 @@ def __init__(
"mlp_speculator",
"eagle",
"deepseek_mtp",
- "mimo_mtp")) \
+ "glm4_moe_mtp",
+ "mimo_mtp")) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner