@@ -196,6 +196,18 @@ class AscendMetadata:
196
196
is_only_prefill : bool = False
197
197
198
198
199
@dataclass
class AscendAttentionMetadataBuildInfo:
    """Intermediate bundle of per-batch tensors gathered by
    AscendAttentionMetadataBuilder._assemble_build_info and later merged
    with AscendCommonAttentionMetadata into the final AscendMetadata
    (see _assemble_attn_metadata)."""

    # Number of tokens actually scheduled in this step.
    num_actual_tokens: int
    # Per-request KV-cache block table (becomes AscendMetadata.block_tables).
    block_table: torch.Tensor
    # Cumulative query start offsets, already moved to the builder's device.
    query_start_loc: torch.Tensor
    # Per-request query lengths.
    query_lens: torch.Tensor
    # Per-request sequence lengths.
    seq_lens: torch.Tensor
    # Token -> KV-cache slot mapping.
    slot_mapping: torch.Tensor
    # Attention mask; on 310P devices it has already been cast to the
    # ACL fractal-NZ format by _assemble_build_info.
    attn_mask: torch.Tensor
    # Attention phase (e.g. PrefillNoCache / ChunkedPrefill) for this batch.
    attn_state: AscendAttentionState
199
211
class AscendAttentionMetadataBuilder:
    """Builds AscendMetadata for a scheduled batch from the common
    attention metadata (see build())."""

    # NOTE(review): threshold consumed by the batch-reordering machinery;
    # this backend's reorder_batch() is a no-op, so the exact semantics
    # live in the caller — confirm against the vLLM builder protocol.
    reorder_batch_threshold: ClassVar[int] = 1
    def reorder_batch(self, input_batch,
                      scheduler_output: "SchedulerOutput") -> bool:
        """Hook invoked before metadata construction to optionally reorder
        requests in *input_batch*.

        This backend never reorders; the return value False signals that
        the batch order was left unchanged.
        """
        return False
232
+ def _assemble_build_info (
233
+ self ,
234
+ num_actual_tokens ,
235
+ block_table ,
236
+ query_start_loc ,
237
+ query_lens ,
238
+ seq_lens ,
239
+ slot_mapping ,
240
+ attn_mask ,
241
+ attn_state : "AscendAttentionState" ,
242
+ ) -> "AscendAttentionMetadataBuildInfo" :
243
+ if is_310p ():
244
+ if attn_state == AscendAttentionState .PrefillNoCache :
245
+ mask_nz = nd_to_nz_2d (attn_mask )
246
+ attn_mask = torch_npu .npu_format_cast (mask_nz .contiguous (),
247
+ ACL_FORMAT_FRACTAL_NZ )
248
+ elif attn_state == AscendAttentionState .ChunkedPrefill :
249
+ mask_nz = nd_to_nz_spec (attn_mask )
250
+ attn_mask = torch_npu .npu_format_cast (mask_nz .contiguous (),
251
+ ACL_FORMAT_FRACTAL_NZ )
252
+
253
+ build_info = AscendAttentionMetadataBuildInfo (
254
+ num_actual_tokens = num_actual_tokens ,
255
+ block_table = block_table ,
256
+ query_start_loc = query_start_loc ,
257
+ query_lens = query_lens ,
258
+ seq_lens = seq_lens ,
259
+ slot_mapping = slot_mapping ,
260
+ attn_mask = attn_mask ,
261
+ attn_state = attn_state )
262
+ return build_info
263
+
264
+ def _assemble_attn_metadata (
265
+ self ,
266
+ build_info : "AscendAttentionMetadataBuildInfo" ,
267
+ common_attn_metadata : "AscendCommonAttentionMetadata" ,
268
+ ) -> "AscendMetadata" :
269
+ attn_metadata = AscendMetadata (
270
+ num_actual_tokens = build_info .num_actual_tokens ,
271
+ block_tables = build_info .block_table ,
272
+ query_start_loc = build_info .query_start_loc ,
273
+ query_lens = build_info .query_lens ,
274
+ seq_lens = build_info .seq_lens ,
275
+ max_query_len = common_attn_metadata .max_query_len ,
276
+ slot_mapping = build_info .slot_mapping ,
277
+ attn_mask = build_info .attn_mask ,
278
+ attn_state = build_info .attn_state ,
279
+ enable_dbo_across_dp = common_attn_metadata .enable_dbo_across_dp ,
280
+ is_only_prefill = common_attn_metadata .is_only_prefill )
281
+ return attn_metadata
282
+
220
283
    def build(
        self,
        common_prefix_len: int,
        # Forward-reference fixed: the annotation string must not contain a
        # leading space (" AscendCommonAttentionMetadata" would not resolve).
        common_attn_metadata: "AscendCommonAttentionMetadata",
        model: nn.Module,
    ):
        """Build the AscendMetadata for the currently scheduled batch.

        Derives the per-batch tensors from *common_attn_metadata*, then
        delegates assembly to _assemble_build_info / _assemble_attn_metadata.
        """
        num_reqs = common_attn_metadata.num_reqs
@@ -244,28 +307,12 @@ def build(
244
307
        # Move the CPU-side cumulative query offsets to the builder's device
        # without blocking the host.
        query_start_loc = query_start_loc_cpu.to(self.device,
                                                 non_blocking=True)

        # Gather the per-build tensors (including any 310P fractal-NZ mask
        # conversion), then merge them with the batch-wide common metadata.
        build_info = self._assemble_build_info(num_actual_tokens, block_table,
                                               query_start_loc, query_lens,
                                               seq_lens, slot_mapping,
                                               attn_mask, attn_state)
        attn_metadata = self._assemble_attn_metadata(build_info,
                                                     common_attn_metadata)
        return attn_metadata
270
317
271
318
0 commit comments