@@ -181,8 +181,7 @@ def setup_arg_parser():
     parser.add_argument(
         "--kv-bits",
         type=int,
-        help="Number of bits for KV cache quantization. "
-        "Defaults to no quantization.",
+        help="Number of bits for KV cache quantization. Defaults to no quantization.",
        default=None,
     )
     parser.add_argument(
@@ -548,7 +547,9 @@ def _step(model, cache, y, n_predict=1):
                 y = y[: -(n_predict - 1)]
             for i in range(n_predict):
                 prev_tokens = (
-                    mx.concat([prev_tokens, y]) if prev_tokens is not None else y
+                    mx.concatenate([prev_tokens, y])
+                    if prev_tokens is not None
+                    else y
                 )
                 y, logprobs = _process_and_sample(prev_tokens, logits[:, i, :])
                 out_y.append(y)
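The rename lines up with MLX's array-joining API (`mx.concatenate`). A standalone sketch of the call being made, not part of the patch:

    import mlx.core as mx

    prev_tokens = mx.array([1, 2])
    y = mx.array([3])
    # Joins along axis 0 by default, matching the usage above.
    print(mx.concatenate([prev_tokens, y]))  # array([1, 2, 3], dtype=int32)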
@@ -840,6 +841,9 @@ class Batch:
     max_tokens: List[int]
     num_tokens: List[int]
     cache: List[Any]
+    samplers: List[Any]
+    logits_processors: List[Any]
+    tokens: List[mx.array]

     def __len__(self):
         return len(self.uids)
@@ -849,6 +853,9 @@ def filter(self, keep_idx: List[int]):
         self.logprobs = [self.logprobs[k] for k in keep_idx]
         self.max_tokens = [self.max_tokens[k] for k in keep_idx]
         self.num_tokens = [self.num_tokens[k] for k in keep_idx]
+        self.samplers = [self.samplers[k] for k in keep_idx]
+        self.logits_processors = [self.logits_processors[k] for k in keep_idx]
+        self.tokens = [self.tokens[k] for k in keep_idx]
         keep_idx = mx.array(keep_idx, mx.int32)
         self.y = self.y[keep_idx]
         for c in self.cache:
@@ -860,6 +867,9 @@ def extend(self, other):
         self.logprobs.extend(other.logprobs)
         self.num_tokens.extend(other.num_tokens)
         self.max_tokens.extend(other.max_tokens)
+        self.samplers.extend(other.samplers)
+        self.logits_processors.extend(other.logits_processors)
+        self.tokens.extend(other.tokens)
         for c, o in zip(self.cache, other.cache):
             c.extend(o)

@@ -912,7 +922,6 @@ def _merge_caches(caches):


 class BatchGenerator:
-
     @dataclass
     class Response:
         uid: int
@@ -927,6 +936,9 @@ def __init__(
         max_tokens: int = 128,
         stop_tokens: Optional[set] = None,
         sampler: Optional[Callable[[mx.array], mx.array]] = None,
+        logits_processors: Optional[
+            List[Callable[[mx.array, mx.array], mx.array]]
+        ] = None,
         completion_batch_size: int = 32,
         prefill_batch_size: int = 8,
         prefill_step_size: int = 2048,
@@ -939,6 +951,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.stop_tokens = stop_tokens or set()
         self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+        self.logits_processors = logits_processors or []
         self.uid_count = 0
         self.prefill_step_size = prefill_step_size
         self.prefill_batch_size = prefill_batch_size
@@ -965,7 +978,12 @@ def __del__(self):
         self.close()

     def insert(
-        self, prompts, max_tokens: Union[List[int], int, None] = None, caches=None
+        self,
+        prompts,
+        max_tokens: Union[List[int], int, None] = None,
+        caches=None,
+        samplers: list | None = None,
+        logits_processors: list | None = None,
     ):
         uids = []

@@ -978,8 +996,13 @@ def insert(
             if caches[i] is None:
                 caches[i] = cache.make_prompt_cache(self.model)

-        for p, m, c in zip(prompts, max_tokens, caches):
-            self.unprocessed_prompts.append((self.uid_count, p, m, c))
+        samplers = samplers or [None] * len(prompts)
+        logits_processors = logits_processors or [self.logits_processors] * len(prompts)
+
+        for p, m, c, s, lp in zip(
+            prompts, max_tokens, caches, samplers, logits_processors
+        ):
+            self.unprocessed_prompts.append((self.uid_count, p, m, c, s, lp))
             uids.append(self.uid_count)
             self.uid_count += 1
         # Sort in ascending order of length
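A possible call into the extended insert signature. This assumes an existing BatchGenerator instance named gen, already-tokenized prompts, and mlx_lm's make_sampler helper; my_processor stands in for any hypothetical (tokens, logits) -> logits callable:

    from mlx_lm.sample_utils import make_sampler

    # One sampler and one processor list per prompt; a None sampler falls
    # back to the generator-level default inside _step.
    uids = gen.insert(
        prompts,
        max_tokens=[64, 128],
        samplers=[make_sampler(temp=0.7), None],
        logits_processors=[[my_processor], []],
    )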
@@ -1003,7 +1026,7 @@ def remove(self, uids: List[int]):
                 self.unprocessed_prompts.pop(i)

     def _process_prompts(self, prompts):
-        uids, inputs, max_tokens, caches = zip(*prompts)
+        uids, inputs, max_tokens, caches, samplers, logits_processors = zip(*prompts)

         cache_lengths = [cache.cache_length(c) for c in caches]
         max_cache_length = max(cache_lengths)
@@ -1013,6 +1036,7 @@ def _process_prompts(self, prompts):

         self._stats.prompt_tokens += sum(lengths)

+        tokens = [mx.array(inp) for inp in inputs]
         processed_tokens = 0

         # New prompts so
@@ -1069,17 +1093,56 @@ def _process_prompts(self, prompts):
             mx.clear_cache()
         inputs = last_inputs

-        y, logprobs = self._step(inputs, prompt_cache)
+        y, logprobs = self._step(
+            inputs, prompt_cache, samplers, logits_processors, tokens
+        )
         mx.async_eval(y, logprobs)
+
         return Batch(
-            list(uids), y, logprobs, list(max_tokens), [0] * len(uids), prompt_cache
+            list(uids),
+            y,
+            logprobs,
+            list(max_tokens),
+            [0] * len(uids),
+            prompt_cache,
+            list(samplers),
+            list(logits_processors),
+            tokens,
         )

-    def _step(self, input_tokens: mx.array, prompt_cache: List[Any]):
+    def _step(
+        self,
+        input_tokens: mx.array,
+        prompt_cache: List[Any],
+        samplers: list | None,
+        logits_processors: list | None,
+        tokens: List[mx.array],
+    ):
+        batch_size = input_tokens.shape[0]
+
         logits = self.model(input_tokens, cache=prompt_cache)
         logits = logits[:, -1, :]
+
+        if any(logits_processors):
+            processed_logits = []
+            for e in range(batch_size):
+                sample_logits = logits[e : e + 1]
+                for processor in logits_processors[e]:
+                    sample_logits = processor(tokens[e], sample_logits)
+                processed_logits.append(sample_logits)
+            logits = mx.concatenate(processed_logits, axis=0)
+
         logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-        sampled = self.sampler(logprobs)
+        if any(samplers):
+            all_samples = []
+            for e in range(batch_size):
+                sample_sampler = samplers[e] or self.sampler
+                sampled = sample_sampler(logprobs[e : e + 1])
+                all_samples.append(sampled)
+            sampled = mx.concatenate(all_samples, axis=0)
+        else:
+            sampled = self.sampler(logprobs)
+
         return sampled, list(logprobs)

     def stats(self):
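For reference, a minimal processor that fits the per-sample (tokens, logits) contract applied in _step above. The helper name and the banned-token use case are illustrative, not part of the patch:

    import mlx.core as mx

    def ban_token(token_id: int):
        # Build a processor that pins one vocabulary entry to -inf so the
        # sampler can never select it; `tokens` is unused here but kept to
        # match the expected signature.
        def processor(tokens: mx.array, logits: mx.array) -> mx.array:
            mask = mx.arange(logits.shape[-1]) == token_id
            return mx.where(mask, float("-inf"), logits)

        return processor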
@@ -1129,7 +1192,16 @@ def _next(self):

         batch = self.active_batch
         y, logprobs = batch.y, batch.logprobs
-        batch.y, batch.logprobs = self._step(y[:, None], batch.cache)
+        for i, toks in enumerate(batch.tokens):
+            batch.tokens[i] = mx.concatenate((toks, y[i : i + 1]))
+        batch.y, batch.logprobs = self._step(
+            y[:, None],
+            batch.cache,
+            batch.samplers,
+            batch.logits_processors,
+            batch.tokens,
+        )
+
         mx.async_eval(batch.y, batch.logprobs)

         y = y.tolist()
@@ -1184,6 +1256,7 @@ def batch_generate(
     max_tokens: Union[int, List[int]] = 128,
     verbose: bool = False,
     return_prompt_caches: bool = False,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     **kwargs,
 ) -> BatchResponse:
     """
@@ -1202,11 +1275,17 @@ def batch_generate(
             can be per prompt if a list is provided.
         return_prompt_caches (bool): Return the prompt caches in the batch
             responses. Default: ``False``.
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
+            A list of functions that take tokens and logits and return the processed logits. Default: ``None``.
         kwargs: The remaining options get passed to :obj:`BatchGenerator`.
             See :obj:`BatchGenerator` for more details.
     """

-    gen = BatchGenerator(model, stop_tokens=tokenizer.eos_token_ids, **kwargs)
+    gen = BatchGenerator(
+        model,
+        stop_tokens=tokenizer.eos_token_ids,
+        **kwargs,
+    )
     num_samples = len(prompts)
     fin = 0
     if verbose:
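An end-to-end sketch of the new batch_generate option, assuming mlx_lm's load helper and the illustrative ban_token processor from earlier; the model path is a placeholder:

    from mlx_lm import load

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
    prompts = [tokenizer.encode(p) for p in ["Hello,", "The capital of France is"]]

    # Per the insert defaults, a generator-level processor list appears to
    # apply to every prompt in the batch; per-prompt control goes through
    # BatchGenerator.insert instead.
    response = batch_generate(
        model,
        tokenizer,
        prompts,
        max_tokens=32,
        logits_processors=[ban_token(tokenizer.eos_token_id)],
    )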