Skip to content

Conversation

MengqingCao
Copy link
Collaborator

@MengqingCao MengqingCao commented Sep 18, 2025

What this PR does / why we need it?

This pr fixes a few issues on prefill disaggregation:

  1. Fix prefill disaggregation kvcache addr alignment issue, llmdatadist needs the addr of tensors to be aligned with 2M
  2. Fix prefill disaggregation kvcache shape error, llmdatadist requires k/v tensors with shape [num_blocks, ...], however the implentment before this pr is [2, num_blocks, ...], which will break prefill disaggregation
  3. Use hybrid kv cache only when running qwen3_next to fix accuracy issue on prefill disaggregation.

Does this PR introduce any user-facing change?

N/A

How was this patch tested?

Tested locally by @liziyu179

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces fixes for KV cache address alignment and restricts the hybrid KV cache usage to the qwen3_next model. The changes are mostly in NPUModelRunner.

My review focuses on improving code quality and maintainability. I've identified a few areas with significant code duplication that should be refactored into helper methods to reduce redundancy. Specifically, the logic for tensor allocation with alignment and the calculation/validation of num_blocks are repeated. Additionally, a magic number for memory alignment has been introduced, which should be defined as a named constant for better readability. Addressing these points will make the code cleaner and easier to maintain.

kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
kv_cache_raw_tensors = {}
# llmdatadist need the addr of cache tensor be aligned with 2M
alignment = 2 * 1024 * 1024
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The alignment value 2 * 1024 * 1024 is a magic number. It should be defined as a constant with a descriptive name (e.g., _LLMDATADIST_ALIGNMENT_BYTES) at a suitable scope (like module level) to improve readability and make it easier to manage if this value needs to be changed in the future.

Comment on lines 2594 to 2615
elif "self_attn" in layer_name:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=self.device)
kv_cache_raw_tensors[layer_name] = tensor
if self.vllm_config.kv_transfer_config:
k_tensor = torch.zeros(kv_cache_tensor.size // 2,
dtype=torch.int8,
device=self.device)
v_tensor = torch.zeros(kv_cache_tensor.size // 2,
dtype=torch.int8,
device=self.device)
else:
cache_size = kv_cache_tensor.size // 2
cache_size_aligned = kv_cache_tensor.size // 2 + alignment
k_tensor = torch.zeros(cache_size_aligned,
dtype=torch.int8,
device=self.device)
v_tensor = torch.zeros(cache_size_aligned,
dtype=torch.int8,
device=self.device)
k_tensor = self._align_memory(k_tensor,
alignment)[:cache_size]
v_tensor = self._align_memory(v_tensor,
alignment)[:cache_size]
kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic for allocating tensors with optional alignment is duplicated between the if "linear_attn" in layer_name_inner case (lines 2582-2593) and this elif "self_attn" in layer_name case. This duplication makes the code harder to maintain. Consider extracting this logic into a private helper method, for instance _create_cache_tensor(size, alignment, align_memory: bool), which would handle the conditional allocation and alignment. This would simplify both branches of the conditional.

Comment on lines 2681 to 2694
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel(
) % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel(
) // kv_cache_spec.page_size_bytes

# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
# KVCacheManager may allocate.
# Since different GPUs may have different number of layers and
# different memory capacities, `num_blocks` can be different on
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code, which calculates num_blocks and asserts it against the configuration, is nearly identical to the logic for FullAttentionSpec on lines 2637-2651. This duplication increases maintenance overhead. This logic could be extracted into a private helper method that takes the raw tensor(s) and the spec as input and returns the calculated num_blocks after performing the necessary assertions.

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
@wangxiyuan wangxiyuan added ready read for review ready-for-test start test by label for PR labels Sep 18, 2025
@wangxiyuan wangxiyuan merged commit 367edff into vllm-project:main Sep 18, 2025
49 of 50 checks passed
@MengqingCao MengqingCao deleted the kv_fix branch September 19, 2025 06:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module:core ready read for review ready-for-test start test by label for PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants