@@ -455,18 +455,25 @@ def _forward_v1_style(
         attn_metadata.seq_lens = \
             attn_metadata.seq_lens.to(device=query.device)

-        torch_npu._npu_paged_attention_splitfuse(
+        num_block, block_size, head_num, head_dim = self.key_cache.shape
+        key = self.key_cache.view(num_block, block_size, -1)
+        value = self.value_cache.view(num_block, block_size, -1)
+
+        output, _ = torch_npu.npu_fused_infer_attention_score(
             query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            mask=attn_metadata.attn_mask,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask.to(device=query.device),
             block_table=attn_metadata.block_tables,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.query_start_loc[1:],
+            actual_seq_lengths_kv=attn_metadata.seq_lens,
+            num_key_value_heads=self.num_kv_heads,
             num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+            scale=self.scale,
+            sparse_mode=3,
+        )

         return output

     def forward(
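A note on the new call path (a sketch, not part of the PR diff): where the old `_npu_paged_attention_splitfuse` wrote into a preallocated `out` tensor, `npu_fused_infer_attention_score` returns its output and expects the paged key/value caches flattened from (num_blocks, block_size, num_kv_heads, head_dim) to (num_blocks, block_size, num_kv_heads * head_dim). The helper below mirrors exactly the keyword arguments used in the hunk; the function name, the standalone framing, and the comments are illustrative, and running it requires Ascend hardware with the torch_npu package installed.

    import torch
    import torch_npu


    def fused_prefill_attention(query, key_cache, value_cache, attn_metadata,
                                num_heads, num_kv_heads, scale):
        # Paged caches: (num_blocks, block_size, num_kv_heads, head_dim).
        # The fused kernel takes the two head dimensions flattened together.
        num_blocks, block_size, _, _ = key_cache.shape
        key = key_cache.view(num_blocks, block_size, -1)
        value = value_cache.view(num_blocks, block_size, -1)

        # input_layout="TND": query is (total_tokens, num_heads, head_dim).
        # actual_seq_lengths are the cumulative per-request query lengths
        # (query_start_loc without its leading zero); actual_seq_lengths_kv
        # are the per-request context lengths, matching the hunk above.
        output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask.to(device=query.device),
            block_table=attn_metadata.block_tables,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=attn_metadata.query_start_loc[1:],
            actual_seq_lengths_kv=attn_metadata.seq_lens,
            num_key_value_heads=num_kv_heads,
            num_heads=num_heads,
            scale=scale,
            sparse_mode=3,  # causal masking mode, as in the hunk above
        )
        return output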