
Commit 8a8d615

add support for flashcomm2 in qwen3

Signed-off-by: David9857 <985700846@qq.com>
fix
Signed-off-by: David9857 <985700846@qq.com>
remove self.pad_size
Signed-off-by: David9857 <985700846@qq.com>

1 parent da2d5ac commit 8a8d615

File tree: 2 files changed, +293 -7 lines changed


vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -143,6 +143,8 @@
     # Batch MC2 in prefill: The number of tokens in each batch
     "VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
+    "VLLM_ENABLE_FC":
+    lambda: int(os.getenv("VLLM_ENABLE_FC", 0))
 }
 
 # end-env-vars-definition
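Note: a minimal sketch of toggling the new flag from Python, assuming this commit is installed and that vllm_ascend.envs resolves these lambdas lazily through a module __getattr__, as vLLM's own envs module does (the attribute name simply mirrors the key added above):

import os

# Hypothetical quick check: set the flag before vllm_ascend reads it.
os.environ["VLLM_ENABLE_FC"] = "1"

from vllm_ascend import envs

# The table maps each name to a lambda over os.getenv, so the value is an int.
print(envs.VLLM_ENABLE_FC)  # expected: 1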

vllm_ascend/models/qwen3.py

Lines changed: 291 additions & 7 deletions
@@ -3,37 +3,227 @@
 
 import torch
 from torch import nn
+import torch.nn.functional as F
+import torch.distributed as dist
 from transformers import Qwen3Config
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group
+from vllm.attention import Attention, AttentionType
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              get_tensor_model_parallel_rank,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.linear import RowParallelLinear, ReplicatedLinear
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
 from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
+from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
+from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer, Qwen3Attention
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormQuant
 
 
-class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
+def pad(tensor, x):
+    length = tensor.size(0)
+    pad_size = (x - (length % x)) % x
+    if pad_size > 0:
+        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
+    return tensor, pad_size
+
+
+def unpad(tensor, pad_size):
+    if pad_size > 0:
+        return tensor[:-pad_size, :]
+    return tensor
+
+
+class CustomQwen3MLP(Qwen3MLP):
 
     def __init__(
         self,
-        config: Qwen3Config,
-        cache_config: Optional[CacheConfig] = None,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config=config,
-                         cache_config=cache_config,
+        super().__init__(hidden_size=hidden_size,
+                         intermediate_size=intermediate_size,
+                         hidden_act=hidden_act,
                          quant_config=quant_config,
                          prefix=prefix)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            # if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
+            self.down_proj = ReplicatedLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+        else:
+            self.down_proj = RowParallelLinear(
+                intermediate_size,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.down_proj",
+            )
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            x, pad_size = pad(x, self.tp_size)
+            output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+            dist.all_to_all_single(output, x)
+            x = output.reshape(self.tp_size, -1, output.size(-1)) \
+                    .transpose(0, 1) \
+                    .reshape(-1, output.size(-1)*self.tp_size)
+        x, _ = self.down_proj(x)
+        return x, pad_size
+
+
+class CustomQwen3Attention(Qwen3Attention):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 num_kv_heads: int,
+                 max_position: int = 4096 * 32,
+                 head_dim: Optional[int] = None,
+                 rms_norm_eps: float = 1e-06,
+                 qkv_bias: bool = False,
+                 rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 rope_scaling: Optional[tuple] = None,
+                 prefix: str = "",
+                 attn_type: str = AttentionType.DECODER) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         head_dim=head_dim,
+                         rms_norm_eps=rms_norm_eps,
+                         qkv_bias=qkv_bias,
+                         rope_theta=rope_theta,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        if self.enable_fc:
+            self.o_proj = ReplicatedLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+        else:
+            self.o_proj = RowParallelLinear(
+                self.total_num_heads * self.head_dim,
+                hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.o_proj",
+            )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # Add qk-norm
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                           self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
+                           self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        pad_size = 0
+        if self.enable_fc:
+            # pad input because AllGather requires token_num to be divisible by tp_size
+            attn_output, pad_size = pad(attn_output, self.tp_size)
+            output = torch.empty(attn_output.shape, dtype=attn_output.dtype, device=attn_output.device)
+            dist.all_to_all_single(output, attn_output)
+            attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
+                    .transpose(0, 1) \
+                    .reshape(-1, output.size(-1)*self.tp_size)
+        output, _ = self.o_proj(attn_output)
+        return output, pad_size
+
+
+class CustomQwen3DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.enable_fc = envs.VLLM_ENABLE_FC
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+
+        # By default, Qwen3 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rms_norm_eps=config.rms_norm_eps,
+            qkv_bias=getattr(config, 'attention_bias', False),
+            head_dim=getattr(config, 'head_dim', None),
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+        )
+        self.mlp = CustomQwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
         if quant_config is None:
             return
 
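The core of the flashcomm2 data movement added above is pad() followed by dist.all_to_all_single and a reshape/transpose that turns tp_size stacked shards back into per-token rows of full width. A single-process, shape-only sketch (no process group; tp_size and the tensor sizes are made-up, and torch.chunk/torch.cat stand in for the collective):

import torch
import torch.nn.functional as F

def pad(tensor, x):
    # same helper as in the diff: pad dim 0 up to a multiple of x
    length = tensor.size(0)
    pad_size = (x - (length % x)) % x
    if pad_size > 0:
        return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
    return tensor, pad_size

tp_size = 4
x = torch.randn(10, 8)                 # 10 tokens, per-rank shard width 8 (made-up)
x, pad_size = pad(x, tp_size)          # -> 12 rows, pad_size == 2
# stand-in for dist.all_to_all_single: tp_size equal chunks along dim 0
output = torch.cat(torch.chunk(x, tp_size, dim=0), dim=0)
x = output.reshape(tp_size, -1, output.size(-1)) \
          .transpose(0, 1) \
          .reshape(-1, output.size(-1) * tp_size)
print(x.shape, pad_size)               # torch.Size([3, 32]) 2

Each of the 3 remaining rows carries the full width (4 x 8), which is what the ReplicatedLinear down_proj/o_proj consume; the full token count is restored later by the all-gather in pre_attention_process/pre_mlp_process.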
@@ -56,6 +246,58 @@ def __init__(
                 layer=self.mlp.gate_up_proj,
                 eps=config.rms_norm_eps)
 
+    def pre_attention_process(self, hidden_states, residual, pad_size=0):
+        hidden_states, residual = self.input_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def pre_mlp_process(self, hidden_states, residual, pad_size=0):
+        token_num = hidden_states.size(0)
+        if token_num != residual.size(0):
+            if pad_size > 0:
+                residual = F.pad(residual, (0, 0, 0, pad_size))
+            split_size_list = [token_num] * self.tp_size
+            residual = torch.split(residual, split_size_list)[self.tp_rank]
+
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        hidden_states = unpad(hidden_states, pad_size)
+        return hidden_states, residual
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        pad_size: int = 0
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            if self.enable_fc:
+                hidden_states, residual = self.pre_attention_process(hidden_states, residual, pad_size)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
+        hidden_states, pad_size = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+        )
+
+        # Fully Connected
+        if self.enable_fc:
+            hidden_states, residual = self.pre_mlp_process(hidden_states, residual, pad_size)
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual)
+        hidden_states, pad_size = self.mlp(hidden_states)
+        return hidden_states, residual, pad_size
+
 
 ALL_DECODER_LAYER_TYPES = {
     "attention": CustomQwen3DecoderLayer,
@@ -77,6 +319,48 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
77319
super().__init__(vllm_config=vllm_config,
78320
prefix=prefix,
79321
decoder_layer_type=CustomQwen3DecoderLayer)
322+
self.tp_size = get_tensor_model_parallel_world_size()
323+
self.tp_rank = get_tensor_model_parallel_rank()
324+
self.enable_fc = envs.VLLM_ENABLE_FC
325+
326+
def forward(
327+
self,
328+
input_ids: torch.Tensor,
329+
positions: torch.Tensor,
330+
intermediate_tensors: Optional[IntermediateTensors] = None,
331+
inputs_embeds: Optional[torch.Tensor] = None,
332+
) -> Union[torch.Tensor, IntermediateTensors]:
333+
if get_pp_group().is_first_rank:
334+
if inputs_embeds is not None:
335+
hidden_states = inputs_embeds
336+
else:
337+
hidden_states = self.get_input_embeddings(input_ids)
338+
residual = None
339+
else:
340+
assert intermediate_tensors is not None
341+
hidden_states = intermediate_tensors["hidden_states"]
342+
residual = intermediate_tensors["residual"]
343+
pad_size = 0
344+
for layer in self.layers[self.start_layer:self.end_layer]:
345+
hidden_states, residual, pad_size = layer(
346+
positions,
347+
hidden_states,
348+
residual,
349+
pad_size
350+
)
351+
if self.enable_fc:
352+
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
353+
residual = tensor_model_parallel_all_gather(residual, 0)
354+
if pad_size > 0:
355+
hidden_states = hidden_states[:-pad_size]
356+
residual = residual[:-pad_size]
357+
if not get_pp_group().is_last_rank:
358+
return IntermediateTensors({
359+
"hidden_states": hidden_states,
360+
"residual": residual
361+
})
362+
hidden_states, _ = self.norm(hidden_states, residual)
363+
return hidden_states
80364

81365

82366
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
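The forward added above keeps per-rank, possibly padded activations between layers when flashcomm2 is on, and only restores the full sequence once at the end. A single-process sketch of that final step (torch.cat stands in for tensor_model_parallel_all_gather on dim 0; sizes are made-up):

import torch

tp_size, tokens, hidden = 4, 10, 16                    # made-up sizes
pad_size = (tp_size - tokens % tp_size) % tp_size      # 2, carried out of the last layer
rows_per_rank = (tokens + pad_size) // tp_size         # 3
shards = [torch.randn(rows_per_rank, hidden) for _ in range(tp_size)]
hidden_states = torch.cat(shards, dim=0)               # stand-in for the all-gather
if pad_size > 0:
    hidden_states = hidden_states[:-pad_size]
print(hidden_states.shape)                             # torch.Size([10, 16])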
