# feat: allow callers to cancel stream generation via callback check, and ensure prompt cache consistency (#710)

zhutao100 wants to merge 1 commit into ml-explore:main.
## 1.1 Add a first-class cancellation signal
- Add `class GenerationCancelled(Exception): ...` near `GenerationResponse`.
- Add a new optional kwarg to both token generators:
  - `generate_step(..., should_cancel: Callable[[], bool] | None = None, ...)`
  - `speculative_generate_step(..., should_cancel: Callable[[], bool] | None = None, prompt_progress_callback: Callable[[int, int], None] | None = None, ...)`
## 1.2 Cooperative cancel in `generate_step` (non-speculative)
Implement cancellation checks exactly at safe boundaries:
- Initialize: `should_cancel = should_cancel or (lambda: False)`
- Before prefill loop starts (inside the `with mx.stream(generation_stream)`): if `should_cancel(): raise GenerationCancelled()`
- For each prefill chunk iteration:
  - check `should_cancel()` **before** `_model_call(...)` for that chunk
  - after `mx.eval([c.state for c in prompt_cache])` + `prompt_progress_callback(...)`, check `should_cancel()` and raise
- Before the “first-step” call (`y, logprobs = _step(input_tokens=prompt, ...)`): if `should_cancel(): raise GenerationCancelled()`
- Decode loop: check `should_cancel()` once per iteration at a consistent point (e.g., top of `while True`) and raise.
Notes:
- Keep using existing `prompt_progress_callback` calls for prefill chunks.
- Don’t introduce new “heartbeat yields”; cancellation should be via exception so callers can treat it as non-error.
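Putting these boundaries together, a simplified skeleton of the non-speculative path might look like the following. `model_call` stands in for `_model_call(...)` plus the `mx.eval` of the cache state; the real implementation operates on `mx.array`s, not Python lists:

```python
from typing import Callable, Iterator, List, Optional


class GenerationCancelled(Exception):
    pass


def generate_step_sketch(
    prompt: List[int],
    model_call: Callable,  # stand-in for _model_call(...) + cache eval
    prefill_chunk: int = 4,
    should_cancel: Optional[Callable[[], bool]] = None,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Iterator[int]:
    should_cancel = should_cancel or (lambda: False)
    total = len(prompt)
    processed = 0
    # Prefill: process all but the last token in chunks.
    while processed < total - 1:
        if should_cancel():  # before the chunk's model call
            raise GenerationCancelled()
        chunk = prompt[processed : min(processed + prefill_chunk, total - 1)]
        model_call(chunk)
        processed += len(chunk)
        if prompt_progress_callback:
            prompt_progress_callback(processed, total)
        if should_cancel():  # after cache eval + progress callback
            raise GenerationCancelled()
    if should_cancel():  # before the first decode step
        raise GenerationCancelled()
    # Decode loop: one check per iteration, at the top.
    while True:
        if should_cancel():
            raise GenerationCancelled()
        yield model_call(prompt[-1:])
```

Because every check sits after a completed `mx.eval`, the prompt cache is never left in a half-written state when the exception fires.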
## 1.3 Add equivalent hooks/cancel points to `speculative_generate_step`
This needs two parts: (a) prompt-progress parity, (b) safe cancel points.
**(a) Reuse `prompt_progress_callback`**
- Add `prompt_progress_callback` param (same signature as in `generate_step`).
- Move speculative prefill to a “prefill all-but-last-token” structure (like `generate_step`) so you can report progress per chunk:
  - For each chunk processed:
    - run both draft + main model on the same chunk
    - `mx.eval([c.state for c in draft_cache])` and `mx.eval([c.state for c in model_cache])`
    - update `prompt_processed_tokens` and call `prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)`
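The chunked, two-cache prefill above can be sketched as follows; `draft_step` and `main_step` are stand-ins for running each model on a chunk and `mx.eval`-ing its cache state:

```python
from typing import Callable, List, Optional


def speculative_prefill_sketch(
    prompt: List[int],
    draft_step: Callable,  # stand-in: draft model on chunk + mx.eval(draft_cache)
    main_step: Callable,   # stand-in: main model on chunk + mx.eval(model_cache)
    chunk_size: int = 2,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> None:
    """Prefill all but the last token, reporting progress per chunk."""
    total = len(prompt)
    processed = 0
    while processed < total - 1:
        chunk = prompt[processed : min(processed + chunk_size, total - 1)]
        draft_step(chunk)  # both caches advance over the *same* chunk
        main_step(chunk)
        processed += len(chunk)
        if prompt_progress_callback:
            prompt_progress_callback(processed, total)
```

Driving both caches from one loop is what gives speculative prefill the same per-chunk progress reporting as `generate_step`.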
**(b) Cancellation checks**
- Same pattern: check at prefill chunk boundaries and **before** entering the first speculative step.
- In the speculative decode loop, only check cancellation at boundaries that preserve cache consistency (e.g., before `_draft_generate(...)` begins, and after `_rewind_cache(...)` completes), not mid-way through draft generation unless you also add rollback logic there.
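A sketch of those decode-loop boundaries, with `draft_generate` and `verify_and_rewind` standing in for `_draft_generate(...)` and the verify-plus-`_rewind_cache(...)` step:

```python
from typing import Callable, List, Optional


class GenerationCancelled(Exception):
    pass


def speculative_decode_sketch(
    draft_generate: Callable[[], List[int]],
    verify_and_rewind: Callable[[List[int]], List[int]],
    num_rounds: int,
    should_cancel: Optional[Callable[[], bool]] = None,
) -> List[int]:
    """Check cancellation only at cache-consistent boundaries."""
    should_cancel = should_cancel or (lambda: False)
    accepted: List[int] = []
    for _ in range(num_rounds):
        if should_cancel():  # before _draft_generate(...) begins
            raise GenerationCancelled()
        draft_tokens = draft_generate()
        accepted += verify_and_rewind(draft_tokens)
        if should_cancel():  # after _rewind_cache(...) completes
            raise GenerationCancelled()
    return accepted
```

Checking only at these two points means a cancel can never leave the draft cache ahead of (or behind) the main cache.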
## 1.4 Thread cancellation through `stream_generate`
- Don’t add a new top-level API function; keep `stream_generate(..., **kwargs)`, but:
  - Stop dropping `prompt_progress_callback` on the speculative path (remove the `kwargs.pop("prompt_progress_callback", None)`).
  - Ensure `should_cancel` passes through to either `generate_step` or `speculative_generate_step`.
  - Wrap the token loop so `GenerationCancelled` propagates cleanly (no “final response” emission on cancel).
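From the caller’s side the pass-through could be as thin as the sketch below, where `step_fn` stands in for whichever of `generate_step` / `speculative_generate_step` is selected:

```python
from typing import Callable, Optional


class GenerationCancelled(Exception):
    pass


def stream_generate_sketch(step_fn, *,
                           should_cancel: Optional[Callable[[], bool]] = None,
                           **kwargs):
    # Forward should_cancel (and any other kwargs) to the chosen generator.
    # GenerationCancelled propagates to the caller unchanged; no final
    # response object is yielded on cancel.
    yield from step_fn(should_cancel=should_cancel, **kwargs)
```

Callers then wrap their consumption loop in `try/except GenerationCancelled` and treat the exception as a clean stop, not an error.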
**Author:**

Example usage in the inference workflow: zhutao100/mlx-omni-server@264fd30