Add per-sequence trim support to BatchKVCache #873
+805 −2
`BatchKVCache` tracks per-sequence `offset` and `left_padding` arrays, but `update_and_fetch()` and `make_mask()` only use the scalar `_idx` to slice keys/values and compute attention masks. This means all sequences in a batch are forced to share the same effective cache length, even though the per-sequence metadata already exists.

This is a problem for batched speculative decoding: each sequence in a batch may accept a different number of draft tokens, so you need to trim each sequence independently. Without per-sequence trim, you're stuck either trimming all sequences to `min(accepted)` (wasting accepted tokens) or doing an extra forward pass per sequence.

The fix is small (~24 lines added, 2 lines changed in `mlx_lm/models/cache.py`):
- `update_and_fetch` and `make_mask` now compute `end = max(left_padding + offset)` instead of using `_idx` directly. When all offsets are uniform (the normal case), `max(left_padding + offset) == _idx`, so behavior is identical. When offsets diverge after a per-sequence trim, each sequence sees its correct cache window.
- `make_mask` computes `right_padding = end - (left_padding + offset)` for sequences that are shorter than `end`, and passes it to `create_causal_mask()`, which already supports the `right_padding` parameter (used by `BatchRotatingKVCache`). See the sketch after this list.
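To make the arithmetic concrete, here is a minimal standalone sketch of that window computation. The helper name `cache_window` and the example values are illustrative; the real logic lives inside `update_and_fetch` and `make_mask`:

```python
import mlx.core as mx

def cache_window(left_padding: mx.array, offset: mx.array) -> tuple[int, mx.array]:
    """Shared slice end plus per-sequence right padding for the mask."""
    filled = left_padding + offset      # end of each sequence's valid region
    end = int(filled.max().item())      # slice keys/values to [..., :end, :]
    right_padding = end - filled        # all zeros when offsets are uniform
    return end, right_padding

# After a per-sequence trim, offsets can diverge:
left_padding = mx.array([0, 0, 2])
offset = mx.array([12, 9, 10])
end, right_padding = cache_window(left_padding, offset)
print(end)            # 12 -> every sequence attends over the same sliced window
print(right_padding)  # [0, 3, 0] -> stale tail positions get masked out
```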
`trim_per_sequence(n)` is a new method that takes an `mx.array` of shape `(B,)` specifying how many tokens to trim from each sequence. It mirrors the existing `trim(n)` method but operates per-sequence; the `_idx` scalar is updated to `max(left_padding + offset)` so the allocated buffer still covers all sequences.
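A minimal sketch of the method, based on the description above; the exact attribute handling and return value are assumptions, not the PR's diff:

```python
import mlx.core as mx

def trim_per_sequence(self, n: mx.array) -> mx.array:
    """Trim n[i] tokens from the end of sequence i's cache."""
    trimmed = mx.minimum(n, self.offset)   # never trim below zero
    self.offset = self.offset - trimmed    # per-sequence, unlike trim(n)
    # Keep the scalar buffer index covering the longest remaining sequence.
    self._idx = int((self.left_padding + self.offset).max().item())
    return trimmed
```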
### Example usage (speculative decoding)
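The example block under this heading did not survive extraction; here is a hedged reconstruction of the kind of loop this enables, assuming a per-layer list of `BatchKVCache` objects and illustrative names (`num_layers`, `num_draft`, `accepted`, `cache`):

```python
import mlx.core as mx
from mlx_lm.models.cache import BatchKVCache

num_layers, num_draft = 2, 4
cache = [BatchKVCache() for _ in range(num_layers)]  # construction assumed
# ... prefill and draft-verification forward passes populate the caches ...

accepted = mx.array([4, 2, 3])   # tokens each sequence kept after verification

# Each sequence wrote num_draft draft tokens into the cache; roll back only
# the rejected ones instead of trimming every sequence to min(accepted).
for layer_cache in cache:
    layer_cache.trim_per_sequence(num_draft - accepted)
```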
### Backward compatibility

- When offsets are uniform, `max(left_padding + offset) == _idx` and `right_padding` is all zeros (skipped). The output is identical.
- The `trim(n)` method is unchanged.
- `filter()`, `extend()`, `extract()`, `merge()` all work as before, since they reconstruct `_idx` from the buffer.
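A quick numeric illustration of the uniform case (values are illustrative):

```python
import mlx.core as mx

# Ragged prompts, but no per-sequence trim yet: left_padding + offset is uniform.
left_padding = mx.array([2, 0, 1])
offset = mx.array([10, 12, 11])

end = int((left_padding + offset).max().item())  # 12, same as the old scalar _idx
right_padding = end - (left_padding + offset)    # [0, 0, 0] -> mask path skipped
```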
### Test plan

Tests are in `tests/test_batchkvcache_per_seq_trim.py` (17 tests across 4 classes), including a check that `trim_per_sequence(mx.array([n, n, n]))` produces identical results to `trim(n)`.

All 17 tests pass locally on Apple Silicon.
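A hedged sketch of that uniform-trim equivalence check; the cache construction and the compared attribute are assumptions, and the real tests may differ:

```python
import unittest
import mlx.core as mx
from mlx_lm.models.cache import BatchKVCache

class TestPerSequenceTrim(unittest.TestCase):
    def test_uniform_trim_matches_scalar_trim(self):
        B, H, L, D, n = 3, 2, 16, 8, 5
        kv = mx.random.normal((B, H, L, D))

        a, b = BatchKVCache(), BatchKVCache()  # construction assumed
        a.update_and_fetch(kv, kv)
        b.update_and_fetch(kv, kv)

        a.trim(n)                               # scalar trim
        b.trim_per_sequence(mx.array([n] * B))  # uniform per-sequence trim

        # Identical per-sequence offsets imply identical cache windows.
        self.assertTrue(mx.array_equal(a.offset, b.offset))

if __name__ == "__main__":
    unittest.main()
```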