-
Notifications
You must be signed in to change notification settings - Fork 468
[Feature] Add CustomStep3TextForCausalLM model #2382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,160 @@ | ||||||||||||||||||||||||||||||||||||
from collections.abc import Iterable | ||||||||||||||||||||||||||||||||||||
from typing import Any, Dict, Iterable, Optional, Tuple | ||||||||||||||||||||||||||||||||||||
from vllm.model_executor.models.step3_text import Step3TextForCausalLM | ||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||
from vllm.model_executor.layers.fused_moe import FusedMoE | ||||||||||||||||||||||||||||||||||||
from vllm.model_executor.models.utils import (PPMissingLayer, is_pp_missing_parameter, | ||||||||||||||||||||||||||||||||||||
make_empty_intermediate_tensors_factory, make_layers) | ||||||||||||||||||||||||||||||||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||||||||||||||||||||||||||||||||||||
from vllm.config import VllmConfig | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
class CustomStep3TextForCausalLM(Step3TextForCausalLM): | ||||||||||||||||||||||||||||||||||||
experts_ = [] | ||||||||||||||||||||||||||||||||||||
for i in range(48): | ||||||||||||||||||||||||||||||||||||
experts_.extend([f"experts.{i}.down_proj", f"experts.{i}.gate_proj", f"experts.{i}.up_proj"]) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
packed_modules_mapping = { | ||||||||||||||||||||||||||||||||||||
"qkv_proj": [ | ||||||||||||||||||||||||||||||||||||
"q_proj", | ||||||||||||||||||||||||||||||||||||
"k_proj", | ||||||||||||||||||||||||||||||||||||
"v_proj", | ||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||
"gate_up_proj": [ | ||||||||||||||||||||||||||||||||||||
"gate_proj", | ||||||||||||||||||||||||||||||||||||
"up_proj", | ||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||
"experts": experts_ | ||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||
*, | ||||||||||||||||||||||||||||||||||||
vllm_config: VllmConfig, | ||||||||||||||||||||||||||||||||||||
prefix: str = "", | ||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||
super().__init__(vllm_config=vllm_config, prefix="model") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]: | ||||||||||||||||||||||||||||||||||||
qkv_params_mapping = [ | ||||||||||||||||||||||||||||||||||||
# (param_name, shard_name, relative_start_idx, relative_end_idx) | ||||||||||||||||||||||||||||||||||||
(".qkv_proj", ".q_proj", 0, self.config.share_q_dim / (self.config.share_q_dim + self.config.head_dim * 2)), | ||||||||||||||||||||||||||||||||||||
(".qkv_proj", ".k_proj", self.config.share_q_dim / (self.config.share_q_dim + self.config.head_dim * 2), (self.config.share_q_dim + self.config.head_dim) / (self.config.share_q_dim + self.config.head_dim * 2)), | ||||||||||||||||||||||||||||||||||||
(".qkv_proj", ".v_proj", (self.config.share_q_dim + self.config.head_dim) / (self.config.share_q_dim + self.config.head_dim * 2), (self.config.share_q_dim + self.config.head_dim * 2) / (self.config.share_q_dim + self.config.head_dim * 2)), | ||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||
Comment on lines
+38
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The initialization of
Suggested change
|
||||||||||||||||||||||||||||||||||||
stacked_params_mapping = [ | ||||||||||||||||||||||||||||||||||||
# (param_name, shard_name, shard_id) | ||||||||||||||||||||||||||||||||||||
(".gate_up_proj", ".gate_proj", 0), | ||||||||||||||||||||||||||||||||||||
(".gate_up_proj", ".up_proj", 1), | ||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||
params_dict = dict(self.named_parameters()) | ||||||||||||||||||||||||||||||||||||
loaded_params = set() | ||||||||||||||||||||||||||||||||||||
params_need_to_load = set() | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
if self.vllm_config.quant_config is not None: | ||||||||||||||||||||||||||||||||||||
expert_params_mapping = FusedMoE.make_expert_params_mapping( | ||||||||||||||||||||||||||||||||||||
ckpt_gate_proj_name="gate_proj", | ||||||||||||||||||||||||||||||||||||
ckpt_down_proj_name="down_proj", | ||||||||||||||||||||||||||||||||||||
ckpt_up_proj_name="up_proj", | ||||||||||||||||||||||||||||||||||||
num_experts=self.model.config.moe_num_experts) | ||||||||||||||||||||||||||||||||||||
is_fused_moe = False | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
expert_params_mapping = [ | ||||||||||||||||||||||||||||||||||||
(".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), | ||||||||||||||||||||||||||||||||||||
(".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), | ||||||||||||||||||||||||||||||||||||
(".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") | ||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||
is_fused_moe = True | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
disable_moe_stacked_params = [data[1] for data in expert_params_mapping] | ||||||||||||||||||||||||||||||||||||
for name, loaded_weight in weights: | ||||||||||||||||||||||||||||||||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping: | ||||||||||||||||||||||||||||||||||||
if weight_name not in name: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
if any(disable_moe_stacked_param in name for disable_moe_stacked_param in disable_moe_stacked_params): | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
name = name.replace(weight_name, param_name) | ||||||||||||||||||||||||||||||||||||
if is_pp_missing_parameter(name, self): | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
if name not in params_dict: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
param = params_dict[name] | ||||||||||||||||||||||||||||||||||||
weight_loader = param.weight_loader | ||||||||||||||||||||||||||||||||||||
weight_loader(param, loaded_weight, shard_id) | ||||||||||||||||||||||||||||||||||||
loaded_params.add(name) | ||||||||||||||||||||||||||||||||||||
break | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
for mapping in expert_params_mapping: | ||||||||||||||||||||||||||||||||||||
if len(mapping) == 4: | ||||||||||||||||||||||||||||||||||||
param_name, weight_name, expert_id, shard_id = mapping | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
param_name, weight_name, shard_id = mapping | ||||||||||||||||||||||||||||||||||||
if weight_name not in name: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
name = name.replace(weight_name, param_name) | ||||||||||||||||||||||||||||||||||||
# Skip layers on other devices. | ||||||||||||||||||||||||||||||||||||
if is_pp_missing_parameter(name, self): | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
# Skip loading extra bias for GPTQ models. | ||||||||||||||||||||||||||||||||||||
if ((name.endswith(".bias") or name.endswith("_bias")) | ||||||||||||||||||||||||||||||||||||
and name not in params_dict): | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
if name not in params_dict: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
param = params_dict[name] | ||||||||||||||||||||||||||||||||||||
weight_loader = param.weight_loader | ||||||||||||||||||||||||||||||||||||
if is_fused_moe: | ||||||||||||||||||||||||||||||||||||
for expert_id in range(loaded_weight.shape[0]): | ||||||||||||||||||||||||||||||||||||
loaded_weight_expert = loaded_weight[expert_id] | ||||||||||||||||||||||||||||||||||||
weight_loader(param, | ||||||||||||||||||||||||||||||||||||
loaded_weight_expert, | ||||||||||||||||||||||||||||||||||||
name, | ||||||||||||||||||||||||||||||||||||
shard_id=shard_id, | ||||||||||||||||||||||||||||||||||||
expert_id=expert_id) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
weight_loader(param, | ||||||||||||||||||||||||||||||||||||
loaded_weight, | ||||||||||||||||||||||||||||||||||||
name, | ||||||||||||||||||||||||||||||||||||
shard_id=shard_id, | ||||||||||||||||||||||||||||||||||||
expert_id=expert_id) | ||||||||||||||||||||||||||||||||||||
loaded_params.add(name) | ||||||||||||||||||||||||||||||||||||
break | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
for (param_name, weight_name, start_idx, end_idx) in qkv_params_mapping: | ||||||||||||||||||||||||||||||||||||
if weight_name not in name: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
name = name.replace(weight_name, param_name) | ||||||||||||||||||||||||||||||||||||
if is_pp_missing_parameter(name, self): | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
if name not in params_dict: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
param = params_dict[name] | ||||||||||||||||||||||||||||||||||||
if hasattr(param, "output_dim"): | ||||||||||||||||||||||||||||||||||||
dim = param.shape[param.output_dim] | ||||||||||||||||||||||||||||||||||||
begin_idx = int(start_idx * dim) | ||||||||||||||||||||||||||||||||||||
end_idx = int(end_idx * dim) | ||||||||||||||||||||||||||||||||||||
param_slice = param.narrow(param.output_dim,begin_idx,end_idx-begin_idx) | ||||||||||||||||||||||||||||||||||||
param_slice.copy_(loaded_weight) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
param.copy_(loaded_weight) | ||||||||||||||||||||||||||||||||||||
loaded_params.add(name) | ||||||||||||||||||||||||||||||||||||
break | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
if is_pp_missing_parameter(name, self): | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
if name not in params_dict: | ||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||
param = params_dict[name] | ||||||||||||||||||||||||||||||||||||
weight_loader = getattr(param, "weight_loader", | ||||||||||||||||||||||||||||||||||||
default_weight_loader) | ||||||||||||||||||||||||||||||||||||
weight_loader(param, loaded_weight) | ||||||||||||||||||||||||||||||||||||
loaded_params.add(name) | ||||||||||||||||||||||||||||||||||||
Comment on lines
+69
to
+150
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||||||||||||||||||||||||||||||||
for name in params_dict: | ||||||||||||||||||||||||||||||||||||
params_need_to_load.add(name) | ||||||||||||||||||||||||||||||||||||
if params_need_to_load != loaded_params: | ||||||||||||||||||||||||||||||||||||
param_name_example = list(params_need_to_load - loaded_params)[0] | ||||||||||||||||||||||||||||||||||||
raise RuntimeError( | ||||||||||||||||||||||||||||||||||||
f"Some parameters like {param_name_example} are not in the checkpoint and will falsely use random initialization" | ||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||
return loaded_params | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can the model be adapted without adding the models file? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The list construction for
experts_
can be simplified using a more concise and Pythonic list comprehension. Also, the hardcoded number48
is a magic number. It would be better to define it as a named constant (e.g.,_STEP3_EXPERT_COUNT = 48
) for improved readability and maintainability, and use that constant here.