
Commit a148041

rjg-lyh authored and weijinqian_v1 committed
[V0.9.1] Add support for flashcomm_v1 in Qwen2.5 (vllm-project#1745)
### What this PR does / why we need it?

Add support for flashcomm_v1 in Qwen2.5.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

**① Functional Testing**: CI passed with the existing tests, and a new test was added in tests/multicard/test_offline_inference_distributed.py.

**② Accuracy Testing**: Using offline inference, evaluate the difference in model outputs between enabling and disabling the FlashComm v1 feature. As shown in the figures below:

- disabling
  <img width="1543" height="358" alt="image" src="https://github.com/user-attachments/assets/f7fab4e3-c3d1-412a-958e-11e2b9ec8f58" />
- enabling
  <img width="1541" height="531" alt="image" src="https://github.com/user-attachments/assets/11a2c5bf-22f0-4a63-b76d-c7b7575397be" />

**③ Performance Stress Testing**: Comparison of TTFT, based on QwQ-32B-BF16 with input_len=16K~32K, output_len=8K, and max_concurrency=16:

| Metric           | disabled | enabled |
|------------------|----------|---------|
| Mean TTFT (ms)   | 1419.58  | 1322.36 |
| Median TTFT (ms) | 1073.32  | 1006.09 |
| P99 TTFT (ms)    | 9549.34  | 8268.28 |

---------

Signed-off-by: rjg-lyh <1318825571@qq.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
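For the offline-inference comparison described above, the feature is toggled through the `VLLM_ASCEND_ENABLE_FLASHCOMM` environment variable, the same switch the new tests patch. Below is a minimal sketch, assuming a 4-card Ascend setup; the model, TP size, and context length mirror the new test, while the prompt and sampling settings are purely illustrative and not the exact evaluation configuration:

```python
# Minimal sketch: enable FlashComm v1 via the environment variable checked by
# vllm_ascend, then run offline inference. Set the variable before importing
# vLLM; "0" (or leaving it unset) disables the feature.
import os

os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM"] = "1"

from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# Model, tensor_parallel_size, and max_model_len follow the new multicard test;
# adjust to your hardware.
llm = LLM(model="Qwen/QwQ-32B", tensor_parallel_size=4, max_model_len=8192)

for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```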
1 parent 3b0cdb9 commit a148041

File tree

7 files changed: +441 −2 lines changed


.github/workflows/vllm_ascend_test.yaml

Lines changed: 2 additions & 0 deletions
@@ -205,6 +205,8 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_w8a8_ep_dbo
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ_with_flashcomm_v1
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_with_flashcomm_v2
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py --ignore=tests/multicard/test_w4a8_deepseek.py
         fi

tests/multicard/test_offline_inference_distributed.py

Lines changed: 19 additions & 2 deletions
@@ -195,8 +195,26 @@ def test_models_distributed_DeepSeek_W8A8():
         vllm_model.generate_greedy(example_prompts, max_tokens)


+@pytest.mark.parametrize("enforce_eager", [True, False])
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+def test_models_distributed_QwQ_with_flashcomm_v1(enforce_eager: bool):
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(
+            snapshot_download("Qwen/QwQ-32B"),
+            max_model_len=8192,
+            enforce_eager=enforce_eager,
+            dtype="auto",
+            tensor_parallel_size=4,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "2"})
-def test_models_distributed_Qwen3_with_flashcomm2():
+def test_models_distributed_Qwen3_with_flashcomm_v2():
     example_prompts = [
         "Hello, my name is",
     ]
@@ -208,6 +226,5 @@ def test_models_distributed_Qwen3_with_flashcomm2():
             enforce_eager=True,
             dtype="auto",
             tensor_parallel_size=2,
-            quantization="ascend",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

vllm_ascend/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,4 +28,6 @@ def register_model():
     import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401

     from .models import register_model
+
+    import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # isort: skip  # noqa: F401
     register_model()

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,7 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
@@ -17,6 +18,9 @@ def register_model():
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
+
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

vllm_ascend/models/qwen2.py

Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@ (new file)
from collections.abc import Iterable
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Qwen2Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as ascend_envs
from vllm_ascend.attention.attention_v1 import AscendAttentionState


def all_gather_and_maybe_unpad(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
    if pad_size > 0:
        return hidden_states[:-pad_size, :]
    return hidden_states


def maybe_pad_and_reduce_scatter(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    if pad_size > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
    return hidden_states


class CustomQwen2DecoderLayer(Qwen2DecoderLayer):

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.self_attn.o_proj.reduce_results = False
        self.mlp.down_proj.reduce_results = False

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        flashcomm_v1_enabled: bool,
        pad_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            if flashcomm_v1_enabled:
                if pad_size > 0:
                    residual = F.pad(residual, (0, 0, 0, pad_size))
                residual = torch.chunk(residual, self.tp_size,
                                       dim=0)[self.tp_rank]
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
            if flashcomm_v1_enabled:
                hidden_states = all_gather_and_maybe_unpad(
                    hidden_states, pad_size)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        if flashcomm_v1_enabled:
            hidden_states = maybe_pad_and_reduce_scatter(
                hidden_states, pad_size)
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        if flashcomm_v1_enabled:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        hidden_states = self.mlp(hidden_states)
        if flashcomm_v1_enabled:
            hidden_states = maybe_pad_and_reduce_scatter(
                hidden_states, pad_size)
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen2Model(Qwen2Model):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=decoder_layer_type)
        self.tp_size = get_tensor_model_parallel_world_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        pad_size = 0
        flashcomm_v1_enabled = False
        attn_metadata = get_forward_context().attn_metadata
        if ascend_envs.VLLM_ASCEND_ENABLE_FLASHCOMM == 1 and \
                attn_metadata is not None and \
                attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
            flashcomm_v1_enabled = True
        if flashcomm_v1_enabled:
            num_tokens = hidden_states.size(0)
            pad_size = (self.tp_size -
                        (num_tokens % self.tp_size)) % self.tp_size
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                flashcomm_v1_enabled,
                pad_size,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        if flashcomm_v1_enabled:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        return hidden_states


class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen2Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
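The padding in `CustomQwen2Model.forward` guarantees that the token dimension divides evenly across TP ranks before hidden states are chunked or reduce-scattered along dim 0, and `all_gather_and_maybe_unpad` drops the extra rows afterwards. Below is a single-process sketch of that arithmetic, with `torch.cat` standing in for the all-gather collective; the shapes and variable names are illustrative only:

```python
# Single-process sketch of the FlashComm v1 pad/unpad round-trip used above.
# No real collectives are involved: torch.chunk plays the per-rank split and
# torch.cat plays tensor_model_parallel_all_gather.
import torch
import torch.nn.functional as F

tp_size = 4
num_tokens, hidden = 6, 8                 # 6 tokens do not split evenly over 4 ranks
hidden_states = torch.randn(num_tokens, hidden)

# Same formula as CustomQwen2Model.forward
pad_size = (tp_size - (num_tokens % tp_size)) % tp_size   # -> 2

padded = F.pad(hidden_states, (0, 0, 0, pad_size))        # (8, hidden)
shards = torch.chunk(padded, tp_size, dim=0)              # 4 shards of (2, hidden)

# "All-gather" the shards and drop the padding rows, mirroring
# all_gather_and_maybe_unpad.
gathered = torch.cat(shards, dim=0)
restored = gathered[:-pad_size, :] if pad_size > 0 else gathered

assert torch.equal(restored, hidden_states)
```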

vllm_ascend/patch/platform/patch_0_9_1/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -14,3 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+# patch_utils should be the first import, because it will be used by other
+# patch files.
+import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa isort:skip
+import vllm_ascend.patch.platform.patch_0_9_1.patch_decorator  # noqa
