
Commit d3dc2e3

lazarust and awni authored
Add logits_processors support to batch_generate (#635)
* Update generate.py

* add samplers and logits processors with per example option and server support

---------

Co-authored-by: Awni Hannun <awni@apple.com>
1 parent 7423bf6 commit d3dc2e3
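A minimal usage sketch of the new keyword (not part of this commit; it assumes mlx_lm.load, mlx_lm.generate.batch_generate, the existing mlx_lm.sample_utils.make_logits_processors helper, a hypothetical model repo, and that prompts are passed as token id lists):

from mlx_lm import load
from mlx_lm.generate import batch_generate
from mlx_lm.sample_utils import make_logits_processors

# Hypothetical model choice, for illustration only.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompts = [
    tokenizer.encode("Write a haiku about the ocean."),
    tokenizer.encode("Explain KV cache quantization in one sentence."),
]

# One shared list of processors applied to every prompt in the batch,
# here a repetition penalty built with the existing helper.
processors = make_logits_processors(repetition_penalty=1.1)

response = batch_generate(
    model,
    tokenizer,
    prompts,
    max_tokens=64,
    logits_processors=processors,
    verbose=True,
)
# The returned BatchResponse holds the per-prompt completions.

Per-example control is available through BatchGenerator.insert, as sketched after the generate.py diff below.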

3 files changed: +199 −52 lines changed


mlx_lm/generate.py

Lines changed: 93 additions & 14 deletions
@@ -181,8 +181,7 @@ def setup_arg_parser():
     parser.add_argument(
         "--kv-bits",
         type=int,
-        help="Number of bits for KV cache quantization. "
-        "Defaults to no quantization.",
+        help="Number of bits for KV cache quantization. Defaults to no quantization.",
         default=None,
     )
     parser.add_argument(
@@ -548,7 +547,9 @@ def _step(model, cache, y, n_predict=1):
         y = y[: -(n_predict - 1)]
         for i in range(n_predict):
             prev_tokens = (
-                mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                mx.concatenate([prev_tokens, y])
+                if prev_tokens is not None
+                else y
             )
             y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
             out_y.append(y)
@@ -840,6 +841,9 @@ class Batch:
     max_tokens: List[int]
     num_tokens: List[int]
    cache: List[Any]
+    samplers: List[Any]
+    logits_processors: List[Any]
+    tokens: List[mx.array]

    def __len__(self):
        return len(self.uids)
@@ -849,6 +853,9 @@ def filter(self, keep_idx: List[int]):
         self.logprobs = [self.logprobs[k] for k in keep_idx]
         self.max_tokens = [self.max_tokens[k] for k in keep_idx]
         self.num_tokens = [self.num_tokens[k] for k in keep_idx]
+        self.samplers = [self.samplers[k] for k in keep_idx]
+        self.logits_processors = [self.logits_processors[k] for k in keep_idx]
+        self.tokens = [self.tokens[k] for k in keep_idx]
         keep_idx = mx.array(keep_idx, mx.int32)
         self.y = self.y[keep_idx]
         for c in self.cache:
@@ -860,6 +867,9 @@ def extend(self, other):
         self.logprobs.extend(other.logprobs)
         self.num_tokens.extend(other.num_tokens)
         self.max_tokens.extend(other.max_tokens)
+        self.samplers.extend(other.samplers)
+        self.logits_processors.extend(other.logits_processors)
+        self.tokens.extend(other.tokens)
         for c, o in zip(self.cache, other.cache):
             c.extend(o)

@@ -912,7 +922,6 @@ def _merge_caches(caches):


 class BatchGenerator:
-
     @dataclass
     class Response:
         uid: int
@@ -927,6 +936,9 @@ def __init__(
         max_tokens: int = 128,
         stop_tokens: Optional[set] = None,
         sampler: Optional[Callable[[mx.array], mx.array]] = None,
+        logits_processors: Optional[
+            List[Callable[[mx.array, mx.array], mx.array]]
+        ] = None,
         completion_batch_size: int = 32,
         prefill_batch_size: int = 8,
         prefill_step_size: int = 2048,
@@ -939,6 +951,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.stop_tokens = stop_tokens or set()
         self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+        self.logits_processors = logits_processors or []
         self.uid_count = 0
         self.prefill_step_size = prefill_step_size
         self.prefill_batch_size = prefill_batch_size
@@ -965,7 +978,12 @@ def __del__(self):
         self.close()

     def insert(
-        self, prompts, max_tokens: Union[List[int], int, None] = None, caches=None
+        self,
+        prompts,
+        max_tokens: Union[List[int], int, None] = None,
+        caches=None,
+        samplers: list | None = None,
+        logits_processors: list | None = None,
     ):
         uids = []

@@ -978,8 +996,13 @@ def insert(
             if caches[i] is None:
                 caches[i] = cache.make_prompt_cache(self.model)

-        for p, m, c in zip(prompts, max_tokens, caches):
-            self.unprocessed_prompts.append((self.uid_count, p, m, c))
+        samplers = samplers or [None] * len(prompts)
+        logits_processors = logits_processors or [self.logits_processors] * len(prompts)
+
+        for p, m, c, s, lp in zip(
+            prompts, max_tokens, caches, samplers, logits_processors
+        ):
+            self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp))
             uids.append(self.uid_count)
             self.uid_count += 1
         # Sort in ascending order of length
@@ -1003,7 +1026,7 @@ def remove(self, uids: List[int]):
                 self.unprocessed_prompts.pop(i)

     def _process_prompts(self, prompts):
-        uids, inputs, max_tokens, caches = zip(*prompts)
+        uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts)

         cache_lengths = [cache.cache_length(c) for c in caches]
         max_cache_length = max(cache_lengths)
@@ -1013,6 +1036,7 @@ def _process_prompts(self, prompts):

         self._stats.prompt_tokens += sum(lengths)

+        tokens = [mx.array(inp) for inp in inputs]
         processed_tokens = 0

         # New prompts so
@@ -1069,17 +1093,56 @@ def _process_prompts(self, prompts):
             mx.clear_cache()
         inputs = last_inputs

-        y, logprobs = self._step(inputs, prompt_cache)
+        y, logprobs = self._step(
+            inputs, prompt_cache, samplers, logits_processors, tokens
+        )
         mx.async_eval(y, logprobs)
+
         return Batch(
-            list(uids), y, logprobs, list(max_tokens), [0] * len(uids), prompt_cache
+            list(uids),
+            y,
+            logprobs,
+            list(max_tokens),
+            [0] * len(uids),
+            prompt_cache,
+            list(samplers),
+            list(logits_processors),
+            tokens,
         )

-    def _step(self, input_tokens: mx.array, prompt_cache: List[Any]):
+    def _step(
+        self,
+        input_tokens: mx.array,
+        prompt_cache: List[Any],
+        samplers: list | None,
+        logits_processors: list | None,
+        tokens: List[mx.array],
+    ):
+        batch_size = input_tokens.shape[0]
+
         logits = self.model(input_tokens, cache=prompt_cache)
         logits = logits[:, -1, :]
+
+        if any(logits_processors):
+            processed_logits = []
+            for e in range(batch_size):
+                sample_logits = logits[e : e + 1]
+                for processor in logits_processors[e]:
+                    sample_logits = processor(tokens[e], sample_logits)
+                processed_logits.append(sample_logits)
+            logits = mx.concatenate(processed_logits, axis=0)
+
         logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-        sampled = self.sampler(logprobs)
+        if any(samplers):
+            all_samples = []
+            for e in range(batch_size):
+                sample_sampler = samplers[e] or self.sampler
+                sampled = sample_sampler(logprobs[e : e + 1])
+                all_samples.append(sampled)
+            sampled = mx.concatenate(all_samples, axis=0)
+        else:
+            sampled = self.sampler(logprobs)
+
         return sampled, list(logprobs)

     def stats(self):
@@ -1129,7 +1192,16 @@ def _next(self):

         batch = self.active_batch
         y, logprobs = batch.y, batch.logprobs
-        batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
+        for i, toks in enumerate(batch.tokens):
+            batch.tokens[i] = mx.concatenate((toks, y[i : i + 1]))
+        batch.y, batch.logprobs = self._step(
+            y[:, None],
+            batch.cache,
+            batch.samplers,
+            batch.logits_processors,
+            batch.tokens,
+        )
+
         mx.async_eval(batch.y, batch.logprobs)

         y = y.tolist()
@@ -1184,6 +1256,7 @@ def batch_generate(
     max_tokens: Union[int, List[int]] = 128,
     verbose: bool = False,
     return_prompt_caches: bool = False,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     **kwargs,
 ) -> BatchResponse:
     """
@@ -1202,11 +1275,17 @@ def batch_generate(
             can be per prompt if a list is provided.
         return_prompt_caches (bool): Return the prompt caches in the batch
             responses. Default: ``False``.
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
+            A list of functions that take tokens and logits and return the processed logits. Default: ``None``.
         kwargs: The remaining options get passed to :obj:`BatchGenerator`.
             See :obj:`BatchGenerator` for more details.
     """

-    gen = BatchGenerator(model, stop_tokens=tokenizer.eos_token_ids, **kwargs)
+    gen = BatchGenerator(
+        model,
+        stop_tokens=tokenizer.eos_token_ids,
+        **kwargs,
+    )
     num_samples = len(prompts)
     fin = 0
     if verbose:
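The per-example hooks added above can also be driven directly through BatchGenerator.insert, e.g. to give each prompt in a batch its own sampler and processors. A hypothetical sketch based only on the signatures in this diff (model repo and prompt encoding are assumptions; the response-draining loop is omitted):

from mlx_lm import load
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Hypothetical model choice, for illustration only.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

gen = BatchGenerator(model, stop_tokens=tokenizer.eos_token_ids)

prompts = [
    tokenizer.encode("List three prime numbers."),
    tokenizer.encode("Write a limerick about GPUs."),
]

# Per-example settings: greedy decoding for the first prompt,
# creative sampling plus a repetition penalty for the second.
samplers = [make_sampler(temp=0.0), make_sampler(temp=0.9, top_p=0.95)]
processors = [[], make_logits_processors(repetition_penalty=1.2)]

uids = gen.insert(
    prompts,
    max_tokens=64,
    samplers=samplers,
    logits_processors=processors,
)
# Completions for each uid are then drained through the generator's
# normal response loop (not shown here).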

mlx_lm/server.py

Lines changed: 30 additions & 38 deletions
@@ -467,6 +467,29 @@ def validate_draft_tokenizer(draft_tokenizer):
         return self.model, self.tokenizer


+def _make_sampler(args, tokenizer):
+    return make_sampler(
+        args.sampling.temperature,
+        top_p=args.sampling.top_p,
+        top_k=args.sampling.top_k,
+        min_p=args.sampling.min_p,
+        xtc_probability=args.sampling.xtc_probability,
+        xtc_threshold=args.sampling.xtc_threshold,
+        xtc_special_tokens=[
+            tokenizer.eos_token_id,
+            tokenizer.encode("\n"),
+        ],
+    )
+
+
+def _make_logits_processors(args):
+    return make_logits_processors(
+        args.logits.logit_bias,
+        args.logits.repetition_penalty,
+        args.logits.repetition_context_size,
+    )
+
+
 class ResponseGenerator:
     def __init__(self, model_provider: ModelProvider, prompt_cache: LRUPromptCache):
         self.model_provider = model_provider
@@ -510,12 +533,6 @@ def _is_batchable(self, args):
         for c in self.model_provider.cache_types:
             if c not in (KVCache, RotatingKVCache):
                 return False
-        if args.logits.logit_bias is not None:
-            return False
-        if args.logits.repetition_penalty != 0:
-            return False
-        if args.logprobs > 0:
-            return False
         if args.seed is not None:
             return False

@@ -569,7 +586,6 @@ def progress_callback(info):
                 if (
                     batch_generator is not None
                     and current_model == args.model
-                    and current_sampling == args.sampling
                     and is_batchable
                 ):
                     prompt = self._tokenize(current_tokenizer, request)
@@ -593,7 +609,11 @@ def progress_callback(info):
                     cache = make_prompt_cache(self.model_provider.model)

                     (uid,) = batch_generator.insert(
-                        [rest], args.max_tokens, caches=[cache]
+                        [rest],
+                        args.max_tokens,
+                        caches=[cache],
+                        samplers=[_make_sampler(args, tokenizer)],
+                        logits_processors=[_make_logits_processors(args)],
                     )
                     batch_results[uid] = {
                         "ctx": ctx,
620640
continue
621641

622642
current_model = args.model
623-
current_sampling = args.sampling
624643
current_tokenizer = tokenizer
625644
current_model_key = self.model_provider.model_key
626645
batch_results = {}
627646
batch_generator = BatchGenerator(
628647
model,
629648
stop_tokens=tokenizer.eos_token_ids,
630-
sampler=make_sampler(
631-
args.sampling.temperature,
632-
top_p=args.sampling.top_p,
633-
top_k=args.sampling.top_k,
634-
min_p=args.sampling.min_p,
635-
xtc_probability=args.sampling.xtc_probability,
636-
xtc_threshold=args.sampling.xtc_threshold,
637-
xtc_special_tokens=[
638-
tokenizer.eos_token_id,
639-
tokenizer.encode("\n"),
640-
],
641-
),
642649
prompt_progress_callback=progress_callback,
643650
)
644651
unprocessed_requests.append((rqueue, request, args))
@@ -750,23 +757,8 @@ def progress(tokens_processed, tokens_total):
            mx.random.seed(args.seed)

        # Make the sampler and logit processor
-        sampler = make_sampler(
-            args.sampling.temperature,
-            top_p=args.sampling.top_p,
-            top_k=args.sampling.top_k,
-            min_p=args.sampling.min_p,
-            xtc_probability=args.sampling.xtc_probability,
-            xtc_threshold=args.sampling.xtc_threshold,
-            xtc_special_tokens=[
-                tokenizer.eos_token_id,
-                tokenizer.encode("\n"),
-            ],
-        )
-        logits_processors = make_logits_processors(
-            args.logits.logit_bias,
-            args.logits.repetition_penalty,
-            args.logits.repetition_context_size,
-        )
+        sampler = _make_sampler(args, tokenizer)
+        logits_processors = _make_logits_processors(args)

        # Load the KV cache
        cache, rest = self.prompt_cache.fetch_nearest_cache(
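With the per-request helpers above, requests that differ only in sampling or repetition-penalty settings no longer fall out of the batched path. A hypothetical client-side sketch, assuming a server started via mlx_lm.server on the default 127.0.0.1:8080 and accepting these OpenAI-style request fields:

import requests  # assumed to be installed; any HTTP client works

# Two chat requests with different sampling settings; after this change they
# can share the same BatchGenerator instead of forcing the non-batched path.
for temperature, penalty in [(0.0, 1.0), (0.8, 1.3)]:
    resp = requests.post(
        "http://127.0.0.1:8080/v1/chat/completions",  # assumed default address
        json={
            "messages": [{"role": "user", "content": "Name a prime number."}],
            "max_tokens": 32,
            "temperature": temperature,
            "repetition_penalty": penalty,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])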
