
Commit 6ebc3c3

root authored and committed
Support W8A8_dynamic on Step3 Model
1 parent 05a700d commit 6ebc3c3

File tree

vllm_ascend/models/__init__.py
vllm_ascend/models/step3_text.py

2 files changed: +156 -0 lines changed

vllm_ascend/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -44,3 +44,6 @@ def register_model():
     ModelRegistry.register_model(
         "Qwen3NextForCausalLM",
         "vllm_ascend.models.qwen3_next:Qwen3NextForCausalLM")
+    ModelRegistry.register_model(
+        "Step3TextForCausalLM",
+        "vllm_ascend.models.step3_text:CustomStep3TextForCausalLM")

vllm_ascend/models/step3_text.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
from typing import Dict, Iterable, Tuple

import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.step3_text import Step3TextForCausalLM
from vllm.model_executor.models.utils import is_pp_missing_parameter


class CustomStep3TextForCausalLM(Step3TextForCausalLM):
    # Per-expert projection names referenced by packed_modules_mapping below.
    experts_ = [
        f"experts.{i}.{proj}" for i in range(48)
        for proj in ("down_proj", "gate_proj", "up_proj")
    ]

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts": experts_,
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__(vllm_config=vllm_config, prefix="model")

    def load_weights(self,
                     weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        qkv_params_mapping = [
            # (param_name, shard_name, relative_start_idx, relative_end_idx)
            (".qkv_proj", ".q_proj", 0, self.config.share_q_dim /
             (self.config.share_q_dim + self.config.head_dim * 2)),
            (".qkv_proj", ".k_proj", self.config.share_q_dim /
             (self.config.share_q_dim + self.config.head_dim * 2),
             (self.config.share_q_dim + self.config.head_dim) /
             (self.config.share_q_dim + self.config.head_dim * 2)),
            (".qkv_proj", ".v_proj",
             (self.config.share_q_dim + self.config.head_dim) /
             (self.config.share_q_dim + self.config.head_dim * 2),
             (self.config.share_q_dim + self.config.head_dim * 2) /
             (self.config.share_q_dim + self.config.head_dim * 2)),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        if self.vllm_config.quant_config is not None:
            # Quantized checkpoints (e.g. W8A8_dynamic) store per-expert
            # weights and scales, so use FusedMoE's per-expert mapping.
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="gate_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="up_proj",
                num_experts=self.model.config.moe_num_experts)
            is_fused_moe = False
        else:
            # Unquantized checkpoints store fused expert tensors that are
            # split into per-expert slices while loading below.
            expert_params_mapping = [
                (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
                (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
                (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"),
            ]
            is_fused_moe = True

        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    if len(mapping) == 4:
                        param_name, weight_name, expert_id, shard_id = mapping
                    else:
                        param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if is_fused_moe:
                        # Fused checkpoint tensor: the first dim is the expert
                        # index, so load one expert slice at a time.
                        for expert_id in range(loaded_weight.shape[0]):
                            loaded_weight_expert = loaded_weight[expert_id]
                            weight_loader(param,
                                          loaded_weight_expert,
                                          name,
                                          shard_id=shard_id,
                                          expert_id=expert_id)
                    else:
                        weight_loader(param,
                                      loaded_weight,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    # q/k/v are stored separately in the checkpoint; copy each
                    # one into its slice of the fused qkv_proj parameter.
                    for (param_name, weight_name, start_idx,
                         end_idx) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        if hasattr(param, "output_dim"):
                            dim = param.shape[param.output_dim]
                            begin_idx = int(start_idx * dim)
                            end_idx = int(end_idx * dim)
                            param_slice = param.narrow(
                                param.output_dim, begin_idx,
                                end_idx - begin_idx)
                            param_slice.copy_(loaded_weight)
                        else:
                            param.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
        return loaded_params
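With the custom model registered, a W8A8_dynamic-quantized Step3 text checkpoint can be served through the standard vLLM entrypoints on Ascend. A minimal sketch, assuming vllm-ascend is installed as a platform plugin, the checkpoint path below is a hypothetical placeholder for a W8A8_dynamic-quantized model, and quantization="ascend" follows vllm-ascend's documented convention for quantized checkpoints:

from vllm import LLM, SamplingParams

# Hypothetical local path to a W8A8_dynamic-quantized Step3 text checkpoint.
llm = LLM(model="/path/to/step3-text-w8a8-dynamic",
          quantization="ascend",
          tensor_parallel_size=1)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)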
