make cache op support non-contiguous num_blocks dim #2772

Open
ganyi1996ppo wants to merge 1 commit into main from ganyi/cache_op_stride_0

Conversation

@ganyi1996ppo
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@ganyi1996ppo ganyi1996ppo requested review from a team and Copilot April 17, 2026 07:42
…ontiguous block dim

Signed-off-by: ganyi <ygan@amd.com>
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label           Tests
ci:triton-355   Run Triton tests on MI355 in addition to MI325
ci:sglang       SGLang integration tests
ci:atom         ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm         vLLM benchmark
ci:all          All of the above

Add labels via the sidebar or gh pr edit 2772 --add-label <label>


Copilot AI left a comment


Pull request overview

Updates KV-cache write kernels to support a strided (non-contiguous) num_blocks dimension by using the cache tensors’ stride(0) when computing target indices.

Changes:

  • Add key_cache_stride0 / value_cache_stride0 (and per-token variants) to relevant kernel signatures and launches.
  • Update key/value cache linear index calculations to use block_idx * stride(0) instead of assuming dense packing for dim0.
  • Thread stride(0) values from host (key_cache.stride(0), value_cache.stride(0)) into the affected kernel launches.
Comments suppressed due to low confidence (1)

csrc/kernels/cache_kernels.cu:189

  • reshape_and_cache_kernel now supports a non-contiguous num_blocks dim via key_cache_stride0 / value_cache_stride0, but the kernel still hardcodes the inner-dimension layout (it assumes the remaining dims are densely packed/contiguous). Please add input validation (e.g., TORCH_CHECK on the expected strides for dims 1..end) or extend the kernel to use full per-dim strides; otherwise passing a tensor that's non-contiguous in other dims will silently write to the wrong locations.
                         const int64_t key_cache_stride0,
                         const int64_t value_cache_stride0,
                         const int num_heads,
                         const int head_size,
                         const int block_size,
                         const int x,


Comment on lines +2656 to +2658
int64_t k_cache_stride0 = key_cache.stride(0);
int64_t v_cache_stride0 = value_cache.stride(0);


Copilot AI Apr 17, 2026


The PR title suggests cache ops broadly support non-contiguous num_blocks, but only reshape_and_cache / per-token-quant paths were updated to use stride(0). Other cache paths in this file (e.g., the block-quant kernels) still compute block_idx * num_heads * head_size * block_size and will remain broken for a strided num_blocks dim. Either extend the same stride(0) handling there as well, or narrow/clarify the PR scope in the title/description.

Copilot uses AI. Check for mistakes.
Comment on lines +2656 to +2658
int64_t k_cache_stride0 = key_cache.stride(0);
int64_t v_cache_stride0 = value_cache.stride(0);


Copilot AI Apr 17, 2026


There’s no test coverage exercising the new “non-contiguous num_blocks” behavior (e.g., key_cache = big_cache[layer_idx] so stride(0) differs from the dense product of inner dims). Please add a unit/integration test (see op_tests/test_kvcache.py) that constructs strided views for key_cache/value_cache and verifies reshape_and_cache{,_with_pertoken_quant} writes to the correct blocks.

Comment on lines +339 to +340
const int64_t k_cache_stride0,
const int64_t v_cache_stride0,

Copilot AI Apr 17, 2026


Parameter naming is inconsistent between kernels (key_cache_stride0/value_cache_stride0 vs k_cache_stride0/v_cache_stride0). Please standardize the naming across these kernels and the corresponding host variables to make the API easier to follow and reduce the chance of wiring the wrong stride into a launch.

Suggested change
- const int64_t k_cache_stride0,
- const int64_t v_cache_stride0,
+ const int64_t key_cache_stride0,
+ const int64_t value_cache_stride0,
