
Commit ea15107

Refactor AscendAttentionMetadataBuilder for better extensibility and make the torchair builder class inherit from it
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 2693196 commit ea15107
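
In outline, the change turns AscendAttentionMetadataBuilder.build() into a template method: the steps subclasses need to customize are pulled out into two overridable hooks, _assemble_build_info() and _assemble_attn_metadata(), and the torchair builder then subclasses the base builder instead of carrying its own copy of build(). The sketch below illustrates only this pattern; the CommonMeta / BuildInfo / Metadata dataclasses, their fields, and the method bodies are simplified stand-ins invented for the example rather than the real vllm_ascend types (which carry torch tensors and many more fields), and the real torchair subclass constructs its own AscendTorchairMetadata instead of calling super().

from dataclasses import dataclass
from typing import Any


# Simplified stand-ins for the real vllm_ascend types, trimmed down so the
# sketch runs without torch/torch_npu installed.
@dataclass
class CommonMeta:  # stands in for AscendCommonAttentionMetadata
    num_actual_tokens: int
    max_query_len: int


@dataclass
class BuildInfo:  # stands in for AscendAttentionMetadataBuildInfo
    num_actual_tokens: int
    attn_mask: Any


@dataclass
class Metadata:  # stands in for AscendMetadata
    num_actual_tokens: int
    max_query_len: int
    attn_mask: Any
    decode: Any = None


class BaseBuilder:
    """Pattern of AscendAttentionMetadataBuilder after the refactor:
    build() is the shared template, the hooks are the extension points."""

    def _assemble_build_info(self, num_actual_tokens, attn_mask) -> BuildInfo:
        # Hook 1: device/format-specific preprocessing (e.g. the 310P mask
        # cast in the real code) happens before the BuildInfo is assembled.
        return BuildInfo(num_actual_tokens=num_actual_tokens,
                         attn_mask=attn_mask)

    def _assemble_attn_metadata(self, build_info: BuildInfo,
                                common: CommonMeta) -> Metadata:
        # Hook 2: combine the intermediate BuildInfo with the common
        # metadata into the final attention metadata object.
        return Metadata(num_actual_tokens=build_info.num_actual_tokens,
                        max_query_len=common.max_query_len,
                        attn_mask=build_info.attn_mask)

    def build(self, common: CommonMeta, attn_mask: Any) -> Metadata:
        # Template method: derive the shared fields, then delegate to the
        # hooks so subclasses never re-implement this skeleton.
        build_info = self._assemble_build_info(common.num_actual_tokens,
                                               attn_mask)
        return self._assemble_attn_metadata(build_info, common)


class TorchairLikeBuilder(BaseBuilder):
    """Loosely modeled on the torchair builder: it defines no build() of
    its own and only overrides a hook (the real subclass constructs its
    own AscendTorchairMetadata rather than calling super())."""

    def _assemble_attn_metadata(self, build_info: BuildInfo,
                                common: CommonMeta) -> Metadata:
        meta = super()._assemble_attn_metadata(build_info, common)
        meta.decode = "torchair-specific decode metadata would go here"
        return meta


if __name__ == "__main__":
    common = CommonMeta(num_actual_tokens=8, max_query_len=4)
    print(BaseBuilder().build(common, attn_mask=None))
    print(TorchairLikeBuilder().build(common, attn_mask=None))

Either builder is driven the same way by the caller: a single build() call, with the backend-specific work hidden behind the two hooks.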

2 files changed (+111, -58 lines)


vllm_ascend/attention/attention_v1.py (70 additions, 23 deletions)
@@ -192,6 +192,18 @@ class AscendMetadata:
     is_only_prefill: bool = False


+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    num_actual_tokens: int
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    query_lens: torch.Tensor
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    attn_mask: torch.Tensor
+    attn_state: AscendAttentionState
+
+
 class AscendAttentionMetadataBuilder:

     def __init__(
@@ -209,9 +221,60 @@ def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False

+    def _assemble_build_info(
+        self,
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=build_info.num_actual_tokens,
+            block_tables=build_info.block_table,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
+            max_query_len=common_attn_metadata.max_query_len,
+            slot_mapping=build_info.slot_mapping,
+            attn_mask=build_info.attn_mask,
+            attn_state=build_info.attn_state,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
+        return attn_metadata
+
     def build(
         self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
+        common_attn_metadata: "AscendCommonAttentionMetadata",
         model: nn.Module,
     ):
         num_reqs = common_attn_metadata.num_reqs
@@ -239,28 +302,12 @@ def build(
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)

-        if is_310p():
-            if attn_state == AscendAttentionState.PrefillNoCache:
-                mask_nz = nd_to_nz_2d(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-            elif attn_state == AscendAttentionState.ChunkedPrefill:
-                mask_nz = nd_to_nz_spec(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-
-        attn_metadata = AscendMetadata(
-            num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
-            seq_lens=seq_lens,
-            max_query_len=common_attn_metadata.max_query_len,
-            slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
-            attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
-            is_only_prefill=common_attn_metadata.is_only_prefill)
+        build_info = self._assemble_build_info(num_actual_tokens, block_table,
+                                               query_start_loc, query_lens,
+                                               seq_lens, slot_mapping,
+                                               attn_mask, attn_state)
+        attn_metadata = self._assemble_attn_metadata(build_info,
+                                                     common_attn_metadata)
         return attn_metadata

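With the base-class template in place, the torchair builder below no longer defines its own build(): it inherits AscendAttentionMetadataBuilder.build() and only overrides the two assembly hooks, so the torchair-specific pieces (the 310P mask cast variant, graph padding, and decode metadata) now live inside its _assemble_build_info() and _assemble_attn_metadata().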

vllm_ascend/torchair/torchair_attention.py (41 additions, 35 deletions)
@@ -20,18 +20,16 @@

 import numpy as np
 import torch
-import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                               AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.utils import cdiv

-from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
-                                                AscendAttentionMetadataBuilder,
-                                                AscendAttentionState,
-                                                AscendMetadata)
+from vllm_ascend.attention.attention_v1 import (
+    AscendAttentionBackend, AscendAttentionMetadataBuilder,
+    AscendAttentionMetadataBuildInfo, AscendAttentionState, AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
@@ -169,44 +167,52 @@ def build_torchair_graph_dummy(
             decode=decode_metadata)
         return attn_metadata

-    def build(
+    def _assemble_build_info(
         self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
-        model: nn.Module,
-    ):
-        num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
-
-        block_table = common_attn_metadata.block_table_tensor
-        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
-            block_table[:num_reqs])
-
-        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
-                                                             num_actual_tokens].to(
-                                                                 self.device,
-                                                                 non_blocking=
-                                                                 True)
-        attn_mask = common_attn_metadata.attn_mask
-
-        attn_state = common_attn_metadata.attn_state
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
         if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
             mask_nz = nd_to_nz_2d(attn_mask)
             attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)

-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
-                                                                       num_reqs
-                                                                       + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
-        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        num_actual_tokens = build_info.num_actual_tokens
+        block_table = build_info.block_table
+        seq_lens = build_info.seq_lens
+        slot_mapping = build_info.slot_mapping
+        attn_state = build_info.attn_state
+
+        num_reqs = common_attn_metadata.num_reqs
         input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
+        graph_pad_size = common_attn_metadata.graph_pad_size

         decode_metadata = None
-        graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size > -1
+
         if common_attn_metadata.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
@@ -259,12 +265,12 @@ def build(
             decode=decode_metadata,
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
+            attn_mask=build_info.attn_mask,
             attn_state=attn_state,
             enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata
