
Commit d26e032

add support for flashcomm2 in qwen3
Signed-off-by: David9857 <985700846@qq.com>
1 parent f08283a commit d26e032

File tree

2 files changed: +176 -10 lines changed


vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -143,6 +143,8 @@
     # Batch MC2 in prefill: The number of tokens in each batch
     "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
+    "VLLM_ENABLE_FC":
+    lambda: int(os.getenv("VLLM_ENABLE_FC", '0'))
 }
 
 # end-env-vars-definition
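For reference, a minimal usage sketch (not part of the diff; it assumes the envs module resolves variables lazily on attribute access, which the qwen3.py code below relies on): the flag is parsed as an integer, so "0" (the default) keeps the existing RowParallelLinear + AllReduce path and "1" enables the FlashComm2 path.

    # illustrative only: toggle FlashComm2 before the model code reads the flag
    import os
    os.environ["VLLM_ENABLE_FC"] = "1"

    from vllm_ascend import envs
    print(envs.VLLM_ENABLE_FC)  # 1, via int(os.getenv("VLLM_ENABLE_FC", '0'))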

vllm_ascend/models/qwen3.py

Lines changed: 174 additions & 10 deletions
@@ -2,27 +2,101 @@
 from typing import Optional, Union
 
 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen3Config
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
 from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3MLP
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
 
 
+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            # if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            x, pad_size = pad(x, self.tp_size)
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output, x)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1)*self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
+
 class CustomQwen3Attention(Qwen3Attention):
 
     def __init__(self,
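The pad/unpad helpers above only round the token dimension up to a multiple of tp_size with trailing zero rows and strip those rows again afterwards. A tiny standalone illustration (shapes chosen arbitrarily, not taken from the commit):

    import torch
    import torch.nn.functional as F

    tp_size = 4
    x = torch.randn(6, 8)                                  # 6 tokens, hidden size 8
    pad_size = (tp_size - x.size(0) % tp_size) % tp_size   # -> 2
    padded = F.pad(x, (0, 0, 0, pad_size))                 # (8, 8): two zero rows appended
    restored = padded[:-pad_size, :]                       # unpad: back to (6, 8)
    assert torch.equal(restored, x)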
@@ -52,6 +126,25 @@ def __init__(self,
                          rope_scaling=rope_scaling,
                          prefix=prefix,
                          attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
 
     def forward(
         self,
@@ -78,8 +171,19 @@ def forward(
                             sin=sin,
                             skip_index_select=True)
         attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape,
+                                 dtype=attn_output.dtype,
+                                 device=attn_output.device)
+            dist.all_to_all_single(output, attn_output)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1)*self.tp_size)
         output, _ = self.o_proj(attn_output)
-        return output
+        return output, pad_size
 
 
 class CustomQwen3DecoderLayer(nn.Module):
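CustomQwen3MLP.forward and CustomQwen3Attention.forward both follow dist.all_to_all_single with the same reshape/transpose: each rank trades a full-token, partial-hidden activation for a partial-token, full-hidden one that the ReplicatedLinear can consume directly. A single-process sketch of just that layout transform (tp_size = 2, 4 padded tokens, a per-rank hidden slice of width 3; no process group involved):

    import torch

    tp_size = 2
    # stand-in for the all-to-all output on one rank: rows 0-1 hold this rank's two
    # tokens with rank 0's hidden columns, rows 2-3 the same tokens with rank 1's columns
    output = torch.arange(4 * 3, dtype=torch.float32).reshape(4, 3)
    x = output.reshape(tp_size, -1, output.size(-1)) \
        .transpose(0, 1) \
        .reshape(-1, output.size(-1) * tp_size)
    print(x.shape)  # torch.Size([2, 6]): 2 local tokens, full hidden width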
@@ -93,6 +197,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +228,7 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
         )
-        self.mlp = Qwen3MLP(
+        self.mlp = CustomQwen3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
@@ -185,6 +292,57 @@ def forward(
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        token_num = hidden_states.size(0)
+        if token_num != residual.size(0):
+            if pad_size > 0:
+                residual = F.pad(residual, (0, 0, 0, pad_size))
+            split_size_list = [token_num] * self.tp_size
+            residual = torch.split(residual, split_size_list)[self.tp_rank]
+
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                residual: Optional[torch.Tensor],
+                pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            if self.enable_fc:
+                hidden_states, residual = self.pre_attention_process(
+                    hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        # Fully Connected
+        if self.enable_fc:
+            hidden_states, residual = self.pre_mlp_process(
+                hidden_states, residual, pad_size)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size
+
 
 ALL_DECODER_LAYER_TYPES = {
     "attention": CustomQwen3DecoderLayer,
@@ -207,6 +365,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
 
     def forward(
         self,
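One subtlety in the decoder layer above is pre_mlp_process: after the attention all-to-all, hidden_states carries only this rank's padded_token_num / tp_size rows, while residual still carries every original token, so the method pads residual to the same padded length and keeps only this rank's contiguous slice before the fused add + RMSNorm. A small single-process sketch of that slicing (tp_rank = 1 of 4; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    tp_size, tp_rank, hidden = 4, 1, 8
    residual = torch.randn(6, hidden)      # full batch: 6 tokens
    token_num = 2                          # rows this rank holds: 8 padded tokens / 4 ranks
    pad_size = 2                           # 6 tokens were padded up to 8

    residual = F.pad(residual, (0, 0, 0, pad_size))                    # (8, hidden)
    residual = torch.split(residual, [token_num] * tp_size)[tp_rank]   # rows 2-3
    print(residual.shape)                  # torch.Size([2, 8])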
@@ -235,14 +396,17 @@ def forward(
         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
             1, -1, 1, last_dim).contiguous()
 
+        pad_size = 0
         for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                cos,
-                sin,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual, pad_size = layer(positions, hidden_states,
+                                                      residual, pad_size)
+        if self.enable_fc:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            residual = tensor_model_parallel_all_gather(residual, 0)
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size]
+                residual = residual[:-pad_size]
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,

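At the model level, when VLLM_ENABLE_FC is set every rank leaves the layer loop holding a (padded_token_num / tp_size, hidden) slice, so hidden_states and residual are all-gathered along dim 0 and the trailing pad rows are dropped before the output is returned. A minimal single-process sketch of that final step (plain concatenation stands in for tensor_model_parallel_all_gather; shapes are illustrative):

    import torch

    tp_size, hidden = 4, 8
    pad_size = 2                                    # 6 real tokens padded up to 8
    per_rank = [torch.randn(8 // tp_size, hidden) for _ in range(tp_size)]

    hidden_states = torch.cat(per_rank, dim=0)      # simulated all_gather along dim 0
    if pad_size > 0:                                # same guard as in the diff
        hidden_states = hidden_states[:-pad_size]
    print(hidden_states.shape)                      # torch.Size([6, 8])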