5 changes: 5 additions & 0 deletions vllm_ascend/models/__init__.py
@@ -59,3 +59,8 @@ def register_model():
    ModelRegistry.register_model(
        "PanguProMoEForCausalLM",
        "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")

    ModelRegistry.register_model(
        "Step3TextForCausalLM",
        "vllm_ascend.models.step3_text:CustomStep3TextForCausalLM")
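
For context, ModelRegistry.register_model maps the architecture string found in a checkpoint's config.json "architectures" field to a lazily imported implementation class, so loading a Step3 checkpoint picks up the Ascend variant automatically. A minimal sketch of the effect (the checkpoint path is hypothetical):

from vllm import LLM

# "Step3TextForCausalLM" in the checkpoint's config.json now resolves to
# vllm_ascend.models.step3_text:CustomStep3TextForCausalLM via the registry.
llm = LLM(model="/path/to/step3-text-checkpoint")  # hypothetical path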
160 changes: 160 additions & 0 deletions vllm_ascend/models/step3_text.py
@@ -0,0 +1,160 @@
from collections.abc import Iterable
from typing import Tuple

import torch

from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.step3_text import Step3TextForCausalLM
from vllm.model_executor.models.utils import is_pp_missing_parameter

class CustomStep3TextForCausalLM(Step3TextForCausalLM):
    experts_ = []
    for i in range(48):
        experts_.extend([
            f"experts.{i}.down_proj", f"experts.{i}.gate_proj",
            f"experts.{i}.up_proj"
        ])
Comment on lines +12 to +14
Contributor

high

The list construction for experts_ can be simplified using a more concise and Pythonic list comprehension. Also, the hardcoded number 48 is a magic number. It would be better to define it as a named constant (e.g., _STEP3_EXPERT_COUNT = 48) for improved readability and maintainability, and use that constant here.

Suggested change
-    experts_ = []
-    for i in range(48):
-        experts_.extend([f"experts.{i}.down_proj", f"experts.{i}.gate_proj", f"experts.{i}.up_proj"])
+    experts_ = [f"experts.{i}.{proj}" for i in range(48) for proj in ("down_proj", "gate_proj", "up_proj")]
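
A minimal sketch combining both of the reviewer's suggestions. The constant name _STEP3_EXPERT_COUNT is the reviewer's illustrative example, and the assumption that 48 matches config.moe_num_experts is unverified:

_STEP3_EXPERT_COUNT = 48  # assumed to match config.moe_num_experts

experts_ = [
    f"experts.{i}.{proj}"
    for i in range(_STEP3_EXPERT_COUNT)
    for proj in ("down_proj", "gate_proj", "up_proj")
]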


    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts": experts_
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__(vllm_config=vllm_config, prefix="model")

    def load_weights(self,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
        qkv_params_mapping = [
            # (param_name, shard_name, relative_start_idx, relative_end_idx)
            (".qkv_proj", ".q_proj", 0,
             self.config.share_q_dim /
             (self.config.share_q_dim + self.config.head_dim * 2)),
            (".qkv_proj", ".k_proj",
             self.config.share_q_dim /
             (self.config.share_q_dim + self.config.head_dim * 2),
             (self.config.share_q_dim + self.config.head_dim) /
             (self.config.share_q_dim + self.config.head_dim * 2)),
            (".qkv_proj", ".v_proj",
             (self.config.share_q_dim + self.config.head_dim) /
             (self.config.share_q_dim + self.config.head_dim * 2),
             (self.config.share_q_dim + self.config.head_dim * 2) /
             (self.config.share_q_dim + self.config.head_dim * 2)),
        ]
Comment on lines +38 to +43
Contributor

high

The initialization of qkv_params_mapping involves complex and repetitive calculations, which harms readability and maintainability. It's better to extract these calculations into separate variables. This makes the logic clearer and reduces the chance of errors if the dimensions change in the future. Also, the final relative end index is always 1.0, which can be used directly.

Suggested change
-        qkv_params_mapping = [
-            # (param_name, shard_name, relative_start_idx, relative_end_idx)
-            (".qkv_proj", ".q_proj", 0, self.config.share_q_dim / (self.config.share_q_dim + self.config.head_dim * 2)),
-            (".qkv_proj", ".k_proj", self.config.share_q_dim / (self.config.share_q_dim + self.config.head_dim * 2), (self.config.share_q_dim + self.config.head_dim) / (self.config.share_q_dim + self.config.head_dim * 2)),
-            (".qkv_proj", ".v_proj", (self.config.share_q_dim + self.config.head_dim) / (self.config.share_q_dim + self.config.head_dim * 2), (self.config.share_q_dim + self.config.head_dim * 2) / (self.config.share_q_dim + self.config.head_dim * 2)),
-        ]
+        q_dim = self.config.share_q_dim
+        kv_dim = self.config.head_dim
+        total_dim = q_dim + 2 * kv_dim
+        q_end_rel = q_dim / total_dim
+        k_end_rel = (q_dim + kv_dim) / total_dim
+        qkv_params_mapping = [
+            # (param_name, shard_name, relative_start_idx, relative_end_idx)
+            (".qkv_proj", ".q_proj", 0, q_end_rel),
+            (".qkv_proj", ".k_proj", q_end_rel, k_end_rel),
+            (".qkv_proj", ".v_proj", k_end_rel, 1.0),
+        ]
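
To make the relative-index scheme concrete, here is a worked example with assumed dimensions (4096 and 128 are illustrative values, not taken from the Step3 config):

# Assumed dimensions for illustration only: share_q_dim=4096, head_dim=128.
q_dim, kv_dim = 4096, 128
total_dim = q_dim + 2 * kv_dim            # 4352 rows in the fused qkv_proj
q_end_rel = q_dim / total_dim             # q_proj occupies [0.0, ~0.941)
k_end_rel = (q_dim + kv_dim) / total_dim  # k_proj occupies [~0.941, ~0.971)

# load_weights later scales the fractions back up by the parameter's output
# dimension and slices the fused tensor with narrow():
dim = total_dim                           # param.shape[param.output_dim]
begin_idx = int(q_end_rel * dim)          # 4096: first row of k_proj
end_idx = int(k_end_rel * dim)            # 4224: one past the last k_proj row
# param.narrow(param.output_dim, begin_idx, end_idx - begin_idx)
# selects the 128 rows holding k_proj.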

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        params_need_to_load = set()

        if self.vllm_config.quant_config is not None:
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.model.config.moe_num_experts)
            is_fused_moe = False
        else:
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
            ]
            is_fused_moe = True

        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    if len(mapping) == 4:
                        param_name, weight_name, expert_id, shard_id = mapping
                    else:
                        param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    if name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if is_fused_moe:
                        for expert_id in range(loaded_weight.shape[0]):
                            loaded_weight_expert = loaded_weight[expert_id]
                            weight_loader(param,
                                          loaded_weight_expert,
                                          name,
                                          shard_id=shard_id,
                                          expert_id=expert_id)
                    else:
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_idx,
                         end_idx) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        if name not in params_dict:
                            continue
                        param = params_dict[name]
                        if hasattr(param, "output_dim"):
                            dim = param.shape[param.output_dim]
                            begin_idx = int(start_idx * dim)
                            end_idx = int(end_idx * dim)
                            param_slice = param.narrow(
                                param.output_dim, begin_idx,
                                end_idx - begin_idx)
                            param_slice.copy_(loaded_weight)
                        else:
                            param.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        if name not in params_dict:
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
Comment on lines +69 to +150
Contributor

high

The load_weights method is very complex, using a chain of for-else blocks that spans over 80 lines. This makes the code difficult to read, debug, and maintain. It is highly recommended to refactor this logic by breaking it down into smaller, more manageable helper functions, each responsible for loading a specific type of parameter (e.g., _load_stacked_param, _load_expert_param). The main method can then orchestrate calls to these helpers using a clearer control flow like an if/elif/else structure. This will greatly improve the code's modularity and readability.
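
A minimal sketch of the refactor this comment proposes, assuming each helper returns the name of the parameter it loaded, or None when its mapping does not match. _load_stacked_param and _load_expert_param follow the reviewer's example names; _load_qkv_param and _load_default_param are additional hypothetical helpers in the same style:

    def load_weights(self,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Try each parameter family in turn; the first helper whose
            # mapping matches performs the load and returns the param name.
            loaded = (self._load_stacked_param(name, loaded_weight, params_dict)
                      or self._load_expert_param(name, loaded_weight, params_dict)
                      or self._load_qkv_param(name, loaded_weight, params_dict)
                      or self._load_default_param(name, loaded_weight, params_dict))
            if loaded is not None:
                loaded_params.add(loaded)
        return loaded_params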

        for name in params_dict:
            params_need_to_load.add(name)
        if params_need_to_load != loaded_params:
            param_name_example = list(params_need_to_load - loaded_params)[0]
            raise RuntimeError(
                f"Some parameters (e.g. {param_name_example}) are missing "
                "from the checkpoint and would silently keep their random "
                "initialization")
        return loaded_params


Collaborator

Can the model be adapted without adding the models file?