Commit f3b44b4

Author: unknown
Commit message: fc1 for glm
1 parent 4c380f3 · commit f3b44b4

File tree (8 files changed, +302 / -20 lines):

  1.patch
  pyproject.toml
  requirements.txt
  vllm_ascend/ascend_forward_context.py
  vllm_ascend/models/glm4_moe.py
  vllm_ascend/ops/common_fused_moe.py
  vllm_ascend/ops/layernorm.py
  vllm_ascend/ops/linear.py

1.patch

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
diff --git a/pyproject.toml b/pyproject.toml
index 1a140ce..f383d6d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,9 +12,9 @@ requires = [
     "scipy",
     "setuptools>=64",
     "setuptools-scm>=8",
-    "torch-npu==2.7.1.dev20250724",
-    "torch>=2.7.1",
-    "torchvision",
+    #"torch-npu==2.7.1.dev20250724",
+    #"torch>=2.7.1",
+    #"torchvision",
     "wheel",
     "msgpack",
     "quart",
diff --git a/requirements.txt b/requirements.txt
index 7808e85..f422081 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,8 +10,8 @@ pyyaml
 scipy
 setuptools>=64
 setuptools-scm>=8
-torch>=2.7.1
-torchvision
+#torch>=2.7.1
+#torchvision
 wheel
 
 # requirements for disaggregated prefill
@@ -22,6 +22,6 @@ quart
 numba
 
 # Install torch_npu
---pre
---extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
-torch-npu==2.7.1.dev20250724
+#--pre
+#--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
+#torch-npu==2.7.1.dev20250724
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index f4a7887..a896ae4 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -87,10 +87,14 @@ def set_ascend_forward_context(
         is_deepseek_v3_r1 = hasattr(
             vllm_config.model_config.hf_config, 'n_routed_experts'
         ) and vllm_config.model_config.hf_config.n_routed_experts == 256
+        is_glm4_moe = hasattr(
+            vllm_config.model_config.hf_config, 'n_routed_experts'
+        ) and vllm_config.model_config.hf_config.model_type == 'glm4_moe'
         fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
                                                is_deepseek_v3_r1)
         forward_context.fused_moe_state = fused_moe_state
         forward_context.in_profile_run = in_profile_run
+        forward_context.is_glm4_moe = is_glm4_moe
 
         from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher
         dispatcher_name = _moe_method_to_dispatcher[moe_comm_method]
diff --git a/vllm_ascend/models/glm4_moe.py b/vllm_ascend/models/glm4_moe.py
index 9e1ca4b..46f3d0e 100644
--- a/vllm_ascend/models/glm4_moe.py
+++ b/vllm_ascend/models/glm4_moe.py
@@ -76,7 +76,7 @@ class CustomGlm4MoE(Glm4MoE):
         final_hidden_states = (
             self.experts.maybe_all_reduce_tensor_model_parallel(
                 final_hidden_states))
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(-1, hidden_dim)
 
 
 class CustomGlm4MoeDecoderLayer(nn.Module):
@@ -133,9 +133,9 @@ class CustomGlm4MoeDecoderLayer(nn.Module):
                                  prefix=f"{prefix}.mlp")
 
         self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+                                       eps=config.rms_norm_eps,prefix=f"{prefix}.input_layernorm")
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+                                                eps=config.rms_norm_eps,prefix=f"{prefix}.post_attention_layernorm")
         self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
@@ -197,7 +197,7 @@ class CustomGlm4MoeModel(Glm4MoeModel):
             prefix=f"{prefix}.layers")
 
         if get_pp_group().is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, prefix=f"{prefix}.norm")
         else:
             self.norm = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
@@ -267,4 +267,4 @@ class CustomGlm4MoeForCausalLM(Glm4MoeForCausalLM):
         self.num_local_physical_experts = example_moe.n_local_physical_experts
         self.num_routed_experts = example_moe.n_routed_experts
         self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
\ No newline at end of file
+        self.num_redundant_experts = example_moe.n_redundant_experts
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 773684e..2a55c7d 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -35,6 +35,9 @@ from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                  AlltoAllCommImpl, MC2CommImpl)
 from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+import torch.nn.functional as F
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -305,10 +308,18 @@ class AscendFusedMoE(FusedMoE):
         """
         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            if flashcomm_v1_enabled:
+                pad_size = forward_context.pad_size
+                if pad_size > 0:
+                    final_hidden_states = F.pad(final_hidden_states, (0, 0, 0, pad_size))
+                tp_size = get_tensor_model_parallel_world_size()
+                tp_rank = get_tensor_model_parallel_rank()
+                final_hidden_states = torch.chunk(final_hidden_states, tp_size, dim=0)[tp_rank]
             return final_hidden_states
         else:
-            return tensor_model_parallel_all_reduce(final_hidden_states)
+            return torch.ops.vllm.maybe_pad_and_reduce(final_hidden_states)
 
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index ccd031c..cc39db9 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union, cast
 
 import torch
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.forward_context import get_forward_context
 
 
 class AddRMSNormW8A8Quant(RMSNorm):
@@ -54,14 +55,27 @@ class AddRMSNormW8A8Quant(RMSNorm):
                 self.layer.aclnn_input_offset,
                 epsilon=self.variance_epsilon)
             torch.ops.vllm.maybe_wait_prefetch_done(x)
+            is_glm4_moe = get_forward_context().is_glm4_moe
+            x = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(x, is_glm4_moe)
             return x, residual
 
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
         return x
 
-
+cnt = 0
 class AscendRMSNorm(RMSNorm):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
+        has_weight: bool = True,
+        dtype: Optional[torch.dtype] = None,
+        prefix: str = None,
+    ) -> None:
+        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
+        self.prefix = prefix
 
     def forward_oot(
         self,
@@ -69,10 +83,11 @@ class AscendRMSNorm(RMSNorm):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
-
         from vllm_ascend.utils import is_310p
         if residual is not None:
+            global cnt
             residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
+            cnt = cnt + 1
             assert x.size(0) == residual.size(0)
             if is_310p():
                 orig_dtype = residual.dtype
@@ -84,6 +99,8 @@ class AscendRMSNorm(RMSNorm):
                 x, _, residual = torch_npu.npu_add_rms_norm(
                     x, residual, self.weight, self.variance_epsilon)
                 torch.ops.vllm.maybe_wait_prefetch_done(x)
+            is_glm4_moe = get_forward_context().is_glm4_moe
+            x = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(x, is_glm4_moe and "post_attention_layernorm" in self.prefix)
             return x, residual
 
         x, residual = torch_npu.npu_rms_norm(x, self.weight,
diff --git a/vllm_ascend/ops/linear.py b/vllm_ascend/ops/linear.py
index 8ffce39..47d9ba7 100644
--- a/vllm_ascend/ops/linear.py
+++ b/vllm_ascend/ops/linear.py
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import ( # noqa
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
-
+from vllm.forward_context import get_forward_context
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, matmul_allreduce_enable,
@@ -151,7 +151,7 @@ class AscendRowParallelLinear(RowParallelLinear):
             comm_group = get_tp_group()
             self.forward_type = "matmul_allreduce"
             self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
-        elif dense_optim_enable():
+        elif prefix.find("shared_experts") == -1 and dense_optim_enable():
             comm_group = get_tp_group()
             self.forward_type = "dense_optim"
         else:
@@ -403,12 +403,13 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
         if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
             comm_group = get_mlp_tp_group()
             self.forward_type = "mlp_tp"
-        elif dense_optim_enable():
+        elif prefix.find("shared_experts") == -1 and dense_optim_enable():
             comm_group = get_tp_group()
             self.forward_type = "dense_optim"
         else:
             comm_group = get_tp_group()
             self.forward_type = "normal_tp"
+        self.prefix = prefix
         self.comm_group = comm_group
         # TODO: check for disable_tp
         self.tp_rank = comm_group.rank_in_group
@@ -469,7 +470,9 @@ class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
         # Matrix multiply.
         assert self.quant_method is not None
 
-        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
+        is_glm4_moe = get_forward_context().is_glm4_moe
+        if not is_glm4_moe:
+            input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
         output_parallel = self.quant_method.apply(self, input_, bias)
 
         if self.gather_output:
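Two hunks in this patch (vllm_ascend/ops/layernorm.py and vllm_ascend/ops/linear.py) coordinate where the flashcomm v1 all-gather happens for GLM-4 MoE: AscendMergedColumnParallelLinear skips torch.ops.vllm.maybe_all_gather_and_maybe_unpad when is_glm4_moe is set, and AscendRMSNorm performs it instead, but only after the post-attention layernorm, identified by the prefix it now receives. A tiny, hypothetical sketch of that gating predicate (not the vLLM-Ascend source):

# Hypothetical helper mirroring the condition used in AscendRMSNorm.forward_oot;
# the real code passes this flag to torch.ops.vllm.maybe_all_gather_and_maybe_unpad(x, flag).
def should_all_gather_after_norm(is_glm4_moe: bool, prefix: str) -> bool:
    return is_glm4_moe and "post_attention_layernorm" in prefix

print(should_all_gather_after_norm(True, "model.layers.0.post_attention_layernorm"))   # True
print(should_all_gather_after_norm(True, "model.layers.0.input_layernorm"))            # False
print(should_all_gather_after_norm(False, "model.layers.0.post_attention_layernorm"))  # False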

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -12,9 +12,9 @@ requires = [
     "scipy",
     "setuptools>=64",
     "setuptools-scm>=8",
-    "torch-npu==2.7.1.dev20250724",
-    "torch>=2.7.1",
-    "torchvision",
+    #"torch-npu==2.7.1.dev20250724",
+    #"torch>=2.7.1",
+    #"torchvision",
     "wheel",
     "msgpack",
     "quart",

requirements.txt

Lines changed: 5 additions & 5 deletions
@@ -10,8 +10,8 @@ pyyaml
 scipy
 setuptools>=64
 setuptools-scm>=8
-torch>=2.7.1
-torchvision
+#torch>=2.7.1
+#torchvision
 wheel
 
 # requirements for disaggregated prefill
@@ -22,6 +22,6 @@ quart
 numba
 
 # Install torch_npu
---pre
---extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
-torch-npu==2.7.1.dev20250724
+#--pre
+#--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
+#torch-npu==2.7.1.dev20250724

vllm_ascend/ascend_forward_context.py

Lines changed: 4 additions & 0 deletions
@@ -87,10 +87,14 @@ def set_ascend_forward_context(
         is_deepseek_v3_r1 = hasattr(
             vllm_config.model_config.hf_config, 'n_routed_experts'
         ) and vllm_config.model_config.hf_config.n_routed_experts == 256
+        is_glm4_moe = hasattr(
+            vllm_config.model_config.hf_config, 'n_routed_experts'
+        ) and vllm_config.model_config.hf_config.model_type == 'glm4_moe'
         fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
                                                is_deepseek_v3_r1)
         forward_context.fused_moe_state = fused_moe_state
         forward_context.in_profile_run = in_profile_run
+        forward_context.is_glm4_moe = is_glm4_moe
 
         from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher
         dispatcher_name = _moe_method_to_dispatcher[moe_comm_method]
vllm_ascend/models/glm4_moe.py

Lines changed: 5 additions & 5 deletions
@@ -76,7 +76,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = (
             self.experts.maybe_all_reduce_tensor_model_parallel(
                 final_hidden_states))
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(-1, hidden_dim)
 
 
 class CustomGlm4MoeDecoderLayer(nn.Module):
@@ -133,9 +133,9 @@ def __init__(
                                  prefix=f"{prefix}.mlp")
 
         self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
+                                       eps=config.rms_norm_eps,prefix=f"{prefix}.input_layernorm")
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+                                                eps=config.rms_norm_eps,prefix=f"{prefix}.post_attention_layernorm")
         self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
@@ -197,7 +197,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=f"{prefix}.layers")
 
         if get_pp_group().is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, prefix=f"{prefix}.norm")
         else:
             self.norm = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
@@ -267,4 +267,4 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.num_local_physical_experts = example_moe.n_local_physical_experts
         self.num_routed_experts = example_moe.n_routed_experts
         self.num_shared_experts = example_moe.n_shared_experts
-        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
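The switch to view(-1, hidden_dim) matters because, once flashcomm v1 pads and chunks the token dimension across TP ranks, the tensor flowing back here can hold fewer rows than the original num_tokens. A toy, single-process illustration with assumed shapes (not the real model path):

import torch
import torch.nn.functional as F

num_tokens, hidden_dim, tp_size = 5, 4, 2
hidden = torch.randn(num_tokens, hidden_dim)

# Pad the token dim to a multiple of tp_size, then keep one rank's chunk,
# as the flashcomm v1 branch in common_fused_moe.py does.
pad = (-num_tokens) % tp_size
local = torch.chunk(F.pad(hidden, (0, 0, 0, pad)), tp_size, dim=0)[0]

print(local.view(-1, hidden_dim).shape)   # torch.Size([3, 4]) -- adapts to the local row count
# local.view(num_tokens, hidden_dim)      # would raise: 3*4 elements cannot be viewed as (5, 4)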

vllm_ascend/ops/common_fused_moe.py

Lines changed: 12 additions & 1 deletion
@@ -35,6 +35,9 @@
                                                  AlltoAllCommImpl, MC2CommImpl)
 from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+import torch.nn.functional as F
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -305,10 +308,18 @@ def maybe_all_reduce_tensor_model_parallel(
         """
         forward_context = get_forward_context()
         moe_comm_method_name = forward_context.moe_comm_method_name
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+            if flashcomm_v1_enabled:
+                pad_size = forward_context.pad_size
+                if pad_size > 0:
+                    final_hidden_states = F.pad(final_hidden_states, (0, 0, 0, pad_size))
+                tp_size = get_tensor_model_parallel_world_size()
+                tp_rank = get_tensor_model_parallel_rank()
+                final_hidden_states = torch.chunk(final_hidden_states, tp_size, dim=0)[tp_rank]
             return final_hidden_states
         else:
-            return tensor_model_parallel_all_reduce(final_hidden_states)
+            return torch.ops.vllm.maybe_pad_and_reduce(final_hidden_states)
 
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
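Conceptually, the new branch replaces the plain all-reduce with a pad-then-keep-local-chunk step along the token dimension (a reduce-scatter-like split), with the matching gather-and-unpad happening later through custom ops such as torch.ops.vllm.maybe_all_gather_and_maybe_unpad. A single-process sketch of that roundtrip, with plain tensor ops standing in for the collectives:

import torch
import torch.nn.functional as F

def pad_and_take_local(hidden: torch.Tensor, tp_size: int, tp_rank: int):
    """Pad the token dim so it splits evenly, then keep this rank's chunk."""
    pad_size = (-hidden.size(0)) % tp_size
    if pad_size > 0:
        hidden = F.pad(hidden, (0, 0, 0, pad_size))      # pad rows at the end
    return torch.chunk(hidden, tp_size, dim=0)[tp_rank], pad_size

def gather_and_unpad(chunks, pad_size: int) -> torch.Tensor:
    """torch.cat stands in for the all-gather; then strip the padding rows."""
    full = torch.cat(chunks, dim=0)
    return full[:full.size(0) - pad_size] if pad_size > 0 else full

tp_size = 4
hidden = torch.randn(10, 8)                              # 10 tokens, hidden size 8
pad_size = (-hidden.size(0)) % tp_size                   # 2 padding rows
chunks = [pad_and_take_local(hidden, tp_size, r)[0] for r in range(tp_size)]
restored = gather_and_unpad(chunks, pad_size)
print(torch.equal(restored, hidden))                     # True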

0 commit comments
