
Commit c985f62

Author: taoyuxiang (committed)
Add graph mode for Qwen2.5 and Qwen3
1 parent 63944db commit c985f62

File tree: 7 files changed, +725 -50 lines


vllm_ascend/ascend_config.py

Lines changed: 3 additions & 3 deletions
@@ -169,12 +169,12 @@ def check_ascend_config(vllm_config, enforce_eager):
                 "Torchair graph mode is still experimental and not supported for V1 without mla currently, "
                 "it has been disabled automatically.")
             ascend_config.torchair_graph_config.enabled = False
-        # torchair_graph is supported for deepseek model only currently.
+        # torchair_graph is supported for deepseek or qwen models currently.
         if vllm_config.model_config:
            model_type = vllm_config.model_config.hf_config.model_type
-            if "deepseek" not in model_type:
+            if "deepseek" not in model_type and "qwen" not in model_type:
                raise NotImplementedError(
-                    "Torchair graph mode only works with deepseek model."
+                    "Torchair graph mode only works with deepseek or qwen models."
                )
    # aclgraph case
    else:
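
For reference, the gate is a plain substring test on the Hugging Face model_type, so the whole Qwen family passes. A minimal standalone sketch of the same check (the model_type strings are assumptions based on common HF configs, not taken from this commit):

    # Standalone sketch of the gate in check_ascend_config().
    def torchair_supported(model_type: str) -> bool:
        # Substring test, mirroring the check above.
        return "deepseek" in model_type or "qwen" in model_type

    assert torchair_supported("qwen2")        # Qwen2.5 family reports "qwen2"
    assert torchair_supported("qwen3")
    assert torchair_supported("deepseek_v3")
    assert not torchair_supported("llama")    # would raise NotImplementedError upstream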

vllm_ascend/attention/attention_v1.py

Lines changed: 160 additions & 15 deletions
@@ -19,11 +19,12 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import CommonAttentionState, PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
@@ -32,9 +33,9 @@
 
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
-from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -140,7 +141,8 @@ class AscendMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.
 
     enable_dbo_across_dp: bool = False
-
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False
     def split_metadata_for_multistream(
         self,
         ms_split_config: MSAttentionMetadataSplitConfig,
@@ -153,7 +155,6 @@ def split_metadata_for_multistream(
         _metadata_cls=AscendMetadata,
     )
 
-
 class AscendAttentionMetadataBuilder:
 
     def __init__(self, runner):
@@ -163,6 +164,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :num_blocks] = \
+                block_tables[:num_seqs, :num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :max_blocks] = \
+                block_tables[:num_seqs, :max_blocks]
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
     def build(self,
               num_reqs,
               num_actual_tokens,
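
Torchair replays a graph captured at a fixed batch shape, so the kernel must always see a (max_batch_size, max_blocks) block table; the helper above copies whatever rows and columns exist at runtime and leaves the rest zero. A standalone sketch of the same copy with invented sizes:

    import torch

    max_batch_size, max_blocks = 8, 16   # fixed capture-time shape (illustrative)
    graph_block_tables = torch.zeros(max_batch_size, max_blocks, dtype=torch.int32)

    num_seqs = 3
    live = torch.arange(1, 1 + num_seqs * 5, dtype=torch.int32).view(num_seqs, 5)

    # Clamp the copy width to whichever of the two tables is narrower.
    width = min(live.size(1), max_blocks)
    graph_block_tables[:num_seqs, :width] = live[:, :width]
    print(graph_block_tables[:num_seqs])  # live blocks, zero-padded to 16 columns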
@@ -188,6 +215,41 @@ def build(self,
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs["graph_pad_size"]
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = seq_lens.tolist() + \
+                        [pad_value] * graph_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
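
Here graph_pad_size == -1 signals that no captured graph applies; otherwise the decode batch is padded up to the captured size: dummy sequences report length 1, their slots point at PAD_SLOT_ID so the cache write is harmless, and their block-table rows stay zero. A standalone sketch of the padding arithmetic (sizes invented; PAD_SLOT_ID = -1 is assumed from vllm's attention utils):

    import numpy as np
    import torch

    PAD_SLOT_ID = -1        # value assumed from vllm.attention.backends.utils
    graph_pad_size = 2      # pad a 3-request batch to 5 captured slots

    seq_lens = torch.tensor([7, 12, 3], dtype=torch.int32)
    slot_mapping = torch.tensor([70, 71, 30], dtype=torch.int32)

    padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size  # dummies decode 1 token
    seq_lens = torch.from_numpy(np.array(padded_seq_lens).astype(np.int32))
    slot_mapping = torch.cat(
        [slot_mapping,
         torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=torch.int32)])

    print(seq_lens.tolist())      # [7, 12, 3, 1, 1]
    print(slot_mapping.tolist())  # [70, 71, 30, -1, -1]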
@@ -200,7 +262,44 @@ def build(self,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph)
+        return attn_metadata
+
+    def build_torchair_graph_dummy(self, num_reqs: int,
+                                   num_actual_tokens: int):
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
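
Graph capture runs before any real request arrives, so the builder needs a fully shaped AscendMetadata to trace against; every dummy request decodes one token into PAD_SLOT_ID. A hypothetical capture loop showing where this would be called from (builder and runner.capture_decode_graph are invented names for illustration, not this commit's API):

    # Hypothetical capture-time usage; the real wiring lives in the model runner.
    for captured_batch_size in (1, 2, 4, 8):
        dummy_meta = builder.build_torchair_graph_dummy(
            num_reqs=captured_batch_size,
            num_actual_tokens=captured_batch_size)  # one decode token per request
        # Trace the decode step once per captured size; later real batches are
        # padded up to the nearest captured size and replayed.
        runner.capture_decode_graph(dummy_meta)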
@@ -248,6 +347,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        prefix: Optional[str] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -267,11 +367,34 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
         self.block_size = vllm_config.cache_config.block_size
 
+    def update_kv_cache(self,
+                        key: torch.Tensor,
+                        value: torch.Tensor,
+                        key_cache: torch.Tensor,
+                        value_cache: torch.Tensor,
+                        slot_indices: torch.Tensor) -> None:
+        # Compute (block index, in-block offset) pairs from the flat slot indices.
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # [blocknum, blocksize, numKvHeads, headDims]
+        # -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
     def forward(
         self,
         layer: AttentionLayer,
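
update_kv_cache exists because _npu_reshape_and_cache takes flat slot indices, while npu_scatter_nd_update_ (an in-place op torchair can capture) wants explicit (block, offset) coordinates. The conversion is plain floor-division and modulo; a standalone check with an illustrative block size:

    import torch

    block_size = 128                     # illustrative; real value comes from cache_config
    slots = torch.tensor([300, 129, 0])  # flat slot ids from slot_mapping

    block_idx = torch.div(slots, block_size, rounding_mode='floor')
    block_offset = slots % block_size

    print(block_idx.tolist())     # [2, 1, 0] -> rows in the paged KV cache
    print(block_offset.tolist())  # [44, 1, 0] -> positions within each block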
@@ -320,12 +443,18 @@ def forward(
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slots)
+            if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                self.update_kv_cache(key=key,
+                                     value=value,
+                                     key_cache=self.key_cache,
+                                     value_cache=self.value_cache,
+                                     slot_indices=slots.to(torch.int64))
+            else:
+                torch_npu._npu_reshape_and_cache(key=key[:num_actual_tokens],
+                                                 value=value[:num_actual_tokens],
+                                                 key_cache=self.key_cache,
+                                                 value_cache=self.value_cache,
+                                                 slot_indices=slots)
 
         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
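
The dispatch keeps captured shapes static: the graph path scatters every token of the padded batch, while the eager path first truncates to num_actual_tokens. A sketch of the rule as the two flags read (any prefill across the DP group appears to force the eager kernel, since the captured decode graph cannot cover prefill shapes):

    # Minimal decision-table sketch of the KV-write dispatch above.
    def kv_write_path(torchair_graph_enabled: bool,
                      with_prefill_across_dp: bool) -> str:
        if torchair_graph_enabled and not with_prefill_across_dp:
            return "npu_scatter_nd_update_ (graph-capturable, full padded batch)"
        return "_npu_reshape_and_cache (eager, truncated to real tokens)"

    print(kv_write_path(True, False))  # graph path
    print(kv_write_path(True, True))   # falls back to eager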
@@ -363,10 +492,25 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # Reshape query to BSH layout (batch, seq=1, num_heads * head_size).
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, blocksize, numKvHeads, headDims]
+                # -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(*self.key_cache.shape[:-2], -1)
+                value_cache = self.value_cache.view(
+                    *self.value_cache.shape[:-2], -1)
+
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1])
+            elif not get_forward_context().capturing:
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
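
npu_incre_flash_attention with input_layout='BSH' expects the head dimensions fused, so the views collapse only the two trailing axes of the [blocknum, blocksize, numKvHeads, headDims] cache. A shape-only sketch with invented sizes:

    import torch

    num_heads, head_size, num_kv_heads = 32, 128, 8  # illustrative sizes
    # [blocknum, blocksize, numKvHeads, headDims]
    key_cache = torch.empty(64, 128, num_kv_heads, head_size)

    query = torch.empty(4, num_heads * head_size)    # 4 decode tokens, one per sequence
    bsh_query = query.view(-1, 1, num_heads * head_size)
    bsh_cache = key_cache.view(*key_cache.shape[:-2], -1)

    print(bsh_query.shape)  # torch.Size([4, 1, 4096])
    print(bsh_cache.shape)  # torch.Size([64, 128, 1024])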
@@ -384,6 +528,7 @@ def forward(
                 event = torch.npu.ExternalEvent()
                 event.wait(stream)
                 event.reset(stream)
+                graph_params = get_graph_params()
                 graph_params.events[num_tokens].append(event)
 
                 graph_params.attn_params[num_tokens].append((

vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@ def register_model():
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen3 import CustomQwen3ForCausalLM  # noqa: F401
 
     ModelRegistry.register_model(
@@ -60,3 +61,7 @@ def register_model():
 
     ModelRegistry.register_model(
         "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM",
+        "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
