20 changes: 3 additions & 17 deletions vllm_ascend/attention/attention_mask.py
@@ -39,11 +39,13 @@ def __init__(
         self,
         max_seq_len: int,
         dtype: torch.dtype,
+        device: torch.device,
     ):
         attn_mask = _generate_attn_mask(max_seq_len, dtype)

         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
+        self.chunked_prefill_attn_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(device)

     @staticmethod
     def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -66,24 +68,8 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,

     def get_splitfuse_attn_mask(
         self,
-        seq_lens: torch.Tensor,
-        position: torch.Tensor,
-        dtype: torch.dtype,
-        device: torch.device,
     ) -> torch.Tensor:
-        if dtype not in [torch.float16, torch.bfloat16]:
-            raise ValueError(
-                "splitfuse_attn_mask now only supports bf16 and fp16")
-        max_seq_len = max(seq_lens, default=0)
-        self._update_attn_cache(max_seq_len, dtype)
-        # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-        # is not the same. Fix this in the future when kernel is ready.
-        mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
-        attn_mask = torch.index_select(self.attn_mask_cache,
-                                       dim=0,
-                                       index=position)[:, :max_seq_len]
-        attn_mask *= mask_scale_factor
-        return attn_mask.contiguous().to(device, non_blocking=True)
+        return self.chunked_prefill_attn_mask

     def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
         if seqlen > self._seq_len_cached:
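Note: a minimal sketch (not part of the diff) of the mask the new constructor precomputes: a fixed 2048x2048 upper-triangular int8 matrix shared by all chunked-prefill requests, where 1 marks future positions to be masked out. Run on CPU here for illustration; the PR passes the runner's NPU device instead.

import torch

device = torch.device("cpu")  # the PR uses self.device (an NPU device) here
chunked_prefill_attn_mask = torch.triu(
    torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(device)

print(chunked_prefill_attn_mask[:4, :4])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int8)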
25 changes: 16 additions & 9 deletions vllm_ascend/attention/attention_v1.py
@@ -455,18 +455,25 @@
         attn_metadata.seq_lens = \
             attn_metadata.seq_lens.to(device=query.device)

-        torch_npu._npu_paged_attention_splitfuse(
+        num_block, block_size, head_num, head_dim = self.key_cache.shape
+        key = self.key_cache.view(num_block, block_size, -1)
+        value = self.value_cache.view(num_block, block_size, -1)

[GitHub Actions / lint / pre-commit] Check failure on line 458: "None" has no attribute "shape" [attr-defined]
[GitHub Actions / lint / pre-commit] Check failure on line 459: "None" has no attribute "view" [attr-defined]
[GitHub Actions / lint / pre-commit] Check failure on line 460: "None" has no attribute "view" [attr-defined]
+        output, _ = torch_npu.npu_fused_infer_attention_score(
             query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            mask=attn_metadata.attn_mask,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
             block_table=attn_metadata.block_tables,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.query_start_loc[1:],
Contributor:

critical

The actual_seq_lengths parameter is being passed attn_metadata.query_start_loc[1:], which contains cumulative token positions rather than the lengths of the individual sequences. This is very likely to produce incorrect attention results. You should use attn_metadata.query_lens instead, which holds the correct sequence lengths, and make sure it is on the correct device (see the sketch after this diff).

Suggested change
-            actual_seq_lengths=attn_metadata.query_start_loc[1:],
+            actual_seq_lengths=attn_metadata.query_lens.to(query.device),

+            actual_seq_lengths_kv=attn_metadata.seq_lens,
+            num_key_value_heads=self.num_kv_heads,
             num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+            scale=self.scale,
+            sparse_mode=3,
+        )
+        return output

     def forward(
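Two side notes on this hunk, both sketches under stated assumptions rather than fixes taken from the PR.

First, the pre-commit failures flagged inline come from self.key_cache / self.value_cache apparently being typed as Optional. One common way to satisfy mypy, shown here as a free-function equivalent for illustration only, is to narrow the Optional before calling .shape / .view:

from typing import Optional, Tuple

import torch

def reshape_kv_cache(key_cache: Optional[torch.Tensor],
                     value_cache: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Narrow the Optional types so mypy accepts .shape / .view on the caches.
    assert key_cache is not None and value_cache is not None
    num_block, block_size, _, _ = key_cache.shape
    key = key_cache.view(num_block, block_size, -1)
    value = value_cache.view(num_block, block_size, -1)
    return key, value

Second, regarding the review comment on actual_seq_lengths: a self-contained sketch of the distinction it draws (the names mirror vLLM's metadata, the values are made up). query_start_loc holds cumulative token offsets with one entry per sequence boundary, while query_lens holds per-sequence token counts; which of the two the fused kernel expects for the TND layout is for the author to confirm against the torch_npu documentation.

import torch

query_lens = torch.tensor([3, 5, 2])                 # tokens per sequence
query_start_loc = torch.zeros(4, dtype=torch.int64)  # len(query_lens) + 1 boundaries
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)

print(query_start_loc)      # tensor([ 0,  3,  8, 10])
print(query_start_loc[1:])  # tensor([ 3,  8, 10])  -> cumulative end offsets
print(query_lens)           # tensor([3, 5, 2])     -> per-sequence lengths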
5 changes: 2 additions & 3 deletions vllm_ascend/worker/model_runner_v1.py
@@ -293,7 +293,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         )

         self.attn_mask_builder = AttentionMaskBuilder(
-            self.model_config.max_model_len, self.dtype)
+            self.model_config.max_model_len, self.dtype, self.device)

         # Set up speculative decoding.
         self.spec_attn_mask = None
@@ -797,8 +797,7 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, position, self.dtype, self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
             max_seq_len = max(seq_lens, default=0)
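For reference, a hedged, self-contained sketch of the caller-side contract after this change, using a simplified stand-in rather than the real vllm_ascend class: the builder now receives the device at construction time and get_splitfuse_attn_mask() takes no arguments, since the chunked-prefill mask no longer depends on seq_lens, position, or dtype.

import torch

class AttentionMaskBuilderSketch:
    """Stand-in mirroring the new AttentionMaskBuilder signature (illustration only)."""

    def __init__(self, max_seq_len: int, dtype: torch.dtype, device: torch.device):
        # Only the piece relevant to chunked prefill is reproduced here.
        self.chunked_prefill_attn_mask = torch.triu(
            torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(device)

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        # Precomputed once in __init__; nothing to recompute per batch.
        return self.chunked_prefill_attn_mask

builder = AttentionMaskBuilderSketch(4096, torch.bfloat16, torch.device("cpu"))
mask = builder.get_splitfuse_attn_mask()
print(mask.shape, mask.dtype)  # torch.Size([2048, 2048]) torch.int8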