Conversation

0xDaizz commented Feb 11, 2026

BatchKVCache tracks per-sequence offset and left_padding arrays, but update_and_fetch() and make_mask() only use the scalar _idx to slice keys/values and compute attention masks. This means all sequences in a batch are forced to share the same effective cache length, even though the per-sequence metadata already exists.

This is a problem for batched speculative decoding: each sequence in a batch may accept a different number of draft tokens, so you need to trim each sequence independently. Without per-sequence trim, you're stuck either trimming all sequences to min(accepted) (wasting accepted tokens) or doing an extra forward pass per sequence.

The fix is small (~24 lines added, 2 lines changed in mlx_lm/models/cache.py):

update_and_fetch and make_mask now compute end = max(left_padding + offset) instead of using _idx directly. When all offsets are uniform (the normal case), max(left_padding + offset) == _idx, so behavior is identical. When offsets diverge after a per-sequence trim, each sequence sees its correct cache window.

make_mask computes right_padding = end - (left_padding + offset) for sequences that are shorter than end, and passes it to create_causal_mask() which already supports the right_padding parameter (used by BatchRotatingKVCache).
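To make the window arithmetic concrete, here is a minimal sketch of the computation described above. The left_padding and offset values are made-up illustrations, not values from the PR's tests; only the names and formulas come from the PR.

import mlx.core as mx

# Per-sequence metadata BatchKVCache already tracks (illustrative values):
left_padding = mx.array([2, 0, 3])
offset = mx.array([10, 10, 11])

# Effective end of the shared buffer: the longest left_padding + offset.
end = int(mx.max(left_padding + offset).item())    # 14

# Sequences shorter than end get their tail masked out via right_padding,
# which make_mask forwards to create_causal_mask() per the PR.
right_padding = end - (left_padding + offset)      # [2, 4, 0]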

trim_per_sequence(n) is a new method that takes an mx.array of shape (B,) specifying how many tokens to trim from each sequence:

def trim_per_sequence(self, n: mx.array):
    # Never trim more than a sequence's current cache position.
    n = mx.minimum(n, self.left_padding + self.offset)
    self.offset -= n
    # Keep the scalar write index covering the longest remaining sequence.
    self._idx = int(mx.max(self.left_padding + self.offset).item())

This mirrors the existing trim(n) method but operates per-sequence. The _idx scalar is updated to max(left_padding + offset) so the allocated buffer still covers all sequences.

Example usage (speculative decoding)

# After verifying draft tokens, each sequence accepted a different count
# Sequence 0 accepted 3 drafts, sequence 1 accepted 1, sequence 2 accepted 5
# Need to trim the speculative tokens that were rejected:
rejected = mx.array([2, 4, 0])  # = draft_len - accepted
for layer_cache in cache:
    layer_cache.trim_per_sequence(rejected)
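For concreteness, a hedged walk-through of the bookkeeping in this example, assuming (hypothetically) that the three prompts had lengths 7, 9, and 6, were left-padded to length 9, and that draft_len = 5; a sequence's buffer position is left_padding + offset, the quantity the PR uses throughout:

# left_padding = [2, 0, 3]
# offset after prefill + 5 drafts = [12, 14, 11]
#   -> left_padding + offset = [14, 14, 14], _idx = 14
#
# After trim_per_sequence(mx.array([2, 4, 0])):
#   offset                = [10, 10, 11]
#   left_padding + offset = [12, 10, 14]
#   _idx = max(...)       = 14    (the buffer still covers sequence 2)
#   right_padding in the next make_mask = 14 - [12, 10, 14] = [2, 4, 0]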

Backward compatibility

  • When all sequences have uniform offsets (the standard path), max(left_padding + offset) == _idx and right_padding is all zeros, so the right-padding branch is skipped. The output is identical.
  • The existing trim(n) method is unchanged.
  • filter(), extend(), extract(), merge() all work as before since they reconstruct _idx from the buffer.

Test plan

Tests are in tests/test_batchkvcache_per_seq_trim.py (17 tests across 4 classes):

  • Variable trim: trim different amounts per sequence and verify each sequence's keys/values and mask are correct
  • Multi-step trim: trim, generate more tokens, trim again -- verifies offset tracking across multiple trims
  • Uniform trim equivalence: trim_per_sequence(mx.array([n, n, n])) produces identical results to trim(n) (a minimal sketch of this check follows the list)
  • Edge cases: trim zero, trim to empty, trim more than available (clamped), single sequence batch, batch size 1
  • Backward compatibility: standard prefill + decode without any per-sequence trim produces identical results to the unmodified code
  • Real model tests: end-to-end with a quantized model to verify correct generation after per-sequence trim (skipped if model not available locally)
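As a rough illustration of the uniform-trim equivalence bullet, the check boils down to something like the sketch below. check_uniform_trim_equivalence and its two identically prepared BatchKVCache arguments are hypothetical stand-ins for the actual test setup; the real tests also compare the fetched keys/values and the masks.

import mlx.core as mx

def check_uniform_trim_equivalence(cache_a, cache_b, n: int):
    # cache_a and cache_b: two identically prepared BatchKVCache instances
    # (batch size 3 here). Trimming every sequence by the same amount should
    # leave the per-sequence metadata identical to a plain trim(n).
    cache_a.trim(n)
    cache_b.trim_per_sequence(mx.array([n, n, n]))
    assert mx.array_equal(cache_a.offset, cache_b.offset).item()
    assert cache_a._idx == cache_b._idx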

All 17 tests pass locally on Apple Silicon.

- update_and_fetch(): use max(left_padding + offset) for return slice
- make_mask(): compute right_padding for divergent per-sequence offsets
- New trim_per_sequence(n) method for variable per-sequence trimming

Enables batched speculative decoding where each sequence accepts a
different number of draft tokens. Fully backward compatible — when
offsets are uniform (normal case), behavior is identical.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>