@@ -566,13 +566,11 @@ def forward(
         value_cache = value_cache.view(num_blocks, block_size,
                                        self.num_kv_heads, self.head_size)
         slots = attn_metadata.slot_mapping
-        torch_npu.npu_reshapecache(key=key,
-                                   value=value,
-                                   keyCache=key_cache,
-                                   valueCache=value_cache,
-                                   slotMapping=slots,
-                                   compressType=0,
-                                   kvCacheCfg=0)
+        torch_npu._npu_reshape_and_cache(key=key,
+                                         value=value,
+                                         key_cache=key_cache,
+                                         value_cache=value_cache,
+                                         slot_indices=slots)


         if attn_metadata.num_prefills > 0:
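Reviewer note: this hunk swaps the generic ATB wrapper `npu_reshapecache` for the dedicated `_npu_reshape_and_cache` custom op; the `compressType`/`kvCacheCfg` flags disappear because the new op only covers the plain key/value caching case. A minimal standalone sketch of the new call, assuming an Ascend/torch_npu environment; all sizes and the int32 slot dtype are illustrative assumptions, not values from this PR:

    import torch
    import torch_npu

    num_tokens, num_kv_heads, head_size = 4, 2, 128   # made-up sizes
    num_blocks, block_size = 8, 16

    # One new K/V row per token for this step.
    key = torch.randn(num_tokens, num_kv_heads, head_size).npu()
    value = torch.randn(num_tokens, num_kv_heads, head_size).npu()

    # Paged caches shaped like the .view(...) above:
    # (num_blocks, block_size, num_kv_heads, head_size).
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size).npu()
    value_cache = torch.zeros_like(key_cache)

    # Flat slot index (block_id * block_size + in-block offset) per token; int32 is an assumption.
    slots = torch.tensor([0, 1, 16, 17], dtype=torch.int32).npu()

    torch_npu._npu_reshape_and_cache(key=key, value=value,
                                     key_cache=key_cache,
                                     value_cache=value_cache,
                                     slot_indices=slots)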
@@ -581,16 +579,16 @@ def forward(
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
             self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.prefill_metadata.seq_lens).astype(np.int32))
-            torch_npu.npu_selfattention(query=query, key=key, value=value,
-                                        mask=mask, maskType=1, isTriuMask=0,
-                                        seqLen=self.seq_lens_tensor_cpu,
-                                        scale=self.scale, qScale=1,
-                                        headNum=self.num_heads, kvHeadNum=self.num_kv_heads, mlaVHeadSize=0,
-                                        calcType=3, kernelType=0, clampType=0,
-                                        scaleType=0, quantType=0, cacheType=0,
-                                        batchRunStatusEnable=False, kvcacheCfg=0,
-                                        clampMin=0, clampMax=0, inputLayout=0,
-                                        windowSize=0, outDataType=0, out=output)
+            torch_npu._npu_flash_attention(
+                query=query,
+                key=key,
+                value=value,
+                mask=mask,
+                seq_len=self.seq_lens_tensor_cpu,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                out=output)
         else:
             # TODO: Will support prefix cache and chunked prefill soon.
             raise RuntimeError(
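The prefill path now goes through `_npu_flash_attention`, keeping only the arguments with clear semantics (mask, per-sequence lengths, softmax scale, head counts) and dropping the dozen tuning flags of `npu_selfattention`. A hedged sketch of an isolated call; the mask layout and all sizes are assumptions, while the CPU int32 `seq_len` tensor mirrors the diff:

    import numpy as np
    import torch
    import torch_npu

    num_heads, head_size = 4, 128                 # made-up sizes
    seq_lens = [3, 5]                             # two prefill sequences
    num_tokens = sum(seq_lens)

    query = torch.randn(num_tokens, num_heads, head_size).npu()
    key = torch.randn(num_tokens, num_heads, head_size).npu()
    value = torch.randn(num_tokens, num_heads, head_size).npu()
    # Additive causal mask (-inf above the diagonal); the exact shape the
    # kernel expects is an assumption.
    max_len = max(seq_lens)
    mask = torch.triu(torch.full((max_len, max_len), float("-inf")), diagonal=1).npu()
    output = torch.empty_like(query)

    torch_npu._npu_flash_attention(
        query=query, key=key, value=value, mask=mask,
        # seq_len stays a CPU int32 tensor, exactly as in the diff.
        seq_len=torch.from_numpy(np.array(seq_lens).astype(np.int32)),
        scale_value=1.0 / head_size**0.5,
        num_heads=num_heads, num_kv_heads=num_heads,
        out=output)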
@@ -600,13 +598,16 @@ def forward(
             assert kv_cache is not None
             self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.decode_metadata.seq_lens).astype(np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
-            torch_npu.npu_pagedattention(query=query, keyCache=key_cache, valueCache=value_cache,
-                                         contextLens=self.seq_lens_tensor_cpu, maskType=0,
-                                         kvHeadNum=self.num_kv_heads, headNum=self.num_heads, mlaVHeadSize=0,
-                                         qkScale=self.scale, scaleType=0, blockTables=block_tables,
-                                         batchRunStatusEnable=False, hasQuantOffset=False,
-                                         calcType=3, quantType=0, compressType=0,
-                                         inputLayout=0, outDataType=0, attnOut=output)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=block_tables,
+                context_lens=self.seq_lens_tensor_cpu,
+                out=output)


         return output.view(num_tokens, self.hidden_size)
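For decode, `_npu_paged_attention` reads K/V straight from the paged caches via a per-sequence block table. A sketch under the same assumptions (one decode token per sequence; the block-table padding value is a guess):

    import numpy as np
    import torch
    import torch_npu

    num_seqs, num_heads, num_kv_heads, head_size = 2, 4, 2, 128  # made-up sizes
    num_blocks, block_size = 8, 16

    query = torch.randn(num_seqs, num_heads, head_size).npu()    # one token per sequence
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size).npu()
    value_cache = torch.randn_like(key_cache)
    # Row i lists the cache blocks of sequence i, padded with 0 (assumption).
    block_tables = torch.tensor([[0, 1], [2, 0]], dtype=torch.int32).npu()
    # Tokens already cached per sequence; kept on CPU as int32, as in the diff.
    context_lens = torch.from_numpy(np.array([20, 7]).astype(np.int32))
    output = torch.empty_like(query)

    torch_npu._npu_paged_attention(
        query=query, key_cache=key_cache, value_cache=value_cache,
        num_kv_heads=num_kv_heads, num_heads=num_heads,
        scale_value=1.0 / head_size**0.5,
        block_table=block_tables, context_lens=context_lens,
        out=output)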
@@ -743,13 +744,9 @@ def forward(
         key_cache = key_cache.view(num_blocks, block_size, self.num_kv_heads,
                                    self.qk_rope_head_dim + self.kv_lora_rank)
         slots = attn_metadata.slot_mapping
-        torch_npu.npu_reshapecache(key=k_cache,
-                                   value=None,
-                                   keyCache=key_cache,
-                                   valueCache=None,
-                                   slotMapping=slots,
-                                   compressType=0,
-                                   kvCacheCfg=1)
+        torch_npu._npu_reshape_and_cache_siso(key=k_cache,
+                                              key_cache=key_cache,
+                                              slot_indices=slots)

         if attn_metadata.num_prefills > 0:
             attn_output = torch.empty(num_tokens,
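In the MLA path only a single latent key tensor is cached, which the old call expressed awkwardly as `value=None` plus `kvCacheCfg=1`; the `_npu_reshape_and_cache_siso` variant ("single in, single out" is my reading of the suffix) makes that explicit and takes just the key side. Minimal sketch, with the cache width `qk_rope_head_dim + kv_lora_rank` taken from the `.view(...)` above and every concrete number assumed:

    import torch
    import torch_npu

    num_tokens, num_kv_heads = 4, 1                  # made-up sizes
    qk_rope_head_dim, kv_lora_rank = 64, 512
    num_blocks, block_size = 8, 16
    width = qk_rope_head_dim + kv_lora_rank          # latent + rope channels per entry

    k_latent = torch.randn(num_tokens, num_kv_heads, width).npu()
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, width).npu()
    slots = torch.tensor([0, 1, 2, 3], dtype=torch.int32).npu()  # int32 is an assumption

    torch_npu._npu_reshape_and_cache_siso(key=k_latent,
                                          key_cache=key_cache,
                                          slot_indices=slots)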
@@ -762,16 +759,16 @@ def forward(
             assert attn_metadata.attn_mask is not None
             mask = attn_metadata.attn_mask
             self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.prefill_metadata.seq_lens).astype(np.int32))
-            torch_npu.npu_selfattention(query=query, key=key, value=value, kvcacheCfg=0,
-                                        mask=mask, maskType=1, isTriuMask=0,
-                                        seqLen=self.seq_lens_tensor_cpu,
-                                        scale=self.scale, qScale=1, scaleType=0,
-                                        headNum=self.num_heads, kvHeadNum=self.num_heads, mlaVHeadSize=0,
-                                        calcType=3, kernelType=0, clampType=0,
-                                        quantType=0, cacheType=0, windowSize=0,
-                                        clampMin=0, clampMax=0,
-                                        batchRunStatusEnable=False, inputLayout=0,
-                                        outDataType=0, out=attn_output)
+            torch_npu._npu_flash_attention(
+                query=query,
+                key=key,
+                value=value,
+                mask=mask,
+                seq_len=self.seq_lens_tensor_cpu,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
         else:
             # TODO: Will support prefix cache and chunked prefill soon.
             raise RuntimeError(
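Same `_npu_flash_attention` migration as in the non-MLA prefill hunk above; the one semantic difference is the KV head count, since MLA prefill materializes full per-head K/V from the latent cache before attention, so query and key/value head counts match here:

    # Non-MLA prefill (earlier hunk): num_kv_heads=self.num_kv_heads  (GQA: KV heads <= Q heads)
    # MLA prefill (this hunk):        num_kv_heads=self.num_heads     (K/V expanded per query head)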
@@ -786,15 +783,16 @@ def forward(
                                       device="npu")
             self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.decode_metadata.seq_lens).astype(np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
-            torch_npu.npu_pagedattention(query=query, keyCache=key_cache, valueCache=None,
-                                         contextLens=self.seq_lens_tensor_cpu,
-                                         maskType=0,
-                                         kvHeadNum=self.num_kv_heads, headNum=self.num_heads,
-                                         mlaVHeadSize=self.kv_lora_rank,
-                                         qkScale=self.scale, blockTables=block_tables,
-                                         batchRunStatusEnable=False, hasQuantOffset=False,
-                                         compressType=0, calcType=0, scaleType=0, quantType=0,
-                                         inputLayout=0, outDataType=-1, attnOut=attn_output)
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=key_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=block_tables,
+                context_lens=self.seq_lens_tensor_cpu,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output)
         attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2), require_contiguous=True)
         attn_output_t = torch.bmm(attn_output_t, self.w_vc)
         attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2), require_contiguous=True)
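`_npu_paged_attention_mla` is the key-cache-only variant: there is no `value_cache` argument, and `mla_vheadsize=self.kv_lora_rank` tells the kernel how many cached channels form the latent value, which the two `npu_transpose` + `torch.bmm` lines then project back through `self.w_vc`. A hedged sketch of the decode call plus that up-projection; every concrete size, and the query layout of `kv_lora_rank + qk_rope_head_dim` channels, is an assumption:

    import numpy as np
    import torch
    import torch_npu

    num_seqs, num_heads, num_kv_heads = 2, 8, 1      # made-up sizes
    qk_rope_head_dim, kv_lora_rank = 64, 512
    num_blocks, block_size, v_head_dim = 8, 16, 128
    width = qk_rope_head_dim + kv_lora_rank

    query = torch.randn(num_seqs, num_heads, width).npu()
    key_cache = torch.randn(num_blocks, block_size, num_kv_heads, width).npu()
    block_tables = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32).npu()
    context_lens = torch.from_numpy(np.array([20, 7]).astype(np.int32))  # CPU int32, as in the diff
    attn_output = torch.empty(num_seqs, num_heads, kv_lora_rank, device="npu")

    torch_npu._npu_paged_attention_mla(
        query=query, key_cache=key_cache,
        num_kv_heads=num_kv_heads, num_heads=num_heads,
        scale_value=1.0 / width**0.5,                # illustrative scale
        block_table=block_tables, context_lens=context_lens,
        mla_vheadsize=kv_lora_rank, out=attn_output)

    # Project the latent outputs back to value space, as the diff does:
    w_vc = torch.randn(num_heads, kv_lora_rank, v_head_dim).npu()
    out_t = torch_npu.npu_transpose(attn_output, (1, 0, 2), require_contiguous=True)  # (heads, seqs, rank)
    out_t = torch.bmm(out_t, w_vc)                                                    # (heads, seqs, v_head_dim)
    final = torch_npu.npu_transpose(out_t, (1, 0, 2), require_contiguous=True)        # (seqs, heads, v_head_dim)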