 from torch import nn
 from transformers.activations import ACT2FN
 from vllm import envs
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import AttentionBackend, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
-from vllm.distributed import (divide, get_ep_group, get_pp_group,
+from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ ... @@
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ ... @@
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import \
-    GPTQMarlinConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ ... @@
     SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
+                                                   Qwen3NextSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ ... @@
 from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule


-class Qwen3NextSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-
-        self.ep_group = get_ep_group().device_group
-        self.ep_rank = self.ep_group.rank()
-        self.ep_size = self.ep_group.size()
-        self.n_routed_experts = config.num_experts
-
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
-
-        self.n_logical_experts = self.n_routed_experts
-        self.n_redundant_experts = eplb_config.num_redundant_experts
-        self.n_physical_experts = (self.n_logical_experts +
-                                   self.n_redundant_experts)
-        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
-
-        self.physical_expert_start = (self.ep_rank *
-                                      self.n_local_physical_experts)
-        self.physical_expert_end = (self.physical_expert_start +
-                                    self.n_local_physical_experts)
-
-        self.experts = FusedMoE(num_experts=self.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts",
-                                enable_eplb=self.enable_eplb,
-                                num_redundant_experts=self.n_redundant_experts)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=f"{prefix}.gate")
-
-        if config.shared_expert_intermediate_size > 0:
-            self.shared_expert = Qwen3NextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.shared_expert_intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
-            )
-        else:
-            self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
-                                                  1,
-                                                  bias=False)
-
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
-        hidden_states = hidden_states.view(-1, hidden_dim)
-
-        shared_output = None
-        if self.shared_expert is not None:
-            shared_output = self.shared_expert(hidden_states)
-            if self.shared_expert_gate is not None:
-                shared_output = F.sigmoid(
-                    self.shared_expert_gate(hidden_states)) * shared_output
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
-
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
-                final_hidden_states)
-
-        return final_hidden_states.view(orig_shape)
-
-
 def torch_chunk_gated_delta_rule(
     query,
     key,
@@ -473,7 +363,7 @@ def forward(
         output: torch.Tensor,
         cache_params: Optional[MambaCacheParams] = None,
     ):
-        return torch.ops.vllm.gdn_attention(
+        return torch.ops.vllm.npu_gdn_attention(
             hidden_states,
             output,
             self.prefix,
@@ -737,123 +627,6 @@ def _forward(
         output[:num_actual_tokens], _ = self.out_proj(core_attn_out)


-class Qwen3NextAttention(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
-        self.total_num_kv_heads = config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
-            # Number of KV heads is greater than TP size, so we partition
-            # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
-        else:
-            # Number of KV heads is less than TP size, so we replicate
-            # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-        self.scaling = self.head_dim**-0.5
-        self.dual_chunk_attention_config = getattr(
-            config, "dual_chunk_attention_config", None)
-        self.attn_output_gate = getattr(config, "attn_output_gate", True)
-
-        self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
-            self.head_dim,
-            self.total_num_heads * (1 + self.attn_output_gate),
-            self.total_num_kv_heads,
-            bias=getattr(config, "qkv_bias", False),
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            config.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
-
-        self.rotary_emb = get_rope(
-            head_size=self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=config.max_position_embeddings,
-            base=config.rope_theta,
-            rope_scaling=config.rope_scaling,
-            partial_rotary_factor=config.partial_rotary_factor,
-            dual_chunk_attention_config=self.dual_chunk_attention_config,
-        )
-
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            **{
-                "layer_idx": extract_layer_index(prefix),
-                "dual_chunk_attention_config":
-                self.dual_chunk_attention_config,
-            } if self.dual_chunk_attention_config else {},
-        )
-
-        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-    ):
-        qkv, _ = self.qkv_proj(hidden_states)
-
-        if self.attn_output_gate:
-            q_gate, k, v = qkv.split(
-                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
-            orig_shape = q_gate.shape[:-1]
-            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
-            q, gate = torch.chunk(q_gate, 2, dim=-1)
-            q = q.reshape(*orig_shape, -1)
-            gate = gate.reshape(*orig_shape, -1)
-        else:
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
-            -1, self.num_heads * self.head_dim)
-        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
-            -1, self.num_kv_heads * self.head_dim)
-
-        q, k = self.rotary_emb(positions, q, k)
-
-        attn_output = self.attn(q, k, v)
-
-        if self.attn_output_gate:
-            gate = torch.sigmoid(gate)
-            attn_output = attn_output * gate
-
-        output[:], _ = self.o_proj(attn_output)
-
-
 class Qwen3NextDecoderLayer(nn.Module):

     def __init__(
@@ -1325,7 +1098,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()


-def gdn_attention(
+def npu_gdn_attention(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1335,7 +1108,7 @@ def gdn_attention(
     self._forward(hidden_states=hidden_states, output=output)


-def gdn_attention_fake(
+def npu_gdn_attention_fake(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1344,9 +1117,9 @@ def gdn_attention_fake(
 
 
 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
+    op_name="npu_gdn_attention",
+    op_func=npu_gdn_attention,
     mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    fake_impl=npu_gdn_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
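
The last three hunks rename the module-level custom op from `gdn_attention` to `npu_gdn_attention`, presumably so the Ascend plugin's registration does not clash with the identically named op that upstream vLLM already registers in the `torch.ops.vllm` namespace; the call site in the earlier hunk dispatches through `torch.ops.vllm.npu_gdn_attention` accordingly. For readers unfamiliar with this register-then-dispatch pattern, below is a minimal, self-contained sketch using PyTorch's public `torch.library.custom_op` API (PyTorch 2.4+) rather than vLLM's `direct_register_custom_op` helper; the `demo` namespace, the op name, and the toy kernel body are hypothetical stand-ins, not part of this PR.

```python
import torch


# Standalone analogue of the pattern in this diff: the real npu_gdn_attention
# writes its result into a caller-provided buffer, so it is registered with
# mutates_args=["output"] plus a do-nothing fake impl for torch.compile tracing.
@torch.library.custom_op("demo::npu_gdn_attention_demo",
                         mutates_args=("output", ))
def npu_gdn_attention_demo(hidden_states: torch.Tensor,
                           output: torch.Tensor) -> None:
    # Toy stand-in for the real NPU kernel: fill the preallocated output buffer.
    output.copy_(hidden_states * 2.0)


@npu_gdn_attention_demo.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake (meta) implementation: no computation, mirrors npu_gdn_attention_fake.
    return None


if __name__ == "__main__":
    x = torch.randn(4, 8)
    out = torch.empty_like(x)
    # Dispatch through torch.ops, just as the model calls
    # torch.ops.vllm.npu_gdn_attention(...) after registration.
    torch.ops.demo.npu_gdn_attention_demo(x, out)
    assert torch.allclose(out, x * 2.0)
```

Declaring the in-place `output` contract via `mutates_args` and supplying a fake implementation is what lets `torch.compile` trace the op without executing the device kernel, which is why the registration above keeps `mutates_args=["output"]` and `fake_impl=npu_gdn_attention_fake` after the rename.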