
Commit 7465c21

[V0.9.1] Add support for flashcomm_v1 in Qwen2.5
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 8e42f71

4 files changed: 269 additions & 2 deletions

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,8 @@
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
     ),
+    "VLLM_ENABLE_FlashComm":
+    lambda: int(os.getenv("VLLM_ENABLE_FlashComm", '0')),
     # MOE_ALL2ALL_BUFFER:
     # 0: default, normal init.
     # 1: enable moe_all2all_buffer.
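
With this entry in place, FlashComm v1 is controlled entirely by the VLLM_ENABLE_FlashComm environment variable (0 disables it, 1 selects the flashcomm_v1 path added below). A minimal sketch of reading the flag back, assuming vllm_ascend.envs resolves entries lazily from this lambda table the way vLLM's own envs.py does:

import os

# The variable must be set in the environment of the vLLM process; here it is
# set programmatically purely for illustration.
os.environ["VLLM_ENABLE_FlashComm"] = "1"

import vllm_ascend.envs as ascend_envs

# The lambda registered above parses the value as an int, so this prints 1.
print(ascend_envs.VLLM_ENABLE_FlashComm)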

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,12 +11,16 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401

     ModelRegistry.register_model(
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
+
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

vllm_ascend/models/qwen2.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
import torch.nn.functional as F
from transformers import Qwen2Config

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix)

from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as ascend_envs
from vllm_ascend.ops.linear import MaybeScatterRowParallelLinear


def all_gather_and_maybe_unpad(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
    if pad_size > 0:
        return hidden_states[:-pad_size, :]
    return hidden_states

def maybe_pad_and_reduce_scatter(
    hidden_states: torch.Tensor,
    pad_size: int,
) -> torch.Tensor:
    if pad_size > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
    return hidden_states

class CustomQwen2DecoderLayer(Qwen2DecoderLayer):

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.fc_level = ascend_envs.VLLM_ENABLE_FlashComm
        if self.fc_level:
            self.self_attn.reduce_results = False
            self.mlp.reduce_results = False

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        with_prefill: bool,
        pad_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            if self.fc_level == 1 and with_prefill:
                if pad_size > 0:
                    residual = F.pad(residual, (0, 0, 0, pad_size))
                residual = torch.chunk(residual, self.tp_size, dim=0)[self.tp_rank]
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
            if self.fc_level == 1 and with_prefill:
                hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            with_prefill=with_prefill,
            pad_size=pad_size,
        )
        if self.fc_level == 1 and with_prefill:
            if pad_size > 0:
                hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
            hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        if self.fc_level == 1 and with_prefill:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        hidden_states = self.mlp(hidden_states, with_prefill, pad_size)
        if self.fc_level == 1 and with_prefill:
            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
        else:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen2Model(Qwen2Model):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=decoder_layer_type)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.fc_level = ascend_envs.VLLM_ENABLE_FlashComm

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        pad_size = 0
        with_prefill = get_forward_context().with_prefill
        if self.fc_level == 1 and with_prefill:
            num_tokens = hidden_states.size(0)
            pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                with_prefill,
                pad_size,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        if self.fc_level == 1 and with_prefill:
            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
        return hidden_states


class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen2Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen2Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
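
When fc_level == 1 and the batch contains prefill tokens, the layer replaces the usual tensor-parallel all-reduce after attention and MLP with a reduce-scatter over the padded token dimension, and all-gathers again just before the next attention/MLP call, so the norms and residual adds run on a 1/tp_size slice of the sequence per rank. A standalone sketch of the shape bookkeeping behind pad_size, all_gather_and_maybe_unpad, and maybe_pad_and_reduce_scatter, using plain torch ops in place of the collective calls (no NPU or distributed setup required):

import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 4, 10, 8
# Same formula as CustomQwen2Model.forward: pad to a multiple of tp_size.
pad_size = (tp_size - (num_tokens % tp_size)) % tp_size    # -> 2

hidden_states = torch.randn(num_tokens, hidden)
padded = F.pad(hidden_states, (0, 0, 0, pad_size))          # (12, 8)
# Each rank keeps one contiguous slice of the token dimension, as the first
# decoder layer does for the residual via torch.chunk(...)[tp_rank].
shards = torch.chunk(padded, tp_size, dim=0)                 # 4 x (3, 8)

# All-gathering the shards and dropping the pad rows is the exact inverse,
# which is what all_gather_and_maybe_unpad restores before attention/MLP.
gathered = torch.cat(shards, dim=0)
if pad_size > 0:
    gathered = gathered[:-pad_size, :]
assert torch.equal(gathered, hidden_states)

When with_prefill is False, the layer instead keeps the plain tensor_model_parallel_all_reduce branch shown above, so only prefill batches pay the padding and re-gather cost.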

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
@@ -2033,8 +2033,8 @@ def capture_model(self) -> None:
                 for num_tokens in reversed(self.aclgraph_batch_sizes):
                     for _ in range(self.vllm_config.compilation_config.
                                    cudagraph_num_of_warmups):
-                        self._dummy_run(num_tokens, skip_attn=skip_attn)
-                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                        self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False)
         else:
             logger.info("Skipping NPU graph capture for eager mode.")
             return
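
The dummy runs used for ACL graph capture now pass with_prefill=False explicitly; since the FlashComm v1 branch in CustomQwen2Model and CustomQwen2DecoderLayer is gated on with_prefill, this keeps the reduce-scatter/all-gather prefill path out of the captured decode graphs.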
