chunked prefill, access splitfuse op #2962
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review

This PR integrates the new splitfuse chunked-prefill operator. The changes touch two files: attention_v1.py and model_runner_v1.py. In attention_v1.py, the attention computation in _forward_v1_style is switched from _npu_paged_attention_splitfuse to npu_fused_infer_attention_score. In model_runner_v1.py, the logic that builds the attention mask for the ChunkedPrefill case is modified.

My review found two critical issues:
- In attention_v1.py, the value passed to the new operator as actual_seq_lengths is wrong: it uses cumulative token positions rather than sequence lengths, which leads to incorrect attention results.
- In model_runner_v1.py, the attention mask built for ChunkedPrefill uses a hard-coded size of (2048, 2048), which makes the code fragile and breaks as soon as any sequence exceeds 2048 tokens.

Please fix these two critical issues to ensure correctness and robustness.
num_kv_heads=self.num_kv_heads,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.query_start_loc[1:],
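For illustration only, here is a minimal plain-PyTorch sketch (the example values are made up, not taken from this PR) of the distinction the review draws between cumulative start offsets in query_start_loc and per-sequence lengths:

```python
import torch

# Hypothetical example: query_start_loc holds cumulative token offsets,
# e.g. [0, 3, 8, 12] for three sequences of lengths 3, 5 and 4.
query_start_loc = torch.tensor([0, 3, 8, 12], dtype=torch.int32)

# query_start_loc[1:] yields the cumulative end positions [3, 8, 12],
# whereas the per-sequence lengths are the consecutive differences.
end_positions = query_start_loc[1:]                    # tensor([ 3,  8, 12])
seq_lens = query_start_loc[1:] - query_start_loc[:-1]  # tensor([3, 5, 4])
```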
@@ -797,8 +797,7 @@ def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
         if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
-            return self.attn_mask_builder.get_splitfuse_attn_mask(
-                seq_lens, position, self.dtype, self.device)
+            return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
The attention mask uses a hard-coded size of (2048, 2048). This magic number makes the implementation fragile: if any sequence in the batch exceeds 2048 tokens, the mask will be incorrect or go out of bounds. The mask size should be derived from the model's configured maximum sequence length to ensure correctness and avoid the magic number.
Suggested change:
-            return torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8)
+            return torch.triu(torch.ones(self.model_config.max_model_len, self.model_config.max_model_len), diagonal=1).to(torch.int8)
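As a sketch of the direction suggested above (the helper name build_causal_mask and the example max_model_len value are illustrative, not part of this PR), the mask could be sized from the configured maximum length and built once rather than hard-coding 2048:

```python
import torch

def build_causal_mask(max_len: int, device: torch.device) -> torch.Tensor:
    # Upper-triangular int8 mask; 1 marks positions to be masked out.
    return torch.triu(torch.ones(max_len, max_len, device=device),
                      diagonal=1).to(torch.int8)

# Built once (e.g. at runner init) from the configured maximum length and
# reused, instead of re-creating a fixed 2048x2048 mask on every call.
max_model_len = 4096  # illustrative; would come from the model config
causal_mask = build_causal_mask(max_model_len, torch.device("cpu"))
```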
Could you integrate other scenarios, such as full FlashAttention, using the FIA interface as well, and provide the performance test results?
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?