
Commit e8131b9

Upgrade to newest pta (vllm-project#205)

Authored by baymax591 and 白超

Co-authored-by: 白超 <baichao19@huawei.com>

1 parent 263a62d · commit e8131b9

2 files changed (+49, -51 lines)


vllm_ascend/attention.py

Lines changed: 48 additions & 50 deletions
@@ -566,13 +566,11 @@ def forward(
         value_cache = value_cache.view(num_blocks, block_size,
                                        self.num_kv_heads, self.head_size)
         slots = attn_metadata.slot_mapping
-        torch_npu.npu_reshapecache(key=key,
-                                   value=value,
-                                   keyCache=key_cache,
-                                   valueCache=value_cache,
-                                   slotMapping=slots,
-                                   compressType=0,
-                                   kvCacheCfg=0)
+        torch_npu._npu_reshape_and_cache(key=key,
+                                         value=value,
+                                         key_cache=key_cache,
+                                         value_cache=value_cache,
+                                         slot_indices=slots)

         if attn_metadata.num_prefills > 0:
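For orientation, here is a minimal sketch of how the new cache-write op is driven, using only the keyword names and cache layout visible in this hunk. The tensor sizes, dtypes, and the toy slot mapping are illustrative assumptions, not values from the repository, and the snippet assumes an Ascend device with the upgraded torch_npu installed.

```python
import torch
import torch_npu  # assumes an Ascend build of PyTorch exposing the new PTA ops

# Toy sizes chosen purely for illustration.
num_tokens, num_kv_heads, head_size = 4, 2, 16
num_blocks, block_size = 8, 128

key = torch.randn(num_tokens, num_kv_heads, head_size,
                  dtype=torch.float16, device="npu")
value = torch.randn_like(key)
# Cache layout matches the .view() above: (num_blocks, block_size, num_kv_heads, head_size).
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="npu")
value_cache = torch.zeros_like(key_cache)
# One flat slot index per token, standing in for attn_metadata.slot_mapping.
slots = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="npu")

# Scatter this step's key/value vectors into the paged KV cache in place.
torch_npu._npu_reshape_and_cache(key=key,
                                 value=value,
                                 key_cache=key_cache,
                                 value_cache=value_cache,
                                 slot_indices=slots)
```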

@@ -581,16 +579,16 @@ def forward(
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
                 self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.prefill_metadata.seq_lens).astype(np.int32))
-                torch_npu.npu_selfattention(query=query, key=key, value=value,
-                                            mask=mask, maskType=1, isTriuMask=0,
-                                            seqLen=self.seq_lens_tensor_cpu,
-                                            scale=self.scale, qScale=1,
-                                            headNum=self.num_heads, kvHeadNum=self.num_kv_heads, mlaVHeadSize=0,
-                                            calcType=3, kernelType=0, clampType=0,
-                                            scaleType=0, quantType=0, cacheType=0,
-                                            batchRunStatusEnable=False, kvcacheCfg=0,
-                                            clampMin=0, clampMax=0, inputLayout=0,
-                                            windowSize=0, outDataType=0, out=output)
+                torch_npu._npu_flash_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    mask=mask,
+                    seq_len=self.seq_lens_tensor_cpu,
+                    scale_value=self.scale,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    out=output)
             else:
                 # TODO: Will support prefix cache and chunked prefill soon.
                 raise RuntimeError(
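To make the prefill-side rename easier to follow, the mapping below pairs each keyword of the removed npu_selfattention call with its counterpart in the new _npu_flash_attention call, read directly off this hunk. The mapping name is ours, for illustration only; the old tuning flags listed in the trailing comment simply have no counterpart at this call site.

```python
# Old npu_selfattention keyword -> new _npu_flash_attention keyword, per this hunk.
SELF_ATTENTION_TO_FLASH_ATTENTION = {
    "query": "query",
    "key": "key",
    "value": "value",
    "mask": "mask",
    "seqLen": "seq_len",
    "scale": "scale_value",
    "headNum": "num_heads",
    "kvHeadNum": "num_kv_heads",
    "out": "out",
}
# Dropped at this call site: maskType, isTriuMask, qScale, mlaVHeadSize, calcType,
# kernelType, clampType, scaleType, quantType, cacheType, batchRunStatusEnable,
# kvcacheCfg, clampMin, clampMax, inputLayout, windowSize, outDataType.
```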
@@ -600,13 +598,16 @@ def forward(
             assert kv_cache is not None
             self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.decode_metadata.seq_lens).astype(np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
-            torch_npu.npu_pagedattention(query=query, keyCache=key_cache, valueCache=value_cache,
-                                         contextLens=self.seq_lens_tensor_cpu, maskType=0,
-                                         kvHeadNum=self.num_kv_heads, headNum=self.num_heads, mlaVHeadSize=0,
-                                         qkScale=self.scale, scaleType=0, blockTables=block_tables,
-                                         batchRunStatusEnable=False, hasQuantOffset=False,
-                                         calcType=3, quantType=0, compressType=0,
-                                         inputLayout=0, outDataType=0, attnOut=output)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=block_tables,
+                context_lens=self.seq_lens_tensor_cpu,
+                out=output)

         return output.view(num_tokens, self.hidden_size)
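The decode-side rename follows the same pattern. The keyword changes visible in this hunk are collected below, again as an illustrative mapping of ours rather than an artifact of the patch; the remaining old-style flags are dropped outright.

```python
# Old npu_pagedattention keyword -> new _npu_paged_attention keyword, per this hunk.
PAGED_ATTENTION_KEYWORD_MAP = {
    "keyCache": "key_cache",
    "valueCache": "value_cache",
    "contextLens": "context_lens",
    "kvHeadNum": "num_kv_heads",
    "headNum": "num_heads",
    "qkScale": "scale_value",
    "blockTables": "block_table",  # note the singular form in the new API
    "attnOut": "out",
}
# Dropped at this call site: maskType, mlaVHeadSize, scaleType, batchRunStatusEnable,
# hasQuantOffset, calcType, quantType, compressType, inputLayout, outDataType.
```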

@@ -743,13 +744,9 @@ def forward(
         key_cache = key_cache.view(num_blocks, block_size, self.num_kv_heads,
                                    self.qk_rope_head_dim+self.kv_lora_rank)
         slots = attn_metadata.slot_mapping
-        torch_npu.npu_reshapecache(key=k_cache,
-                                   value=None,
-                                   keyCache=key_cache,
-                                   valueCache=None,
-                                   slotMapping=slots,
-                                   compressType=0,
-                                   kvCacheCfg=1)
+        torch_npu._npu_reshape_and_cache_siso(key=k_cache,
+                                              key_cache=key_cache,
+                                              slot_indices=slots)

         if attn_metadata.num_prefills > 0:
             attn_output = torch.empty(num_tokens,
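In the MLA path there is no separate value stream to cache, so the old code expressed this as value=None, valueCache=None with kvCacheCfg=1 on the generic op; the new code switches to a dedicated single-input, single-output variant instead. A comment-level comparison of the two cache-write call sites, derived only from the two hunks above:

```python
# MHA cache write (first hunk):  key + value -> torch_npu._npu_reshape_and_cache(
#                                                   key=..., value=...,
#                                                   key_cache=..., value_cache=...,
#                                                   slot_indices=...)
# MLA cache write (this hunk):   key only    -> torch_npu._npu_reshape_and_cache_siso(
#                                                   key=k_cache,
#                                                   key_cache=key_cache,
#                                                   slot_indices=slots)
# The old interface expressed the single-stream case as value=None / valueCache=None
# plus kvCacheCfg=1 on npu_reshapecache; the new interface names it explicitly.
```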
@@ -762,16 +759,16 @@ def forward(
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
                 self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.prefill_metadata.seq_lens).astype(np.int32))
-                torch_npu.npu_selfattention(query=query, key=key, value=value, kvcacheCfg=0,
-                                            mask=mask, maskType=1, isTriuMask=0,
-                                            seqLen=self.seq_lens_tensor_cpu,
-                                            scale=self.scale, qScale=1, scaleType=0,
-                                            headNum=self.num_heads, kvHeadNum=self.num_heads, mlaVHeadSize=0,
-                                            calcType=3, kernelType=0, clampType=0,
-                                            quantType=0, cacheType=0, windowSize=0,
-                                            clampMin=0, clampMax=0,
-                                            batchRunStatusEnable=False, inputLayout=0,
-                                            outDataType=0, out=attn_output)
+                torch_npu._npu_flash_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    mask=mask,
+                    seq_len=self.seq_lens_tensor_cpu,
+                    scale_value=self.scale,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_heads,
+                    out=attn_output)
             else:
                 # TODO: Will support prefix cache and chunked prefill soon.
                 raise RuntimeError(
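Aside from writing into attn_output rather than output, this prefill call site differs from the earlier one only in passing num_kv_heads=self.num_heads, which mirrors the kvHeadNum=self.num_heads of the removed npu_selfattention call; the keyword mapping given after the first prefill hunk applies here unchanged.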
@@ -786,15 +783,16 @@ def forward(
                                       device="npu")
             self.seq_lens_tensor_cpu = torch.from_numpy(np.array(attn_metadata.decode_metadata.seq_lens).astype(np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
-            torch_npu.npu_pagedattention(query=query, keyCache=key_cache, valueCache=None,
-                                         contextLens=self.seq_lens_tensor_cpu,
-                                         maskType=0,
-                                         kvHeadNum=self.num_kv_heads, headNum=self.num_heads,
-                                         mlaVHeadSize=self.kv_lora_rank,
-                                         qkScale=self.scale, blockTables=block_tables,
-                                         batchRunStatusEnable=False, hasQuantOffset=False,
-                                         compressType=0, calcType=0, scaleType=0, quantType=0,
-                                         inputLayout=0, outDataType=-1, attnOut=attn_output)
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=key_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=block_tables,
+                context_lens=self.seq_lens_tensor_cpu,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output)
             attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2), require_contiguous=True)
             attn_output_t = torch.bmm(attn_output_t, self.w_vc)
             attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2), require_contiguous=True)
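The MLA decode path now gets its own op rather than calling the generic paged attention with valueCache=None. Beyond the renames already seen above, the one MLA-specific keyword that survives is the latent value head size; the mapping below is ours, read off this hunk.

```python
# Old npu_pagedattention (MLA call site) -> new _npu_paged_attention_mla keyword.
PAGED_ATTENTION_MLA_KEYWORD_MAP = {
    "keyCache": "key_cache",          # valueCache=None disappears: the MLA op takes no value cache
    "contextLens": "context_lens",
    "kvHeadNum": "num_kv_heads",
    "headNum": "num_heads",
    "qkScale": "scale_value",
    "blockTables": "block_table",
    "mlaVHeadSize": "mla_vheadsize",  # kept, and still set to self.kv_lora_rank
    "attnOut": "out",
}
```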

vllm_ascend/ops/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def rope_forward_oot(
     # TODO: Remove the contiguous in the future.
     query = query.contiguous()
     key = key.contiguous()
-    torch_npu.npu_rope(
+    torch_npu._npu_rotary_embedding(
         positions,
         query,
         key,
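For the rotary embedding, the visible change is a pure rename of the op; the positional arguments shown in the hunk are untouched, and since this file's diff is a single addition and deletion, the arguments that continue past the hunk are unchanged by this commit as well. A comment-level before/after, with the trailing ellipsis marking arguments outside the hunk that are deliberately not reproduced here:

```python
# Before: torch_npu.npu_rope(positions, query, key, ...)
# After:  torch_npu._npu_rotary_embedding(positions, query, key, ...)
# Only the operator name changes at this call site.
```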
