Force the model to write some tokens mid-generation? #37771
Comments
Asked ChatGPT about it. Its answer (summarized) was a set of practical “entry points” in 🤗 Transformers that let you overwrite the next-token choice, even if the token you want (e.g. "failed") has lower probability than the model’s top guess:

1. Understand where the decision is made. Key location in the source tree (v4.39.*): src/transformers/generation/utils.py – the generation loops.
2. Three out-of-the-box hooks you can already use:
   A. Forcing a token at a fixed position – works when you know in advance that at position N the token must be X.
   B. A callable invoked every step – you receive (batch_id, input_ids) and return the set of IDs that are legal next tokens.
   C. Stop-generate-then-append – let generation stop after the code block, run pytest, then feed the original prompt + "Result: failed\n" back into generate() with past_key_values to keep it fast.
3. Example: force a single token at a particular time step with a custom LogitsProcessor (ForceTokenAtPosition); the class body was cut off in the paste – a sketch follows below.
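A minimal sketch of what such a processor could look like, keeping the class name and constructor arguments from the paste; the implementation details (masking every logit except the forced id) are my reconstruction, not the original code:

```python
import torch
from transformers import LogitsProcessor


class ForceTokenAtPosition(LogitsProcessor):
    """Force a specific token id once the sequence reaches a given length."""

    def __init__(self, position: int, token_id: int):
        self.position = position  # sequence length at which the forced token is emitted
        self.token_id = token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # `scores` holds the next-token logits; when the running length hits
        # `position`, mask everything except the forced token.
        if input_ids.shape[-1] == self.position:
            mask = torch.full_like(scores, float("-inf"))
            mask[:, self.token_id] = 0.0
            return mask
        return scores
```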
4. Usage: build a ForceTokenAtPosition processor and pass it to generate() through logits_processor (the snippet in the paste was truncated – a sketch follows below).
5. Dynamic forcing (decide after pytest) with a DynamicForceNext logits processor whose target token is set from outside (also truncated – sketched below).
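A sketch of DynamicForceNext and of the usage, under the same assumptions (only the class name, the .set() method, and the dyn_proc/failed_id names come from the paste; model, inputs, and the forced position are placeholders):

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class DynamicForceNext(LogitsProcessor):
    """Force whatever token id was registered via .set() on the next step only."""

    def __init__(self):
        self.token_id = None

    def set(self, token_id: int):
        self.token_id = token_id

    def __call__(self, input_ids, scores):
        if self.token_id is None:
            return scores
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.token_id] = 0.0
        self.token_id = None  # force a single step, then fall back to normal decoding
        return mask


# Usage (ForceTokenAtPosition is the sketch above; position and token ids are placeholders):
processor = ForceTokenAtPosition(position=42, token_id=failed_id)
dyn_proc = DynamicForceNext()
outputs = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([processor, dyn_proc]),
)
```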
6. In your outer loop, run dyn_proc.set(failed_id) immediately after executing the tests.
7. Where to drop the code in the repo, and then unit-test it in tests/generation/test_force_token.py with a def test_force_token() (also cut off – a sketch follows below).
8. Performance notes.
9. TL;DR.

Doesn't make much sense to me, but someone who's familiar may know what to do.
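A rough sketch of such a unit test, assuming the ForceTokenAtPosition sketch above; the checkpoint name and the forced token id are placeholders, not part of the original paste:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList


def test_force_token():
    # Placeholder checkpoint; any small causal LM is enough for this check.
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    inputs = tokenizer("Result:", return_tensors="pt")
    forced_id = 123  # arbitrary token id for the test
    position = inputs.input_ids.shape[-1]  # force the very first generated token

    processor = ForceTokenAtPosition(position=position, token_id=forced_id)
    out = model.generate(
        **inputs,
        max_new_tokens=3,
        do_sample=False,
        logits_processor=LogitsProcessorList([processor]),
    )

    # The token generated at `position` must be the forced one.
    assert out[0, position].item() == forced_id
```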
It seems like you're asking for structured generation here - cc @gante what do we recommend for people these days? Just using Outlines?
Oh, so we don't need a custom Logits Processor?
Feature request
Here’s an example:
User: Hello make a python function for something
Assistant: Here’s a function for that:
def function():
pass
← This is a line we tuned the model to generate
import pytest
assert foo == bar
← Execute the tests right after this token was predicted
Result: tests succeeded ← THESE are the forced tokens; we also tuned the model to generate this
Ok, looks like the function is working…
EDIT:
The LLM is trained to respond with the block given above. However, since LLMs are bad at detecting when they have made a mistake, they will lean towards saying “succeeded” for everything.
After the inference pass that predicts the verdict token (“succeeded” vs. “failed”), there is a probability distribution, e.g.:
succeeded 0.5
failed 0.3
etc.
So I want to “force” the model to pick “failed” (or “succeeded”) even though it is a less likely token. This seems very simple, but there is currently no support for it.
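To make this concrete, here is a minimal sketch of how the forcing step can be approximated today with a manual decoding step; model, tokenizer, input_ids, and run_pytest() are hypothetical names, and treating each verdict word as a single token is a simplification:

```python
import torch

# `input_ids` is assumed to end right after "Result: tests" (hypothetical setup),
# and run_pytest() is a hypothetical helper returning True when the tests pass.
succeeded_id = tokenizer(" succeeded", add_special_tokens=False).input_ids[0]
failed_id = tokenizer(" failed", add_special_tokens=False).input_ids[0]

with torch.no_grad():
    next_token_logits = model(input_ids).logits[:, -1, :]
probs = next_token_logits.softmax(dim=-1)
# e.g. P(" succeeded") = 0.5, P(" failed") = 0.3 -- the model leans towards "succeeded"
print(probs[0, succeeded_id].item(), probs[0, failed_id].item())

# Override the model's choice with the ground-truth verdict, then keep generating.
verdict_id = succeeded_id if run_pytest() else failed_id
input_ids = torch.cat(
    [input_ids, torch.tensor([[verdict_id]], device=input_ids.device)], dim=-1
)
continuation = model.generate(input_ids, max_new_tokens=64)
```

Doing the same thing inside a single generate() call would amount to a custom logits processor along the lines of the sketches in the comments above.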
Motivation
In case you didn't realize the point already, doing this could make open-source LLMs significantly better for agentic workflows. Unlike stopping generation, calling tools, and otherwise creating delays, this works right between inference passes. Agentic workflows on proprietary LLMs can add up costs FAST.
Your contribution
I'm not familiar with this codebase; it seems very complex, but the feature is very simple. Maybe someone could give me pointers.