@@ -435,6 +435,41 @@ def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
         mock_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)

+    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('torch_npu._npu_paged_attention')
+    @patch('torch_npu.npu_fused_infer_attention_score')
+    def test_forward_decode_only_swa_seq_len_mismatch(
+            self, mock_fused_infer_attention_score, mock_paged_attention,
+            mock_npu_reshape_and_cache):
+        """Test forward pass in DecodeOnly state when seq_len mismatches."""
+        query = torch.randn(10, 8 * 64)
+        key = torch.randn(10, 8 * 64)
+        value = torch.randn(10, 8 * 64)
+        kv_cache = torch.empty(2, 5, 128, 8, 64)
+
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.DecodeOnly
+        metadata.seq_lens = torch.tensor([10])  # len == 1 != query.size(0) == 10
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+
+        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8, 64), 1)
+
+        output = self.impl_swa.forward(self.layer_no_quant,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)
+
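+        # With a mismatched seq_lens, the SWA decode path should fall back to
+        # paged attention instead of the fused infer attention kernel.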
+        mock_paged_attention.assert_called_once()
+        mock_fused_infer_attention_score.assert_not_called()
+
+        assert output.shape == (10, 8 * 64)
+
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')