
Commit eef2d73

flashcomm2_qwen3_mode_code
Signed-off-by: iKunHvv <puppybala@outlook.com>
1 parent e75b568 commit eef2d73

File tree: 5 files changed (+263, -18 lines changed)
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import unittest
+import torch
+from unittest.mock import patch
+
+from vllm_ascend.ops.sequence_parallel import init_metadata_for_flashcomm2
+
+
+class TestInitMetadataForFlashcomm2(unittest.TestCase):
+
+    def setUp(self):
+        patcher = patch("vllm_ascend.ops.sequence_parallel.MetadataForPadding")
+        self.MockMetadata = patcher.start()
+        self.addCleanup(patcher.stop)
+
+    def _run_case(self, tp_size, input_len, expected_padding_flag, expected_pad_size):
+        with patch(
+                "vllm_ascend.ops.sequence_parallel.get_tensor_model_parallel_world_size",
+                return_value=tp_size,
+        ):
+            input_ids = torch.arange(input_len)
+
+            result = init_metadata_for_flashcomm2(input_ids)
+
+        # Verify the arguments passed to MetadataForPadding
+        self.MockMetadata.assert_called_once_with(
+            lengths_sum_unpadding=input_len,
+            lengths_sum_padding=((input_len + tp_size - 1) // tp_size) * tp_size,
+            padding_flag=expected_padding_flag,
+            pad_size=expected_pad_size,
+            not_dummy_and_is_prefill=False,
+        )
+
+        # Verify the return value
+        self.assertEqual(result, self.MockMetadata.return_value)
+
+    def test_no_padding(self):
+        self._run_case(tp_size=4, input_len=8, expected_padding_flag=False, expected_pad_size=0)
+
+    def test_with_padding(self):
+        self._run_case(tp_size=4, input_len=10, expected_padding_flag=True, expected_pad_size=2)
+
+    def test_with_padding_non_multiple(self):
+        self._run_case(tp_size=3, input_len=7, expected_padding_flag=True, expected_pad_size=2)
+
+    def test_exact_multiple(self):
+        self._run_case(tp_size=5, input_len=5, expected_padding_flag=False, expected_pad_size=0)
+
+    def test_empty_input(self):
+        self._run_case(tp_size=4, input_len=0, expected_padding_flag=False, expected_pad_size=0)
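
These tests pin down the FlashComm2 padding arithmetic: the token count is rounded up to the next multiple of the tensor-parallel world size, and pad_size is the difference. A standalone sketch of that arithmetic (plain Python, no mocks; expected_padding is a hypothetical helper that mirrors the values the tests assert are passed to MetadataForPadding):

def expected_padding(input_len: int, tp_size: int) -> tuple[int, int, bool]:
    # Round the token count up to the next multiple of tp_size and return
    # (lengths_sum_padding, pad_size, padding_flag), the values asserted above.
    padded_len = ((input_len + tp_size - 1) // tp_size) * tp_size
    pad_size = padded_len - input_len
    return padded_len, pad_size, pad_size > 0

assert expected_padding(8, 4) == (8, 0, False)   # test_no_padding
assert expected_padding(10, 4) == (12, 2, True)  # test_with_padding
assert expected_padding(7, 3) == (9, 2, True)    # test_with_padding_non_multiple
assert expected_padding(5, 5) == (5, 0, False)   # test_exact_multiple
assert expected_padding(0, 4) == (0, 0, False)   # test_empty_input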

vllm_ascend/envs.py

Lines changed: 2 additions & 1 deletion
@@ -133,8 +133,9 @@
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
     # Whether to enable FlashComm optimization when tensor parallel is enabled.
     # This feature will get better performance when concurrency is large.
+    # FlashComm optimization: enable v1 or v2 by setting this flag to 1 or 2, respectively.
     "VLLM_ASCEND_ENABLE_FLASHCOMM":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0')),
     # Whether to enable dense model and general optimizations for better performance.
     # Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
     # However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
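
After this change the variable is parsed as an integer instead of being coerced to a boolean, so a single flag selects the FlashComm version. A minimal sketch of the semantics assumed by the model code below (0 disables the feature, 1 selects v1, 2 selects v2):

import os

flashcomm_mode = int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))
enable_flashcomm2 = flashcomm_mode == 2  # the check used in qwen3_moe.py and fused_moe.py below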

vllm_ascend/models/qwen3_moe.py

Lines changed: 180 additions & 16 deletions
@@ -17,11 +17,14 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.

-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
+import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
+import torch.distributed as dist
+import vllm_ascend.envs as envs_ascend
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, CompilationLevel, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -30,7 +33,8 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -47,9 +51,11 @@
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors

+from vllm.distributed.communication_op import tensor_model_parallel_all_gather
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
-                                               init_metadata_for_sp)
+                                               init_metadata_for_sp, init_metadata_for_flashcomm2)


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -125,6 +131,153 @@ def forward(
         return hidden_states


+class CustomQwen3MoeMLP(Qwen3MoeMLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         reduce_results=reduce_results,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2
+        if self.enable_flashcomm2:
+            # If FlashComm2 is enabled, replace Linear + AllReduce with All2All + Linear.
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x, _metadata_for_padding=None):
+        # If FlashComm2 is enabled, the MLP input is data-parallel (DP), so we all-gather
+        # hidden_states, run gate_up_proj in TP, and run down_proj in DP (via all-to-all).
+        if self.enable_flashcomm2:
+            x = tensor_model_parallel_all_gather(x, 0)
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        if self.enable_flashcomm2:
+            # No need to pad the input here: the MLP input is the attention output, which is already padded.
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output,
+                                   x,
+                                   group=get_tp_group().device_group)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class CustomQwen3MoeAttention(Qwen3MoeAttention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        qkv_bias: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         rope_theta=rope_theta,
+                         rope_scaling=rope_scaling,
+                         max_position_embeddings=max_position_embeddings,
+                         head_dim=head_dim,
+                         rms_norm_eps=rms_norm_eps,
+                         qkv_bias=qkv_bias,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2
+        if self.enable_flashcomm2:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+                                            hidden_size,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.o_proj")
+
+    def attn_output_all_to_all(
+            self,
+            attn_output: torch.Tensor,
+            _metadata_for_padding: Optional[MetadataForPadding] = None) -> torch.Tensor:
+        assert _metadata_for_padding is not None, "Metadata for padding is required for FlashComm2."
+        # Pad the input so token_num is divisible by tp_size, as the all-to-all requires.
+        attn_output = _metadata_for_padding.padding_full(attn_output)
+        output = torch.empty(attn_output.shape,
+                             dtype=attn_output.dtype,
+                             device=attn_output.device)
+        dist.all_to_all_single(output,
+                               attn_output,
+                               group=get_tp_group().device_group)
+        attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+            .transpose(0, 1) \
+            .reshape(-1, output.size(-1) * self.tp_size)
+        return attn_output
+
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            _metadata_for_padding: Optional[MetadataForPadding] = None) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Add qk-norm
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        if self.enable_flashcomm2:
+            attn_output = self.attn_output_all_to_all(attn_output, _metadata_for_padding)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):

     def __init__(
@@ -142,7 +295,7 @@ def __init__(
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
-        self.self_attn = Qwen3MoeAttention(
+        self.self_attn = CustomQwen3MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -178,7 +331,7 @@ def __init__(
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp")
         else:
-            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
+            self.mlp = CustomQwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    hidden_act=config.hidden_act,
                                    quant_config=quant_config,
@@ -191,6 +344,7 @@ def __init__(
         self.enable_sequence_parallelism = (
             vllm_config.compilation_config.pass_config.
             enable_sequence_parallelism if vllm_config is not None else False)
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2

     def forward(
         self,
@@ -201,34 +355,37 @@ def forward(
     ) -> torch.Tensor:

         # To prevent precision issues during the decoder phase when only prefilling enables SP
-        if not self.enable_sequence_parallelism:
-            self.self_attn.o_proj.reduce_results = True
-        else:
-            self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
+        if not self.enable_flashcomm2:
+            if not self.enable_sequence_parallelism:
+                self.self_attn.o_proj.reduce_results = True
+            else:
+                self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True

         # Self Attention
         if residual is None:
             residual = hidden_states
-            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+            if _metadata_for_padding and (_metadata_for_padding.not_dummy_and_is_prefill or self.enable_flashcomm2):
                 residual = _metadata_for_padding.padding_slice(residual)

             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+        if _metadata_for_padding and (_metadata_for_padding.not_dummy_and_is_prefill or self.enable_flashcomm2):
             hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                 hidden_states)

         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            _metadata_for_padding=_metadata_for_padding
         )

-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
-            hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
-                hidden_states)
+        if not self.enable_flashcomm2:
+            if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+                hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
+                    hidden_states)

         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
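
With FlashComm2 the decoder layer now keeps residual in the per-rank sliced (and padded) layout on every call, not only on sequence-parallel prefill, and all-gathers hidden_states back to the full sequence before attention. A single-process sketch of that round trip, using plain tensor ops as stand-ins for the MetadataForPadding methods padding_slice and allgather_unpadding_aligned (their real implementations are not part of this diff):

import torch

tp_size, seq_len, hidden = 4, 10, 8
pad_size = (tp_size - seq_len % tp_size) % tp_size      # 2, as in the unit tests above
x = torch.randn(seq_len, hidden)

padded = torch.cat([x, x.new_zeros(pad_size, hidden)])  # pad to a multiple of tp_size
slices = list(padded.chunk(tp_size, dim=0))             # "padding_slice": rank r keeps slices[r]

# "allgather_unpadding_aligned": gather every rank's slice, then drop the padding.
regathered = torch.cat(slices, dim=0)[:seq_len]
assert torch.equal(regathered, x)
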
@@ -276,6 +433,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2

     def forward(
         self,
@@ -310,7 +468,7 @@ def forward(

         hidden_states, _ = self.norm(hidden_states, residual)

-        if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
+        if _metadata_for_padding and (_metadata_for_padding.not_dummy_and_is_prefill or self.enable_flashcomm2):
             hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
                 hidden_states)

@@ -354,6 +512,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.model.make_empty_intermediate_tensors)

         self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2
         # Set MoE hyperparameters
         self.expert_weights: list[torch.Tensor] = []

@@ -382,8 +541,13 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        _metadata_for_padding = init_metadata_for_sp(
-            input_ids, self.enable_sequence_parallelism)
+        if self.enable_flashcomm2:
+            if self.enable_sequence_parallelism:
+                raise ValueError(
+                    "Sequence parallelism and FlashComm2 cannot be enabled simultaneously.")
+            _metadata_for_padding = init_metadata_for_flashcomm2(input_ids)
+        else:
+            _metadata_for_padding = init_metadata_for_sp(
+                input_ids, self.enable_sequence_parallelism)
         hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds, _metadata_for_padding)
         return hidden_states
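
The core FlashComm2 idiom added above is all_to_all_single followed by reshape(tp_size, -1, d).transpose(0, 1).reshape(-1, d * tp_size), which turns "all padded tokens, this rank's feature shard" into "this rank's token slice, full feature width" before the now-replicated o_proj / down_proj. A single-process sketch of just that layout transform (no process group or NPU; tp_size, tokens_per_rank and shard_dim are illustrative):

import torch

tp_size, tokens_per_rank, shard_dim = 4, 3, 2

# What one rank holds after all_to_all_single: tp_size blocks stacked along dim 0,
# where block i carries this rank's tokens_per_rank tokens with rank i's feature shard.
received = torch.arange(tp_size * tokens_per_rank * shard_dim,
                        dtype=torch.float32).reshape(tp_size * tokens_per_rank, shard_dim)

# The idiom from the commit: regroup so each local token row carries the full width.
x = received.reshape(tp_size, -1, received.size(-1)) \
    .transpose(0, 1) \
    .reshape(-1, received.size(-1) * tp_size)

assert x.shape == (tokens_per_rank, shard_dim * tp_size)
# Row t concatenates token t's shard from every rank, in rank order.
assert torch.equal(x[0], torch.cat([received[i * tokens_per_rank] for i in range(tp_size)]))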

vllm_ascend/ops/fused_moe.py

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import vllm_ascend.envs as envs_ascend
 from torch import nn
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -366,6 +367,7 @@ def __init__(
             num_experts=self.global_num_experts,
             num_global_redundant_experts=self.global_redundant_expert_num,
             num_local_experts=self.local_num_experts)
+        self.enable_flashcomm2 = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM == 2

     def naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
@@ -415,7 +417,7 @@ def forward(self,

         enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
         tp_size = get_tensor_model_parallel_world_size()
-        if enable_sp:
+        if enable_sp or self.enable_flashcomm2:
             tp_rank = get_tensor_model_parallel_rank()
             mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
             chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
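
For reference, this guarded branch now also runs under FlashComm2: the padded mc2_mask is split evenly across TP ranks with torch.tensor_split so that each rank dispatches only its own token slice to the experts (presumably indexed by tp_rank in the lines that follow, which are not shown in this hunk). A minimal single-process sketch (mask length and tp_size are illustrative; FlashComm2's padding guarantees the divisibility):

import torch

tp_size = 4
mc2_mask_sp = torch.tensor([True] * 10 + [False] * 2)  # 12 padded tokens, last 2 are padding

chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
assert all(chunk.numel() == mc2_mask_sp.numel() // tp_size for chunk in chunk_mc2_mask)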
