Skip to content

Commit 572ada2

Browse files
authored
server: add usage.prompt_tokens_details.cached_tokens to json response (#849)
1 parent fb47f8f commit 572ada2

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

mlx_lm/server.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ class GenerationContext:
371371
eos_token_ids: set
372372
stop_token_sequences: List[List[int]]
373373
prompt: List[int]
374+
prompt_cache_count: int = -1
374375

375376
_should_stop: bool = False
376377

@@ -753,6 +754,7 @@ def progress_callback(info):
753754
cache, rest = self.prompt_cache.fetch_nearest_cache(
754755
current_model_key, prompt
755756
)
757+
ctx.prompt_cache_count = len(prompt) - len(rest)
756758
if cache is None:
757759
cache = make_prompt_cache(self.model_provider.model)
758760

@@ -917,6 +919,7 @@ def progress(tokens_processed, tokens_total):
917919
cache, rest = self.prompt_cache.fetch_nearest_cache(
918920
self.model_provider.model_key, prompt
919921
)
922+
ctx.prompt_cache_count = len(prompt) - len(rest)
920923
cache_key = prompt[:]
921924
if cache is None:
922925
cache = make_prompt_cache(self.model_provider.model)
@@ -1184,6 +1187,7 @@ def generate_response(
11841187
finish_reason: Union[Literal["length", "stop"], None],
11851188
prompt_token_count: Optional[int] = None,
11861189
completion_token_count: Optional[int] = None,
1190+
prompt_cache_count: Optional[int] = None,
11871191
token_logprobs: Optional[List[float]] = None,
11881192
top_tokens: Optional[List[Tuple[Dict[str, Any]]]] = None,
11891193
tokens: Optional[List[int]] = None,
@@ -1202,6 +1206,8 @@ def generate_response(
12021206
used to populate the "usage" field (not used when stream).
12031207
completion_token_count (Optional[int]): The number of tokens in the
12041208
response, used to populate the "usage" field (not used when stream).
1209+
prompt_cache_count (Optional[int]): The portion of prompt_token_count
1210+
that was found in the cache when servicing the request.
12051211
token_logprobs (Optional[List[float]]): The log probabilities per token,
12061212
in token order.
12071213
top_tokens (Optional[List[Tuple[Dict[str, Any]]]]): List of outputs from
@@ -1260,6 +1266,10 @@ def generate_response(
12601266
"completion_tokens": completion_token_count,
12611267
"total_tokens": prompt_token_count + completion_token_count,
12621268
}
1269+
if prompt_cache_count is not None and prompt_cache_count >= 0:
1270+
response["usage"]["prompt_tokens_details"] = {
1271+
"cached_tokens": prompt_cache_count,
1272+
}
12631273

12641274
choice = response["choices"][0]
12651275

@@ -1501,7 +1511,11 @@ def parse_tools(tool_calls):
15011511
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
15021512
self.wfile.flush()
15031513
if self.stream_options is not None and self.stream_options["include_usage"]:
1504-
response = self.completion_usage_response(len(ctx.prompt), len(tokens))
1514+
response = self.completion_usage_response(
1515+
len(ctx.prompt),
1516+
len(tokens),
1517+
ctx.prompt_cache_count,
1518+
)
15051519
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
15061520
self.wfile.flush()
15071521
self.wfile.write("data: [DONE]\n\n".encode())
@@ -1512,6 +1526,7 @@ def parse_tools(tool_calls):
15121526
finish_reason,
15131527
len(ctx.prompt),
15141528
len(tokens),
1529+
ctx.prompt_cache_count,
15151530
token_logprobs=token_logprobs,
15161531
top_tokens=top_tokens,
15171532
tokens=tokens,
@@ -1532,6 +1547,7 @@ def completion_usage_response(
15321547
self,
15331548
prompt_token_count: Optional[int] = None,
15341549
completion_token_count: Optional[int] = None,
1550+
prompt_cache_count: Optional[int] = None,
15351551
):
15361552
response = {
15371553
"id": self.request_id,
@@ -1546,6 +1562,10 @@ def completion_usage_response(
15461562
"total_tokens": prompt_token_count + completion_token_count,
15471563
},
15481564
}
1565+
if prompt_cache_count is not None and prompt_cache_count >= 0:
1566+
response["usage"]["prompt_tokens_details"] = {
1567+
"cached_tokens": prompt_cache_count,
1568+
}
15491569
return response
15501570

15511571
def handle_chat_completions(self) -> CompletionRequest:

0 commit comments

Comments (0)