Commit 16a59c4

taoyuxiang authored and NicholasTao committed
Add graph mode for Qwen2.5 and Qwen3
1 parent 63944db commit 16a59c4

File tree

8 files changed: 1256 additions, 51 deletions

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 3 deletions

@@ -169,12 +169,12 @@ def check_ascend_config(vllm_config, enforce_eager):
                     "Torchair graph mode is still experimental and not supported for V1 without mla currently, "
                     "it has been disabled automatically.")
                 ascend_config.torchair_graph_config.enabled = False
-            # torchair_graph is supported for deepseek model only currently.
+            # torchair_graph is supported for deepseek or qwen currently.
             if vllm_config.model_config:
                 model_type = vllm_config.model_config.hf_config.model_type
-                if "deepseek" not in model_type:
+                if "deepseek" not in model_type and "qwen" not in model_type:
                     raise NotImplementedError(
-                        "Torchair graph mode only works with deepseek model."
+                        "Torchair graph mode only works with deepseek or qwen model."
                     )
         # aclgraph case
         else:
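For context, a minimal launch sketch (not part of the commit) of a configuration that now passes the relaxed model_type check above. It assumes an Ascend NPU environment with vllm-ascend installed and that your vLLM version forwards additional_config to the platform plugin; the model name is a placeholder.

# Hedged sketch: enabling torchair graph mode for a Qwen model.
# Assumes vllm-ascend is installed and vLLM forwards additional_config
# to ascend_config; requires an NPU environment to actually run.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder; any model whose model_type contains "qwen"
    additional_config={
        "torchair_graph_config": {"enabled": True},  # now accepted for deepseek or qwen
    },
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)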

vllm_ascend/attention/attention_v1.py

Lines changed: 160 additions & 13 deletions

@@ -19,11 +19,12 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import CommonAttentionState, PAD_SLOT_ID
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
@@ -35,6 +36,7 @@
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import get_graph_params
+from vllm_ascend.ascend_config import get_ascend_config
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -140,6 +142,8 @@ class AscendMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.
 
     enable_dbo_across_dp: bool = False
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False
 
     def split_metadata_for_multistream(
         self,
@@ -153,7 +157,6 @@ def split_metadata_for_multistream(
             _metadata_cls=AscendMetadata,
         )
 
-
 class AscendAttentionMetadataBuilder:
 
     def __init__(self, runner):
@@ -163,6 +166,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :num_blocks] = block_tables[:num_seqs, :num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :max_blocks] = block_tables[:num_seqs, :max_blocks]
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
     def build(self,
               num_reqs,
               num_actual_tokens,
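The helper above copies each request's block table into a fixed-shape buffer so a captured graph always replays with the same tensor shape. A CPU-only toy run of the same copy-and-truncate logic (invented sizes, not part of the commit):

# Illustrative sketch of _get_graph_runner_block_tables' behaviour.
import torch

max_batch_size, max_blocks = 4, 8                 # graph-sized buffer
block_tables = torch.arange(6).reshape(2, 3)      # 2 live requests, 3 blocks each

graph_block_tables = torch.zeros((max_batch_size, max_blocks),
                                 dtype=block_tables.dtype)
num_seqs, num_blocks = block_tables.shape
copy_blocks = min(num_blocks, max_blocks)         # truncate if a request has too many blocks
graph_block_tables[:num_seqs, :copy_blocks] = block_tables[:, :copy_blocks]

print(graph_block_tables[:num_seqs, :max_blocks])
# tensor([[0, 1, 2, 0, 0, 0, 0, 0],
#         [3, 4, 5, 0, 0, 0, 0, 0]])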
@@ -188,6 +217,41 @@ def build(self,
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs.get("graph_pad_size", -1)
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = seq_lens.tolist() + [pad_value] * graph_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
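A toy CPU illustration (invented sizes and a stand-in PAD_SLOT_ID value, not part of the commit) of the decode-path padding above: seq_lens, slot_mapping and block_table are all grown by graph_pad_size so the torchair graph replays on fixed shapes.

# Sketch of the padding performed in build() for the graph decode path.
import numpy as np
import torch

PAD_SLOT_ID = -1          # stands in for vllm's PAD_SLOT_ID sentinel
graph_pad_size = 2        # graph batch is 2 requests larger than the real batch

seq_lens = torch.tensor([5, 7], dtype=torch.int32)
slot_mapping = torch.tensor([40, 41], dtype=torch.int32)
block_table = torch.tensor([[3], [9]], dtype=torch.int32)

padded_seq_lens = seq_lens.tolist() + [1] * graph_pad_size           # pad_value = 1
seq_lens = torch.from_numpy(np.array(padded_seq_lens).astype(np.int32))
slot_mapping = torch.cat(
    [slot_mapping,
     torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype)])
block_table = torch.cat(
    [block_table,
     torch.zeros((graph_pad_size, 1), dtype=block_table.dtype)], dim=0)

print(seq_lens.tolist(), slot_mapping.tolist(), tuple(block_table.shape))
# [5, 7, 1, 1] [40, 41, -1, -1] (4, 1)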
@@ -200,7 +264,44 @@ def build(self,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph
+        )
+        return attn_metadata
+
+    def build_torchair_graph_dummy(self, num_reqs: int, num_actual_tokens: int):
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
@@ -248,6 +349,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        prefix: Optional[str] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -267,11 +369,34 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
         self.block_size = vllm_config.cache_config.block_size
 
+    def update_kv_cache(self,
+                        key: torch.Tensor,
+                        value: torch.Tensor,
+                        key_cache: torch.Tensor,
+                        value_cache: torch.Tensor,
+                        slot_indices: torch.Tensor) -> None:
+        # calc indices by block_size
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # [blocknum, blocksize, numKvHeads, headDims]
+        # -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
     def forward(
         self,
         layer: AttentionLayer,
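The index math in update_kv_cache splits each flat slot index into (block index, offset within the block); that [N, 1, 2] tensor is what the scatter update consumes. A CPU-only check of the arithmetic (no NPU ops called, numbers invented, not part of the commit):

# Sketch of the slot -> (block_idx, block_offset) decomposition.
import torch

block_size = 128
slot_indices = torch.tensor([0, 129, 300], dtype=torch.int64).view(-1, 1, 1)

block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
block_offset = slot_indices % block_size
indices = torch.cat([block_idx, block_offset], dim=2)   # shape [3, 1, 2]

print(indices.squeeze(1).tolist())   # [[0, 0], [1, 1], [2, 44]]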
@@ -320,12 +445,18 @@ def forward(
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slots)
+            if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                self.update_kv_cache(key=key,
+                                     value=value,
+                                     key_cache=self.key_cache,
+                                     value_cache=self.value_cache,
+                                     slot_indices=slots.to(torch.int64))
+            else:
+                torch_npu._npu_reshape_and_cache(key=key[:num_actual_tokens],
+                                                 value=value[:num_actual_tokens],
+                                                 key_cache=self.key_cache,
+                                                 value_cache=self.value_cache,
+                                                 slot_indices=slots)
 
         if hasattr(layer, 'quant_method'):
             # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
@@ -363,10 +494,25 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # query change to BSND
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, numKvHeads, blocksize, headDims] -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(*self.key_cache.shape[:-2], -1)
+                value_cache = self.value_cache.view(*self.value_cache.shape[:-2], -1)
+
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1])
+            elif not get_forward_context().capturing:
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
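A shape-only sketch (random CPU tensors, invented sizes, not part of the commit) of the layout changes feeding npu_incre_flash_attention above: the decode query is viewed as BSH with sequence length 1, and each cache collapses its last two dims (kv heads x head dim) into a single hidden dim.

# Sketch of the BSH reshapes used on the torchair graph decode path.
import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 8, 2, 16
blocknum, blocksize = 10, 128

query = torch.randn(num_tokens, num_heads, head_size)
key_cache = torch.randn(blocknum, blocksize, num_kv_heads, head_size)

query_bsh = query.view(-1, 1, num_heads * head_size)        # [4, 1, 128]
key_cache_bsh = key_cache.view(*key_cache.shape[:-2], -1)   # [10, 128, 32]
print(tuple(query_bsh.shape), tuple(key_cache_bsh.shape))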
@@ -384,6 +530,7 @@ def forward(
                 event = torch.npu.ExternalEvent()
                 event.wait(stream)
                 event.reset(stream)
+                graph_params = get_graph_params()
                 graph_params.events[num_tokens].append(event)
 
                 graph_params.attn_params[num_tokens].append((

vllm_ascend/models/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -8,6 +8,7 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
@@ -60,3 +61,7 @@ def register_model():
 
     ModelRegistry.register_model(
         "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM",
+        "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
