Commit a0e44e3

taoyuxiang authored and NicholasTao committed
Add graph mode for Qwen2.5 and Qwen3
Signed-off-by: taoyuxiang <t30002884@china.huawei.com>
1 parent 63944db commit a0e44e3

File tree

8 files changed

+1267 -50 lines changed

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -169,12 +169,12 @@ def check_ascend_config(vllm_config, enforce_eager):
                 "Torchair graph mode is still experimental and not supported for V1 without mla currently, "
                 "it has been disabled automatically.")
             ascend_config.torchair_graph_config.enabled = False
-        # torchair_graph is supported for deepseek model only currently.
+        # torchair_graph is supported for deepseek or qwen currently.
         if vllm_config.model_config:
             model_type = vllm_config.model_config.hf_config.model_type
-            if "deepseek" not in model_type:
+            if "deepseek" not in model_type and "qwen" not in model_type:
                 raise NotImplementedError(
-                    "Torchair graph mode only works with deepseek model."
+                    "Torchair graph mode only works with deepseek or qwen model."
                 )
     # aclgraph case
     else:
```
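The gate now passes for any HuggingFace `model_type` containing "deepseek" or "qwen". A standalone sketch of the new predicate (Qwen2.5 checkpoints report `model_type` "qwen2" and Qwen3 reports "qwen3", so both clear the check):

```python
# Minimal sketch of the updated gating predicate; no vllm imports needed.
def torchair_graph_supported(model_type: str) -> bool:
    # After this commit the gate accepts deepseek *or* qwen model types.
    return "deepseek" in model_type or "qwen" in model_type

assert torchair_graph_supported("deepseek_v3")
assert torchair_graph_supported("qwen2")   # Qwen2 / Qwen2.5
assert torchair_graph_supported("qwen3")
assert not torchair_graph_supported("llama")  # still raises NotImplementedError
```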

vllm_ascend/attention/attention_v1.py

Lines changed: 159 additions & 12 deletions
```diff
@@ -19,17 +19,19 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -140,6 +142,8 @@ class AscendMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.
 
     enable_dbo_across_dp: bool = False
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False
 
     def split_metadata_for_multistream(
         self,
@@ -163,6 +167,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :
+                               num_blocks] = block_tables[:num_seqs, :
+                                                          num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :
+                               max_blocks] = block_tables[:num_seqs, :
+                                                          max_blocks]
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
     def build(self,
               num_reqs,
               num_actual_tokens,
```
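`_get_graph_runner_block_tables` copies the live block table into a fixed-shape buffer so the captured graph always sees the same tensor shape, truncating any request whose block count exceeds the buffer width. A CPU-only sketch of that copy with illustrative sizes:

```python
import torch

# Fixed-shape buffer sizes as the graph runner would pre-allocate them
# (illustrative values; the real shape comes from runner.graph_block_tables).
max_batch_size, max_blocks = 8, 16

# A live batch of 3 sequences, each holding 20 blocks (wider than the buffer).
block_tables = torch.arange(3 * 20, dtype=torch.int32).view(3, 20)

graph_block_tables = torch.zeros((max_batch_size, max_blocks),
                                 dtype=block_tables.dtype)
num_seqs, num_blocks = block_tables.shape
copy_blocks = min(num_blocks, max_blocks)  # truncate when wider than the buffer
graph_block_tables[:num_seqs, :copy_blocks] = block_tables[:num_seqs, :copy_blocks]

print(graph_block_tables[:num_seqs, :max_blocks].shape)  # torch.Size([3, 16])
```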
```diff
@@ -188,6 +218,41 @@ def build(self,
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs.get("graph_pad_size", -1)
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = seq_lens.tolist() + [pad_value
+                                                           ] * graph_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
```
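When a decode-only (or spec-decoding) batch is smaller than the captured graph batch, `build()` pads `seq_lens` with dummy length-1 entries, `slot_mapping` with `PAD_SLOT_ID`, and the block table with zero rows. A standalone sketch of the padding (assuming `PAD_SLOT_ID` is vLLM's `-1` sentinel):

```python
import numpy as np
import torch

PAD_SLOT_ID = -1  # assumption: vLLM's sentinel for padded slot-mapping entries

# Three live decode requests, padded up to a captured graph batch of 8.
seq_lens = torch.tensor([5, 9, 2], dtype=torch.int32)
slot_mapping = torch.tensor([37, 41, 53], dtype=torch.int32)
graph_pad_size = 5  # graph batch size 8 minus 3 live requests (illustrative)

padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size  # pad_value = 1
seq_lens = torch.from_numpy(np.array(padded_seq_lens).astype(np.int32))
padding = torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype)
slot_mapping = torch.cat([slot_mapping, padding])

print(seq_lens.tolist())      # [5, 9, 2, 1, 1, 1, 1, 1]
print(slot_mapping.tolist())  # [37, 41, 53, -1, -1, -1, -1, -1]
```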
```diff
@@ -200,7 +265,44 @@
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph)
+        return attn_metadata
+
+    def build_torchair_graph_dummy(self, num_reqs: int,
+                                   num_actual_tokens: int):
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
```
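`build_torchair_graph_dummy` fabricates a minimal decode batch for graph capture: one token per request, fully padded slots, and `-1` sentinel query offsets. A CPU sketch of those placeholder tensors:

```python
import torch

PAD_SLOT_ID = -1  # assumption: same sentinel as above
num_reqs = 4

# Placeholder tensors a dummy capture batch would carry; every request
# pretends to hold exactly one token and no real KV slots.
seq_lens = torch.ones(num_reqs, dtype=torch.int32)
query_lens = torch.ones(num_reqs, dtype=torch.int32)
slot_mapping = torch.full((num_reqs, ), PAD_SLOT_ID, dtype=torch.int32)
query_start_loc = torch.full((num_reqs, ), -1, dtype=torch.int32)

print(seq_lens.tolist(), slot_mapping.tolist())  # [1, 1, 1, 1] [-1, -1, -1, -1]
```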
```diff
@@ -248,6 +350,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        prefix: Optional[str] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -267,11 +370,29 @@
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
         self.block_size = vllm_config.cache_config.block_size
 
+    def update_kv_cache(self, key: torch.Tensor, value: torch.Tensor,
+                        key_cache: torch.Tensor, value_cache: torch.Tensor,
+                        slot_indices: torch.Tensor) -> None:
+        # calc indices by block_size
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # [blocknum, blocksize, numKvHeads, headDims]
+        # -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
     def forward(
         self,
         layer: AttentionLayer,
```
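`update_kv_cache` splits each flat slot index into a `(block index, in-block offset)` pair, which `torch_npu.npu_scatter_nd_update_` then uses to scatter new keys and values into the paged cache. A CPU sketch of just the index math:

```python
import torch

block_size = 128  # illustrative; the real value is key_cache.shape[1]
slot_indices = torch.tensor([0, 127, 128, 300], dtype=torch.int64).view(-1, 1, 1)

# Decompose flat slots into (block_idx, block_offset) pairs.
block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
block_offset = slot_indices % block_size
indices = torch.cat([block_idx, block_offset], dim=2)  # shape [num_tokens, 1, 2]

# Round trip: block_idx * block_size + block_offset recovers the flat slot.
assert torch.equal(indices[..., 0] * block_size + indices[..., 1],
                   slot_indices.squeeze(-1))
print(indices.squeeze(1).tolist())  # [[0, 0], [0, 127], [1, 0], [2, 44]]
```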
```diff
@@ -320,12 +441,19 @@ def forward(
         if self.key_cache is None:
             self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
         slots = attn_metadata.slot_mapping
-        torch_npu._npu_reshape_and_cache(
-            key=key[:num_actual_tokens],
-            value=value[:num_actual_tokens],
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            slot_indices=slots)
+        if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+            self.update_kv_cache(key=key,
+                                 value=value,
+                                 key_cache=self.key_cache,
+                                 value_cache=self.value_cache,
+                                 slot_indices=slots.to(torch.int64))
+        else:
+            torch_npu._npu_reshape_and_cache(
+                key=key[:num_actual_tokens],
+                value=value[:num_actual_tokens],
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                slot_indices=slots)
 
         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
@@ -363,10 +491,28 @@
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # query change to BSND
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, numKvHeads, blocksize, headDims] -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(*self.key_cache.shape[:-2],
+                                                -1)
+                value_cache = self.value_cache.view(
+                    *self.value_cache.shape[:-2], -1)
+
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1],
+                )
+            elif not get_forward_context().capturing:
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
```
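In the torchair decode path, the query is viewed as batch-major with sequence length 1 and the KV caches fold their last two axes so everything matches the `'BSH'` layout that `npu_incre_flash_attention` expects. A shape-only CPU sketch with illustrative sizes:

```python
import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 64
block_num, block_size = 10, 128

query = torch.randn(num_tokens, num_heads * head_size)
key_cache = torch.randn(block_num, block_size, num_kv_heads, head_size)

query = query.view(-1, 1, num_heads * head_size)       # one decode token per seq
key_cache = key_cache.view(*key_cache.shape[:-2], -1)  # fold heads into hidden dim

print(query.shape)      # torch.Size([4, 1, 512])
print(key_cache.shape)  # torch.Size([10, 128, 128])
```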
```diff
@@ -384,6 +530,7 @@
                 event = torch.npu.ExternalEvent()
                 event.wait(stream)
                 event.reset(stream)
+                graph_params = get_graph_params()
                 graph_params.events[num_tokens].append(event)
 
                 graph_params.attn_params[num_tokens].append((
```

vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
@@ -60,3 +61,6 @@ def register_model():
 
     ModelRegistry.register_model(
         "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
```
