
Commit 7e4f2fc

add support for flashcomm2 in qwen3

Signed-off-by: David9857 <985700846@qq.com>
Parent: f08283a

3 files changed: +194, -30 lines

tests/multicard/test_offline_inference_distributed.py

Lines changed: 18 additions & 0 deletions

@@ -169,3 +169,21 @@ def test_models_distributed_DeepSeek_W8A8():
             quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "2"})
+def test_models_distributed_Qwen3_with_flashcomm2():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("Qwen/Qwen3-0.6B-Base"),
+            max_model_len=8192,
+            enforce_eager=True,
+            dtype="auto",
+            tensor_parallel_size=2,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
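
The new test toggles the feature purely through the environment: the flag added to vllm_ascend/envs.py below is a lambda around os.getenv, so it is presumably evaluated when the model code reads it during engine construction inside the with block, and patch.dict only has to hold the value for that duration. A minimal, hypothetical sketch of that mechanism (read_flag is illustrative and not part of the commit):

import os
from unittest.mock import patch


def read_flag() -> int:
    # Mirrors the lazy lambda added to vllm_ascend/envs.py in this commit.
    return int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", "0"))


with patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "2"}):
    assert read_flag() == 2  # FlashComm v2 is visible while the patch is active
# Assuming the variable is not otherwise set, the default is restored on exit.
assert read_flag() == 0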

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions

@@ -143,6 +143,9 @@
     # Batch MC2 in prefill: The number of tokens in each batch
     "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
+    # FlashComm optimization: Enable v1 and v2 by setting this flag to 1 or 2 respectively
+    "VLLM_ASCEND_ENABLE_FLASHCOMM":
+    lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))
 }
 
 # end-env-vars-definition
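
For reference, a hypothetical offline snippet (not part of this commit) showing how the flag would be switched on outside the test suite; it assumes a working vllm + vllm-ascend install on Ascend hardware and mirrors the settings used by the new test:

import os

# Must be set before the engine is created; 0 = off, 1 = FlashComm v1, 2 = FlashComm v2.
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "2"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B-Base",
          tensor_parallel_size=2,  # FlashComm2 only changes the TP communication path
          enforce_eager=True,
          max_model_len=8192)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=5))
print(outputs[0].outputs[0].text)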

vllm_ascend/models/qwen3.py

Lines changed: 173 additions & 30 deletions

@@ -2,13 +2,19 @@
 from typing import Optional, Union
 
 import torch
+import torch.distributed as dist
+import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen3Config
 from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              get_tp_group, tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,9 +26,78 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
 
 
+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
+        if self.enable_fc == 2:
+            # if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc == 2:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            x, pad_size = pad(x, self.tp_size)
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output,
                                   x,
                                   group=get_tp_group().device_group)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1)*self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
+
 class CustomQwen3Attention(Qwen3Attention):
 
     def __init__(self,
@@ -52,13 +127,32 @@ def __init__(self,
                          rope_scaling=rope_scaling,
                          prefix=prefix,
                          attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
+        if self.enable_fc == 2:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
 
     def forward(
         self,
         positions: torch.Tensor,
+        hidden_states: torch.Tensor,
         cos: torch.Tensor,
         sin: torch.Tensor,
-        hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -78,8 +172,21 @@ def forward(
                               sin=sin,
                               skip_index_select=True)
         attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc == 2:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape,
                                 dtype=attn_output.dtype,
                                 device=attn_output.device)
+            dist.all_to_all_single(output,
                                   attn_output,
                                   group=get_tp_group().device_group)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                .transpose(0, 1) \
+                .reshape(-1, output.size(-1)*self.tp_size)
         output, _ = self.o_proj(attn_output)
-        return output
+        return output, pad_size
 
 
 class CustomQwen3DecoderLayer(nn.Module):
@@ -93,6 +200,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +231,7 @@ def __init__(
             prefix=f"{prefix}.self_attn",
             attn_type=attn_type,
         )
-        self.mlp = Qwen3MLP(
+        self.mlp = CustomQwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
@@ -159,31 +269,56 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
 
-    def forward(
-        self,
-        positions: torch.Tensor,
-        cos: torch.Tensor,
-        sin: torch.Tensor,
-        hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor],
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(self,
+                positions: torch.Tensor,
+                hidden_states: torch.Tensor,
+                residual: Optional[torch.Tensor],
+                cos: torch.Tensor,
+                sin: torch.Tensor,
+                pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
         # Self Attention
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
+            if self.enable_fc == 2:
+                residual, pad_size = pad(residual, self.tp_size)
+                chunk_size = residual.size(0) // self.tp_size
+                residual = residual[chunk_size * self.tp_rank:chunk_size *
                                    (self.tp_rank + 1)]
+        else:
+            if self.enable_fc == 2:
+                hidden_states, residual = self.pre_attention_process(
+                    hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(positions=positions,
                                                 hidden_states=hidden_states,
                                                 cos=cos,
                                                 sin=sin)
+
+        # Fully Connected
+        if self.enable_fc == 2:
+            hidden_states, residual = self.pre_mlp_process(
+                hidden_states, residual, pad_size)
         else:
-            hidden_states, residual = self.input_layernorm(
+            hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            cos=cos,
-            sin=sin,
-            hidden_states=hidden_states,
-        )
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
-        return hidden_states, residual
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size
 
 
 ALL_DECODER_LAYER_TYPES = {
@@ -207,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)
         self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
 
     def forward(
         self,
@@ -235,20 +373,25 @@ def forward(
         cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
             1, -1, 1, last_dim).contiguous()
 
+        pad_size = 0
         for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                cos,
-                sin,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual, pad_size = layer(positions, hidden_states,
                                                      residual, cos, sin,
                                                      pad_size)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if self.enable_fc == 2:
+            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+            residual = tensor_model_parallel_all_gather(residual, 0)
+            if pad_size > 0:
+                hidden_states = hidden_states[:-pad_size]
+                residual = residual[:-pad_size]
         return hidden_states
254397
