|
| 1 | +from typing import Dict, Iterable, Tuple |
| 2 | +from vllm.model_executor.models.step3_text import Step3TextForCausalLM |
| 3 | +import torch |
| 4 | +from vllm.model_executor.layers.fused_moe import FusedMoE |
| 5 | +from vllm.model_executor.models.utils import is_pp_missing_parameter |
| 6 | +from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 7 | +from vllm.config import VllmConfig |
| 8 | + |
| 9 | +class CustomStep3TextForCausalLM(Step3TextForCausalLM): |
| 10 | + experts_ = [f"experts.{i}.{proj}" for i in range(48) for proj in ("down_proj", "gate_proj", "up_proj")] |
| 11 | + |
| 12 | + packed_modules_mapping = { |
| 13 | + "qkv_proj": [ |
| 14 | + "q_proj", |
| 15 | + "k_proj", |
| 16 | + "v_proj", |
| 17 | + ], |
| 18 | + "gate_up_proj":[ |
| 19 | + "gate_proj", |
| 20 | + "up_proj", |
| 21 | + ], |
| 22 | + "experts": experts_ |
| 23 | + } |
| 24 | + |
| 25 | + def __init__( |
| 26 | + self, |
| 27 | + *, |
| 28 | + vllm_config: VllmConfig, |
| 29 | + prefix: str = "", |
| 30 | + ): |
| 31 | + super().__init__(vllm_config=vllm_config, prefix="model") |
| 32 | + |
| 33 | + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
| 34 | + qkv_params_mapping = [ |
| 35 | + # (param_name, shard_name, relative_start_idx, relative_end_idx) |
| 36 | + (".qkv_proj", ".q_proj", 0, self.config.share_q_dim / |
| 37 | + (self.config.share_q_dim + self.config.head_dim * 2)), |
| 38 | + (".qkv_proj", ".k_proj", self.config.share_q_dim / |
| 39 | + (self.config.share_q_dim + self.config.head_dim * 2), |
| 40 | + (self.config.share_q_dim + self.config.head_dim) / |
| 41 | + (self.config.share_q_dim + self.config.head_dim * 2)), |
| 42 | + (".qkv_proj", ".v_proj", |
| 43 | + (self.config.share_q_dim + self.config.head_dim) / |
| 44 | + (self.config.share_q_dim + self.config.head_dim * 2), |
| 45 | + (self.config.share_q_dim + self.config.head_dim * 2) / |
| 46 | + (self.config.share_q_dim + self.config.head_dim * 2)), |
| 47 | + ] |
| 48 | + stacked_params_mapping = [ |
| 49 | + # (param_name, shard_name, shard_id) |
| 50 | + (".gate_up_proj", ".gate_proj", 0), |
| 51 | + (".gate_up_proj", ".up_proj", 1), |
| 52 | + ] |
| 53 | + params_dict = dict(self.named_parameters()) |
| 54 | + loaded_params: set[str] = set() |
| 55 | + |
| 56 | + if self.vllm_config.quant_config is not None: |
| 57 | + expert_params_mapping = FusedMoE.make_expert_params_mapping( |
| 58 | + ckpt_gate_proj_name="gate_proj", |
| 59 | + ckpt_down_proj_name="down_proj", |
| 60 | + ckpt_up_proj_name="up_proj", |
| 61 | + num_experts=self.model.config.moe_num_experts) |
| 62 | + is_fused_moe = False |
| 63 | + else: |
| 64 | + expert_params_mapping = [ |
| 65 | + (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), |
| 66 | + (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), |
| 67 | + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") |
| 68 | + ] |
| 69 | + is_fused_moe = True |
| 70 | + |
| 71 | + disable_moe_stacked_params = [ |
| 72 | + data[1] for data in expert_params_mapping |
| 73 | + ] |
| 74 | + |
| 75 | + for name, loaded_weight in weights: |
| 76 | + for (param_name, weight_name, shard_id) in stacked_params_mapping: |
| 77 | + if weight_name not in name: |
| 78 | + continue |
| 79 | + if any(disable_moe_stacked_param in name |
| 80 | + for disable_moe_stacked_param in |
| 81 | + disable_moe_stacked_params): |
| 82 | + continue |
| 83 | + name = name.replace(weight_name, param_name) |
| 84 | + if is_pp_missing_parameter(name, self): |
| 85 | + continue |
| 86 | + param = params_dict[name] |
| 87 | + weight_loader = param.weight_loader |
| 88 | + weight_loader(param, loaded_weight, shard_id) |
| 89 | + loaded_params.add(name) |
| 90 | + break |
| 91 | + else: |
| 92 | + for mapping in expert_params_mapping: |
| 93 | + if len(mapping) == 4: |
| 94 | + param_name, weight_name, expert_id, shard_id = mapping |
| 95 | + else: |
| 96 | + param_name, weight_name, shard_id = mapping |
| 97 | + if weight_name not in name: |
| 98 | + continue |
| 99 | + name = name.replace(weight_name, param_name) |
| 100 | + # Skip layers on other devices. |
| 101 | + if is_pp_missing_parameter(name, self): |
| 102 | + continue |
| 103 | + # Skip loading extra bias for GPTQ models. |
| 104 | + if ((name.endswith(".bias") or name.endswith("_bias")) |
| 105 | + and name not in params_dict): |
| 106 | + continue |
| 107 | + param = params_dict[name] |
| 108 | + weight_loader = param.weight_loader |
| 109 | + if is_fused_moe: |
| 110 | + for expert_id in range(loaded_weight.shape[0]): |
| 111 | + loaded_weight_expert = loaded_weight[expert_id] |
| 112 | + weight_loader(param, |
| 113 | + loaded_weight_expert, |
| 114 | + name, |
| 115 | + shard_id=shard_id, |
| 116 | + expert_id=expert_id) |
| 117 | + else: |
| 118 | + weight_loader(param, |
| 119 | + loaded_weight, |
| 120 | + name, |
| 121 | + shard_id=shard_id, |
| 122 | + expert_id=expert_id) |
| 123 | + loaded_params.add(name) |
| 124 | + break |
| 125 | + else: |
| 126 | + for (param_name, weight_name, start_idx, |
| 127 | + end_idx) in qkv_params_mapping: |
| 128 | + if weight_name not in name: |
| 129 | + continue |
| 130 | + name = name.replace(weight_name, param_name) |
| 131 | + if is_pp_missing_parameter(name, self): |
| 132 | + continue |
| 133 | + param = params_dict[name] |
| 134 | + if hasattr(param, "output_dim"): |
| 135 | + dim = param.shape[param.output_dim] |
| 136 | + begin_idx = int(start_idx * dim) |
| 137 | + end_idx = int(end_idx * dim) |
| 138 | + param_slice = param.narrow(param.output_dim, begin_idx, |
| 139 | + end_idx-begin_idx) |
| 140 | + param_slice.copy_(loaded_weight) |
| 141 | + else: |
| 142 | + param.copy_(loaded_weight) |
| 143 | + loaded_params.add(name) |
| 144 | + break |
| 145 | + else: |
| 146 | + if is_pp_missing_parameter(name, self): |
| 147 | + continue |
| 148 | + param = params_dict[name] |
| 149 | + weight_loader = getattr(param, "weight_loader", |
| 150 | + default_weight_loader) |
| 151 | + weight_loader(param, loaded_weight) |
| 152 | + loaded_params.add(name) |
| 153 | + return loaded_params |
0 commit comments