
Commit e7c0013

taoyuxiang authored and NicholasTao committed
Add graph mode for Qwen2.5 and Qwen3
Signed-off-by: taoyuxiang <t30002884@china.huawei.com>
1 parent 63944db commit e7c0013

File tree

9 files changed, +1312 -52 lines changed

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import unittest
+import torch
+
+class DummyNPU:
+    @staticmethod
+    def npu_scatter_nd_update_(tensor, indices, updates):
+        batch = indices.shape[0]
+        for i in range(batch):
+            b = indices[i, 0, 0].item()
+            o = indices[i, 0, 1].item()
+            tensor[b, o] = updates[i]
+
+import torch_npu
+from vllm_ascend.attention.attention_v1 import AscendAttentionBackendImpl
+
+
+class TestUpdateKVCache(unittest.TestCase):
+
+    def test_basic_update(self):
+        block_num, block_size = 3, 2
+        num_heads, head_dim = 1, 1
+
+        key_cache = torch.zeros(block_num, block_size, num_heads, head_dim)
+        value_cache = torch.zeros_like(key_cache)
+
+        batch_size = 2
+        key = torch.tensor([[[1.0]], [[2.0]]])
+        value = torch.tensor([[[3.0]], [[4.0]]])
+
+        slot_indices = torch.tensor([1, 3])
+
+        AscendAttentionBackendImpl.update_kv_cache(key, value, key_cache, value_cache, slot_indices)
+
+        self.assertEqual(key_cache[0, 1, 0, 0].item(), 1.0)
+        self.assertEqual(value_cache[0, 1, 0, 0].item(), 3.0)
+
+        self.assertEqual(key_cache[1, 1, 0, 0].item(), 2.0)
+        self.assertEqual(value_cache[1, 1, 0, 0].item(), 4.0)
+
+if __name__ == '__main__':
+    unittest.main()
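
For reference, the cache positions asserted above follow from the slot-to-block arithmetic used by update_kv_cache: with block_size = 2, slot s maps to block s // 2 and offset s % 2, so slot 1 lands at (block 0, offset 1) and slot 3 at (block 1, offset 1). A minimal CPU-only sketch of that mapping (plain Python, no torch_npu needed):

# Slot -> (block, offset) mapping mirrored from update_kv_cache.
# block_size and the slot values match the test above.
block_size = 2
for slot in (1, 3):
    block_idx, block_offset = divmod(slot, block_size)
    print(f"slot {slot} -> block {block_idx}, offset {block_offset}")
# slot 1 -> block 0, offset 1
# slot 3 -> block 1, offset 1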

vllm_ascend/ascend_config.py

Lines changed: 3 additions & 3 deletions
@@ -169,12 +169,12 @@ def check_ascend_config(vllm_config, enforce_eager):
                 "Torchair graph mode is still experimental and not supported for V1 without mla currently, "
                 "it has been disabled automatically.")
             ascend_config.torchair_graph_config.enabled = False
-        # torchair_graph is supported for deepseek model only currently.
+        # torchair_graph is supported for deepseek or qwen currently.
         if vllm_config.model_config:
             model_type = vllm_config.model_config.hf_config.model_type
-            if "deepseek" not in model_type:
+            if "deepseek" not in model_type and "qwen" not in model_type:
                 raise NotImplementedError(
-                    "Torchair graph mode only works with deepseek model."
+                    "Torchair graph mode only works with deepseek or qwen model."
                 )
    # aclgraph case
    else:

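With the check relaxed, Torchair graph mode can be requested for Qwen models the same way it already is for DeepSeek. A hedged usage sketch: the additional_config shape below mirrors the torchair_graph_config field read via get_ascend_config(), but the model name is only an example and the exact option spelling should be confirmed against the vllm-ascend docs.

# Hypothetical launch example: enable Torchair graph mode for a Qwen model.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # example model, not taken from this commit
    additional_config={"torchair_graph_config": {"enabled": True}},
)
outputs = llm.generate("Hello, my name is")
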
vllm_ascend/attention/attention_v1.py

Lines changed: 162 additions & 14 deletions
@@ -19,17 +19,19 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type
 
+import numpy as np
 import torch
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
-from vllm.attention.backends.utils import CommonAttentionState
+from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.utils import \
     AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -52,7 +54,7 @@ def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
     def get_metadata_cls() -> Type["AscendMetadata"]:
         return AscendMetadata
 
-    @staticmethod
+    @staticmethod
     def get_state_cls() -> Type["CommonAttentionState"]:
         return CommonAttentionState
 

@@ -140,6 +142,8 @@ class AscendMetadata:
     num_input_tokens: int = 0  # Number of tokens including padding.
 
     enable_dbo_across_dp: bool = False
+    with_prefill_across_dp: bool = False
+    use_torchair_graph: bool = False
 
     def split_metadata_for_multistream(
         self,
@@ -163,6 +167,32 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
 
+    def _get_graph_runner_block_tables(
+            self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
+
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        if isinstance(self.runner.graph_block_tables, np.ndarray):
+            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
+                                             dtype=block_tables.dtype,
+                                             device=block_tables.device)
+        else:
+            graph_block_tables = self.runner.graph_block_tables.to(
+                device=block_tables.device, dtype=block_tables.dtype)
+
+        num_blocks = block_tables.size(1)
+        if num_blocks <= max_blocks:
+            graph_block_tables[:num_seqs, :
+                               num_blocks] = block_tables[:num_seqs, :
+                                                          num_blocks]
+        else:
+            graph_block_tables[:num_seqs, :
+                               max_blocks] = block_tables[:num_seqs, :
+                                                          max_blocks]
+
+        return graph_block_tables[:num_seqs, :max_blocks]
+
     def build(self,
               num_reqs,
               num_actual_tokens,
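
The helper above copies each request's variable-length block table into a fixed [max_batch_size, max_blocks] buffer so the captured graph always receives the same shape. A rough standalone sketch of that copy (the toy sizes are assumptions, not values from this commit):

import numpy as np
import torch

graph_block_tables = np.zeros((8, 16), dtype=np.int32)  # fixed [max_batch_size, max_blocks]
block_tables = torch.tensor([[10, 11, 12], [20, 21, 22]], dtype=torch.int32)  # 2 live requests

num_seqs, num_blocks = block_tables.shape
max_batch_size, max_blocks = graph_block_tables.shape

padded = torch.zeros((max_batch_size, max_blocks), dtype=block_tables.dtype)
copy_width = min(num_blocks, max_blocks)
padded[:num_seqs, :copy_width] = block_tables[:num_seqs, :copy_width]

print(padded[:num_seqs, :max_blocks].shape)  # torch.Size([2, 16]) -- fixed width for the graph
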
@@ -178,7 +208,7 @@ def build(self,
                                   block_table[:num_reqs])
 
         query_start_loc = common_attn_metadata.query_start_loc
-        seq_lens = common_attn_metadata.seq_lens
+        seq_lens = common_attn_metadata.seq_lens  # type: ignore
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
         query_lens = common_attn_metadata.query_lens or self.runner.query_lens
@@ -188,6 +218,41 @@
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
+
+        graph_pad_size = kwargs.get("graph_pad_size", -1)
+        with_prefill_across_dp = kwargs["with_prefill_across_dp"]
+        use_torchair_graph = graph_pad_size != -1
+        if not with_prefill_across_dp:
+            if use_torchair_graph and self.runner.attn_state in [
+                    AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
+            ]:
+                num_seqs = len(seq_lens)
+                if graph_pad_size != 0:
+                    pad_value = 1
+                    padded_seq_lens = seq_lens.tolist() + [pad_value
+                                                           ] * graph_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
+
+                seq_lens = torch.from_numpy(
+                    np.array(padded_seq_lens).astype(np.int32))
+                padding = torch.full((graph_pad_size, ),
+                                     PAD_SLOT_ID,
+                                     dtype=slot_mapping.dtype,
+                                     device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, padding])
+                block_table_padding = torch.zeros(
+                    (graph_pad_size, ) + block_table.shape[1:],
+                    dtype=block_table.dtype,
+                    device=block_table.device)
+                block_table = torch.cat([block_table, block_table_padding],
+                                        dim=0)
+                block_table = self._get_graph_runner_block_tables(
+                    num_seqs + graph_pad_size, block_table)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
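
In the Torchair decode path the batch is padded out to the graph's captured size: dummy sequences get length 1, their slots point at PAD_SLOT_ID, and their block-table rows stay zero. A small sketch of that padding, assuming PAD_SLOT_ID is the usual -1 sentinel from vllm.attention.backends.utils and a graph_pad_size of 2:

import torch

PAD_SLOT_ID = -1     # assumed value of the vLLM sentinel
graph_pad_size = 2   # example padding, not a value from this commit

seq_lens = torch.tensor([5, 3], dtype=torch.int32)
slot_mapping = torch.tensor([17, 42], dtype=torch.int32)

# Dummy sequences are given length 1 and a padded slot, mirroring build() above.
seq_lens = torch.cat([seq_lens, torch.ones(graph_pad_size, dtype=torch.int32)])
slot_mapping = torch.cat(
    [slot_mapping,
     torch.full((graph_pad_size, ), PAD_SLOT_ID, dtype=slot_mapping.dtype)])

print(seq_lens.tolist())      # [5, 3, 1, 1]
print(slot_mapping.tolist())  # [17, 42, -1, -1]
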
@@ -200,7 +265,44 @@ def build(self,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=enable_dbo_across_dp,
+            with_prefill_across_dp=with_prefill_across_dp,
+            use_torchair_graph=use_torchair_graph)
+        return attn_metadata
+
+    def build_torchair_graph_dummy(self, num_reqs: int,
+                                   num_actual_tokens: int):
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
+
+        query_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        attn_mask = self.runner.attn_mask
+
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=num_actual_tokens,
+            block_tables=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_query_len=query_lens.max().item(),
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly)
         return attn_metadata
 
     def build_dummy_metadata(self, num_actual_tokens, num_reqs,
@@ -248,6 +350,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        prefix: Optional[str] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -267,11 +370,30 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
         vllm_config = get_current_vllm_config()
         self.full_graph = vllm_config.compilation_config.full_cuda_graph
         self.block_size = vllm_config.cache_config.block_size
 
+    @staticmethod
+    def update_kv_cache(key: torch.Tensor, value: torch.Tensor,
+                        key_cache: torch.Tensor, value_cache: torch.Tensor,
+                        slot_indices: torch.Tensor) -> None:
+        # calc indices by block_size
+        block_size = key_cache.shape[1]
+        slot_indices = slot_indices.view(-1, 1, 1).to(torch.int64)
+        block_idx = torch.div(slot_indices, block_size, rounding_mode='floor')
+        block_offset = slot_indices % block_size
+        indices = torch.cat([block_idx, block_offset], dim=2)
+        indices = indices.npu()
+
+        # [blocknum, blocksize, numKvHeads, headDims]
+        # -> [blocknum, blocksize, numKvHeads * headDims]
+        torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
+        torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -320,12 +442,19 @@ def forward(
             if self.key_cache is None:
                 self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
             slots = attn_metadata.slot_mapping
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_actual_tokens],
-                value=value[:num_actual_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slots)
+            if not attn_metadata.with_prefill_across_dp and self.torchair_graph_enabled:
+                self.update_kv_cache(key=key,
+                                     value=value,
+                                     key_cache=self.key_cache,
+                                     value_cache=self.value_cache,
+                                     slot_indices=slots.to(torch.int64))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_actual_tokens],
+                    value=value[:num_actual_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slots)
 
         if hasattr(layer, 'quant_method'):
            # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
@@ -363,10 +492,28 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            graph_params = get_graph_params()
-
-            forward_context = get_forward_context()
-            if not forward_context.capturing:
+            if self.torchair_graph_enabled:
+                # query change to BSND
+                query = query.view(-1, 1, self.num_heads * self.head_size)
+                # [blocknum, numKvHeads, blocksize, headDims] -> [blocknum, blocksize, numKvHeads * headDims]
+                key_cache = self.key_cache.view(  # type: ignore
+                    *self.key_cache.shape[:-2], -1)  # type: ignore
+                value_cache = self.value_cache.view(  # type: ignore
+                    *self.value_cache.shape[:-2], -1)  # type: ignore
+
+                output = torch_npu.npu_incre_flash_attention(
+                    query=query,
+                    key=key_cache,
+                    value=value_cache,
+                    num_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    input_layout='BSH',
+                    scale_value=self.scale,
+                    actual_seq_lengths=attn_metadata.seq_lens_list,
+                    block_table=attn_metadata.block_tables,
+                    block_size=kv_cache[0].shape[1],
+                )
+            elif not get_forward_context().capturing:
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -384,6 +531,7 @@ def forward(
                 event = torch.npu.ExternalEvent()
                 event.wait(stream)
                 event.reset(stream)
+                graph_params = get_graph_params()
                 graph_params.events[num_tokens].append(event)
 
                 graph_params.attn_params[num_tokens].append((

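In the DecodeOnly graph path, tensors are flattened into the 'BSH' layout that npu_incre_flash_attention expects: the query becomes [batch, 1, num_heads * head_size] and each cache view becomes [block_num, block_size, num_kv_heads * head_dim]. A CPU-only shape check of those views (torch only; the sizes are illustrative assumptions):

import torch

batch, num_heads, head_size = 4, 8, 128
block_num, block_size, num_kv_heads = 16, 128, 2

query = torch.randn(batch, num_heads * head_size)
key_cache = torch.randn(block_num, block_size, num_kv_heads, head_size)

# Same views as the torchair branch of forward():
query_bsh = query.view(-1, 1, num_heads * head_size)
key_cache_bsh = key_cache.view(*key_cache.shape[:-2], -1)

print(query_bsh.shape)      # torch.Size([4, 1, 1024])
print(key_cache_bsh.shape)  # torch.Size([16, 128, 256])
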
vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,7 @@ def register_model():
     from .deepseek_mtp import CustomDeepSeekMTP  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV2ForCausalLM  # noqa: F401
     from .deepseek_v2 import CustomDeepseekV3ForCausalLM  # noqa: F401
+    from .qwen2 import CustomQwen2ForCausalLM  # noqa: F401
     from .qwen2_5_vl import \
         AscendQwen2_5_VLForConditionalGeneration  # noqa: F401
     from .qwen2_vl import AscendQwen2VLForConditionalGeneration  # noqa: F401
@@ -60,3 +61,6 @@ def register_model():
 
     ModelRegistry.register_model(
         "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "Qwen2ForCausalLM", "vllm_ascend.models.qwen2:CustomQwen2ForCausalLM")
