make cache op support non-contiguous num_blocks dim #2772

ganyi1996ppo wants to merge 1 commit into main from …
Conversation
…ontiguous block dim

Signed-off-by: ganyi <ygan@amd.com>
Force-pushed from c75e1dd to 0b7d21d
Pull request overview
Updates KV-cache write kernels to support a strided (non-contiguous) num_blocks dimension by using the cache tensors’ stride(0) when computing target indices.
Changes:
- Add `key_cache_stride0`/`value_cache_stride0` (and per-token variants) to the relevant kernel signatures and launches.
- Update the key/value cache linear index calculations to use `block_idx * stride(0)` instead of assuming dense packing for dim 0.
- Thread the `stride(0)` values from the host (`key_cache.stride(0)`, `value_cache.stride(0)`) into the affected kernel launches.
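To make the indexing change concrete, here is a minimal sketch (NumPy standing in for torch; the names `big`, `cache`, and `flat` are hypothetical) of why `block_idx * stride(0)` is needed once dim 0 is strided:

```python
import numpy as np

# Hypothetical layout: a big buffer packs two layers along dim 1, so each
# per-layer view has a strided num_blocks dim (stride(0) = 2 * inner product).
num_blocks, num_layers, inner = 4, 2, 8 * 16 * 16
big = np.arange(num_blocks * num_layers * inner, dtype=np.int64)
cache = big.reshape(num_blocks, num_layers, inner)[:, 0]  # strided view

flat = big  # the raw storage a kernel would index into
block_idx = 3
stride0 = cache.strides[0] // cache.itemsize  # stride(0) in elements
# Indexing with the dense inner product lands in the wrong block's storage:
assert flat[block_idx * inner] != cache[block_idx, 0]
# Indexing with the actual stride(0) reaches the right element:
assert flat[block_idx * stride0] == cache[block_idx, 0]
```

The same arithmetic holds inside the kernels: the dense product only happens to equal `stride(0)` when the cache tensor is contiguous in dim 0.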
Comments suppressed due to low confidence (1)
csrc/kernels/cache_kernels.cu:189
`reshape_and_cache_kernel` now supports a non-contiguous `num_blocks` dim via `key_cache_stride0`/`value_cache_stride0`, but the kernel still hardcodes the inner-dimension layout (it assumes the remaining dims are densely packed/contiguous). Please add input validation (e.g., a `TORCH_CHECK` on the expected strides for dims 1..end) or extend the kernel to use full per-dim strides; otherwise, passing a tensor that is non-contiguous in other dims will silently write to the wrong locations.
```cuda
    const int64_t key_cache_stride0,
    const int64_t value_cache_stride0,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int x,
```
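The suggested validation could look like the following sketch (plain Python over size/stride lists; the helper name is hypothetical). The real host code would wrap the same condition in a `TORCH_CHECK` against the cache tensors' strides:

```python
def inner_dims_contiguous(sizes, strides):
    """Check that dims 1..end are densely packed (row-major); only dim 0 may
    carry an arbitrary stride. A real implementation would express this as a
    TORCH_CHECK on key_cache/value_cache before launching the kernel."""
    expected = 1
    # Walk the inner dims from last to second, accumulating the dense stride.
    for size, stride in zip(sizes[:0:-1], strides[:0:-1]):
        if stride != expected:
            return False
        expected *= size
    return True

# [num_blocks, num_heads, head_size, block_size] = [4, 8, 16, 16]
assert inner_dims_contiguous([4, 8, 16, 16], [2048, 256, 16, 1])      # dense
assert inner_dims_contiguous([4, 8, 16, 16], [4096, 256, 16, 1])      # strided dim 0: OK
assert not inner_dims_contiguous([4, 8, 16, 16], [2048, 512, 16, 1])  # bad inner dim
```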
```cpp
int64_t k_cache_stride0 = key_cache.stride(0);
int64_t v_cache_stride0 = value_cache.stride(0);
```
The PR title suggests cache ops broadly support a non-contiguous `num_blocks` dim, but only the `reshape_and_cache` / per-token-quant paths were updated to use `stride(0)`. Other cache paths in this file (e.g., the block-quant kernels) still compute `block_idx * num_heads * head_size * block_size` and will remain broken for a strided `num_blocks` dim. Either extend the same `stride(0)` handling there as well, or narrow/clarify the PR scope in the title/description.
```cpp
int64_t k_cache_stride0 = key_cache.stride(0);
int64_t v_cache_stride0 = value_cache.stride(0);
```
There’s no test coverage exercising the new “non-contiguous `num_blocks`” behavior (e.g., `key_cache = big_cache[layer_idx]` so that `stride(0)` differs from the dense product of the inner dims). Please add a unit/integration test (see `op_tests/test_kvcache.py`) that constructs strided views for `key_cache`/`value_cache` and verifies `reshape_and_cache{,_with_pertoken_quant}` writes to the correct blocks.
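A sketch of such a test setup (NumPy standing in for torch; the names `big_cache` and `layer_idx` are hypothetical, and slicing a layer-packed buffer on a non-leading dim is one way to obtain a view whose `stride(0)` differs from the dense inner product):

```python
import numpy as np

# Hypothetical buffer packing all layers: [num_blocks, num_layers, ...].
num_blocks, num_layers, num_heads, head_size, block_size = 4, 2, 8, 16, 16
big_cache = np.zeros(
    (num_blocks, num_layers, num_heads, head_size, block_size), dtype=np.float16
)

layer_idx = 1
# Per-layer view: [num_blocks, num_heads, head_size, block_size], dim 0 strided.
key_cache = big_cache[:, layer_idx]

dense = num_heads * head_size * block_size
stride0 = key_cache.strides[0] // key_cache.itemsize  # stride(0) in elements
assert stride0 == num_layers * dense  # stride(0) != dense product of inner dims
assert stride0 != dense
# A real test would now call reshape_and_cache{,_with_pertoken_quant} on
# key_cache/value_cache and compare the result against writes into a
# contiguous copy of the same data.
```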
```cuda
    const int64_t k_cache_stride0,
    const int64_t v_cache_stride0,
```
Parameter naming is inconsistent between kernels (`key_cache_stride0`/`value_cache_stride0` vs `k_cache_stride0`/`v_cache_stride0`). Please standardize the naming across these kernels and the corresponding host variables to make the API easier to follow and reduce the chance of wiring the wrong stride into a launch.
Suggested change:

```diff
-    const int64_t k_cache_stride0,
-    const int64_t v_cache_stride0,
+    const int64_t key_cache_stride0,
+    const int64_t value_cache_stride0,
```
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist