
Commit 843398c

Merge remote-tracking branch 'upstream_gitee/main' into main_eplb_0916
2 parents: 65e348f + 14b39d3

3 files changed: 18 additions, 238 deletions

3 files changed

+18
-238
lines changed

.github/workflows/pre-commit.yml

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,10 @@ name: pre-commit
 
 on:
   workflow_call:
+    inputs:
+      vllm:
+        required: true
+        type: string
 
 permissions:
   contents: read
@@ -22,6 +26,7 @@ jobs:
         with:
           repository: vllm-project/vllm
           path: ./vllm-empty
+          ref: ${{ inputs.vllm }}
       - name: Install vllm
         working-directory: vllm-empty
         run: |

.github/workflows/vllm_ascend_test.yaml

Lines changed: 3 additions & 1 deletion
@@ -41,6 +41,8 @@ concurrency:
 jobs:
   lint:
     uses: ./.github/workflows/pre-commit.yml
+    with:
+      vllm: c60e6137f0bf2034853919b3a9d705d7e06b93cf
 
   changes:
     runs-on: ubuntu-latest
@@ -143,7 +145,7 @@ jobs:
     if: ${{ github.event_name == 'pull_request' && needs.lint.result == 'success' && needs.changes.outputs.e2e_tracker == 'true' && !contains(github.event.pull_request.labels.*.name, 'ready') }}
     uses: ./.github/workflows/_e2e_test.yaml
     with:
-      vllm: v0.10.2
+      vllm: ${{ matrix.vllm_version }}
       runner: linux-aarch64-a2
       image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-910b-ubuntu22.04-py3.11
       type: light

vllm_ascend/models/qwen3_next.py

Lines changed: 10 additions & 237 deletions
@@ -11,11 +11,11 @@
 from torch import nn
 from transformers.activations import ACT2FN
 from vllm import envs
-from vllm.attention import Attention, AttentionBackend, AttentionMetadata
+from vllm.attention import AttentionBackend, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
-from vllm.distributed import (divide, get_ep_group, get_pp_group,
+from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.forward_context import ForwardContext, get_forward_context
@@ -27,8 +27,6 @@
 # yapf: enable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -37,10 +35,6 @@
 from vllm.model_executor.layers.mamba.mamba_utils import (
     MambaStateDtypeCalculator, MambaStateShapeCalculator)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import \
-    GPTQMarlinConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -50,6 +44,8 @@
     SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention,
+                                                    Qwen3NextSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader, PPMissingLayer, extract_layer_index,
     is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
@@ -68,112 +64,6 @@
 from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule
 
 
-class Qwen3NextSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-
-        self.ep_group = get_ep_group().device_group
-        self.ep_rank = self.ep_group.rank()
-        self.ep_size = self.ep_group.size()
-        self.n_routed_experts = config.num_experts
-
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
-
-        self.n_logical_experts = self.n_routed_experts
-        self.n_redundant_experts = eplb_config.num_redundant_experts
-        self.n_physical_experts = (self.n_logical_experts +
-                                   self.n_redundant_experts)
-        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
-
-        self.physical_expert_start = (self.ep_rank *
-                                      self.n_local_physical_experts)
-        self.physical_expert_end = (self.physical_expert_start +
-                                    self.n_local_physical_experts)
-
-        self.experts = FusedMoE(num_experts=self.n_routed_experts,
-                                top_k=config.num_experts_per_tok,
-                                hidden_size=config.hidden_size,
-                                intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
-                                renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts",
-                                enable_eplb=self.enable_eplb,
-                                num_redundant_experts=self.n_redundant_experts)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=f"{prefix}.gate")
-
-        if config.shared_expert_intermediate_size > 0:
-            self.shared_expert = Qwen3NextMLP(
-                hidden_size=config.hidden_size,
-                intermediate_size=config.shared_expert_intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
-                reduce_results=self.experts.must_reduce_shared_expert_outputs(
-                ),
-            )
-        else:
-            self.shared_expert = None
-        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
-                                                  1,
-                                                  bias=False)
-
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
-        hidden_states = hidden_states.view(-1, hidden_dim)
-
-        shared_output = None
-        if self.shared_expert is not None:
-            shared_output = self.shared_expert(hidden_states)
-            if self.shared_expert_gate is not None:
-                shared_output = F.sigmoid(
-                    self.shared_expert_gate(hidden_states)) * shared_output
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
-
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
-                final_hidden_states)
-
-        return final_hidden_states.view(orig_shape)
-
-
 def torch_chunk_gated_delta_rule(
     query,
     key,
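For context, the deleted block above is the shared-expert MoE path that this commit now takes from upstream vLLM via the new Qwen3NextSparseMoeBlock import. A minimal sketch of the combination it computes, using toy stand-in modules (the Linear layers and the random routed_output below are illustrative placeholders for FusedMoE and Qwen3NextMLP, not the actual code):

import torch
import torch.nn.functional as F

hidden_size, num_tokens = 16, 3
hidden_states = torch.randn(num_tokens, hidden_size)

# Stand-ins: the real block uses Qwen3NextMLP for the shared expert and
# FusedMoE for the routed experts.
shared_expert = torch.nn.Linear(hidden_size, hidden_size)
shared_expert_gate = torch.nn.Linear(hidden_size, 1, bias=False)
routed_output = torch.randn(num_tokens, hidden_size)

# The shared expert output is scaled by a per-token sigmoid gate and added
# to the routed expert output, mirroring the deleted forward() above.
shared_output = F.sigmoid(shared_expert_gate(hidden_states)) * shared_expert(hidden_states)
final_hidden_states = routed_output + shared_output
print(final_hidden_states.shape)  # torch.Size([3, 16])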
@@ -473,7 +363,7 @@ def forward(
         output: torch.Tensor,
         cache_params: Optional[MambaCacheParams] = None,
     ):
-        return torch.ops.vllm.gdn_attention(
+        return torch.ops.vllm.npu_gdn_attention(
             hidden_states,
             output,
             self.prefix,
@@ -737,123 +627,6 @@ def _forward(
         output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
 
 
-class Qwen3NextAttention(nn.Module):
-
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
-        self.total_num_kv_heads = config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
-            # Number of KV heads is greater than TP size, so we partition
-            # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
-        else:
-            # Number of KV heads is less than TP size, so we replicate
-            # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
-        self.q_size = self.num_heads * self.head_dim
-        self.kv_size = self.num_kv_heads * self.head_dim
-        self.scaling = self.head_dim**-0.5
-        self.dual_chunk_attention_config = getattr(
-            config, "dual_chunk_attention_config", None)
-        self.attn_output_gate = getattr(config, "attn_output_gate", True)
-
-        self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
-            self.head_dim,
-            self.total_num_heads * (1 + self.attn_output_gate),
-            self.total_num_kv_heads,
-            bias=getattr(config, "qkv_bias", False),
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
-
-        self.o_proj = RowParallelLinear(
-            self.total_num_heads * self.head_dim,
-            config.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
-        )
-
-        self.rotary_emb = get_rope(
-            head_size=self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=config.max_position_embeddings,
-            base=config.rope_theta,
-            rope_scaling=config.rope_scaling,
-            partial_rotary_factor=config.partial_rotary_factor,
-            dual_chunk_attention_config=self.dual_chunk_attention_config,
-        )
-
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            **{
-                "layer_idx": extract_layer_index(prefix),
-                "dual_chunk_attention_config":
-                self.dual_chunk_attention_config,
-            } if self.dual_chunk_attention_config else {},
-        )
-
-        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-    ):
-        qkv, _ = self.qkv_proj(hidden_states)
-
-        if self.attn_output_gate:
-            q_gate, k, v = qkv.split(
-                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
-            orig_shape = q_gate.shape[:-1]
-            q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
-            q, gate = torch.chunk(q_gate, 2, dim=-1)
-            q = q.reshape(*orig_shape, -1)
-            gate = gate.reshape(*orig_shape, -1)
-        else:
-            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
-            -1, self.num_heads * self.head_dim)
-        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
-            -1, self.num_kv_heads * self.head_dim)
-
-        q, k = self.rotary_emb(positions, q, k)
-
-        attn_output = self.attn(q, k, v)
-
-        if self.attn_output_gate:
-            gate = torch.sigmoid(gate)
-            attn_output = attn_output * gate
-
-        output[:], _ = self.o_proj(attn_output)
-
-
 class Qwen3NextDecoderLayer(nn.Module):
 
     def __init__(
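For context, the deleted Qwen3NextAttention above (now imported from upstream vLLM) packs each head's query together with an output gate when attn_output_gate is enabled, which is why the QKV projection output is split as [q_size * 2, kv_size, kv_size]. A minimal sketch of that split with toy sizes (all dimensions below are illustrative, not the model's real configuration):

import torch

# Illustrative toy sizes.
num_heads, num_kv_heads, head_dim, num_tokens = 4, 2, 8, 3
q_size = num_heads * head_dim      # query features per token
kv_size = num_kv_heads * head_dim  # key/value features per token

# With the output gate enabled, qkv_proj emits 2 * q_size query+gate features
# followed by kv_size keys and kv_size values.
qkv = torch.randn(num_tokens, 2 * q_size + 2 * kv_size)

q_gate, k, v = qkv.split([q_size * 2, kv_size, kv_size], dim=-1)
q_gate = q_gate.view(num_tokens, num_heads, -1)  # (tokens, heads, 2 * head_dim)
q, gate = torch.chunk(q_gate, 2, dim=-1)         # per-head query and gate halves
q = q.reshape(num_tokens, -1)
gate = gate.reshape(num_tokens, -1)

assert q.shape == (num_tokens, q_size)
assert gate.shape == (num_tokens, q_size)
# After attention, the gate is applied as attn_output * torch.sigmoid(gate).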
@@ -1325,7 +1098,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
 
-def gdn_attention(
+def npu_gdn_attention(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1335,7 +1108,7 @@ def gdn_attention(
     self._forward(hidden_states=hidden_states, output=output)
 
 
-def gdn_attention_fake(
+def npu_gdn_attention_fake(
     hidden_states: torch.Tensor,
     output: torch.Tensor,
     layer_name: str,
@@ -1344,9 +1117,9 @@ def gdn_attention_fake(
 
 
 direct_register_custom_op(
-    op_name="gdn_attention",
-    op_func=gdn_attention,
+    op_name="npu_gdn_attention",
+    op_func=npu_gdn_attention,
     mutates_args=["output"],
-    fake_impl=gdn_attention_fake,
+    fake_impl=npu_gdn_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
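For context, the rename above follows vLLM's custom-op registration pattern: a real implementation that writes into a pre-allocated output buffer, a fake implementation used during tracing, and a direct_register_custom_op call that exposes the op under torch.ops.vllm. A minimal sketch of that pattern with a toy op (the op name npu_toy_copy and its bodies are illustrative, not part of this commit; it assumes a vLLM install where vllm.utils.direct_register_custom_op and vllm.platforms.current_platform are available):

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def npu_toy_copy(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Real implementation: writes results into the pre-allocated output buffer.
    output.copy_(hidden_states)


def npu_toy_copy_fake(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake implementation: does no work, only satisfies the tracer.
    return


direct_register_custom_op(
    op_name="npu_toy_copy",
    op_func=npu_toy_copy,
    mutates_args=["output"],  # args the op writes in place
    fake_impl=npu_toy_copy_fake,
    dispatch_key=current_platform.dispatch_key,
)

# The op is then reachable through the vllm namespace, e.g.
# torch.ops.vllm.npu_toy_copy(hidden_states, output), which is how the layer's
# forward() dispatches to npu_gdn_attention after this commit.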
