
[V0.9.1] add support for flashcomm2 in qwen3 #1726


Merged: 1 commit merged on Jul 16, 2025
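This PR adds FlashComm v2 support for Qwen3 on Ascend. When VLLM_ASCEND_ENABLE_FLASHCOMM is set to 2, the attention o_proj and MLP down_proj are built as ReplicatedLinear instead of RowParallelLinear, and the usual Linear + AllReduce pattern is replaced by an All2All token exchange followed by the replicated linear. Activations are padded so the token count is divisible by the tensor-parallel world size and unpadded after the corresponding AllGather. A distributed test covering Qwen3-0.6B-Base with tensor_parallel_size=2 is added as well.
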
18 changes: 18 additions & 0 deletions tests/multicard/test_offline_inference_distributed.py
@@ -169,3 +169,21 @@ def test_models_distributed_DeepSeek_W8A8():
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)


@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "2"})
def test_models_distributed_Qwen3_with_flashcomm2():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5

with VllmRunner(
snapshot_download("Qwen/Qwen3-0.6B-Base"),
max_model_len=8192,
enforce_eager=True,
dtype="auto",
tensor_parallel_size=2,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
@@ -143,6 +143,9 @@
# Batch MC2 in prefill: The number of tokens in each batch
"VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE", "128")),
    # FlashComm optimization: enable v1 or v2 by setting this flag to 1 or 2, respectively
"VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))
}

# end-env-vars-definition
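A minimal usage sketch (illustration only, assuming the variable is set before vLLM initializes; the new test above patches the same variable via patch.dict):

import os

# 0 disables FlashComm, 1 enables v1, 2 enables v2 (the mode exercised by this PR)
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "2"
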
203 changes: 173 additions & 30 deletions vllm_ascend/models/qwen3.py
@@ -2,13 +2,19 @@
from typing import Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from transformers import Qwen3Config
from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_gather)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,9 +26,78 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend import envs
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


def pad(tensor, x):
length = tensor.size(0)
pad_size = (x - (length % x)) % x
if pad_size > 0:
return F.pad(tensor, (0, 0, 0, pad_size)), pad_size
return tensor, pad_size


def unpad(tensor, pad_size):
if pad_size > 0:
return tensor[:-pad_size, :]
return tensor


class CustomQwen3MLP(Qwen3MLP):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
quant_config=quant_config,
prefix=prefix)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
if self.enable_fc == 2:
# if flashcomm2 enabled, replace Linear+AllReduce with All2All+Linear
self.down_proj = ReplicatedLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
else:
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
pad_size = 0
if self.enable_fc == 2:
            # pad input so token_num is divisible by tp_size, as required by the All2All below and the later AllGather
x, pad_size = pad(x, self.tp_size)
output = torch.empty(x.shape, dtype=x.dtype, device=x.device)
dist.all_to_all_single(output,
x,
group=get_tp_group().device_group)
x = output.reshape(self.tp_size, -1, output.size(-1)) \
.transpose(0, 1) \
.reshape(-1, output.size(-1)*self.tp_size)
x, _ = self.down_proj(x)
return x, pad_size


class CustomQwen3Attention(Qwen3Attention):

def __init__(self,
@@ -52,13 +127,32 @@ def __init__(self,
rope_scaling=rope_scaling,
prefix=prefix,
attn_type=attn_type)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
if self.enable_fc == 2:
self.o_proj = ReplicatedLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
else:
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -78,8 +172,21 @@ def forward(
sin=sin,
skip_index_select=True)
attn_output = self.attn(q, k, v)
pad_size = 0
if self.enable_fc == 2:
            # pad input so token_num is divisible by tp_size, as required by the All2All below and the later AllGather
attn_output, pad_size = pad(attn_output, self.tp_size)
output = torch.empty(attn_output.shape,
dtype=attn_output.dtype,
device=attn_output.device)
dist.all_to_all_single(output,
attn_output,
group=get_tp_group().device_group)
attn_output = output.reshape(self.tp_size, -1, output.size(-1)) \
.transpose(0, 1) \
.reshape(-1, output.size(-1)*self.tp_size)
output, _ = self.o_proj(attn_output)
return output, pad_size


class CustomQwen3DecoderLayer(nn.Module):
@@ -93,6 +200,9 @@ def __init__(
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
@@ -121,7 +231,7 @@ def __init__(
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = CustomQwen3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
@@ -159,31 +269,56 @@ def __init__(
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps)

def pre_attention_process(self, hidden_states, residual, pad_size=0):
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
hidden_states = unpad(hidden_states, pad_size)
return hidden_states, residual

def pre_mlp_process(self, hidden_states, residual, pad_size=0):
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
hidden_states = unpad(hidden_states, pad_size)
return hidden_states, residual

def forward(self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
cos: torch.Tensor,
sin: torch.Tensor,
pad_size: int = 0) -> tuple[torch.Tensor, torch.Tensor, int]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.enable_fc == 2:
residual, pad_size = pad(residual, self.tp_size)
chunk_size = residual.size(0) // self.tp_size
residual = residual[chunk_size * self.tp_rank:chunk_size *
(self.tp_rank + 1)]
else:
if self.enable_fc == 2:
hidden_states, residual = self.pre_attention_process(
hidden_states, residual, pad_size)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, pad_size = self.self_attn(positions=positions,
hidden_states=hidden_states,
cos=cos,
sin=sin)

# Fully Connected
if self.enable_fc == 2:
hidden_states, residual = self.pre_mlp_process(
hidden_states, residual, pad_size)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, pad_size = self.mlp(hidden_states)
return hidden_states, residual, pad_size


ALL_DECODER_LAYER_TYPES = {
@@ -207,6 +342,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=prefix,
decoder_layer_type=CustomQwen3DecoderLayer)
self.cos_sin_cache = self.layers[0].self_attn.rotary_emb.cos_sin_cache
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM

def forward(
self,
Expand Down Expand Up @@ -235,20 +373,25 @@ def forward(
cos, sin = cos.view(1, -1, 1, last_dim).contiguous(), sin.view(
1, -1, 1, last_dim).contiguous()

pad_size = 0
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual, pad_size = layer(positions, hidden_states,
residual, cos, sin,
pad_size)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)

if self.enable_fc == 2:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
residual = tensor_model_parallel_all_gather(residual, 0)
if pad_size > 0:
hidden_states = hidden_states[:-pad_size]
residual = residual[:-pad_size]
return hidden_states


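A shape-only sketch of the pad + All2All + reshape pattern used in CustomQwen3MLP.forward and CustomQwen3Attention.forward (illustration with assumed sizes; torch.randn_like stands in for the dist.all_to_all_single result, whose rows arrive grouped by source rank):

import torch
import torch.nn.functional as F

tp_size = 2
tokens, width = 5, 4                       # 5 tokens, per-rank feature width 4

x = torch.randn(tokens, width)
pad_size = (tp_size - tokens % tp_size) % tp_size
x = F.pad(x, (0, 0, 0, pad_size))          # pad to 6 rows so the first dim splits evenly

output = torch.randn_like(x)               # stand-in for the All2All result
merged = output.reshape(tp_size, -1, output.size(-1)) \
               .transpose(0, 1) \
               .reshape(-1, output.size(-1) * tp_size)

# each local token now carries the full feature width for the ReplicatedLinear
assert merged.shape == (x.size(0) // tp_size, width * tp_size)   # (3, 8)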