@@ -19,6 +19,8 @@
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.model_executor.layers.fla.ops.fused_recurrent import \
+    fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.fused_moe import FusedMoE
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -44,8 +46,7 @@
     SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
-from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
-                                                   Qwen3NextSparseMoeBlock)
+from vllm.model_executor.models.qwen3_next import fused_gdn_gating
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -60,8 +61,113 @@
 
 from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
                                            causal_conv1d_update_npu)
-from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating
-from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
+from vllm_ascend.ops.fla import RMSNormGated
+
+
+class Qwen3NextSparseMoeBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3NextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_eplb: bool = False,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        eplb_config = vllm_config.parallel_config.eplb_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = eplb_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
+        self.experts = FusedMoE(num_experts=self.n_routed_experts,
+                                top_k=config.num_experts_per_tok,
+                                hidden_size=config.hidden_size,
+                                intermediate_size=config.moe_intermediate_size,
+                                reduce_results=False,
+                                renormalize=config.norm_topk_prob,
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=f"{prefix}.gate")
+
+        if config.shared_expert_intermediate_size > 0:
+            self.shared_expert = Qwen3NextMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=self.experts.must_reduce_shared_expert_outputs(
+                ),
+            )
+        else:
+            self.shared_expert = None
+        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
+                                                  1,
+                                                  bias=False)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid gate quantization.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        shared_output = None
+        if self.shared_expert is not None:
+            shared_output = self.shared_expert(hidden_states)
+            if self.shared_expert_gate is not None:
+                shared_output = F.sigmoid(
+                    self.shared_expert_gate(hidden_states)) * shared_output
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
+
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
 
 
 def torch_chunk_gated_delta_rule(
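To make the data flow of the newly added Qwen3NextSparseMoeBlock easier to follow, here is a minimal, self-contained sketch of what its forward combines: the routed experts' output plus a sigmoid-gated shared expert, followed by an optional tensor-parallel all-reduce. The plain Linear stand-ins and toy sizes below are illustrative only and are not part of this change.

```python
# Illustrative sketch only: Linear layers stand in for FusedMoE,
# Qwen3NextMLP and ReplicatedLinear; sizes are toy values.
import torch

hidden_size, num_experts, num_tokens = 8, 4, 3
x = torch.randn(num_tokens, hidden_size)

gate = torch.nn.Linear(hidden_size, num_experts, bias=False)  # self.gate (router)
shared_gate = torch.nn.Linear(hidden_size, 1, bias=False)     # self.shared_expert_gate
shared_expert = torch.nn.Linear(hidden_size, hidden_size)     # stand-in for Qwen3NextMLP
routed_experts = torch.nn.Linear(hidden_size, hidden_size)    # stand-in for FusedMoE

router_logits = gate(x)                         # (num_tokens, n_experts)
routed_out = routed_experts(x)                  # FusedMoE consumes x and router_logits
shared_out = torch.sigmoid(shared_gate(x)) * shared_expert(x)
final = routed_out + shared_out                 # all-reduced across TP ranks if tp_size > 1
```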
@@ -363,6 +469,7 @@ def forward(
         output: torch.Tensor,
         cache_params: Optional[MambaCacheParams] = None,
     ):
+        return torch.ops.vllm.npu_gdn_attention(
             hidden_states,
             output,
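For readers unfamiliar with vLLM's out-variant custom ops: the `return torch.ops.vllm.npu_gdn_attention(...)` added above hands the op a pre-allocated `output` buffer that the op fills in place; as far as the visible hunks show, the registered Python wrapper ends in `self._forward(...)` and returns nothing. A stand-in sketch of that calling contract follows; the function body, shapes, and the layer name are hypothetical placeholders, not the real kernel.

```python
# Stand-in only: mimics the out-parameter contract of the real op.
import torch


def npu_gdn_attention_stub(hidden_states: torch.Tensor,
                           output: torch.Tensor,
                           layer_name: str) -> None:
    # The real op presumably resolves the layer from `layer_name` and calls
    # its _forward(), which writes the attention result into `output`.
    output.copy_(hidden_states)  # placeholder compute


hidden_states = torch.randn(3, 8)
output = torch.empty_like(hidden_states)
npu_gdn_attention_stub(hidden_states, output, "model.layers.0.linear_attn")
assert output.shape == hidden_states.shape
```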
@@ -1098,7 +1205,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
 
-def gdn_npu_attention(
+def npu_gdn_attention(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1108,7 +1215,7 @@ def gdn_npu_attention(
     self._forward(hidden_states=hidden_states, output=output)
 
 
-def gdn_npu_attention_fake(
+def npu_gdn_attention_fake(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1117,9 +1224,9 @@ def gdn_npu_attention_fake(
 
 
 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_npu_attention,
+    op_name="npu_gdn_attention",
+    op_func=npu_gdn_attention,
     mutates_args=["output"],
-    fake_impl=gdn_npu_attention_fake,
+    fake_impl=npu_gdn_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
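To show how the renamed pieces fit together, here is a hedged toy registration mirroring the pattern above: an eager implementation that writes into its `output` argument, a do-nothing fake (meta) implementation for tracing, and a `mutates_args` declaration so PyTorch knows about the in-place write. The op name `toy_scale` and its body are illustrative; only the `direct_register_custom_op` keyword arguments are taken from the diff, the import paths are the ones vLLM commonly uses (adjust to match this file's imports), and the example assumes the op lands in the `torch.ops.vllm` namespace, as `npu_gdn_attention` does.

```python
import torch
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def toy_scale(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Eager implementation: result goes into the caller-provided buffer.
    output.copy_(hidden_states * 2.0)


def toy_scale_fake(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake/meta implementation: no compute; tracing only needs to know
    # that `output` is mutated, which mutates_args declares below.
    return


direct_register_custom_op(
    op_name="toy_scale",
    op_func=toy_scale,
    mutates_args=["output"],
    fake_impl=toy_scale_fake,
    dispatch_key=current_platform.dispatch_key,
)

x = torch.ones(4)
out = torch.empty(4)
torch.ops.vllm.toy_scale(x, out)  # out is now tensor([2., 2., 2., 2.])
```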