
Commit 5a85a7f

[V0.9.1] Add support for flashcomm_v1 in Qwen2.5

Signed-off-by: rjg-lyh <1318825571@qq.com>
Parent: 8e42f71

File tree: 4 files changed, +258 −2 lines

  vllm_ascend/envs.py
  vllm_ascend/models/__init__.py
  vllm_ascend/models/qwen2.py
  vllm_ascend/worker/model_runner_v1.py

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,8 @@
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
     ),
+    "VLLM_ENABLE_FlashComm":
+    lambda: int(os.getenv("VLLM_ENABLE_FlashComm", '0')),
     # MOE_ALL2ALL_BUFFER:
     # 0: default, normal init.
     # 1: enable moe_all2all_buffer.
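
The new switch is read through vllm_ascend.envs, so it has to be in the process environment before the engine builds the model. A minimal sketch of enabling flashcomm_v1 for an offline run; only the VLLM_ENABLE_FlashComm variable comes from this commit, while the vllm.LLM entry point, checkpoint name, and TP size below are illustrative assumptions:

import os

# From this commit: "0" (default) leaves FlashComm off, "1" enables flashcomm_v1.
# Set it before the model is built, because the decoder layers read
# ascend_envs.VLLM_ENABLE_FlashComm in their __init__.
os.environ["VLLM_ENABLE_FlashComm"] = "1"

from vllm import LLM, SamplingParams  # assumed standard offline API

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",  # placeholder Qwen2.5 checkpoint
          tensor_parallel_size=2)            # flashcomm_v1 only kicks in for TP prefill
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))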

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -11,12 +11,16 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
         "DeepSeekMTPModel",
         "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
 
+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
+
     ModelRegistry.register_model(
         "Qwen2VLForConditionalGeneration",
         "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")

vllm_ascend/models/qwen2.py

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import Qwen2Config
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
+from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix)
+
+from vllm.model_executor.models.qwen2 import Qwen2Model, Qwen2DecoderLayer
+from vllm.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_world_size,
+    get_tensor_model_parallel_rank,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+    tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import get_forward_context
+import vllm_ascend.envs as ascend_envs
+
+
+def all_gather_and_maybe_unpad(
+        hidden_states: torch.Tensor,
+        pad_size: int,
+) -> torch.Tensor:
+    hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+    if pad_size > 0:
+        return hidden_states[:-pad_size, :]
+    return hidden_states
+
+def maybe_pad_and_reduce_scatter(
+        hidden_states: torch.Tensor,
+        pad_size: int,
+) -> torch.Tensor:
+    if pad_size > 0:
+        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
+    hidden_states = tensor_model_parallel_reduce_scatter(hidden_states, 0)
+    return hidden_states
+
+class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.fc_level = ascend_envs.VLLM_ENABLE_FlashComm
+        self.self_attn.reduce_results=False
+        self.mlp.reduce_results=False
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        with_prefill: bool,
+        pad_size: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            if self.fc_level == 1 and with_prefill:
+                if pad_size > 0:
+                    residual = F.pad(residual, (0, 0, 0, pad_size))
+                residual = torch.chunk(residual, self.tp_size, dim=0)[self.tp_rank]
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        if self.fc_level == 1 and with_prefill:
+            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            with_prefill=with_prefill,
+            pad_size=pad_size,
+        )
+        if self.fc_level == 1 and with_prefill:
+            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
+        else:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        if self.fc_level == 1 and with_prefill:
+            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
+        hidden_states = self.mlp(hidden_states, with_prefill, pad_size)
+        if self.fc_level == 1 and with_prefill:
+            hidden_states = maybe_pad_and_reduce_scatter(hidden_states, pad_size)
+        else:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        return hidden_states, residual
+
+
+@support_torch_compile(
+    dynamic_arg_dims={
+        "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
+        # otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+    })
+class CustomQwen2Model(Qwen2Model):
+
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 decoder_layer_type: type[nn.Module] = CustomQwen2DecoderLayer):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         decoder_layer_type=decoder_layer_type)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.fc_level = ascend_envs.VLLM_ENABLE_FlashComm
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        pad_size = 0
+        with_prefill = get_forward_context().with_prefill
+        if self.fc_level == 1 and with_prefill:
+            num_tokens = hidden_states.size(0)
+            pad_size = (self.tp_size - (num_tokens % self.tp_size)) % self.tp_size
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+                with_prefill,
+                pad_size,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        if self.fc_level == 1 and with_prefill:
+            hidden_states = all_gather_and_maybe_unpad(hidden_states, pad_size)
+        return hidden_states
+
+
+class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+    # add `CustomQwen2Model` to init self.model
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.model = CustomQwen2Model(vllm_config=vllm_config,
+                                      prefix=maybe_prefix(prefix, "model"))
+
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
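
The core of flashcomm_v1 is the sequence-dimension bookkeeping above: during prefill the token dimension is padded up to a multiple of the TP world size, each rank keeps one contiguous chunk of tokens between the reduce-scatter/all-gather pairs, and the padding is stripped after the final all-gather. A single-process sketch of just that arithmetic, using torch.chunk/torch.cat to stand in for the distributed collectives (the collectives themselves are not reproduced here):

import torch
import torch.nn.functional as F

tp_size, hidden = 4, 8
num_tokens = 10                    # prefill length that is not a multiple of tp_size
hidden_states = torch.randn(num_tokens, hidden)

# Same formula as CustomQwen2Model.forward: pad the token dim to a multiple of tp_size.
pad_size = (tp_size - (num_tokens % tp_size)) % tp_size   # -> 2
padded = F.pad(hidden_states, (0, 0, 0, pad_size))        # (12, 8)

# The chunk/reduce-scatter step: each rank keeps one contiguous slice of tokens.
shards = torch.chunk(padded, tp_size, dim=0)              # 4 shards of shape (3, 8)

# all_gather_and_maybe_unpad equivalent: concatenate the shards, then drop the padding.
gathered = torch.cat(shards, dim=0)
restored = gathered[:-pad_size, :] if pad_size > 0 else gathered
assert torch.equal(restored, hidden_states)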

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 2 deletions
@@ -2033,8 +2033,8 @@ def capture_model(self) -> None:
                 for num_tokens in reversed(self.aclgraph_batch_sizes):
                     for _ in range(self.vllm_config.compilation_config.
                                    cudagraph_num_of_warmups):
-                        self._dummy_run(num_tokens, skip_attn=skip_attn)
-                    self._dummy_run(num_tokens, skip_attn=skip_attn)
+                        self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False)
+                    self._dummy_run(num_tokens, skip_attn=skip_attn, with_prefill=False)
         else:
             logger.info("Skipping NPU graph capture for eager mode.")
             return
