25 changes: 16 additions & 9 deletions vllm_ascend/attention/attention_v1.py
@@ -455,18 +455,25 @@
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)

torch_npu._npu_paged_attention_splitfuse(
num_block, block_size, head_num, head_dim = self.key_cache.shape

Check failure on line 458 in vllm_ascend/attention/attention_v1.py (GitHub Actions / lint / pre-commit): "None" has no attribute "shape" [attr-defined]
key = self.key_cache.view(num_block, block_size, -1)

Check failure on line 459 in vllm_ascend/attention/attention_v1.py (GitHub Actions / lint / pre-commit): "None" has no attribute "view" [attr-defined]
value = self.value_cache.view(num_block, block_size, -1)

Check failure on line 460 in vllm_ascend/attention/attention_v1.py (GitHub Actions / lint / pre-commit): "None" has no attribute "view" [attr-defined]

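As an aside on the pre-commit failures flagged above: mypy reports these errors because self.key_cache and self.value_cache are presumably typed as Optional and are not narrowed before the .shape and .view calls. A minimal sketch of one common way to satisfy the checker, assuming that Optional typing (not part of this diff):

    # Narrow the Optional types so mypy knows the caches are tensors at this point.
    assert self.key_cache is not None and self.value_cache is not None
    num_block, block_size, head_num, head_dim = self.key_cache.shape
    key = self.key_cache.view(num_block, block_size, -1)
    value = self.value_cache.view(num_block, block_size, -1)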
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask.to(device=query.device),
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.query_start_loc[1:],
Contributor

critical

The actual_seq_lengths parameter is passed attn_metadata.query_start_loc[1:], which contains cumulative token positions rather than the lengths of the individual sequences. This is very likely to produce incorrect attention results. You should use attn_metadata.query_lens, which contains the correct sequence lengths, and make sure it is on the correct device.

Suggested change
actual_seq_lengths=attn_metadata.query_start_loc[1:],
actual_seq_lengths=attn_metadata.query_lens.to(query.device),
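To make the distinction concrete: query_start_loc is an exclusive prefix sum of the per-request query lengths (with a leading zero), so its tail entries are cumulative token offsets, while adjacent differences recover query_lens. A minimal standalone sketch with made-up lengths:

    import torch

    # Hypothetical per-request query lengths, for illustration only.
    query_lens = torch.tensor([4, 2, 5])
    query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int64)
    query_start_loc[1:] = torch.cumsum(query_lens, dim=0)

    print(query_start_loc[1:])                         # tensor([ 4,  6, 11]): cumulative offsets
    print(query_start_loc[1:] - query_start_loc[:-1])  # tensor([4, 2, 5]): the per-request lengths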

actual_seq_lengths_kv=attn_metadata.seq_lens,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
scale=self.scale,
sparse_mode=3,
)
return output

def forward(
3 changes: 1 addition & 2 deletions vllm_ascend/worker/model_runner_v1.py
@@ -797,8 +797,7 @@ def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
# Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
Contributor

critical

The attention mask uses a hardcoded size of (2048, 2048). This is a magic number and makes the implementation less robust. If any sequence in the batch is longer than 2048, it will lead to an incorrect mask or an out-of-bounds error. The mask size should be determined by the model configuration's maximum sequence length to ensure correctness and avoid the magic number.

Suggested change
return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
return torch.triu(torch.ones(self.model_config.max_model_len, self.model_config.max_model_len), diagonal=1).to(torch.int8)
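For illustration, a minimal standalone sketch that builds the upper-triangular mask from a configurable length rather than the literal 2048; the 4096 below is a placeholder, not a value read from the project's configuration:

    import torch

    def make_causal_mask(max_seq_len: int) -> torch.Tensor:
        # 1s strictly above the diagonal mark the positions to mask out.
        return torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).to(torch.int8)

    mask = make_causal_mask(4096)  # e.g. the model config's max_model_len, as in the suggestion above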

# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)