diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 62bca3098d..a6bbf279ba 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -23,7 +23,6 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -318,6 +317,18 @@ def build(
         pcp_metadata = None
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         if common_long_seq_metadata is not None:
+            attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
+            head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
+            tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
+            pcp_size = get_prefill_context_model_parallel_world_size(
+            ) if prefill_context_parallel_enable() else 1
+            if pcp_size > 1:
+                attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
+                                                 dim=0).tolist()
+                head_attn_nomask_seqlens = torch.cumsum(
+                    head_attn_nomask_seqlens[1], dim=0).tolist()
+                tail_attn_nomask_seqlens = torch.cumsum(
+                    tail_attn_nomask_seqlens[1], dim=0).tolist()
             pcp_metadata = AscendPCPMetadata(
                 q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
                 q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -329,12 +340,9 @@ def build(
                 kv_with_q_tail_nomask_idx_tensor,
                 kv_with_q_tail_mask_idx=common_long_seq_metadata.
                 kv_with_q_tail_mask_idx_tensor,
-                attn_mask_seqlens=common_long_seq_metadata.
-                attn_mask_seqlens,
-                head_attn_nomask_seqlens=common_long_seq_metadata.
-                head_attn_nomask_seqlens,
-                tail_attn_nomask_seqlens=common_long_seq_metadata.
-                tail_attn_nomask_seqlens,
+                attn_mask_seqlens=attn_mask_seqlens,
+                head_attn_nomask_seqlens=head_attn_nomask_seqlens,
+                tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
                 pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
         prefill_metadata = AscendMetadataForPrefill(
@@ -701,28 +709,6 @@ def _forward_v1_style(
             out=output)
         return output

-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -732,17 +718,15 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -751,38 +735,46 @@ def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)

         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
            antiquant_mode=0,
            antiquant_scale=None,
            softmax_lse_flag=True,
            actual_seq_lengths_kv=kv_seqlens_mask,
            actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)

         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
+                                                       update_type)
+            output = output.view(T, N, D)
         return output
@@ -807,15 +799,15 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
         # 1. Attention calculation in the first half of Q in load balancing
         output_head = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_head_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
             v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
-            kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=head_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)

         # 2. the Attention calculation in the latter half of Q in load balancing
@@ -823,13 +815,13 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
         # pcp_rank1: Q2*KV0~KV1 + Q2*KV2
         output_tail = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_tail_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
             v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
-            kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=tail_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)

         # 3. Combine the output of the first half and second half.
@@ -838,20 +830,36 @@ def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output

-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out

     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -864,9 +872,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
         else:
             num_heads = self.num_heads

-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -877,7 +882,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -892,9 +897,11 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -910,16 +917,16 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
+                    query, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
-            attn_out = torch.empty_like(q_nope)
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)

             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -929,7 +936,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -939,14 +946,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)

-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -955,35 +960,28 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                      group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
             attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+            if self.dcp_size > 1 and self.pcp_size > 1:
+                attn_out_lse_list_pcp_dcp = []
+                for s in attn_out_lse_list:
+                    attn_out_lse_list_split = list(
+                        torch.chunk(s, self.dcp_size, dim=1))
+                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+                attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out

     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f30a9a39b4..889f5d874d 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1365,13 +1365,13 @@ def _prepare_inputs(
         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
         tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
             tokens)
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         # update total_num_scheduled_tokens
         total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
-        self.input_batch.block_table.commit_slot_mapping(
-            total_num_scheduled_tokens)
         total_num_pcp_pads = sum(self.num_pcp_pads)

         max_num_scheduled_tokens = max(tokens)
@@ -4118,7 +4118,6 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
-        num_prefills = num_reqs - num_decodes
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
@@ -4226,9 +4225,8 @@ def _list_to_tensor(lst, device, dtype=torch.int32):
                                      device=self.device,
                                      dtype=self.dtype), 1)
             else:
-                max_seq_len = max(seq_lens, default=0)
                 pcp_prefill_mask = torch.triu(
-                    torch.full((num_prefills, max_seq_len, max_seq_len),
+                    torch.full((2048, 2048),
                                True,
                                device=self.device,
                                dtype=torch.bool), 1)
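Note for reviewers (not part of the patch): the removed _update_out_and_lse helper merged partial attention outputs with the standard log-sum-exp update, and this change hands that merge to the fused torch_npu.npu_attention_update op with update_type = 0. The sketch below restates the removed helper's formula in plain PyTorch as a reference for parity checks; the function name merge_out_and_lse is ours, and the assumption that update_type = 0 reproduces exactly this formula should be confirmed against the torch_npu documentation.

    import torch

    def merge_out_and_lse(out_list: torch.Tensor,
                          lse_list: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # out_list: [N, batch_size, num_heads, head_size] partial outputs
        # lse_list: [N, batch_size, num_heads, 1] matching log-sum-exp values
        # LSE_final = log(sum_i exp(LSE_i))
        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
        # O_final = sum_i exp(LSE_i - LSE_final) * O_i
        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list, dim=0)
        return out_final, lse_final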