@@ -192,6 +192,18 @@ class AscendMetadata:
    is_only_prefill: bool = False


+@dataclass
+class AscendAttentionMetadataBuildInfo:
+    num_actual_tokens: int
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    query_lens: torch.Tensor
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    attn_mask: torch.Tensor
+    attn_state: AscendAttentionState
+
+
class AscendAttentionMetadataBuilder:

    def __init__(
@@ -209,9 +221,60 @@ def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

+    def _assemble_build_info(
+        self,
+        num_actual_tokens,
+        block_table,
+        query_start_loc,
+        query_lens,
+        seq_lens,
+        slot_mapping,
+        attn_mask,
+        attn_state: "AscendAttentionState",
+    ) -> "AscendAttentionMetadataBuildInfo":
+        if is_310p():
+            if attn_state == AscendAttentionState.PrefillNoCache:
+                mask_nz = nd_to_nz_2d(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+            elif attn_state == AscendAttentionState.ChunkedPrefill:
+                mask_nz = nd_to_nz_spec(attn_mask)
+                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
+                                                      ACL_FORMAT_FRACTAL_NZ)
+
+        build_info = AscendAttentionMetadataBuildInfo(
+            num_actual_tokens=num_actual_tokens,
+            block_table=block_table,
+            query_start_loc=query_start_loc,
+            query_lens=query_lens,
+            seq_lens=seq_lens,
+            slot_mapping=slot_mapping,
+            attn_mask=attn_mask,
+            attn_state=attn_state)
+        return build_info
+
+    def _assemble_attn_metadata(
+        self,
+        build_info: "AscendAttentionMetadataBuildInfo",
+        common_attn_metadata: "AscendCommonAttentionMetadata",
+    ) -> "AscendMetadata":
+        attn_metadata = AscendMetadata(
+            num_actual_tokens=build_info.num_actual_tokens,
+            block_tables=build_info.block_table,
+            query_start_loc=build_info.query_start_loc,
+            query_lens=build_info.query_lens,
+            seq_lens=build_info.seq_lens,
+            max_query_len=common_attn_metadata.max_query_len,
+            slot_mapping=build_info.slot_mapping,
+            attn_mask=build_info.attn_mask,
+            attn_state=build_info.attn_state,
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
+            is_only_prefill=common_attn_metadata.is_only_prefill)
+        return attn_metadata
+
    def build(
        self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
+        common_attn_metadata: "AscendCommonAttentionMetadata",
        model: nn.Module,
    ):
        num_reqs = common_attn_metadata.num_reqs
@@ -239,28 +302,12 @@ def build(
        query_start_loc = query_start_loc_cpu.to(self.device,
                                                 non_blocking=True)

-        if is_310p():
-            if attn_state == AscendAttentionState.PrefillNoCache:
-                mask_nz = nd_to_nz_2d(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-            elif attn_state == AscendAttentionState.ChunkedPrefill:
-                mask_nz = nd_to_nz_spec(attn_mask)
-                attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
-                                                      ACL_FORMAT_FRACTAL_NZ)
-
-        attn_metadata = AscendMetadata(
-            num_actual_tokens=num_actual_tokens,
-            block_tables=block_table,
-            query_start_loc=query_start_loc,
-            query_lens=query_lens,
-            seq_lens=seq_lens,
-            max_query_len=common_attn_metadata.max_query_len,
-            slot_mapping=slot_mapping,
-            attn_mask=attn_mask,
-            attn_state=attn_state,
-            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
-            is_only_prefill=common_attn_metadata.is_only_prefill)
+        build_info = self._assemble_build_info(num_actual_tokens, block_table,
+                                               query_start_loc, query_lens,
+                                               seq_lens, slot_mapping,
+                                               attn_mask, attn_state)
+        attn_metadata = self._assemble_attn_metadata(build_info,
+                                                     common_attn_metadata)
        return attn_metadata
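
Taken together, the diff is an extract-method refactor: build() now gathers its per-step values into an AscendAttentionMetadataBuildInfo via _assemble_build_info() (which also owns the 310P mask-format cast) and then produces the AscendMetadata via _assemble_attn_metadata(), so either step can be overridden on its own. Below is a minimal, self-contained sketch of the same pattern; BuildInfo, make_build_info, make_metadata, and the simplified int/list fields are hypothetical illustrations, not names from this diff.

from dataclasses import dataclass

# Hypothetical stand-in for AscendAttentionMetadataBuildInfo,
# with the tensor fields reduced to plain Python values.
@dataclass
class BuildInfo:
    num_actual_tokens: int
    slot_mapping: list

def make_build_info(num_actual_tokens: int, slot_mapping: list) -> BuildInfo:
    # Device-specific fixups (the is_310p() mask cast in the real code)
    # would be applied here, before the per-step values are packaged.
    return BuildInfo(num_actual_tokens, slot_mapping)

def make_metadata(info: BuildInfo, max_query_len: int) -> dict:
    # Pure assembly: merge per-step info with fields taken from the
    # common metadata object.
    return {
        "num_actual_tokens": info.num_actual_tokens,
        "slot_mapping": info.slot_mapping,
        "max_query_len": max_query_len,
    }

def build(num_actual_tokens: int, slot_mapping: list, max_query_len: int) -> dict:
    info = make_build_info(num_actual_tokens, slot_mapping)
    return make_metadata(info, max_query_len)

print(build(4, [0, 1, 2, 3], 2))
# {'num_actual_tokens': 4, 'slot_mapping': [0, 1, 2, 3], 'max_query_len': 2}

Splitting the build this way means a platform-specific builder only has to override the assembly hook it cares about instead of duplicating the whole build() body.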