@@ -195,7 +195,8 @@ def test_ascend_mla_metadata_builder_default(self):
195
195
ascend_config .torchair_graph_config .enabled = True
196
196
with patch ("vllm_ascend.torchair.torchair_mla.get_ascend_config" ,
197
197
return_value = ascend_config ):
198
- builder = AscendMLATorchairMetadataBuilder (mock_vllm_config ,
198
+ builder = AscendMLATorchairMetadataBuilder (None , None ,
199
+ mock_vllm_config ,
199
200
mock_device )
200
201
201
202
self .assertEqual (builder .block_size ,
@@ -216,7 +217,8 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config):
216
217
ascend_config .torchair_graph_config = MagicMock ()
217
218
ascend_config .torchair_graph_config .enabled = True
218
219
219
- builder = AscendMLATorchairMetadataBuilder (mock_vllm_config ,
220
+ builder = AscendMLATorchairMetadataBuilder (None , None ,
221
+ mock_vllm_config ,
220
222
mock_device )
221
223
222
224
input_batch = MagicMock ()
@@ -252,7 +254,8 @@ def test_reorder_batch_without_torchair_graph(self):
252
254
253
255
with patch ("vllm_ascend.torchair.torchair_mla.get_ascend_config" ,
254
256
return_value = ascend_config ):
255
- builder = AscendMLATorchairMetadataBuilder (mock_vllm_config ,
257
+ builder = AscendMLATorchairMetadataBuilder (None , None ,
258
+ mock_vllm_config ,
256
259
mock_device )
257
260
258
261
input_batch = MagicMock ()
@@ -285,7 +288,8 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
285
288
mock_vllm_config .scheduler_config .chunked_prefill_enabled = False
286
289
mock_device = 'cpu'
287
290
288
- builder = AscendMLATorchairMetadataBuilder (mock_vllm_config ,
291
+ builder = AscendMLATorchairMetadataBuilder (None , None ,
292
+ mock_vllm_config ,
289
293
mock_device )
290
294
block_tables = torch .randint (0 , 100 , (3 , 10 ), dtype = torch .int32 )
291
295
@@ -305,7 +309,8 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
305
309
mock_vllm_config .scheduler_config .chunked_prefill_enabled = False
306
310
mock_device = 'cpu'
307
311
308
- builder = AscendMLATorchairMetadataBuilder (mock_vllm_config ,
312
+ builder = AscendMLATorchairMetadataBuilder (None , None ,
313
+ mock_vllm_config ,
309
314
mock_device )
310
315
block_tables = torch .randint (0 , 100 , (3 , 10 ), dtype = torch .int32 )
311
316
@@ -326,7 +331,8 @@ def test_get_graph_runner_block_tables_from_numpy(self,
326
331
mock_vllm_config .scheduler_config .chunked_prefill_enabled = False
327
332
mock_device = 'cpu'
328
333
329
- builder = AscendMLATorchairMetadataBuilder (mock_vllm_config ,
334
+ builder = AscendMLATorchairMetadataBuilder (None , None ,
335
+ mock_vllm_config ,
330
336
mock_device )
331
337
332
338
block_tables = torch .randint (0 , 100 , (3 , 10 ), dtype = torch .int32 )
@@ -352,6 +358,8 @@ def test_build_dummy(self, mock_ascend_config):
352
358
mock_device = 'cpu'
353
359
354
360
builder = AscendMLATorchairMetadataBuilder (
361
+ None ,
362
+ None ,
355
363
mock_vllm_config ,
356
364
mock_device ,
357
365
metadata_cls = AscendMLATorchairMetadata )
@@ -417,6 +425,8 @@ def test_build_decode(self, mock_ascend_config):
417
425
model .model = MagicMock (spec = nn .Module )
418
426
419
427
builder = AscendMLATorchairMetadataBuilder (
428
+ None ,
429
+ None ,
420
430
mock_vllm_config ,
421
431
mock_device ,
422
432
metadata_cls = AscendMLATorchairMetadata )
@@ -442,9 +452,11 @@ def test_build_decode(self, mock_ascend_config):
442
452
positions = torch .tensor ([1 , 1 ]),
443
453
attn_mask = torch .ones ((15 , 15 )),
444
454
spec_attn_mask = None ,
445
- attn_state = AscendAttentionState .ChunkedPrefill )
455
+ attn_state = AscendAttentionState .ChunkedPrefill ,
456
+ num_computed_tokens_cpu = None ,
457
+ seq_lens = None )
446
458
447
- metadata = builder .build (common_attn_metadata , model )
459
+ metadata = builder .build (1 , common_attn_metadata , model )
448
460
449
461
self .assertIsInstance (metadata , AscendMLATorchairMetadata )
450
462
self .assertEqual (metadata .num_input_tokens , 0 )
0 commit comments