feat: allow callers to cancel stream generation via callback check, and ensure prompt cache consistency#710

Open
zhutao100 wants to merge 1 commit into ml-explore:main from zhutao100:feat-cancel-stream-generate
Conversation

@zhutao100

feat: allow callers to cancel stream generation via callback check, and ensure prompt cache consistency.

## 1.1 Add a first-class cancellation signal
- Add `class GenerationCancelled(Exception): ...` near `GenerationResponse`.
- Add a new optional kwarg to both token generators:
  - `generate_step(..., should_cancel: Callable[[], bool] | None = None, ...)`
  - `speculative_generate_step(..., should_cancel: Callable[[], bool] | None = None, prompt_progress_callback: Callable[[int,int], None] | None = None, ...)`
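
A minimal sketch of the signal and the extended signatures. Parameter names besides `should_cancel` and `prompt_progress_callback` are illustrative stand-ins, not the library's actual argument lists:

```python
from __future__ import annotations

from typing import Callable


class GenerationCancelled(Exception):
    """Raised at a safe boundary when the caller's should_cancel() returns True."""


def generate_step(
    prompt,
    model=None,
    *,
    should_cancel: Callable[[], bool] | None = None,
    # ...existing kwargs (sampler, prompt_progress_callback, ...)...
):
    ...


def speculative_generate_step(
    prompt,
    model=None,
    draft_model=None,
    *,
    should_cancel: Callable[[], bool] | None = None,
    prompt_progress_callback: Callable[[int, int], None] | None = None,
    # ...existing kwargs...
):
    ...
```

Because the default is `None` and the check degrades to a no-op, existing callers are unaffected.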

## 1.2 Cooperative cancel in `generate_step` (non-speculative)
Implement cancellation checks exactly at safe boundaries:
- Initialize: `should_cancel = should_cancel or (lambda: False)`
- Before the prefill loop starts (inside the `with mx.stream(generation_stream)` block): `if should_cancel(): raise GenerationCancelled()`
- For each prefill chunk iteration:
  - check `should_cancel()` **before** `_model_call(...)` for that chunk
  - after `mx.eval([c.state for c in prompt_cache])` + `prompt_progress_callback(...)`, check `should_cancel()` and raise
- Before the “first-step” call (`y, logprobs = _step(input_tokens=prompt, ...)`): `if should_cancel(): raise GenerationCancelled()`
- Decode loop: check `should_cancel()` once per iteration at a consistent point (e.g., top of `while True`) and raise.
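
The boundary placement above can be sketched as follows. `step` and the prefill counter are hypothetical stand-ins for the real `_step`/`_model_call` + `mx.eval` work; only the check locations matter:

```python
from typing import Callable, List, Optional


class GenerationCancelled(Exception):
    """Cooperative cancellation; callers treat it as non-error."""


def generate_step(prompt: List[int], step: Callable[[int], int], *,
                  prefill_chunk: int = 512,
                  should_cancel: Optional[Callable[[], bool]] = None,
                  prompt_progress_callback=None):
    should_cancel = should_cancel or (lambda: False)
    progress = prompt_progress_callback or (lambda done, total: None)
    total = len(prompt)
    processed = 0
    if should_cancel():                    # before the prefill loop starts
        raise GenerationCancelled()
    while processed < total - 1:           # prefill all-but-last-token
        if should_cancel():                # before this chunk's _model_call
            raise GenerationCancelled()
        n = min(prefill_chunk, total - 1 - processed)
        processed += n                     # stands in for _model_call + mx.eval
        progress(processed, total)
        if should_cancel():                # after eval + progress callback
            raise GenerationCancelled()
    if should_cancel():                    # before the first-step call
        raise GenerationCancelled()
    token = step(prompt[-1])               # y, logprobs = _step(...)
    while True:                            # decode loop
        if should_cancel():                # once per iteration, at the top
            raise GenerationCancelled()
        yield token
        token = step(token)
```

Every check sits either between fully evaluated chunks or between decode steps, so the prompt cache is never left half-updated when the exception fires.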

Notes:
- Keep using existing `prompt_progress_callback` calls for prefill chunks.
- Don’t introduce new “heartbeat yields”; cancellation should be via exception so callers can treat it as non-error.

## 1.3 Add equivalent hooks/cancel points to `speculative_generate_step`
This needs two parts: (a) prompt-progress parity, (b) safe cancel points.

**(a) Reuse `prompt_progress_callback`**
- Add `prompt_progress_callback` param (same signature as in `generate_step`).
- Move speculative prefill to a “prefill all-but-last-token” structure (like `generate_step`) so you can report progress per chunk:
  - For each chunk processed:
    - run both draft + main model on the same chunk
    - `mx.eval([c.state for c in draft_cache])` and `mx.eval([c.state for c in model_cache])`
    - update `prompt_processed_tokens` and call `prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)`
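
A sketch of the per-chunk structure, with `draft_call`/`model_call` as hypothetical stand-ins for running the draft and main models on the same chunk:

```python
from typing import Callable, List, Optional


def speculative_prefill(prompt: List[int], *,
                        prefill_chunk: int = 512,
                        draft_call: Optional[Callable] = None,
                        model_call: Optional[Callable] = None,
                        prompt_progress_callback=None) -> int:
    # "Prefill all-but-last-token" in chunks, reporting progress per chunk.
    progress = prompt_progress_callback or (lambda done, total: None)
    total = len(prompt)
    processed = 0
    while processed < total - 1:
        n = min(prefill_chunk, total - 1 - processed)
        chunk = prompt[processed:processed + n]
        if draft_call:
            draft_call(chunk)   # + mx.eval([c.state for c in draft_cache])
        if model_call:
            model_call(chunk)   # + mx.eval([c.state for c in model_cache])
        processed += n
        progress(processed, total)
    return processed
```

Evaluating both caches before each progress report keeps the two models' caches advancing in lockstep, chunk by chunk.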

**(b) Cancellation checks**
- Same pattern: check at prefill chunk boundaries and **before** entering the first speculative step.
- In the speculative decode loop, only check cancellation at boundaries that preserve cache consistency (e.g., before `_draft_generate(...)` begins, and after `_rewind_cache(...)` completes), not mid-way through draft generation unless you also add rollback logic there.

## 1.4 Thread cancellation through `stream_generate`
- Don't add a new top-level API function; keep `stream_generate(..., **kwargs)`, but:
  - Stop dropping `prompt_progress_callback` for speculative decoding (remove the `kwargs.pop("prompt_progress_callback", None)` call).
  - Ensure `should_cancel` passes through to either `generate_step` or `speculative_generate_step`.
- Wrap the token loop so `GenerationCancelled` propagates cleanly (no “final response” emission on cancel).
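
From the caller's side, cancellation then reads as ordinary control flow. A minimal sketch, with `fake_stream` standing in for `stream_generate` and a `threading.Event` as the cancel flag (e.g. set from a request handler on client disconnect):

```python
import threading


class GenerationCancelled(Exception):
    pass


def fake_stream(should_cancel):
    # Stand-in for stream_generate: yields tokens until cancelled,
    # emitting no "final response" on cancel.
    i = 0
    while True:
        if should_cancel():
            raise GenerationCancelled()
        yield i
        i += 1


cancel_event = threading.Event()
tokens = []
try:
    for tok in fake_stream(cancel_event.is_set):
        tokens.append(tok)
        if tok == 2:
            cancel_event.set()   # normally set from another thread
except GenerationCancelled:
    pass                         # cancellation is not an error
```

Because the signal is an exception rather than a sentinel yield, partial output collected before the cancel remains usable and the generator's `finally`/cleanup paths still run.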
@zhutao100 (Author)

Example usage in the inference workflow: zhutao100/mlx-omni-server@264fd30
