212 changes: 105 additions & 107 deletions vllm_ascend/attention/attention_v1.py
@@ -23,7 +23,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@@ -318,6 +317,18 @@ def build(
pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
if pcp_size > 1:
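# The TND attention calls downstream take cumulative sequence lengths as
# host-side lists, so the packed seqlen tensors are reduced with cumsum and
# converted via tolist() here.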
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
dim=0).tolist()
head_attn_nomask_seqlens = torch.cumsum(
head_attn_nomask_seqlens[1], dim=0).tolist()
tail_attn_nomask_seqlens = torch.cumsum(
tail_attn_nomask_seqlens[1], dim=0).tolist()
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -329,12 +340,9 @@ def build(
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.
attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
attn_mask_seqlens=attn_mask_seqlens,
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
prefill_metadata = AscendMetadataForPrefill(
@@ -701,28 +709,6 @@ def _forward_v1_style(
out=output)
return output

def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
max_len = max(lengths)
splits = torch.split(tensor_tnd, lengths, dim=0)

padded = []
for s in splits:
pad_len = max_len - s.shape[0]
s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
padded.append(s_pad)

tensor_bsnd = torch.stack(padded, dim=0)
return tensor_bsnd

def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
slices = []
for i, length in enumerate(lengths):
slices.append(tensor_bsnd[i, :length])
tensor_tnd = torch.cat(slices, dim=0)
return tensor_tnd

def _attention_with_nomask_and_mask(self, q: torch.Tensor,
q_seqlens: List[int],
k_nomask: torch.Tensor,
@@ -732,17 +718,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
v_mask: torch.Tensor,
kv_seqlens_mask: List[int],
mask: torch.Tensor) -> torch.Tensor:
q = self._pack_tnd_2_bsnd(q, q_seqlens)

# nomask Attention
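# With input_layout="TND" the packed (token, head, dim) tensors are consumed
# directly; actual_seq_lengths / actual_seq_lengths_kv carry the per-request
# boundaries, so no pad/unpad round-trip through a BSND layout is needed.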
if k_nomask is not None:
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
k_nomask,
v_nomask,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
input_layout="TND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
@@ -751,38 +735,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_nomask,
actual_seq_lengths=q_seqlens)
attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
q_seqlens)
# (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
attn_lse_nomask = self._unpack_bsnd_2_tnd(
attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)

# mask Attention
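# sparse_mode=3 requests the op's compressed causal-mask handling for this
# masked branch (assumption: following the npu_fused_infer_attention_score
# sparse-mode convention); the no-mask branch above uses sparse_mode=0 with
# atten_mask=None.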
attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
k_mask,
v_mask,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
input_layout="TND",
atten_mask=mask,
scale=self.scale,
sparse_mode=0,
sparse_mode=3,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_mask,
actual_seq_lengths=q_seqlens)
attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
attn_lse_mask = self._unpack_bsnd_2_tnd(
attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)

# update
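# Merge the masked and no-mask partial results. npu_attention_update is
# expected to apply the standard log-sum-exp combine,
# LSE = log(sum_i exp(LSE_i)) and O = sum_i exp(LSE_i - LSE) * O_i,
# over the flattened (T*N) entries.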
output = attn_out_mask
if k_nomask is not None:
output, _ = self._update_out_and_lse(
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
T = attn_out_mask.shape[0]
N = attn_out_mask.shape[1]
D = attn_out_mask.shape[2]

attn_out_mask, attn_lse_mask = self._out_lse_reshape(
attn_out_mask, attn_lse_mask)
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
attn_out_nomask, attn_lse_nomask)
attn_out_mask = attn_out_mask.to(torch.float32)
attn_out_nomask = attn_out_nomask.to(torch.float32)
attn_lse_mask = attn_lse_mask.to(torch.float32)
attn_lse_nomask = attn_lse_nomask.to(torch.float32)

attn_output = [attn_out_nomask, attn_out_mask]
attn_lse = [attn_lse_nomask, attn_lse_mask]
update_type = 0
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
update_type)
output = output.view(T, N, D)

return output

@@ -807,29 +799,29 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
# 1. Attention calculation for the first half of Q under load balancing
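# Mirroring the tail case below, e.g. with 2 PCP ranks:
# pcp_rank0: Q0*KV0
# pcp_rank1: Q1*KV0 + Q1*KV1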
output_head = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_head_idx),
q_seqlens=attn_mask_seqlens[0].tolist(),
q_seqlens=attn_mask_seqlens,
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
if self.pcp_rank > 0 else None,
v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
if self.pcp_rank > 0 else None,
kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
kv_seqlens_nomask=head_attn_nomask_seqlens,
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
kv_seqlens_mask=attn_mask_seqlens,
mask=mask)

# 2. Attention calculation for the latter half of Q under load balancing
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
output_tail = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_tail_idx),
q_seqlens=attn_mask_seqlens[0].tolist(),
q_seqlens=attn_mask_seqlens,
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
kv_seqlens_nomask=tail_attn_nomask_seqlens,
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
kv_seqlens_mask=attn_mask_seqlens,
mask=mask)

# 3. Combine the outputs of the first and second halves.
@@ -838,20 +830,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
return output

def _update_out_and_lse(self, out_list: torch.Tensor,
lse_list: torch.Tensor) -> torch.Tensor:
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
Args:
out_list: shape = [N, batch_size, num_heads, head_size]
lse_list: shape = [N, batch_size, num_heads, 1]
Returns:
out_final: shape = [batch_size, num_heads, head_size]
lse_final: shape = [batch_size, num_heads, 1]
"""
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
dim=0)
return out_final, lse_final
def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
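# Flatten attn_out to 2-D (tokens*heads, head_dim) and attn_lse to 1-D,
# the packed layout used for torch_npu.npu_attention_update in this file.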
attn_out = attn_out.contiguous().view(
attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
attn_lse = attn_lse.contiguous().view(
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
return attn_out, attn_lse

def _npu_attention_update(
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
update_type = 0

batch = attn_out_lse_list[0].shape[0]
num_heads = attn_out_lse_list[0].shape[1]
head_dim = attn_out_lse_list[0].shape[2] - 1
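# Each list entry packs [attn_out | lse] along the last dim, so the true
# head_dim is one column smaller than the packed tensor.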

attn_out_split_cp = []
attn_lse_split_cp = []

for i in attn_out_lse_list:
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
*torch.split(i, [self.head_size, 1], dim=-1))
attn_out_split_cp.append(attn_out_allgather)
attn_lse_split_cp.append(attn_lse_allgather)

attn_out, attn_lse = torch_npu.npu_attention_update(
attn_lse_split_cp, attn_out_split_cp, update_type)
attn_out = attn_out.view(batch, num_heads, head_dim)

return attn_out

def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -864,9 +872,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
else:
num_heads = self.num_heads

# 1. Compute out&lse by "npu_fused_infer_attention_score"
q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
k_nope = self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1)
value = self.value_cache.view(self.key_cache.shape[0],
@@ -877,7 +882,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
'num_key_value_heads':
self.num_kv_heads,
'input_layout':
"BSND",
"TND",
'atten_mask':
None,
'scale':
@@ -892,9 +897,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata.block_tables,
'block_size':
self.key_cache.shape[1],
"actual_seq_lengths_kv":
attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
'actual_seq_lengths_kv':
attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
'actual_seq_lengths':
attn_metadata.actual_seq_lengths_q[:attn_metadata.
num_decode_tokens]
}
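# TND decode path: actual_seq_lengths comes from the query-side metadata and
# actual_seq_lengths_kv from the per-request KV lengths, both sliced to the
# decode tokens of the batch.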
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
@@ -910,16 +917,16 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, value, **common_kwargs)
query, k_nope, value, **common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
attn_out = torch.empty_like(q_nope)
attn_out = torch.empty_like(query)
attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
dtype=torch.float,
device=q_nope.device)
device=query.device)

graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
(weak_ref_tensors(query), weak_ref_tensors(k_nope),
weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
self.scale, attn_metadata.block_tables,
self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -929,7 +936,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
self.pcp_rank, self.dcp_rank, self.dcp_size))
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
query,
k_nope,
value,
**common_kwargs,
@@ -939,14 +946,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
graph_params.handles[num_tokens].append(handle)
else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
q_nope, k_nope, value, **common_kwargs)
query, k_nope, value, **common_kwargs)

attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
attn_out.shape[3])
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
attn_out_lse_list = []
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
Expand All @@ -955,35 +960,28 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
attn_out_lse_split_on_seq = list(
if self.pcp_size > 1:
attn_out_lse = attn_out_lse_all2all.contiguous()
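# After the all-to-all, the packed tensor is back in [bs, num_heads,
# v_head_dim+1]; chunking along the head dim yields one partial out/lse per
# DCP rank for the subsequent update.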
attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))

attn_out_lse_split_dcp = torch.stack(
attn_out_lse_split_on_seq,
dim=0) # [dcp, batch_size, num_heads, head_size+1]
# Update out&lse
attn_out_split_dcp, attn_lse_split_dcp = torch.split(
attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
attn_out, attn_lse = self._update_out_and_lse(
attn_out_split_dcp, attn_lse_split_dcp)
if self.pcp_size > 1:
# 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# 3. AllGather out&lse within CP group
# AllGather out&lse within CP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
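# Each PCP rank contributes the packed out/lse for the KV shard it attended
# locally; the gathered list is reduced below with the LSE update.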
# 4. Update out&lse
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
attn_out, _ = self._update_out_and_lse(attn_out_allgather,
attn_lse_allgather)
if self.dcp_size > 1 and self.pcp_size > 1:
attn_out_lse_list_pcp_dcp = []
for s in attn_out_lse_list:
attn_out_lse_list_split = list(
torch.chunk(s, self.dcp_size, dim=1))
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
attn_out_lse_list = attn_out_lse_list_pcp_dcp
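# With both PCP and DCP active, every gathered PCP entry is further split per
# DCP rank, so the final update merges pcp_size * dcp_size partial results.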
# Update out&lse
attn_out = self._npu_attention_update(attn_out_lse_list)
return attn_out

def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,