diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index a0e63349b1..cf92affd38 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -39,11 +39,22 @@ def __init__(
         self,
         max_seq_len: int,
         dtype: torch.dtype,
+        device: torch.device = None,
     ):
+        # NOTE: The device argument specifies the target NPU
+        # to be used for the newly added FIA operator.
+        # Only pass this parameter when using the new FIA operator.
+
         attn_mask = _generate_attn_mask(max_seq_len, dtype)
 
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
+        self.device = device
+        if torch.version.cann.startswith("8.3"):
+            assigned_mask_dim = 2048
+            self.chunked_prefill_attn_mask = torch.triu(
+                torch.ones(assigned_mask_dim, assigned_mask_dim),
+                diagonal=1).to(torch.int8).to(device)
 
     @staticmethod
     def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -66,24 +77,28 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
 
     def get_splitfuse_attn_mask(
         self,
-        seq_lens: torch.Tensor,
-        position: torch.Tensor,
-        dtype: torch.dtype,
-        device: torch.device,
+        seq_lens: torch.Tensor = None,
+        position: torch.Tensor = None,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
     ) -> torch.Tensor:
-        if dtype not in [torch.float16, torch.bfloat16]:
-            raise ValueError(
-                "splitfuse_attn_mask now only supports bf16 and fp16")
-        max_seq_len = max(seq_lens, default=0)
-        self._update_attn_cache(max_seq_len, dtype)
-        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-        # is not the same. Fix this in the future when kernel is ready.
-        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
-        attn_mask = torch.index_select(self.attn_mask_cache,
-                                       dim=0,
-                                       index=position)[:, :max_seq_len]
-        attn_mask *= mask_scale_factor
-        return attn_mask.contiguous().to(device, non_blocking=True)
+        if torch.version.cann.startswith("8.3"):
+            return self.chunked_prefill_attn_mask
+        else:
+            if dtype not in [torch.float16, torch.bfloat16]:
+                raise ValueError(
+                    "splitfuse_attn_mask now only supports bf16 and fp16")
+            max_seq_len = max(seq_lens, default=0)
+            self._update_attn_cache(max_seq_len, dtype)
+            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+            # is not the same. Fix this in the future when kernel is ready.
+            mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
+                dtype)
+            attn_mask = torch.index_select(self.attn_mask_cache,
+                                           dim=0,
+                                           index=position)[:, :max_seq_len]
+            attn_mask *= mask_scale_factor
+            return attn_mask.contiguous().to(device, non_blocking=True)
 
     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 10a2f6a416..bc7f69ce5a 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -456,18 +456,43 @@ def _forward_v1_style(
             attn_metadata.seq_lens = \
                 attn_metadata.seq_lens.to(device=query.device)
 
-        torch_npu._npu_paged_attention_splitfuse(
-            query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            mask=attn_metadata.attn_mask,
-            block_table=attn_metadata.block_tables,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+        if torch.version.cann.startswith("8.3"):
+            # TODO: The npu_fused_infer_attention_score op is planned to
+            # be utilized in a wider range in upcoming versions.
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=attn_metadata.block_tables,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.query_start_loc[1:],
+                actual_seq_lengths_kv=attn_metadata.seq_lens,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+        else:
+            torch_npu._npu_paged_attention_splitfuse(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                mask=attn_metadata.attn_mask,
+                block_table=attn_metadata.block_tables,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
 
         return output
 
     def forward(
@@ -561,12 +586,18 @@ def forward(
                                                 output)
             # Normal V1 situation.
             else:
+                if torch.version.cann.startswith("8.3"):
+                    # npu_fused_infer_attention_score does not support cases
+                    # where query.shape[0] != attn_metadata.query_start_loc[-1].
+                    # Thus we need to unpad it here.
+                    num_tokens = attn_metadata.query_start_loc[-1]
+                    query = query[:num_tokens]
                 output = self._forward_v1_style(query, attn_metadata, output)
 
         # to make in-place change to the output tensor
         if hasattr(layer, 'quant_method') and use_kv_cache_int8:
             output = output.view(num_tokens, self.num_heads, self.head_size)
-            ori_output[:, :, :] = output[:num_tokens, :, :]
+            ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
 
         return output.view(num_tokens, self.hidden_size)
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 6e42da1367..659541e300 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -301,8 +301,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             use_mla=self.model_config.use_mla,
         )
 
-        self.attn_mask_builder = AttentionMaskBuilder(
-            self.model_config.max_model_len, self.dtype)
+        if torch.version.cann.startswith("8.3"):
+            self.attn_mask_builder = AttentionMaskBuilder(
+                self.scheduler_config.max_num_batched_tokens, self.dtype,
+                self.device)
+        else:
+            self.attn_mask_builder = AttentionMaskBuilder(
+                self.model_config.max_model_len, self.dtype)
 
         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -818,8 +823,11 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, position, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            else:
+                return self.attn_mask_builder.get_splitfuse_attn_mask(
+                    seq_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
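
The same torch.version.cann.startswith("8.3") gate appears in all three files, so the sketch below summarizes what the two mask paths produce. It is an illustration under plain torch only, not code from this patch; on the legacy path the real builder also multiplies the cached mask by a dtype-dependent scale factor, which is elided here.

    import torch

    def splitfuse_mask_sketch(cann_is_8_3: bool,
                              max_seq_len: int = 6,
                              dtype: torch.dtype = torch.float16) -> torch.Tensor:
        if cann_is_8_3:
            # New FIA path: a single fixed 2048 x 2048 int8 upper-triangular
            # mask, pre-built in AttentionMaskBuilder.__init__ and reused for
            # every chunked-prefill call (paired with sparse_mode=3).
            dim = 2048
            return torch.triu(torch.ones(dim, dim), diagonal=1).to(torch.int8)
        # Legacy path: a per-request causal mask in the model dtype, sliced
        # from a cached triangle and then scaled (scaling elided in this sketch).
        return torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=dtype),
                          diagonal=1)

    assert splitfuse_mask_sketch(True).shape == (2048, 2048)
    assert splitfuse_mask_sketch(False).dtype == torch.float16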
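
The unpadding added in forward() and the actual_seq_lengths arguments fit together: query_start_loc is the cumulative prefix of per-request query lengths, so its last element is the number of real (non-padded) tokens, and its tail query_start_loc[1:] is what the patch passes as actual_seq_lengths. A toy example with made-up numbers (plain torch, not patch code):

    import torch

    # Two requests scheduled with 3 and 2 query tokens; the batch is padded
    # to 8 rows of (heads, head_size) = (4, 16).
    query = torch.randn(8, 4, 16)
    query_start_loc = torch.tensor([0, 3, 5])   # cumulative query lengths
    seq_lens = torch.tensor([7, 4])             # total KV length per request

    num_tokens = int(query_start_loc[-1])       # 5 real tokens
    query = query[:num_tokens]                  # drop padding -> (5, 4, 16)

    actual_seq_lengths = query_start_loc[1:]    # tensor([3, 5]), per the patch
    actual_seq_lengths_kv = seq_lens            # tensor([7, 4])
    assert query.shape[0] == int(actual_seq_lengths[-1])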