
Commit 123b9be

[2/N][Refactor][Qwen3-Next] remove redundant methods in Qwen3NextGatedDeltaNet
Signed-off-by: Icey <1790571317@qq.com>
1 parent: fab27aa · commit: 123b9be

File tree: 4 files changed (+120 additions, -286 deletions)


vllm_ascend/models/qwen3_next.py

Lines changed: 116 additions & 9 deletions
@@ -19,6 +19,8 @@
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.fla.ops.fused_recurrent import \
+    fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -44,8 +46,7 @@
                                                     SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
-from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
-                                                    Qwen3NextSparseMoeBlock)
+from vllm.model_executor.models.qwen3_next import fused_gdn_gating
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -60,8 +61,113 @@
 
 from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
                                            causal_conv1d_update_npu)
-from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
-from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
+from vllm_ascend.ops.fla import RMSNormGated
+
+
+class Qwen3NextSparseMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_eplb: bool = False,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        eplb_config = vllm_config.parallel_config.eplb_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
+        self.experts = FusedMoE(num_experts=self.n_routed_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=f"{prefix}.gate")
+
+        if config.shared_expert_intermediate_size > 0:
+            self.shared_expert = Qwen3NextMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=self.experts.must_reduce_shared_expert_outputs(
+                ),
+            )
+        else:
+            self.shared_expert = None
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
+                                                  1,
+                                                  bias=False)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid gate quantization.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        shared_output = None
+        if self.shared_expert is not None:
+            shared_output = self.shared_expert(hidden_states)
+            if self.shared_expert_gate is not None:
+                shared_output = F.sigmoid(
+                    self.shared_expert_gate(hidden_states)) * shared_output
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
 
 
 def torch_chunk_gated_delta_rule(
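
Editor's note on the hunk above: the inlined Qwen3NextSparseMoeBlock routes tokens through a FusedMoE layer and, when the config enables a shared expert, adds a sigmoid-gated shared-expert output on top of the routed output. The snippet below is a minimal, editor-written PyTorch sketch of just that combine step; the SharedExpertGateSketch class and its plain nn.Linear stand-ins are hypothetical placeholders for the FusedMoE experts, ReplicatedLinear gate, and Qwen3NextMLP used in the real block.

import torch
import torch.nn as nn


class SharedExpertGateSketch(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Stand-in for Qwen3NextMLP (the shared expert).
        self.shared_expert = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size),
        )
        # Per-token scalar gate, mirroring shared_expert_gate above.
        self.shared_expert_gate = nn.Linear(hidden_size, 1, bias=False)
        # Stand-in for the FusedMoE routed-expert path.
        self.routed_experts = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        # sigmoid(W_g x) scales the shared-expert output token by token.
        shared = torch.sigmoid(
            self.shared_expert_gate(hidden_states)) * self.shared_expert(hidden_states)
        routed = self.routed_experts(hidden_states)
        return (routed + shared).view(orig_shape)


x = torch.randn(4, 16, 256)
print(SharedExpertGateSketch(256, 512)(x).shape)  # torch.Size([4, 16, 256])

The real block additionally renormalizes the top-k router weights (renormalize=config.norm_topk_prob) and all-reduces across tensor-parallel ranks, which the sketch omits.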
@@ -363,6 +469,7 @@ def forward(
         output: torch.Tensor,
         cache_params: Optional[MambaCacheParams] = None,
     ):
+        return torch.ops.vllm.npu_gdn_attention(
         return torch.ops.vllm.npu_gdn_attention(
             hidden_states,
             output,
@@ -1098,7 +1205,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
 
-def gdn_npu_attention(
+def npu_gdn_attention(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1108,7 +1215,7 @@ def gdn_npu_attention(
     self._forward(hidden_states=hidden_states, output=output)
 
 
-def gdn_npu_attention_fake(
+def npu_gdn_attention_fake(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1117,9 +1224,9 @@ def gdn_npu_attention_fake(
 
 
 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_npu_attention,
+    op_name="npu_gdn_attention",
+    op_func=npu_gdn_attention,
     mutates_args=["output"],
-    fake_impl=gdn_npu_attention_fake,
+    fake_impl=npu_gdn_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
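
Editor's note: the hunks above rename the custom op so that the registered name ("npu_gdn_attention"), the Python wrappers, and the torch.ops.vllm.npu_gdn_attention call site in forward all line up. As a rough, editor-written illustration of the underlying pattern, the sketch below registers a mutating custom op with a fake (meta) implementation using plain torch.library.custom_op rather than vLLM's direct_register_custom_op helper; the "demo" namespace and the copy-based kernel body are assumptions for the example only, requiring PyTorch >= 2.4.

import torch


# Hypothetical "demo" namespace; the real code registers under torch.ops.vllm.
@torch.library.custom_op("demo::npu_gdn_attention", mutates_args=("output", ))
def npu_gdn_attention(hidden_states: torch.Tensor, output: torch.Tensor,
                      layer_name: str) -> None:
    # A real kernel would run gated-delta-net attention on the NPU;
    # here we just copy to demonstrate the in-place write.
    output.copy_(hidden_states)


@npu_gdn_attention.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor,
      layer_name: str) -> None:
    # Fake impl: no computation, used during tracing/compilation.
    pass


h = torch.randn(2, 8)
out = torch.empty_like(h)
torch.ops.demo.npu_gdn_attention(h, out, "layer.0")
assert torch.equal(out, h)

The fake implementation gives tracing and torch.compile a shape-only stand-in for the op, while mutates_args declares the in-place write to output, matching the mutates_args=["output"] in the registration above.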

vllm_ascend/ops/fla.py

Lines changed: 0 additions & 51 deletions
@@ -328,54 +328,3 @@ def forward(self, x, z=None):
             group_size=self.group_size,
             norm_before_gate=self.norm_before_gate,
         )
-
-
-@triton.jit
-def fused_gdn_gating_kernel(
-    g,
-    A_log,
-    a,
-    dt_bias,
-    seq_len,
-    NUM_HEADS: tl.constexpr,
-    beta: tl.constexpr,
-    threshold: tl.constexpr,
-    BLK_HEADS: tl.constexpr,
-):
-    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
-    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
-    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
-    mask = head_off < NUM_HEADS
-    blk_A_log = tl.load(A_log + head_off, mask=mask)
-    blk_a = tl.load(a + off, mask=mask)
-    blk_bias = tl.load(dt_bias + head_off, mask=mask)
-    # If the model is loaded in fp16, without the .float() here, A might be -inf
-    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
-    softplus_x = tl.where(beta * x <= threshold,
-                          (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
-    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
-    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
-
-
-def fused_gdn_gating(
-    A_log: torch.Tensor,
-    a: torch.Tensor,
-    dt_bias: torch.Tensor,
-    beta: float = 1.0,
-    threshold: float = 20.0,
-) -> torch.Tensor:
-    batch, num_heads = a.shape
-    seq_len = 1
-    grid = (batch, seq_len, triton.cdiv(num_heads, 8))
-    g = torch.empty_like(a, dtype=torch.float32)
-    fused_gdn_gating_kernel[grid](g,
-                                  A_log,
-                                  a,
-                                  dt_bias,
-                                  seq_len,
-                                  num_heads,
-                                  beta,
-                                  threshold,
-                                  8,
-                                  num_warps=1)
-    return g
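
Editor's note: the Triton kernel deleted above computed the gated-delta-net gate g = -exp(A_log) * softplus(a + dt_bias), with softplus switching to the identity once beta * x exceeds the threshold; the commit drops this local copy in favor of the fused_gdn_gating now imported from vllm.model_executor.models.qwen3_next (see the import hunk in vllm_ascend/models/qwen3_next.py). As an editor-written, unfused PyTorch reference for the same math (not the upstream implementation), one could write:

import torch
import torch.nn.functional as F


def gdn_gating_reference(A_log: torch.Tensor, a: torch.Tensor,
                         dt_bias: torch.Tensor, beta: float = 1.0,
                         threshold: float = 20.0) -> torch.Tensor:
    # Compute in fp32, mirroring the kernel's .to(tl.float32) casts.
    x = a.float() + dt_bias.float()
    # F.softplus applies the same threshold cutoff as the kernel's tl.where.
    softplus_x = F.softplus(x, beta=beta, threshold=threshold)
    return -torch.exp(A_log.float()) * softplus_x


num_heads = 16
a = torch.randn(4, num_heads, dtype=torch.float16)  # (batch, num_heads)
A_log = torch.randn(num_heads)
dt_bias = torch.randn(num_heads)
print(gdn_gating_reference(A_log, a, dt_bias).shape)  # torch.Size([4, 16])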
