From 0747d645f82d2bf01a1b076dafb91e7c4483cfb8 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Wed, 9 Jul 2025 11:03:15 +0000 Subject: [PATCH 1/8] [Model] Add Ling implementation Signed-off-by: vito.yy --- docs/models/supported_models.md | 1 + tests/models/registry.py | 2 + vllm/model_executor/models/bailing_moe.py | 540 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/configs/__init__.py | 1 + .../transformers_utils/configs/bailing_moe.py | 83 +++ 6 files changed, 628 insertions(+) create mode 100644 vllm/model_executor/models/bailing_moe.py create mode 100644 vllm/transformers_utils/configs/bailing_moe.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index e75d656af283..2c0f0fc00f68 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -316,6 +316,7 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | diff --git a/tests/models/registry.py b/tests/models/registry.py index 04fff03862fc..d739d066b428 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -141,6 +141,8 @@ def check_available_online( trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py new file mode 100644 index 000000000000..a70c7300d29e --- /dev/null +++ b/vllm/model_executor/models/bailing_moe.py @@ -0,0 +1,540 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py +# Copyright 2023 The vLLM team. +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BailingMoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +class BailingAttention(nn.Module): + + def __init__( + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.total_num_heads % tp_size == 0 + assert self.total_kv_heads % tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = config.head_dim or (self.hidden_size // + self.total_num_heads) + self.q_size_per_rank = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // tp_size + self.kv_size_per_rank = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_kv_heads, + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + 
cache_config=cache_config, + prefix=f"{prefix}.attn") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + is_neox_style=True, + rope_scaling=config.rope_scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([ + self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank + ], + dim=-1) + + q, k = self.rotary_emb(position_ids, q, k) + + context_layer = self.attn( + q, + k, + v, + ) + + attn_output, _ = self.dense(context_layer) + return attn_output + + +class BailingMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BailingMoE(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_expert_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + if self.num_shared_experts > 0: + intermediate_size = (config.moe_intermediate_size * + self.num_shared_experts) + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + +class BailingMoeBlock(nn.Module): + + def __init__( + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.attention = BailingAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.post_attention_layernorm = RMSNorm(hidden_size, + eps=config.rms_norm_eps) + self.mlp = BailingMoE(intermediate_size, + config, + quant_config, + True, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.attention( + hidden_states=hidden_states, + position_ids=position_ids, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BailingMoeModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, self.embed_dim) + else: + self.word_embeddings = PPMissingLayer() + + self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BailingMoeBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + 
prefix=prefix, + ), + prefix=f"{prefix}.layers") + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + hidden_states, + position_ids, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + self.max_position_embeddings = config.max_position_embeddings + self.model = BailingMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = (self.word_embeddings if config.tie_word_embeddings + else ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config)) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = 
self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (("v_head" in name) or ("inv_freq" in name) or + (self.config.tie_word_embeddings and "lm_head" in name)): + continue + if self.config.norm_head and "lm_head.weight" in name: + import torch.nn.functional as F + loaded_weight = F.normalize(loaded_weight, + dim=0, + p=2, + eps=1e-7) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 27d476929855..867fd0bddc2f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -41,6 +41,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 734f1e09d0fd..1fb4b7968074 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig diff --git a/vllm/transformers_utils/configs/bailing_moe.py 
b/vllm/transformers_utils/configs/bailing_moe.py new file mode 100644 index 000000000000..60315dc950be --- /dev/null +++ b/vllm/transformers_utils/configs/bailing_moe.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/configuration_bailing_moe.py +from transformers.configuration_utils import PretrainedConfig + + +class BailingMoeConfig(PretrainedConfig): + model_type = "bailing_moe" + + def __init__( + self, + vocab_size=30592, + hidden_size=1024, + intermediate_size=None, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=0, + hidden_act="silu", + use_qkv_bias=False, # bailing only + use_bias=True, # bailing only + rms_norm_eps=1e-05, + norm_head=False, # bailing only + tie_word_embeddings=False, # PretrainedConfig key, + # here change default value. + embedding_dropout=0.1, + attention_dropout=0.1, + output_dropout=0.1, + initializer_range=0.02, + max_position_embeddings=16384, + rope_theta=10000.0, + use_cache=True, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + rope_scaling=None, + pad_token_id=126081, + num_experts=16, + num_shared_experts=0, + num_experts_per_tok=2, + norm_topk_prob=True, + moe_intermediate_size=None, + first_k_dense_replace=0, + head_dim=None, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.use_qkv_bias = use_qkv_bias + self.use_bias = use_bias + self.norm_head = norm_head + self.rms_norm_eps = rms_norm_eps + self.embedding_dropout = embedding_dropout + self.attention_dropout = attention_dropout + self.output_dropout = output_dropout + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.use_cache = use_cache + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.head_dim = (head_dim if head_dim is not None else + self.hidden_size // self.num_attention_heads) + self.rope_scaling = rope_scaling + + # MoE configs + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + self.moe_intermediate_size = moe_intermediate_size + self.first_k_dense_replace = first_k_dense_replace + + super().__init__(pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs) From 965fc909e3d0d8d032b0fda515de586a8988aa8f Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Wed, 9 Jul 2025 11:18:28 +0000 Subject: [PATCH 2/8] Add BailingMoeConfig Signed-off-by: vito.yy --- vllm/transformers_utils/configs/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 1fb4b7968074..68aa187a13b9 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -31,6 +31,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ + "BailingMoeConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", From f0733880ae8eef67c666a931504dd82313417b95 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: 
Thu, 10 Jul 2025 09:19:55 +0000 Subject: [PATCH 3/8] Remove transformers_utils/configs/bailing_moe.py and move the load_weights to BailingMoeModel Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 157 +++++++++--------- vllm/transformers_utils/configs/__init__.py | 2 - .../transformers_utils/configs/bailing_moe.py | 83 --------- 3 files changed, 82 insertions(+), 160 deletions(-) delete mode 100644 vllm/transformers_utils/configs/bailing_moe.py diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index a70c7300d29e..c7fd4bed4e03 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -28,6 +28,7 @@ import torch from torch import nn +import torch.nn.functional as F from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig @@ -54,7 +55,7 @@ from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -377,6 +378,80 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.norm_head and "lm_head.weight" in name: + loaded_weight = F.normalize(loaded_weight, + dim=0, + p=2, + eps=1e-7) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -463,78 +538,10 @@ def sample( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if (("v_head" in name) or ("inv_freq" in name) or - (self.config.tie_word_embeddings and "lm_head" in name)): - continue - if self.config.norm_head and "lm_head.weight" in name: - import torch.nn.functional as F - loaded_weight = F.normalize(loaded_weight, - dim=0, - p=2, - eps=1e-7) - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - if "mlp.experts" in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - if name.endswith(".bias") and name not in params_dict: - continue - if name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 68aa187a13b9..734f1e09d0fd 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -31,7 +30,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ - "BailingMoeConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", diff --git a/vllm/transformers_utils/configs/bailing_moe.py b/vllm/transformers_utils/configs/bailing_moe.py deleted file mode 100644 index 60315dc950be..000000000000 --- a/vllm/transformers_utils/configs/bailing_moe.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/inclusionAI/Ling/blob/master/models/configuration_bailing_moe.py -from transformers.configuration_utils import PretrainedConfig - - -class BailingMoeConfig(PretrainedConfig): - model_type = "bailing_moe" - - def __init__( - self, - vocab_size=30592, - hidden_size=1024, - intermediate_size=None, - num_hidden_layers=24, - num_attention_heads=16, - num_key_value_heads=0, - hidden_act="silu", - use_qkv_bias=False, # bailing only - use_bias=True, # bailing only - rms_norm_eps=1e-05, - norm_head=False, # bailing only - tie_word_embeddings=False, # PretrainedConfig key, - # here change default value. 
- embedding_dropout=0.1, - attention_dropout=0.1, - output_dropout=0.1, - initializer_range=0.02, - max_position_embeddings=16384, - rope_theta=10000.0, - use_cache=True, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - rope_scaling=None, - pad_token_id=126081, - num_experts=16, - num_shared_experts=0, - num_experts_per_tok=2, - norm_topk_prob=True, - moe_intermediate_size=None, - first_k_dense_replace=0, - head_dim=None, - **kwargs, - ): - self.num_hidden_layers = num_hidden_layers - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.use_qkv_bias = use_qkv_bias - self.use_bias = use_bias - self.norm_head = norm_head - self.rms_norm_eps = rms_norm_eps - self.embedding_dropout = embedding_dropout - self.attention_dropout = attention_dropout - self.output_dropout = output_dropout - self.initializer_range = initializer_range - self.max_position_embeddings = max_position_embeddings - self.rope_theta = rope_theta - self.use_cache = use_cache - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - self.head_dim = (head_dim if head_dim is not None else - self.hidden_size // self.num_attention_heads) - self.rope_scaling = rope_scaling - - # MoE configs - self.num_experts = num_experts - self.num_shared_experts = num_shared_experts - self.num_experts_per_tok = num_experts_per_tok - self.norm_topk_prob = norm_topk_prob - self.moe_intermediate_size = moe_intermediate_size - self.first_k_dense_replace = first_k_dense_replace - - super().__init__(pad_token_id=pad_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs) From db65bc0c04f6193c54555b144712c04121e6deca Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Thu, 10 Jul 2025 10:03:04 +0000 Subject: [PATCH 4/8] Remove LoRA code and implement review suggestions Signed-off-by: vito.yy --- docs/models/supported_models.md | 2 +- vllm/model_executor/models/bailing_moe.py | 20 +++++--------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 2c0f0fc00f68..d6f3bdfe331d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -316,7 +316,7 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. 
| | | | diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index c7fd4bed4e03..c7c40f64dd9d 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -27,8 +27,8 @@ from typing import Optional, Union import torch -from torch import nn import torch.nn.functional as F +from torch import nn from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig @@ -54,7 +54,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -227,6 +227,8 @@ def __init__( quant_config=quant_config, reduce_results=False, prefix=f"{prefix}.shared_experts") + else: + self.shared_experts = None def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -454,7 +456,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class BailingMoeForCausalLM(nn.Module, SupportsPP): packed_modules_mapping = { "query_key_value": ["query_key_value"], @@ -464,16 +466,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__( self, *, @@ -484,10 +476,8 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config self.config = config - self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings self.model = BailingMoeModel(vllm_config=vllm_config, From ed0f1d0709a4ad1f43d5b6dbd49d42ed10429fe3 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Thu, 10 Jul 2025 12:06:19 +0000 Subject: [PATCH 5/8] small fix Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index c7c40f64dd9d..31effb307fc1 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -131,18 +131,14 @@ def forward( ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split([ - self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank - ], - dim=-1) + q, k, v = qkv.split( + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], + dim=-1 + ) q, k = self.rotary_emb(position_ids, q, k) - context_layer = self.attn( - q, - k, - v, - ) + context_layer = self.attn(q, k, v) attn_output, _ = self.dense(context_layer) return attn_output @@ -380,7 +376,7 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -534,4 +530,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - From e535e785ddcba8db525e2efd4ce51a168d196ff4 Mon Sep 17 00:00:00 2001 From: "vito.yy" 
Date: Fri, 11 Jul 2025 02:54:34 +0000 Subject: [PATCH 6/8] small fix Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 31effb307fc1..5f3e4bea1288 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -59,8 +59,6 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -KVCache = tuple[torch.Tensor, torch.Tensor] - class BailingAttention(nn.Module): @@ -131,10 +129,10 @@ def forward( ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split( - [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], - dim=-1 - ) + q, k, v = qkv.split([ + self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank + ], + dim=-1) q, k = self.rotary_emb(position_ids, q, k) From 0056661838cea8f53a5ace151f31869c595d36d3 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Mon, 14 Jul 2025 06:51:55 +0000 Subject: [PATCH 7/8] Merge from main and fix model test Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 5f3e4bea1288..72c9e4721ee5 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -30,6 +30,8 @@ import torch.nn.functional as F from torch import nn +from transformers.configuration_utils import PretrainedConfig + from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -52,7 +54,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, @@ -64,7 +65,7 @@ class BailingAttention(nn.Module): def __init__( self, - config: BailingMoeConfig, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -147,7 +148,7 @@ class BailingMLP(nn.Module): def __init__( self, intermediate_size: int, - config: BailingMoeConfig, + config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, reduce_results: Optional[bool] = True, prefix: str = "", @@ -182,7 +183,7 @@ class BailingMoE(nn.Module): def __init__( self, intermediate_size: int, - config: BailingMoeConfig, + config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, reduce_results: Optional[bool] = True, prefix: str = "", @@ -247,7 +248,7 @@ class BailingMoeBlock(nn.Module): def __init__( self, - config: BailingMoeConfig, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", From 6922d1fb6e1c7794998f6c4627d1c11ad6b0bffd Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Mon, 14 Jul 2025 07:13:10 +0000 Subject: [PATCH 8/8] Fix pre-commit test Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/bailing_moe.py 
b/vllm/model_executor/models/bailing_moe.py index 72c9e4721ee5..325ba7bbad83 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -29,7 +29,6 @@ import torch import torch.nn.functional as F from torch import nn - from transformers.configuration_utils import PretrainedConfig from vllm.attention import Attention
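
Quick smoke test (illustrative only, not part of the patch series): with these patches applied, the new `BailingMoeForCausalLM` registry entry should make Ling checkpoints loadable through the regular vLLM entry points. The sketch below is a minimal example, assuming the `inclusionAI/Ling-lite-1.5` checkpoint and the `trust_remote_code=True` flag used by the tests/models/registry.py entry in PATCH 1/8; the prompt text and sampling settings are arbitrary placeholders.

    # offline_ling_smoke_test.py -- illustrative sketch, not part of the patches above
    from vllm import LLM, SamplingParams

    # Checkpoint name and trust_remote_code mirror the registry entry added in PATCH 1/8.
    llm = LLM(model="inclusionAI/Ling-lite-1.5", trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.7, max_tokens=64)

    # Generate a single completion and print the decoded text.
    outputs = llm.generate(["Explain what a mixture-of-experts layer does."],
                           sampling_params)
    for output in outputs:
        print(output.outputs[0].text)

The same architecture can also be exercised through the OpenAI-compatible server, e.g. `vllm serve inclusionAI/Ling-lite-1.5 --trust-remote-code`. Note that the tensor-parallel size must divide both `num_attention_heads` and `num_key_value_heads`, per the assertions in `BailingAttention.__init__` above.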