Commit 8c47f5e

Refactor AscendAttentionMetadataBuilder for better extensibility and make the torchair builder class extend from it
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 0aba644 commit 8c47f5e
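
The diff below splits AscendAttentionMetadataBuilder.build() into two overridable steps, _assemble_build_info() and _assemble_attn_metadata(), so that backend-specific builders such as the torchair one can customize only those hooks while reusing the shared build() flow. The following is a minimal, self-contained sketch of that pattern; the field lists are heavily simplified and TorchairLikeMetadataBuilder is a hypothetical stand-in for the torchair builder, not the actual vllm-ascend code.

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class AscendAttentionMetadataBuildInfo:
        # Stand-in for the real dataclass (which carries torch tensors such as
        # block_table, slot_mapping and attn_mask); two illustrative fields only.
        num_actual_tokens: int
        seq_lens: List[int]


    @dataclass
    class AscendMetadata:
        # Stand-in for the real metadata object returned by build().
        num_actual_tokens: int
        seq_lens: List[int]
        max_query_len: int


    class AscendAttentionMetadataBuilder:

        def _assemble_build_info(self, num_actual_tokens, seq_lens):
            # Extension point 1: backend-specific preprocessing (the real code
            # performs 310P mask format casts here) before bundling the inputs.
            return AscendAttentionMetadataBuildInfo(num_actual_tokens, seq_lens)

        def _assemble_attn_metadata(self, build_info, max_query_len):
            # Extension point 2: turn the build info into the metadata type a
            # particular backend expects.
            return AscendMetadata(build_info.num_actual_tokens,
                                  build_info.seq_lens, max_query_len)

        def build(self, num_actual_tokens, seq_lens, max_query_len):
            # Shared skeleton: subclasses normally override only the two hooks.
            info = self._assemble_build_info(num_actual_tokens, seq_lens)
            return self._assemble_attn_metadata(info, max_query_len)


    class TorchairLikeMetadataBuilder(AscendAttentionMetadataBuilder):
        # Hypothetical subclass illustrating how the torchair builder extends
        # the base class: it customizes the assembly hooks and keeps build().
        def _assemble_attn_metadata(self, build_info, max_query_len):
            metadata = super()._assemble_attn_metadata(build_info, max_query_len)
            # ... the real torchair builder attaches decode/graph metadata here ...
            return metadata


    if __name__ == "__main__":
        builder = TorchairLikeMetadataBuilder()
        print(builder.build(num_actual_tokens=4, seq_lens=[2, 2], max_query_len=2))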

File tree

2 files changed: +111 additions, -59 deletions

  vllm_ascend/attention/attention_v1.py
  vllm_ascend/torchair/torchair_attention.py


vllm_ascend/attention/attention_v1.py

Lines changed: 70 additions & 23 deletions

@@ -196,6 +196,18 @@ class AscendMetadata:
     is_only_prefill: bool = False


+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    num_actual_tokens: int
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    query_lens: torch.Tensor
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    attn_mask: torch.Tensor
+    attn_state: AscendAttentionState
+
+
 class AscendAttentionMetadataBuilder:
     reorder_batch_threshold: ClassVar[int] = 1

@@ -217,10 +229,61 @@ def reorder_batch(self, input_batch,
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

+    def _assemble_build_info(
+        self,
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=build_info.num_actual_tokens,
+            block_tables=build_info.block_table,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
+            max_query_len=common_attn_metadata.max_query_len,
+            slot_mapping=build_info.slot_mapping,
+            attn_mask=build_info.attn_mask,
+            attn_state=build_info.attn_state,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
+        return attn_metadata
+
     def build(
         self,
         common_prefix_len: int,
-        common_attn_metadata: AscendCommonAttentionMetadata,
+        common_attn_metadata: "AscendCommonAttentionMetadata",
         model: nn.Module,
     ):
         num_reqs = common_attn_metadata.num_reqs
@@ -244,28 +307,12 @@ def build(
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)

-        if is_310p():
-            if attn_state == AscendAttentionState.PrefillNoCache:
-                mask_nz = nd_to_nz_2d(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-            elif attn_state == AscendAttentionState.ChunkedPrefill:
-                mask_nz = nd_to_nz_spec(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-
-        attn_metadata = AscendMetadata(
-            num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
-            seq_lens=seq_lens,
-            max_query_len=common_attn_metadata.max_query_len,
-            slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
-            attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
-            is_only_prefill=common_attn_metadata.is_only_prefill)
+        build_info = self._assemble_build_info(num_actual_tokens, block_table,
+                                               query_start_loc, query_lens,
+                                               seq_lens, slot_mapping,
+                                               attn_mask, attn_state)
+        attn_metadata = self._assemble_attn_metadata(build_info,
+                                                     common_attn_metadata)
         return attn_metadata

vllm_ascend/torchair/torchair_attention.py

Lines changed: 41 additions & 36 deletions

@@ -20,18 +20,16 @@

 import numpy as np
 import torch
-import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                               AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.utils import cdiv

-from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
-                                                AscendAttentionMetadataBuilder,
-                                                AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import (
+    AscendAttentionBackend, AscendAttentionMetadataBuilder,
+    AscendAttentionMetadataBuildInfo, AscendAttentionState, AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
@@ -171,45 +169,52 @@ def build_torchair_graph_dummy(
             decode=decode_metadata)
         return attn_metadata

-    def build(
+    def _assemble_build_info(
         self,
-        common_prefix_len: int,
-        common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
-    ):
-        num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-
-        block_table = common_attn_metadata.block_table_tensor
-        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
-        attn_mask = common_attn_metadata.attn_mask
-
-        attn_state = common_attn_metadata.attn_state
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
         if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
             mask_nz = nd_to_nz_2d(attn_mask)
             attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)

-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
-                                                                       num_reqs
-                                                                       + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
-        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        num_actual_tokens = build_info.num_actual_tokens
+        block_table = build_info.block_table
+        seq_lens = build_info.seq_lens
+        slot_mapping = build_info.slot_mapping
+        attn_state = build_info.attn_state
+
+        num_reqs = common_attn_metadata.num_reqs
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
+        graph_pad_size = common_attn_metadata.graph_pad_size

         decode_metadata = None
-        graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size > -1
+
         if common_attn_metadata.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
@@ -262,12 +267,12 @@ def build(
             decode=decode_metadata,
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
+            attn_mask=build_info.attn_mask,
             attn_state=attn_state,
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata