From e11b28a1e1c45e4b6b7c8bd380d705b1baa98b35 Mon Sep 17 00:00:00 2001
From: zRzRzRzRzRzRzR <2448370773@qq.com>
Date: Thu, 10 Jul 2025 13:43:03 +0800
Subject: [PATCH 01/17] init

Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
---
 benchmarks/kernels/benchmark_moe.py     |  24 +-
 tests/models/registry.py                |   1 +
 vllm/model_executor/models/glm4_moe.py  | 664 ++++++++++++++++++++++++
 vllm/model_executor/models/registry.py  |   1 +
 vllm/transformers_utils/configs/ovis.py |   2 +-
 5 files changed, 675 insertions(+), 17 deletions(-)
 create mode 100644 vllm/model_executor/models/glm4_moe.py

diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 07af58d81c6..7ab63eaffba 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -563,22 +563,14 @@ def main(args: argparse.Namespace):
     if args.model_prefix:
         config = getattr(config, args.model_prefix)

-    if config.architectures[0] == "DbrxForCausalLM":
-        E = config.ffn_config.moe_num_experts
-        topk = config.ffn_config.moe_top_k
-        intermediate_size = config.ffn_config.ffn_hidden_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] == "JambaForCausalLM":
-        E = config.num_experts
-        topk = config.num_experts_per_tok
-        intermediate_size = config.intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
-        E = config.n_routed_experts
-        topk = config.num_experts_per_tok
-        intermediate_size = config.moe_intermediate_size
-        shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
+    if config.architectures[0] in (
+        "DbrxForCausalLM",
+        "JambaForCausalLM",
+        "DeepseekV3ForCausalLM",
+        "DeepseekV2ForCausalLM", "Qwen2MoeForCausalLM",
+        "Qwen3MoeForCausalLM",
+        "Glm4MoeForCausalLM",
+    ):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 04fff03862f..7f5b2b96cf8 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -348,6 +348,7 @@ def check_available_online(
                                         trust_remote_code=True,
                                         hf_overrides={"architectures": ["GLM4VForCausalLM"]}),  # noqa: E501
     "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"),  # noqa: E501
+    "Glm4MoeForCausalLM": _HfExamplesInfo("/model/GLM-4-MoE-100B-A10B", min_transformers_version="4.54"),  # noqa: E501
     "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
                                       extras={"2b": "h2oai/h2ovl-mississippi-2b"},  # noqa: E501
                                       max_transformers_version="4.48",  # noqa: E501
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
new file mode 100644
index 00000000000..dd9f738aeeb
--- /dev/null
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -0,0 +1,664 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The ZhipuAI Team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4-MOE model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Glm4MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Glm4MoeTopkRouter(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias",torch.zeros((self.n_routed_experts))) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +class Glm4MoeSparseMoeBlock(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.config = config + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + if self.tp_size > self.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_experts}.") + + self.gate = Glm4MoeTopkRouter( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.gate" + ) + + self.experts = FusedMoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts" + ) + + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size * config.n_shared_experts, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts" + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + residuals = 
hidden_states + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + topk_indices, topk_weights = self.gate(hidden_states) + batch_size = hidden_states.shape[0] + router_logits = torch.zeros( + batch_size, self.num_experts, + device=hidden_states.device, + dtype=hidden_states.dtype + ) + + for i in range(batch_size): + router_logits[i, topk_indices[i]] = topk_weights[i] + routed_output = self.experts( + hidden_states=hidden_states, + router_logits=router_logits + ) + + if self.tp_size > 1: + routed_output = self.experts.maybe_all_reduce_tensor_model_parallel( + routed_output + ) + + shared_output = self.shared_experts(residuals.view(-1, hidden_dim)) + final_output = routed_output + shared_output + + return final_output.view(orig_shape) + + +class Glm4MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + add_qk_norm: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.add_qk_norm = add_qk_norm + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + if self.add_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + if self.add_qk_norm: + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + + self.self_attn = Glm4MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + add_qk_norm=getattr(config, 'add_qk_norm', False), # Add this + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + if layer_idx >= getattr(config, "first_k_dense_replace", 1): + self.mlp = Glm4MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp" + ) + else: + self.mlp = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Glm4MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens" + ) + + num_layers = config.num_hidden_layers + if hasattr(config, 'num_nextn_predict_layers'): + num_layers = config.num_hidden_layers - config.num_nextn_predict_layers + + self.start_layer, self.end_layer, self.layers = make_layers( + num_layers, + lambda prefix: Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Skip loading extra parameters for GPTQ/modelopt models. 
+ ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", + ".v_scale", "_v_scale", ".weight_scale", + "_weight_scale", ".input_scale", "_input_scale") + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if "gate.weight" in name and "experts" not in name: + if is_pp_missing_parameter(name, self): + continue + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name and "shared_experts" not in name: + continue + + name = name.replace(weight_name, param_name) + + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + break + else: + if "mlp.experts" in name and "shared_experts" not in name: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, name, + shard_id=shard_id, expert_id=expert_id) + loaded_params.add(name) + break + else: + # Handle other parameters + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + # Remapping for FP8 kv-scale + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace(".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", + name, remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Glm4MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.model = Glm4MoeModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model") + ) + + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config + ) + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 27d47692985..c8d17d23d23 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -65,6 +65,7 @@ "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501 "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), + "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index c2728f0ed64..db6050fac57 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -73,7 +73,7 @@ def __init__( IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -AutoConfig.register("aimv2", AIMv2Config) +# AutoConfig.register("aimv2", AIMv2Config) # ---------------------------------------------------------------------- From 818db5926c023d455fe89cf2811a7e695dae47d4 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 13:45:43 +0800 Subject: [PATCH 02/17] ovis aimv2 model type need changed, mark Signed-off-by: 
zRzRzRzRzRzRzR <2448370773@qq.com> --- tests/models/registry.py | 2 +- vllm/transformers_utils/configs/ovis.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 7f5b2b96cf8..f3340acf040 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -348,7 +348,7 @@ def check_available_online( trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501 - "Glm4MoeForCausalLM": _HfExamplesInfo("/model/GLM-4-MoE-100B-A10B", min_transformers_version="4.54"), # noqa: E501 + "Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4-MoE-100B-A10B", min_transformers_version="4.54"), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index db6050fac57..c2728f0ed64 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -73,7 +73,7 @@ def __init__( IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -# AutoConfig.register("aimv2", AIMv2Config) +AutoConfig.register("aimv2", AIMv2Config) # ---------------------------------------------------------------------- From c6b8eb652c31cec3913d775efe04a73d5752505a Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 14:01:25 +0800 Subject: [PATCH 03/17] format Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> --- vllm/model_executor/models/glm4_moe.py | 394 +++++++++++++------------ 1 file changed, 201 insertions(+), 193 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index dd9f738aeeb..832fff897fc 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -27,8 +27,8 @@ from typing import Any, Optional, Union import torch -from torch import nn import torch.nn.functional as F +from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention @@ -63,13 +63,13 @@ class Glm4MoeMLP(nn.Module): def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -97,70 +97,81 @@ def forward(self, x): class Glm4MoeTopkRouter(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias",torch.zeros((self.n_routed_experts))) - - @torch.no_grad() - def get_topk_indices(self, scores): - 
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) - - group_scores = ( - scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - return topk_indices - - def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) - router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - scores = router_logits.sigmoid() - - topk_indices = self.get_topk_indices(scores) - topk_weights = scores.gather(1, topk_indices) - - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer( + "e_score_correction_bias", + torch.zeros((self.n_routed_experts), dtype=torch.float32)) + + @torch.no_grad() + def get_topk_indices(self, scores): + scores_for_choice = scores.view( + -1, + self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + + group_scores = (scores_for_choice.view( + -1, self.n_group, + self.n_routed_experts // self.n_group).topk(2, + dim=-1)[0].sum(dim=-1)) + + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + score_mask = (group_mask.unsqueeze(-1).expand( + -1, self.n_group, self.n_routed_experts // self.n_group).reshape( + -1, self.n_routed_experts)) + + scores_for_choice = scores_for_choice.masked_fill( + ~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, + k=self.top_k, + dim=-1, + sorted=False)[1] + return topk_indices + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), + self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + + topk_indices = self.get_topk_indices(scores) + topk_weights = scores.gather(1, topk_indices) + + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights class Glm4MoeSparseMoeBlock(nn.Module): + def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -174,31 +185,27 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.num_experts}.") - self.gate = Glm4MoeTopkRouter( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.gate" - ) + self.gate = Glm4MoeTopkRouter(config=config, + quant_config=quant_config, + prefix=f"{prefix}.gate") - self.experts = FusedMoE( - num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts" - ) + self.experts = FusedMoE(num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") self.shared_experts = Glm4MoeMLP( hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size * config.n_shared_experts, + intermediate_size=config.moe_intermediate_size * + config.n_shared_experts, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, - prefix=f"{prefix}.shared_experts" - ) + prefix=f"{prefix}.shared_experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -208,23 +215,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) topk_indices, topk_weights = self.gate(hidden_states) batch_size = hidden_states.shape[0] - router_logits = torch.zeros( - batch_size, self.num_experts, - device=hidden_states.device, - dtype=hidden_states.dtype - ) + router_logits = torch.zeros(batch_size, + self.num_experts, + device=hidden_states.device, + dtype=hidden_states.dtype) for i in range(batch_size): router_logits[i, topk_indices[i]] = topk_weights[i] - routed_output = self.experts( - hidden_states=hidden_states, - router_logits=router_logits - ) + routed_output = self.experts(hidden_states=hidden_states, + router_logits=router_logits) if self.tp_size > 1: routed_output = self.experts.maybe_all_reduce_tensor_model_parallel( - routed_output - ) + routed_output) shared_output = self.shared_experts(residuals.view(-1, hidden_dim)) final_output = routed_output + shared_output @@ -235,20 +238,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Glm4MoeAttention(nn.Module): def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 8192, - head_dim: Optional[int] = None, - rms_norm_eps: float = 1e-06, - qkv_bias: bool = False, - add_qk_norm: bool = False, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + add_qk_norm: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -269,7 +272,7 @@ def __init__( 
self.head_dim = head_dim or (hidden_size // self.total_num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.add_qk_norm = add_qk_norm @@ -308,19 +311,21 @@ def __init__( self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.add_qk_norm: - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) @@ -333,17 +338,18 @@ def forward( class Glm4MoeDecoderLayer(nn.Module): def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) self.self_attn = Glm4MoeAttention( hidden_size=self.hidden_size, @@ -355,7 +361,7 @@ def __init__( rms_norm_eps=config.rms_norm_eps, qkv_bias=getattr(config, 'attention_bias', False), head_dim=getattr(config, 'head_dim', None), - add_qk_norm=getattr(config, 'add_qk_norm', False), # Add this + add_qk_norm=getattr(config, 'add_qk_norm', False), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -363,41 +369,41 @@ def __init__( layer_idx = extract_layer_index(prefix) if layer_idx >= getattr(config, "first_k_dense_replace", 1): - self.mlp = Glm4MoeSparseMoeBlock( - config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp" - ) + self.mlp = Glm4MoeSparseMoeBlock(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: - self.mlp = Glm4MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp" - ) + self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + self, + 
positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -419,8 +425,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - prefix=f"{prefix}.embed_tokens" - ) + prefix=f"{prefix}.embed_tokens") num_layers = config.num_hidden_layers if hasattr(config, 'num_nextn_predict_layers'): @@ -428,31 +433,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( num_layers, - lambda prefix: Glm4MoeDecoderLayer( - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix - ), + lambda prefix: Glm4MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - ) + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -478,7 +479,8 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -499,8 +501,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts - ) + num_experts=self.config.n_routed_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -511,7 +512,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue if name in params_dict: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) continue @@ -555,31 +557,39 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: name = 
name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): continue - if name.endswith(ignore_suffixes) and name not in params_dict: + if name.endswith( + ignore_suffixes) and name not in params_dict: continue if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, name, - shard_id=shard_id, expert_id=expert_id) + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) loaded_params.add(name) break else: # Handle other parameters - if name.endswith(ignore_suffixes) and name not in params_dict: + if name.endswith( + ignore_suffixes) and name not in params_dict: continue + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # Remapping for FP8 kv-scale if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace(".kv_scale", ".attn.kv_scale") + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", - name, remapped_kv_scale_name, + name, + remapped_kv_scale_name, ) continue else: @@ -589,7 +599,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -618,47 +629,44 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - self.model = Glm4MoeModel( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model") - ) + self.model = Glm4MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config - ) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) + self.model.make_empty_intermediate_tensors) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + 
def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + return loader.load_weights(weights) From f406a09a006e3861b082c1368ff81efcd78fb57a Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 14:28:10 +0800 Subject: [PATCH 04/17] format Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> --- vllm/model_executor/models/glm4_moe.py | 146 +++++++----------------- vllm/transformers_utils/configs/ovis.py | 2 +- 2 files changed, 44 insertions(+), 104 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 832fff897fc..cbeee1e22f9 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -27,7 +27,6 @@ from typing import Any, Optional, Union import torch -import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig @@ -41,6 +40,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -95,76 +95,6 @@ def forward(self, x): return x -class Glm4MoeTopkRouter(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter( - torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer( - "e_score_correction_bias", - torch.zeros((self.n_routed_experts), dtype=torch.float32)) - - @torch.no_grad() - def get_topk_indices(self, scores): - scores_for_choice = scores.view( - -1, - self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) - - group_scores = (scores_for_choice.view( - -1, self.n_group, - self.n_routed_experts // self.n_group).topk(2, - dim=-1)[0].sum(dim=-1)) - - group_idx = torch.topk(group_scores, - k=self.topk_group, - dim=-1, - sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - - score_mask = (group_mask.unsqueeze(-1).expand( - -1, self.n_group, self.n_routed_experts // self.n_group).reshape( - -1, self.n_routed_experts)) - - scores_for_choice = scores_for_choice.masked_fill( - ~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, - k=self.top_k, - dim=-1, - sorted=False)[1] - return topk_indices - - def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) - router_logits = F.linear(hidden_states.type(torch.float32), - self.weight.type(torch.float32)) - scores = router_logits.sigmoid() - - topk_indices = self.get_topk_indices(scores) - topk_weights = scores.gather(1, topk_indices) - - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - class Glm4MoeSparseMoeBlock(nn.Module): def __init__( @@ -178,6 +108,7 
@@ def __init__( self.config = config self.num_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok + self.routed_scaling_factor = config.routed_scaling_factor self.norm_topk_prob = config.norm_topk_prob if self.tp_size > self.num_experts: @@ -185,18 +116,29 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.num_experts}.") - self.gate = Glm4MoeTopkRouter(config=config, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") - self.experts = FusedMoE(num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") + self.gate.e_score_correction_bias = nn.Parameter( + torch.zeros(config.n_routed_experts, dtype=torch.float32)) + + self.experts = FusedMoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=getattr(config, 'scoring_func', 'sigmoid'), + e_score_correction_bias=self.gate.e_score_correction_bias) self.shared_experts = Glm4MoeMLP( hidden_size=config.hidden_size, @@ -208,31 +150,25 @@ def __init__( prefix=f"{prefix}.shared_experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - - residuals = hidden_states - orig_shape = hidden_states.shape - hidden_dim = hidden_states.shape[-1] + num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - topk_indices, topk_weights = self.gate(hidden_states) - batch_size = hidden_states.shape[0] - router_logits = torch.zeros(batch_size, - self.num_experts, - device=hidden_states.device, - dtype=hidden_states.dtype) - - for i in range(batch_size): - router_logits[i, topk_indices[i]] = topk_weights[i] - routed_output = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + + shared_output = self.shared_experts(hidden_states) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + routed_output = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor if self.tp_size > 1: routed_output = self.experts.maybe_all_reduce_tensor_model_parallel( routed_output) - shared_output = self.shared_experts(residuals.view(-1, hidden_dim)) final_output = routed_output + shared_output - return final_output.view(orig_shape) + return final_output.view(num_tokens, hidden_dim) class Glm4MoeAttention(nn.Module): @@ -317,7 +253,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - + # Add qk-norm if self.add_qk_norm: q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) @@ -367,6 +303,7 @@ def __init__( prefix=f"{prefix}.self_attn", ) + # `mlp_only_layers` in the config. 
layer_idx = extract_layer_index(prefix) if layer_idx >= getattr(config, "first_k_dense_replace", 1): self.mlp = Glm4MoeSparseMoeBlock(config=config, @@ -390,6 +327,7 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -402,6 +340,7 @@ def forward( hidden_states=hidden_states, ) + # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) @@ -555,8 +494,10 @@ def load_weights(self, weights: Iterable[tuple[str, if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith( ignore_suffixes) and name not in params_dict: continue @@ -573,15 +514,14 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params.add(name) break else: - # Handle other parameters + # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith( ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - - # Remapping for FP8 kv-scale + # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale") diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index c2728f0ed64..db6050fac57 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -73,7 +73,7 @@ def __init__( IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -AutoConfig.register("aimv2", AIMv2Config) +# AutoConfig.register("aimv2", AIMv2Config) # ---------------------------------------------------------------------- From efcff2bad35e65845f7a50a3d29d6ff6c1645c0f Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 14:43:54 +0800 Subject: [PATCH 05/17] use ds loading(not work) --- vllm/model_executor/models/glm4_moe.py | 139 ++++++++++++------------- 1 file changed, 67 insertions(+), 72 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index cbeee1e22f9..3b6e6ec301b 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -22,8 +22,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GLM-4-MOE model compatible with HuggingFace weights.""" - -from collections.abc import Iterable +import typing +from collections.abc import Callable, Iterable from typing import Any, Optional, Union import torch @@ -47,7 +47,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -429,11 +430,6 @@ def load_weights(self, weights: Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] - # Skip loading extra parameters for GPTQ/modelopt models. 
- ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", - ".v_scale", "_v_scale", ".weight_scale", - "_weight_scale", ".input_scale", "_input_scale") - # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( @@ -444,18 +440,10 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "gate.weight" in name and "experts" not in name: - if is_pp_missing_parameter(name, self): - continue - if name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -467,82 +455,77 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if "mlp.experts" in name and "shared_experts" not in name: + if (("mlp.experts." in name) and name not in params_dict): continue - name = name.replace(weight_name, param_name) - - # Skip loading extra parameters for GPTQ/modelopt models. - if name.endswith(ignore_suffixes) and name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue - # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - loaded_params.add(name) break else: - if "mlp.experts" in name and "shared_experts" not in name: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Skip loading extra parameters for GPTQ/modelopt models. - if name.endswith( - ignore_suffixes) and name not in params_dict: - continue - if name not in params_dict: - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - loaded_params.add(name) + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
+ weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped break else: - # Skip loading extra parameters for GPTQ/modelopt models. - if name.endswith( - ignore_suffixes) and name not in params_dict: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - logger.warning_once( - "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", - name, - remapped_kv_scale_name, - ) - continue - else: - name = remapped_kv_scale_name - - if name not in params_dict: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + loaded_params.add(name) return loaded_params @@ -610,3 +593,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx+i}."): + return layer_idx + i + return None From e6ad57675773ce146983ed43cfd6a2811c6e7e08 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 18:42:54 +0800 Subject: [PATCH 06/17] update Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> --- vllm/model_executor/models/glm4_moe.py | 126 ++++++++++++------------- 1 file changed, 60 insertions(+), 66 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 3b6e6ec301b..d6cf9c74133 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -33,7 +33,9 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -53,8 +55,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -96,7 +97,7 @@ def 
forward(self, x): return x -class Glm4MoeSparseMoeBlock(nn.Module): +class Glm4MoeE(nn.Module): def __init__( self, @@ -108,9 +109,13 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.config = config self.num_experts = config.n_routed_experts - self.top_k = config.num_experts_per_tok self.routed_scaling_factor = config.routed_scaling_factor self.norm_topk_prob = config.norm_topk_prob + self.n_shared_experts = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") if self.tp_size > self.num_experts: raise ValueError( @@ -123,53 +128,52 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate") + # noaux_tc is not wrote in config now self.gate.e_score_correction_bias = nn.Parameter( torch.zeros(config.n_routed_experts, dtype=torch.float32)) self.experts = FusedMoE( - num_experts=self.num_experts, - top_k=self.top_k, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, - renormalize=self.norm_topk_prob, + renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, num_expert_group=config.n_group, topk_group=config.topk_group, prefix=f"{prefix}.experts", - scoring_func=getattr(config, 'scoring_func', 'sigmoid'), + scoring_func=config.scoring_func, e_score_correction_bias=self.gate.e_score_correction_bias) - self.shared_experts = Glm4MoeMLP( - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size * - config.n_shared_experts, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=f"{prefix}.shared_experts") + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - - shared_output = self.shared_experts(hidden_states) - - # router_logits: (num_tokens, n_experts) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) router_logits, _ = self.gate(hidden_states) - - routed_output = self.experts( + final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor - + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - routed_output = self.experts.maybe_all_reduce_tensor_model_parallel( - routed_output) - - final_output = routed_output + shared_output - - return final_output.view(num_tokens, hidden_dim) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_dim) class Glm4MoeAttention(nn.Module): @@ -181,9 +185,9 @@ def __init__( num_kv_heads: int, rope_theta: float = 10000, rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 8192, + max_position_embeddings: int = 131072, head_dim: Optional[int] = None, - rms_norm_eps: float = 1e-06, + rms_norm_eps: float = 1e-05, qkv_bias: bool = False, add_qk_norm: bool = False, cache_config: 
Optional[CacheConfig] = None, @@ -254,7 +258,6 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - # Add qk-norm if self.add_qk_norm: q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) @@ -286,7 +289,9 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + 131072) + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx self.self_attn = Glm4MoeAttention( hidden_size=self.hidden_size, @@ -296,20 +301,15 @@ def __init__( rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), - add_qk_norm=getattr(config, 'add_qk_norm', False), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - # `mlp_only_layers` in the config. - layer_idx = extract_layer_index(prefix) - if layer_idx >= getattr(config, "first_k_dense_replace", 1): - self.mlp = Glm4MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if layer_idx >= config.first_k_dense_replace: + self.mlp = Glm4MoeE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -328,20 +328,14 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - # Fully Connected + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) @@ -357,30 +351,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size self.config = config - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=f"{prefix}.embed_tokens") + self.vocab_size = config.vocab_size - num_layers = config.num_hidden_layers - if hasattr(config, 'num_nextn_predict_layers'): - num_layers = config.num_hidden_layers - config.num_nextn_predict_layers + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( - num_layers, + config.num_hidden_layers, lambda prefix: Glm4MoeDecoderLayer(config=config, cache_config=cache_config, quant_config=quant_config, prefix=prefix), - prefix=f"{prefix}.layers", - ) + prefix=f"{prefix}.layers") - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = 
PPMissingLayer() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) @@ -461,7 +456,6 @@ def load_weights(self, weights: Iterable[tuple[str, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if is_pp_missing_parameter(name, self): continue From d49c5bc00f0c584198f43a0ca8763f6770a58ff2 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 18:56:05 +0800 Subject: [PATCH 07/17] update 1 Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> --- vllm/model_executor/models/glm4_moe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index d6cf9c74133..bbd62f57329 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -128,9 +128,9 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate") - # noaux_tc is not wrote in config now - self.gate.e_score_correction_bias = nn.Parameter( - torch.zeros(config.n_routed_experts, dtype=torch.float32)) + # noaux_tc is not set in config now + self.gate.e_score_correction_bias = (nn.Parameter( + torch.empty(config.n_routed_experts))) self.experts = FusedMoE( num_experts=config.n_routed_experts, @@ -144,7 +144,7 @@ def __init__( num_expert_group=config.n_group, topk_group=config.topk_group, prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, + scoring_func="sigmoid", e_score_correction_bias=self.gate.e_score_correction_bias) if config.n_shared_experts is not None: @@ -306,7 +306,8 @@ def __init__( prefix=f"{prefix}.self_attn", ) - if layer_idx >= config.first_k_dense_replace: + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace): self.mlp = Glm4MoeE(config=config, quant_config=quant_config, prefix=f"{prefix}.mlp") From c8ad31e0f8d694e79b56013d87f073318c2bf26c Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 21:27:08 +0800 Subject: [PATCH 08/17] test --- vllm/model_executor/models/glm4_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index bbd62f57329..105abba7c8c 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -300,6 +300,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + head_dim=config.head_dim, rms_norm_eps=config.rms_norm_eps, cache_config=cache_config, quant_config=quant_config, From 2440756a46c9f94ea7b54639e3ccbaa49f724ebd Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Thu, 10 Jul 2025 23:07:10 +0800 Subject: [PATCH 09/17] Update glm4_moe.py --- vllm/model_executor/models/glm4_moe.py | 64 ++++++++------------------ 1 file changed, 19 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 105abba7c8c..418bfb29047 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -97,7 +97,7 @@ def forward(self, x): return x -class Glm4MoeE(nn.Module): +class Glm4MoE(nn.Module): def __init__( self, @@ -107,28 +107,20 @@ def __init__( ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() - self.config = config - self.num_experts = config.n_routed_experts self.routed_scaling_factor = 
config.routed_scaling_factor - self.norm_topk_prob = config.norm_topk_prob self.n_shared_experts = config.n_shared_experts if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " "Only silu is supported for now.") - if self.tp_size > self.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_experts}.") - self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, bias=False, quant_config=None, prefix=f"{prefix}.gate") - # noaux_tc is not set in config now + # noaux_tc is not set in transformers new config now self.gate.e_score_correction_bias = (nn.Parameter( torch.empty(config.n_routed_experts))) @@ -236,16 +228,19 @@ def __init__( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, + partial_rotary_factor=0.5, base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) if self.add_qk_norm: self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -259,15 +254,10 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.add_qk_norm: - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) - q_by_head = self.q_norm(q_by_head) - q = q_by_head.view(q.shape) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) - k_by_head = self.k_norm(k_by_head) - k = k_by_head.view(k.shape) + q = self.q_norm(q.reshape(-1, self.num_heads, + self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, + self.head_dim)).reshape(k.shape) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) @@ -309,9 +299,9 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace): - self.mlp = Glm4MoeE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = Glm4MoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -438,10 +428,6 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
if weight_name not in name: @@ -589,15 +575,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) - - -def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): - layer_idx = config.num_hidden_layers - for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): - return layer_idx + i - return None From 5e38b14a135a4aedde57227f614eb668106b9619 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 11 Jul 2025 00:03:40 +0800 Subject: [PATCH 10/17] update --- vllm/model_executor/models/glm4_moe.py | 49 ++++++++++++++++++++------ 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 418bfb29047..52b946865c8 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -32,7 +32,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -104,6 +104,7 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -124,6 +125,22 @@ def __init__( self.gate.e_score_correction_bias = (nn.Parameter( torch.empty(config.n_routed_experts))) + # Load balancing settings. 
+ vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = FusedMoE( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -137,7 +154,9 @@ def __init__( topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func="sigmoid", - e_score_correction_bias=self.gate.e_score_correction_bias) + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -147,7 +166,8 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), prefix=f"{prefix}.shared_experts", ) @@ -228,7 +248,6 @@ def __init__( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, - partial_rotary_factor=0.5, base=rope_theta, rope_scaling=rope_scaling, ) @@ -273,6 +292,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -299,9 +319,12 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace): - self.mlp = Glm4MoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = Glm4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) else: self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, @@ -343,6 +366,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size @@ -358,10 +382,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Glm4MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ), prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: From c227471e74bcb4c4ebd1360776966ea083abd590 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 11 Jul 2025 00:06:09 +0800 Subject: [PATCH 11/17] Update glm4_moe.py --- vllm/model_executor/models/glm4_moe.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 52b946865c8..ca6250b7b76 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ 
b/vllm/model_executor/models/glm4_moe.py @@ -33,7 +33,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_pp_group, +from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger @@ -109,7 +109,12 @@ def __init__( super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " From 9f3ab705a6cf0e4fb931956986f1833729d5c8fb Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 11 Jul 2025 10:52:01 +0800 Subject: [PATCH 12/17] 1 --- vllm/model_executor/models/glm4_moe.py | 42 ++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index ca6250b7b76..b615cd434ee 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -305,6 +305,8 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 131072) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx @@ -341,6 +343,7 @@ def __init__( eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor def forward( self, @@ -580,6 +583,45 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Glm4MoeDecoderLayer) + if isinstance(layer.mlp, Glm4MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) From 8feace60a9db2620fbe869e5d6ce7f875201eab6 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 11 Jul 2025 13:13:37 +0800 Subject: [PATCH 13/17] use ds imp --- vllm/model_executor/models/glm4_moe.py | 227 +++++++++++++------------ 1 file changed, 117 insertions(+), 110 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index b615cd434ee..28eddbf8fa0 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -34,8 +34,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -55,7 +54,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -188,8 +187,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) return final_hidden_states.view(num_tokens, hidden_dim) @@ -441,6 +441,117 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + +class Glm4MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Glm4MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Glm4MoeDecoderLayer) + if 
isinstance(layer.mlp, Glm4MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -544,108 +655,4 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) - return loaded_params - - -class Glm4MoeForCausalLM(nn.Module, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - - self.model = Glm4MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - self.expert_weights = [] - - # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) - self.num_expert_groups = 
config.n_group - - self.moe_layers: list[FusedMoE] = [] - for layer in self.model.layers: - assert isinstance(layer, Glm4MoeDecoderLayer) - if isinstance(layer.mlp, Glm4MoE): - self.moe_layers.append(layer.mlp.experts) - - # Pick last one layer since the first ones may be dense layers. - example_moe = typing.cast( - Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loaded_params \ No newline at end of file From 12666d3da2101a2ed127e7ebd1ca32d7fb07fb9b Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Fri, 11 Jul 2025 22:27:45 +0800 Subject: [PATCH 14/17] update and merge --- vllm/model_executor/models/glm4_moe.py | 204 ++++++++++++------------ vllm/transformers_utils/configs/ovis.py | 2 +- 2 files changed, 106 insertions(+), 100 deletions(-) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 28eddbf8fa0..0bb465267c2 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -54,7 +54,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -178,6 +178,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) router_logits, _ = self.gate(hidden_states) @@ -441,103 +442,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - -class Glm4MoeForCausalLM(nn.Module, SupportsPP): - 
packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = Glm4MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - else: - self.lm_head = PPMissingLayer() - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - self.expert_weights = [] - - # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) - self.num_expert_groups = config.n_group - - self.moe_layers: list[FusedMoE] = [] - for layer in self.model.layers: - assert isinstance(layer, Glm4MoeDecoderLayer) - if isinstance(layer.mlp, Glm4MoE): - self.moe_layers.append(layer.mlp.experts) - - # Pick last one layer since the first ones may be dense layers. - example_moe = typing.cast( - Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. 
- self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - def make_empty_intermediate_tensors( self, batch_size: int, dtype: torch.dtype, device: torch.device) -> IntermediateTensors: @@ -655,4 +559,106 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) - return loaded_params \ No newline at end of file + return loaded_params + + +class Glm4MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Glm4MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Glm4MoeDecoderLayer) + if isinstance(layer.mlp, Glm4MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index db6050fac57..c2728f0ed64 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -73,7 +73,7 @@ def __init__( IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -# AutoConfig.register("aimv2", AIMv2Config) +AutoConfig.register("aimv2", AIMv2Config) # ---------------------------------------------------------------------- From ffe9d62e42a7b222eee5f7e3c2af3fc5e2ebabe7 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 12 Jul 2025 22:28:43 +0800 Subject: [PATCH 15/17] fix partial_rotary_factor Signed-off-by: Isotr0py --- vllm/model_executor/models/glm4_moe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 0bb465267c2..c8e7988f6de 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -198,6 +198,7 @@ class Glm4MoeAttention(nn.Module): def __init__( self, + config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -250,12 +251,14 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.o_proj") + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, ) self.attn = Attention( self.num_heads, @@ -312,6 +315,7 @@ def __init__( self.layer_idx = layer_idx self.self_attn = Glm4MoeAttention( + config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -320,6 +324,7 @@ def __init__( max_position_embeddings=max_position_embeddings, head_dim=config.head_dim, rms_norm_eps=config.rms_norm_eps, + qkv_bias=config.attention_bias, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", From dbf1719cfd05992a81663757515280230a1b57a0 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Mon, 14 Jul 2025 00:02:45 +0800 Subject: [PATCH 16/17] Update for doc --- benchmarks/kernels/benchmark_moe_permute_unpermute.py | 6 ++++-- docs/models/supported_models.md | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git 
a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index dba1f3943b9..e503307f37c 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -321,10 +321,12 @@ def main(args: argparse.Namespace): ): E = config.n_routed_experts topk = config.num_experts_per_tok - elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: + elif ( + config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"] + or config.architectures[0] == "Glm4MoeForCausalLM" + ): E = config.num_experts topk = config.num_experts_per_tok - else: # Support for llama4 config = config.get_text_config() diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index eca37a09058..f36742501fc 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -573,6 +573,7 @@ Specified using `--task generate`. | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `THUDM/GLM-4.1V-9B-Thinkg`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4MoeForCausalLM` | GLM-4-MoE | T + IE+ + VE+ | `THUDM/GLM-4-MoE-100B-A10B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. 
| ✅︎ | | ✅︎ | From 0ea4b998abba2eca6db111b00341d07b11e18caa Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Tue, 15 Jul 2025 15:14:56 +0800 Subject: [PATCH 17/17] update for GLM MPT draft --- tests/models/registry.py | 2 + vllm/config.py | 9 +- vllm/engine/arg_utils.py | 3 +- vllm/model_executor/models/glm4_moe.py | 1 + vllm/model_executor/models/glm4_moe_mtp.py | 285 +++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 6 files changed, 297 insertions(+), 4 deletions(-) create mode 100644 vllm/model_executor/models/glm4_moe_mtp.py diff --git a/tests/models/registry.py b/tests/models/registry.py index 449719991cd..c53d9443a5a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -470,6 +470,8 @@ def check_available_online( is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", tokenizer="openbmb/MiniCPM-2B-sft-bf16"), + "Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4-MoE", + speculative_model="THUDM/GLM-4-MoE"), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL") diff --git a/vllm/config.py b/vllm/config.py index 42410006f60..0264b1931d1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2545,7 +2545,8 @@ def __post_init__(self): SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp"] + "mlp_speculator", "draft_model", "deepseek_mtp", + "glm4_moe_mtp"] SpeculativeAcceptanceMethod = Literal["rejection_sampler", "typical_acceptance_sampler"] @@ -2805,8 +2806,10 @@ def __post_init__(self): elif (self.draft_model_config.hf_config.model_type == "mlp_speculator"): self.method = "mlp_speculator" - elif (self.draft_model_config.hf_config.model_type == - "deepseek_mtp"): + elif (self.draft_model_config.hf_config.model_type + == "deepseek_mtp" + or self.draft_model_config.hf_config.model_type + == "glm4_moe_mtp"): self.method = "deepseek_mtp" if self.num_speculative_tokens > 1: logger.warning( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 269477c4848..84ebabcfc28 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1507,7 +1507,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: is_ngram_enabled = True elif speculative_method == "medusa": is_medusa_enabled = True - elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"): + elif speculative_method in ("eagle", "eagle3", "deepseek_mtp", + "glm4_moe_mtp"): is_eagle_enabled = True else: speculative_model = self.speculative_config.get("model") diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index c8e7988f6de..6cdb02bb1b8 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -328,6 +328,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + add_qk_norm=config.add_qk_norm, ) if (config.n_routed_experts is not None diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py new file mode 100644 index 00000000000..70d90c922bd --- /dev/null +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, 
VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import get_spec_layer_idx_from_weight_name +from .glm4_moe import Glm4MoeDecoderLayer +from .interfaces import SupportsPP +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class Glm4MoeMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = Glm4MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class Glm4MoeMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + Glm4MoeMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: 
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
+            input_ids,
+            positions,
+            previous_hidden_states,
+            inputs_embeds,
+            current_step_idx,
+        )
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        current_step_idx = (spec_step_idx % self.num_mtp_layers)
+        mtp_layer = self.layers[str(self.mtp_start_layer_idx +
+                                    current_step_idx)]
+        logits = self.logits_processor(mtp_layer.shared_head.head,
+                                       mtp_layer.shared_head(hidden_states),
+                                       sampling_metadata)
+        return logits
+
+
+class Glm4MoeMTP(nn.Module, SupportsPP):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        self.config = vllm_config.model_config.hf_config
+        self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config,
+                                                prefix=maybe_prefix(
+                                                    prefix, "model"))
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        previous_hidden_states: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        spec_step_idx: int = 0,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions,
+                                   previous_hidden_states, inputs_embeds,
+                                   spec_step_idx)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        spec_step_idx: int = 0,
+    ) -> Optional[torch.Tensor]:
+        return self.model.compute_logits(hidden_states, sampling_metadata,
+                                         spec_step_idx)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is None:
+                continue
+            name = self._rewrite_spec_layer_name(spec_layer, name)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # According to the DeepSeek-V3 Technical Report, MTP modules
+                    # share the embedding layer, so only the first copy is loaded.
+                    if (spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
+        """
+        Rewrite the weight name to match the format of the original model.
+        Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
+        """
+        spec_layer_weight_names = [
+            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
+        ]
+        shared_weight_names = ["embed_tokens"]
+        spec_layer_weight = False
+        shared_weight = False
+        for weight_name in spec_layer_weight_names:
+            if weight_name in name:
+                spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
+                break
+        if not spec_layer_weight:
+            # treat rest weights as weights for transformer layer block
+            name = name.replace(f"model.layers.{spec_layer}.",
+                                f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
+        return name
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index af13450bd8d..6db157b491d 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -247,6 +247,7 @@
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
+    "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
 }
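
For reviewers, a minimal offline sketch of how the new "glm4_moe_mtp" speculative method could be exercised once this series is applied. The checkpoint path mirrors the registry test entry above; the tensor-parallel size and speculative token count are illustrative assumptions, not values fixed by the patch, and the speculative_config dict is used the same way as for the existing deepseek_mtp path.

    # Hypothetical smoke test, not part of the patch itself.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="THUDM/GLM-4-MoE",  # assumed checkpoint carrying the MTP (num_nextn_predict_layers) weights
        tensor_parallel_size=8,   # illustrative; choose a TP degree that fits the MoE on your GPUs
        speculative_config={
            "method": "glm4_moe_mtp",      # new SpeculativeMethod literal added in vllm/config.py above
            "num_speculative_tokens": 1,   # config.py logs a warning for MTP drafts when this exceeds 1
        },
    )

    outputs = llm.generate(
        ["Explain multi-token prediction in one sentence."],
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)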