Commit e19ed68

[V0.9.1] Add support for flashcomm_v1 in Qwen2.5

Signed-off-by: rjg-lyh <1318825571@qq.com>
Parent: 8e42f71
File tree: 8 files changed, +459 -2 lines

vllm_ascend/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -27,5 +27,10 @@ def register_model():
     # is upgraded to 2.7.0
     import vllm_ascend.patch.worker.patch_common.patch_utils  # noqa: F401

+    from .utils import vllm_version_is
+    # Import specific patches for different versions
+    if vllm_version_is("0.9.1"):
+        import vllm_ascend.patch.compilation.patch_0_9_1.patch_decorator  # noqa: F401
+
     from .models import register_model
     register_model()
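The guarded import only takes effect when the installed vLLM is exactly 0.9.1, so the compilation patch cannot leak into other releases. A minimal sketch of how such a version gate can work (illustrative assumption; the real vllm_ascend.utils.vllm_version_is may be implemented differently):

import importlib.metadata


def vllm_version_is(target: str) -> bool:
    # Compare the installed vllm distribution version with the expected release.
    return importlib.metadata.version("vllm") == target


if vllm_version_is("0.9.1"):
    # Only here would the 0.9.1-specific patch modules be imported.
    ...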

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions

@@ -106,6 +106,8 @@
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
                  ),
+    "VLLM_ENABLE_FlashComm":
+    lambda: int(os.getenv("VLLM_ENABLE_FlashComm", '0')),
     # MOE_ALL2ALL_BUFFER:
     # 0: default, normal init.
     # 1: enable moe_all2all_buffer.
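The flag is read through the env registry above, so flashcomm_v1 stays off unless the variable is set to a non-zero value. A minimal sketch of enabling it from Python before building an engine (only the VLLM_ENABLE_FlashComm key comes from this commit; the rest is illustrative):

import os

# "1" enables flashcomm_v1; the registry lambda above parses it with int(...).
os.environ["VLLM_ENABLE_FlashComm"] = "1"

import vllm_ascend.envs as ascend_envs

# Attribute access goes through the env registry, so the flag reflects the
# environment value set above.
assert ascend_envs.VLLM_ENABLE_FlashComm == 1

Exporting VLLM_ENABLE_FlashComm=1 in the shell before launching the server has the same effect.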

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -11,12 +11,16 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401

     ModelRegistry.register_model(
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
+
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

vllm_ascend/models/qwen2.py (new file)

Lines changed: 253 additions & 0 deletions

from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)

from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_reduce_scatter)
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as ascend_envs


def all_gather_and_maybe_unpad(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
    if pad_size > 0:
        return hidden_states[:-pad_size, :]
    return hidden_states


def maybe_pad_and_reduce_scatter(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    if pad_size > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
    return hidden_states


class CustomQwen2DecoderLayer(Qwen2DecoderLayer):

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.self_attn.o_proj.reduce_results = False
        self.mlp.down_proj.reduce_results = False

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        fc_enabled: bool,
        pad_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            if fc_enabled:
                if pad_size > 0:
                    residual = F.pad(residual, (0, 0, 0, pad_size))
                residual = torch.chunk(residual, self.tp_size, dim=0)[self.tp_rank]
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        if fc_enabled:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        if fc_enabled:
            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        if fc_enabled:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        hidden_states = self.mlp(hidden_states)
        if fc_enabled:
            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen2Model(Qwen2Model):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=decoder_layer_type)
        self.tp_size = get_tensor_model_parallel_world_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        pad_size = 0
        fc_enabled = False
        attn_metadata = get_forward_context().attn_metadata
        if ascend_envs.VLLM_ENABLE_FlashComm == 1 and \
           attn_metadata is not None and \
           attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
            fc_enabled = True
        if fc_enabled:
            num_tokens = hidden_states.size(0)
            pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                fc_enabled,
                pad_size,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        if fc_enabled:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        return hidden_states


class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen2Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
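flashcomm_v1 in this model is a sequence-parallel style rewrite of the tensor-parallel communication: o_proj and down_proj no longer all-reduce their outputs (reduce_results = False); instead, when fc_enabled is true, each layer pads the token count to a multiple of tp_size, reduce-scatters after attention and after the MLP so every rank keeps only a 1/tp_size slice of the tokens, and all-gathers right before the next module needs the full sequence. Decode-only batches skip this path and fall back to the regular all-reduce. A single-process sketch of the shape bookkeeping, with torch.chunk/torch.cat standing in for the collectives (illustrative only; the real reduce-scatter also sums partial results across ranks):

import torch
import torch.nn.functional as F

tp_size = 4
num_tokens, hidden = 10, 8
hidden_states = torch.randn(num_tokens, hidden)

# Same pad_size computation as CustomQwen2Model.forward: make the token
# dimension divisible by tp_size.
pad_size = (tp_size - (num_tokens % tp_size)) % tp_size
padded = F.pad(hidden_states, (0, 0, 0, pad_size))

# "reduce-scatter": each rank keeps one contiguous chunk of tokens.
shards = list(torch.chunk(padded, tp_size, dim=0))
assert all(s.shape[0] == (num_tokens + pad_size) // tp_size for s in shards)

# "all-gather + unpad": reassembling the shards and dropping the padding
# restores the full sequence, which is what all_gather_and_maybe_unpad does
# before attention and the MLP.
restored = torch.cat(shards, dim=0)[:num_tokens]
assert torch.equal(restored, hidden_states)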
Lines changed: 21 additions & 0 deletions (new file)

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vllm_ascend.utils import vllm_version_is

# Import specific patches for different versions
if vllm_version_is("0.9.1"):
    from vllm_ascend.patch.compilation import patch_0_9_1  # noqa: F401
Lines changed: 18 additions & 0 deletions (new file)

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import vllm_ascend.patch.compilation.patch_0_9_1.patch_decorator  # noqa
