|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +from collections.abc import Iterable |
| 5 | + |
| 6 | +import torch |
| 7 | +import torch.nn as nn |
| 8 | + |
| 9 | +from vllm.compilation.decorators import support_torch_compile |
| 10 | +from vllm.config import VllmConfig |
| 11 | +from vllm.distributed.parallel_state import get_pp_group |
| 12 | +from vllm.model_executor.layers.fused_moe import FusedMoE |
| 13 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 14 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 15 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 16 | + VocabParallelEmbedding) |
| 17 | +from vllm.model_executor.model_loader.weight_utils import ( |
| 18 | + default_weight_loader, maybe_remap_kv_scale_name) |
| 19 | +from vllm.model_executor.models.deepseek_v2 import ( |
| 20 | + DeepseekV2DecoderLayer, DeepseekV3ForCausalLM, |
| 21 | + get_spec_layer_idx_from_weight_name) |
| 22 | + |
| 23 | +from .utils import AutoWeightsLoader, maybe_prefix |
| 24 | + |
| 25 | + |
| 26 | +@support_torch_compile |
| 27 | +class DeepseekV2Model(nn.Module): |
| 28 | + |
| 29 | + def __init__( |
| 30 | + self, |
| 31 | + *, |
| 32 | + vllm_config: VllmConfig, |
| 33 | + prefix: str = "", |
| 34 | + start_layer_id: int = 0, |
| 35 | + ) -> None: |
| 36 | + super().__init__() |
| 37 | + self.config = vllm_config. \ |
| 38 | + speculative_config.draft_model_config.hf_config |
| 39 | + model_config = vllm_config.model_config |
| 40 | + cache_config = vllm_config.cache_config |
| 41 | + quant_config = vllm_config.quant_config |
| 42 | + self.vocab_size = self.config.vocab_size |
| 43 | + |
| 44 | + self.embed_tokens = VocabParallelEmbedding( |
| 45 | + self.config.vocab_size, |
| 46 | + self.config.hidden_size, |
| 47 | + quant_config=quant_config, |
| 48 | + prefix=maybe_prefix(prefix, "embed_tokens"), |
| 49 | + ) |
| 50 | + |
| 51 | + self.layers = nn.ModuleList([ |
| 52 | + DeepseekV2DecoderLayer( |
| 53 | + self.config, |
| 54 | + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), |
| 55 | + model_config=model_config, |
| 56 | + cache_config=cache_config, |
| 57 | + quant_config=quant_config, |
| 58 | + ) for i in range(self.config.num_hidden_layers) |
| 59 | + ]) |
| 60 | + |
| 61 | + self.fc = nn.Linear( |
| 62 | + self.config.model.hidden_size * 2, |
| 63 | + self.config.model.hidden_size, |
| 64 | + bias=False, |
| 65 | + ) |
| 66 | + |
| 67 | + self.enorm = RMSNorm(self.config.hidden_size, |
| 68 | + eps=self.config.rms_norm_eps) |
| 69 | + self.hnorm = RMSNorm(self.config.hidden_size, |
| 70 | + eps=self.config.rms_norm_eps) |
| 71 | + self.norm = RMSNorm(self.config.hidden_size, |
| 72 | + eps=self.config.rms_norm_eps) |
| 73 | + |
| 74 | + def forward( |
| 75 | + self, |
| 76 | + input_ids: torch.Tensor, |
| 77 | + positions: torch.Tensor, |
| 78 | + hidden_states: torch.Tensor, |
| 79 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 80 | + input_embeds = self.embed_tokens(input_ids) |
| 81 | + |
| 82 | + inputs = torch.cat( |
| 83 | + [self.enorm(input_embeds), |
| 84 | + self.hnorm(hidden_states)], dim=-1) |
| 85 | + hidden_states = self.fc(inputs) |
| 86 | + |
| 87 | + # masking inputs at position=0 |
| 88 | + hidden_states[positions == 0] = 0 |
| 89 | + residual = None |
| 90 | + for layer in self.layers: |
| 91 | + hidden_states, residual = layer( |
| 92 | + positions, |
| 93 | + hidden_states, |
| 94 | + residual, |
| 95 | + ) |
| 96 | + hidden_states, _ = self.norm(hidden_states, residual) |
| 97 | + return hidden_states, hidden_states |
| 98 | + |
| 99 | + def load_weights(self, weights: Iterable[tuple[str, |
| 100 | + torch.Tensor]]) -> set[str]: |
| 101 | + stacked_params_mapping = [ |
| 102 | + # (param_name, shard_name, shard_id) |
| 103 | + ("gate_up_proj", "gate_proj", 0), |
| 104 | + ("gate_up_proj", "up_proj", 1), |
| 105 | + ] |
| 106 | + |
| 107 | + # Params for weights, fp8 weight scales, fp8 activation scales |
| 108 | + # (param_name, weight_name, expert_id, shard_id) |
| 109 | + expert_params_mapping = FusedMoE.make_expert_params_mapping( |
| 110 | + ckpt_gate_proj_name="gate_proj", |
| 111 | + ckpt_down_proj_name="down_proj", |
| 112 | + ckpt_up_proj_name="up_proj", |
| 113 | + num_experts=self.config.n_routed_experts) |
| 114 | + |
| 115 | + params_dict = dict(self.named_parameters()) |
| 116 | + loaded_params: set[str] = set() |
| 117 | + for name, loaded_weight in weights: |
| 118 | + if "rotary_emb.inv_freq" in name: |
| 119 | + continue |
| 120 | + |
| 121 | + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) |
| 122 | + if spec_layer is not None: |
| 123 | + continue # skip spec decode layers for main model |
| 124 | + |
| 125 | + for param_name, weight_name, shard_id in stacked_params_mapping: |
| 126 | + # Skip non-stacked layers and experts (experts handled below). |
| 127 | + if weight_name not in name: |
| 128 | + continue |
| 129 | + # We have mlp.experts[0].gate_proj in the checkpoint. |
| 130 | + # Since we handle the experts below in expert_params_mapping, |
| 131 | + # we need to skip here BEFORE we update the name, otherwise |
| 132 | + # name will be updated to mlp.experts[0].gate_up_proj, which |
| 133 | + # will then be updated below in expert_params_mapping |
| 134 | + # for mlp.experts[0].gate_gate_up_proj, which breaks load. |
| 135 | + if ("mlp.experts." in name) and name not in params_dict: |
| 136 | + continue |
| 137 | + name = name.replace(weight_name, param_name) |
| 138 | + # Skip loading extra bias for GPTQ models. |
| 139 | + if name.endswith(".bias") and name not in params_dict: |
| 140 | + continue |
| 141 | + |
| 142 | + param = params_dict[name] |
| 143 | + weight_loader = param.weight_loader |
| 144 | + weight_loader(param, loaded_weight, shard_id) |
| 145 | + break |
| 146 | + else: |
| 147 | + for mapping in expert_params_mapping: |
| 148 | + param_name, weight_name, expert_id, shard_id = mapping |
| 149 | + if weight_name not in name: |
| 150 | + continue |
| 151 | + name = name.replace(weight_name, param_name) |
| 152 | + |
| 153 | + param = params_dict[name] |
| 154 | + weight_loader = param.weight_loader |
| 155 | + weight_loader( |
| 156 | + param, |
| 157 | + loaded_weight, |
| 158 | + name, |
| 159 | + shard_id=shard_id, |
| 160 | + expert_id=expert_id, |
| 161 | + ) |
| 162 | + break |
| 163 | + else: |
| 164 | + # if PP disabled then draft will share embed with target |
| 165 | + if get_pp_group().world_size == 1 and \ |
| 166 | + "embed_tokens." in name: |
| 167 | + continue |
| 168 | + |
| 169 | + # Skip loading extra bias for GPTQ models. |
| 170 | + if name.endswith(".bias") and name not in params_dict: |
| 171 | + continue |
| 172 | + |
| 173 | + # Remapping the name of FP8 kv-scale. |
| 174 | + name = maybe_remap_kv_scale_name(name, params_dict) |
| 175 | + if name is None: |
| 176 | + continue |
| 177 | + |
| 178 | + param = params_dict[name] |
| 179 | + weight_loader = getattr(param, "weight_loader", |
| 180 | + default_weight_loader) |
| 181 | + weight_loader(param, loaded_weight) |
| 182 | + loaded_params.add(name) |
| 183 | + return loaded_params |
| 184 | + |
| 185 | + |
| 186 | +class EagleDeepseekForCausalLM(DeepseekV3ForCausalLM): |
| 187 | + |
| 188 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 189 | + nn.Module.__init__(self) |
| 190 | + self.config = vllm_config. \ |
| 191 | + speculative_config.draft_model_config.hf_config |
| 192 | + target_layer_num = vllm_config.model_config.get_num_layers( |
| 193 | + vllm_config.parallel_config) |
| 194 | + self.model = DeepseekV2Model(vllm_config=vllm_config, |
| 195 | + prefix="model", |
| 196 | + start_layer_id=target_layer_num) |
| 197 | + |
| 198 | + logit_scale = getattr(self.config, "logit_scale", 1.0) |
| 199 | + self.logits_processor = LogitsProcessor(self.config.vocab_size, |
| 200 | + scale=logit_scale) |
| 201 | + |
| 202 | + def forward( |
| 203 | + self, |
| 204 | + input_ids: torch.Tensor, |
| 205 | + positions: torch.Tensor, |
| 206 | + hidden_states: torch.Tensor, |
| 207 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 208 | + return self.model(input_ids, positions, hidden_states) |
| 209 | + |
| 210 | + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
| 211 | + loader = AutoWeightsLoader( |
| 212 | + self, |
| 213 | + skip_prefixes=None, |
| 214 | + ) |
| 215 | + |
| 216 | + model_weights = {} |
| 217 | + for name, loaded_weight in weights: |
| 218 | + if "lm_head" not in name: |
| 219 | + name = "model." + name |
| 220 | + model_weights[name] = loaded_weight |
| 221 | + loader.load_weights(model_weights.items()) |
0 commit comments