
Commit 716d6f6

add support for flashcomm2 in qwen3
Signed-off-by: David9857 <985700846@qq.com>
1 parent f08283a commit 716d6f6

File tree

2 files changed (+172, -30 lines)


vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -143,6 +143,8 @@
     # Batch MC2 in prefill: The number of tokens in each batch
     "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
+    "VLLM_ENABLE_FC":
+    lambda: int(os.getenv("VLLM_ENABLE_FC", '0'))
 }

 # end-env-vars-definition
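
Note: the flag is parsed with int(), so flashcomm2 is opt-in and defaults to off. A minimal usage sketch (the import-ordering caveat is an assumption about how the envs module evaluates its getenv lambdas, mirroring vLLM's envs):

```python
import os

# Hypothetical usage sketch: "1" enables flashcomm2; the default "0" keeps
# the stock RowParallelLinear + AllReduce path.
os.environ["VLLM_ENABLE_FC"] = "1"

# Set the variable before accessing vllm_ascend.envs, since the getenv
# lambda is evaluated when the attribute is read.
from vllm_ascend import envs
assert envs.VLLM_ENABLE_FC == 1
```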

vllm_ascend/models/qwen3.py

Lines changed: 170 additions & 30 deletions
@@ -2,13 +2,19 @@
 from typing import Optional, Union

 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen3Config
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,9 +26,76 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            # if flashcomm2 is enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc:
+            # pad so token_num is divisible by tp_size: the All2All (and the
+            # later AllGather) needs equally sized chunks per rank
+            x, pad_size = pad(x, self.tp_size)
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output, x)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
 class CustomQwen3Attention(Qwen3Attention):

     def __init__(self,
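
To make the All2All + reshape above concrete, here is a single-process sketch (tp_size and shapes are illustrative) that simulates what dist.all_to_all_single does to the [token, hidden] layout, and why the reshape/transpose afterwards yields whole tokens per rank:

```python
import torch

# Single-process sketch of the data movement in CustomQwen3MLP.forward when
# flashcomm2 is on (tp = 4 is an assumption for illustration). Each rank holds
# [T_pad, d/tp] activations; all_to_all_single sends the i-th row-chunk to
# rank i, so afterwards every rank owns T_pad/tp tokens of every feature shard.
tp, T_pad, d_shard = 4, 8, 6
per_rank_in = [torch.randn(T_pad, d_shard) for _ in range(tp)]

def all_to_all_out(r):
    # What rank r receives: chunk r of every peer's input, stacked along dim 0.
    return torch.cat([x.chunk(tp, dim=0)[r] for x in per_rank_in], dim=0)

out = all_to_all_out(0)                      # [T_pad, d_shard], tp chunks
x = out.reshape(tp, -1, d_shard) \
       .transpose(0, 1) \
       .reshape(-1, d_shard * tp)            # [T_pad/tp, d_shard*tp]
assert x.shape == (T_pad // tp, d_shard * tp)
# Rows are now whole tokens (full intermediate dim), so the ReplicatedLinear
# down_proj can produce complete hidden states for this rank's token slice.
```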
@@ -52,13 +125,32 @@ def __init__(self,
                          rope_scaling=rope_scaling,
                          prefix=prefix,
                          attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )

     def forward(
         self,
         positions: torch.Tensor,
+        hidden_states: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
-        hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -78,8 +170,19 @@ def forward(
                                    sin=sin,
                                    skip_index_select=True)
         attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc:
+            # pad so token_num is divisible by tp_size before the All2All
+            # (the later AllGather relies on the same alignment)
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape,
+                                 dtype=attn_output.dtype,
+                                 device=attn_output.device)
+            dist.all_to_all_single(output, attn_output)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1) * self.tp_size)
         output, _ = self.o_proj(attn_output)
-        return output
+        return output, pad_size


 class CustomQwen3DecoderLayer(nn.Module):
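
For intuition on why the commit swaps Linear+AllReduce for All2All+Linear (with the AllGather folded into the next pre-process step), here is a back-of-envelope traffic model. The ring-collective cost formulas and the simplification H ≈ num_heads * head_dim are assumptions, not measurements from this commit:

```python
# Per-rank traffic in elements for one [T, H] projection output, assuming
# ring-style collectives (an illustrative model, not a benchmark):
#   baseline TP     : AllReduce(T*H)             ~ 2 * T*H * (tp-1)/tp
#   flashcomm2 path : All2All on [T, H/tp]       ~ (T*H/tp) * (tp-1)/tp
#                     + AllGather back to [T, H] ~ T*H * (tp-1)/tp
T, H, tp = 1024, 4096, 8
allreduce = 2 * T * H * (tp - 1) / tp
flashcomm2 = (T * H / tp) * (tp - 1) / tp + T * H * (tp - 1) / tp
print(f"baseline {allreduce:.3e}, flashcomm2 {flashcomm2:.3e}, "
      f"ratio {flashcomm2 / allreduce:.2f}")   # ratio (1 + 1/tp) / 2 ~= 0.56
```

Under this model the communication volume roughly halves at large tp; actual savings depend on the interconnect and on how well the collectives overlap with compute.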
@@ -93,6 +196,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +227,7 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
         )
-        self.mlp = Qwen3MLP(
+        self.mlp = CustomQwen3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
@@ -159,31 +265,58 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)

-    def forward(
-        self,
-        positions: torch.Tensor,
-        cos: torch.Tensor,
-        sin: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        token_num = hidden_states.size(0)
+        if token_num != residual.size(0):
+            if pad_size > 0:
+                residual = F.pad(residual, (0, 0, 0, pad_size))
+            split_size_list = [token_num] * self.tp_size
+            residual = torch.split(residual, split_size_list)[self.tp_rank]
+
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                residual: Optional[torch.Tensor],
+                cos: torch.Tensor,
+                sin: torch.Tensor,
+                pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Self Attention
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
+            if self.enable_fc:
+                hidden_states, residual = self.pre_attention_process(
+                    hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(positions=positions,
+                                                 hidden_states=hidden_states,
+                                                 cos=cos,
+                                                 sin=sin)
+
+        # Fully Connected
+        if self.enable_fc:
+            hidden_states, residual = self.pre_mlp_process(
+                hidden_states, residual, pad_size)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            cos=cos,
-            sin=sin,
-            hidden_states=hidden_states,
-        )
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
-        return hidden_states, residual
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size


 ALL_DECODER_LAYER_TYPES = {
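
The residual bookkeeping in pre_mlp_process is the subtle part: on the first layer that runs under flashcomm2, residual is still full-length while hidden_states is already a per-rank token slice after the attention All2All. A small sketch of that realignment (tp_size and tp_rank values are illustrative):

```python
import torch
import torch.nn.functional as F

# Sketch of the residual realignment in pre_mlp_process (tp = 4, tp_rank = 1
# assumed). The residual must be padded and split the same way as the hidden
# states before the fused add + layernorm can run.
tp, tp_rank, T, H = 4, 1, 5, 8
residual = torch.randn(T, H)                 # still full length: first FC layer
token_num = 2                                # hidden_states.size(0) = T_pad // tp
pad_size = token_num * tp - T                # 3 rows of padding
if pad_size > 0:
    residual = F.pad(residual, (0, 0, 0, pad_size))
residual = torch.split(residual, [token_num] * tp)[tp_rank]
assert residual.shape == (token_num, H)      # matches hidden_states row count
```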
@@ -207,6 +340,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          prefix=prefix,
                          decoder_layer_type=CustomQwen3DecoderLayer)
         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC

     def forward(
         self,
@@ -235,14 +371,18 @@ def forward(
         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
             1, -1, 1, last_dim).contiguous()

+        pad_size = 0
         for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                cos,
-                sin,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual, pad_size = layer(positions, hidden_states,
+                                                      residual, cos, sin,
+                                                      pad_size)
+        if self.enable_fc:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            residual = tensor_model_parallel_all_gather(residual, 0)
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size]
+                residual = residual[:-pad_size]
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
