Skip to content

Commit ad585eb

Browse files
committed
Add elapsed-time metrics to the API serve framework
1 parent 45758bf commit ad585eb

File tree

4 files changed

+114
-181
lines changed

4 files changed

+114
-181
lines changed

llm/benchmark/rl/api_serve.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,6 @@ class ResponsePayload:
142142

143143

144144
class StatisticsManager:
145-
def __init__(self, responses_num: int):
146-
self.responses_num = responses_num
147-
self.batch_index = 0
148-
149-
def res_stats(self, response: List[ResponsePayload]):
150-
batch_group_pd = pd.DataFrame(response)
151-
res_batch_pd = batch_group_pd[["idx", "question", "responses"]]
152-
responses_batch_pd = pd.DataFrame(
153-
res_batch_pd["responses"].to_list(), columns=[f"response_{i+1}" for i in range(self.responses_num)]
154-
)
155-
res_batch_pd = pd.concat([res_batch_pd[["idx", "question"]], responses_batch_pd], axis=1)
156-
157-
res_batch_pd.to_json(self.res_path, orient="records", lines=True, force_ascii=False, mode="a")
158-
159145
def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
160146
batch_group_pd = pd.DataFrame(responses)
161147

@@ -168,6 +154,10 @@ def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time:
168154
"completion_time": batch_elapsed_time,
169155
"throughput_tokens_per_sec": batch_group_pd["token_lengths"].apply((lambda x: sum(x))).sum()
170156
/ batch_elapsed_time,
157+
"elapsed_times": batch_group_pd["elapsed_times"].to_list(),
158+
"min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
159+
"max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
160+
"avg_time": batch_group_pd["elapsed_times"].apply(lambda x: sum(x) / len(x)).tolist(),
171161
}
172162

173163
return dispersed_stats_dict
@@ -186,6 +176,8 @@ def global_stats(self, responses: List[ResponsePayload], batch_elapsed_time: flo
186176
global_stats_dict["avg_response_tokens"] = total_response_tokens / len(responses)
187177
global_stats_dict["total_response_tokens"] = total_response_tokens
188178
global_stats_dict["group_max_response_tokens"] = dispersed_stats_dict["max_length"]
179+
global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
180+
global_stats_dict["avg_time"] = sum(dispersed_stats_dict["avg_time"]) / len(responses)
189181
global_stats_dict["completion_time"] = dispersed_stats_dict["completion_time"]
190182
global_stats_dict["throughput_tokens_per_sec"] = dispersed_stats_dict["throughput_tokens_per_sec"]
191183

@@ -211,7 +203,7 @@ def __init__(self, args, max_concurrency: int = 1000):
211203
self.rollout_details_path = self.output_dir / "rollout_details.jsonl"
212204
self.status_file_path = self.output_dir / "status.txt"
213205

214-
self.stats_manager = StatisticsManager(args.rollout_output_num)
206+
self.stats_manager = StatisticsManager()
215207

216208
self._load_status()
217209

@@ -282,8 +274,8 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
282274
text = "".join(chunks)
283275
end_time = time.perf_counter()
284276
elapsed_time = end_time - start_time
285-
logger.debug("Streaming response took %.4f seconds", elapsed_time)
286-
return text, elapsed_time
277+
logger.debug("Streaming response took %.2f seconds", elapsed_time)
278+
return text, round(elapsed_time, 2)
287279

288280
except Exception as e:
289281
logger.error("Error while streaming: %s", e)
@@ -333,6 +325,8 @@ def execute(self):
333325
"avg_response_tokens",
334326
"total_response_tokens",
335327
"group_max_response_tokens",
328+
"min_time",
329+
"avg_time",
336330
"completion_time",
337331
"throughput_tokens_per_sec",
338332
]
@@ -346,6 +340,10 @@ def execute(self):
346340
"avg_length",
347341
"completion_time",
348342
"throughput_tokens_per_sec",
343+
"elapsed_times",
344+
"min_time",
345+
"max_time",
346+
"avg_time",
349347
]
350348
)
351349

@@ -374,6 +372,8 @@ def execute(self):
374372
round(global_stats_dict["avg_response_tokens"], 2),
375373
global_stats_dict["total_response_tokens"],
376374
global_stats_dict["group_max_response_tokens"],
375+
global_stats_dict["min_time"],
376+
global_stats_dict["avg_time"],
377377
round(global_stats_dict["completion_time"], 2),
378378
round(global_stats_dict["throughput_tokens_per_sec"], 2),
379379
]
@@ -388,6 +388,10 @@ def execute(self):
388388
dispersed_stats_dict["avg_length"],
389389
round(dispersed_stats_dict["completion_time"], 2),
390390
round(dispersed_stats_dict["throughput_tokens_per_sec"], 2),
391+
dispersed_stats_dict["elapsed_times"],
392+
dispersed_stats_dict["min_time"],
393+
dispersed_stats_dict["max_time"],
394+
dispersed_stats_dict["avg_time"],
391395
]
392396
)
393397

llm/benchmark/rl/paddle_infer.py

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from dataclasses import dataclass, field
2121
from pathlib import Path
2222
from typing import List
23+
from utils import RangeSet
2324

2425
import paddle
2526
import pandas as pd
@@ -36,93 +37,6 @@
3637
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
3738
from paddlenlp.utils.log import logger
3839

39-
40-
@dataclass
41-
class RangeSet:
42-
"""Manage processed line ranges with efficient storage and querying"""
43-
44-
ranges: List[tuple]
45-
46-
def add(self, number: int):
47-
"""Add a number to the range set and merge adjacent ranges"""
48-
new_ranges = []
49-
added = False
50-
for start, end in sorted(self.ranges):
51-
if number < start - 1:
52-
if not added:
53-
new_ranges.append((number, number))
54-
added = True
55-
new_ranges.append((start, end))
56-
elif number == start - 1:
57-
new_ranges.append((number, end))
58-
added = True
59-
elif number <= end:
60-
new_ranges.append((start, end))
61-
added = True
62-
else:
63-
new_ranges.append((start, end))
64-
if not added:
65-
new_ranges.append((number, number))
66-
self.ranges = self.merge_ranges(new_ranges)
67-
68-
@staticmethod
69-
def merge_ranges(ranges: List[tuple]) -> List[tuple]:
70-
"""Merge overlapping or adjacent ranges"""
71-
if not ranges:
72-
return []
73-
sorted_ranges = sorted(ranges)
74-
merged = [sorted_ranges[0]]
75-
for current in sorted_ranges[1:]:
76-
last = merged[-1]
77-
if current[0] <= last[1] + 1:
78-
merged[-1] = (last[0], max(last[1], current[1]))
79-
else:
80-
merged.append(current)
81-
return merged
82-
83-
def contains(self, number: int) -> bool:
84-
"""Check if a number exists in any range"""
85-
for start, end in self.ranges:
86-
if start <= number <= end:
87-
return True
88-
return False
89-
90-
def to_file_format(self) -> str:
91-
"""Serialize ranges to compact string format"""
92-
return ",".join(f"{start}-{end}" if start != end else str(start) for start, end in self.ranges)
93-
94-
@classmethod
95-
def from_file(cls, content: str) -> "RangeSet":
96-
"""Deserialize from string format"""
97-
if not content:
98-
return cls(ranges=[])
99-
ranges = []
100-
for part in content.split(","):
101-
if "-" in part:
102-
start, end = map(int, part.split("-"))
103-
ranges.append((start, end))
104-
else:
105-
num = int(part)
106-
ranges.append((num, num))
107-
return cls(ranges=ranges)
108-
109-
@property
110-
def processed_count(self) -> int:
111-
"""Total number of processed items"""
112-
return sum(end - start + 1 for start, end in self.ranges)
113-
114-
115-
@contextmanager
116-
def switch_level_context(level="ERROR"):
117-
original_level = logger.logLevel
118-
logger.set_level(level)
119-
120-
try:
121-
yield
122-
finally:
123-
logger.set_level(original_level)
124-
125-
12640
def chunk(all_input_ids, size):
12741
if size <= 0:
12842
raise ValueError("Size must be greater than 0")

llm/benchmark/rl/torch_infer.py

Lines changed: 1 addition & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
import math
1919
import time
2020
from contextlib import contextmanager
21-
from dataclasses import dataclass
2221
from pathlib import Path
23-
from typing import List
2422

2523
import pandas as pd
2624
import tqdm
2725
from transformers import AutoTokenizer
26+
from utils import RangeSet
2827
from vllm import LLM, SamplingParams
2928

3029
from paddlenlp.utils.log import logger
@@ -70,81 +69,6 @@ def switch_level_context(level="ERROR"):
7069
logger.set_level(original_level)
7170

7271

73-
@dataclass
74-
class RangeSet:
75-
"""Manage processed line ranges with efficient storage and querying"""
76-
77-
ranges: List[tuple]
78-
79-
def add(self, number: int):
80-
"""Add a number to the range set and merge adjacent ranges"""
81-
new_ranges = []
82-
added = False
83-
for start, end in sorted(self.ranges):
84-
if number < start - 1:
85-
if not added:
86-
new_ranges.append((number, number))
87-
added = True
88-
new_ranges.append((start, end))
89-
elif number == start - 1:
90-
new_ranges.append((number, end))
91-
added = True
92-
elif number <= end:
93-
new_ranges.append((start, end))
94-
added = True
95-
else:
96-
new_ranges.append((start, end))
97-
if not added:
98-
new_ranges.append((number, number))
99-
self.ranges = self.merge_ranges(new_ranges)
100-
101-
@staticmethod
102-
def merge_ranges(ranges: List[tuple]) -> List[tuple]:
103-
"""Merge overlapping or adjacent ranges"""
104-
if not ranges:
105-
return []
106-
sorted_ranges = sorted(ranges)
107-
merged = [sorted_ranges[0]]
108-
for current in sorted_ranges[1:]:
109-
last = merged[-1]
110-
if current[0] <= last[1] + 1:
111-
merged[-1] = (last[0], max(last[1], current[1]))
112-
else:
113-
merged.append(current)
114-
return merged
115-
116-
def contains(self, number: int) -> bool:
117-
"""Check if a number exists in any range"""
118-
for start, end in self.ranges:
119-
if start <= number <= end:
120-
return True
121-
return False
122-
123-
def to_file_format(self) -> str:
124-
"""Serialize ranges to compact string format"""
125-
return ",".join(f"{start}-{end}" if start != end else str(start) for start, end in self.ranges)
126-
127-
@classmethod
128-
def from_file(cls, content: str) -> "RangeSet":
129-
"""Deserialize from string format"""
130-
if not content:
131-
return cls(ranges=[])
132-
ranges = []
133-
for part in content.split(","):
134-
if "-" in part:
135-
start, end = map(int, part.split("-"))
136-
ranges.append((start, end))
137-
else:
138-
num = int(part)
139-
ranges.append((num, num))
140-
return cls(ranges=ranges)
141-
142-
@property
143-
def processed_count(self) -> int:
144-
"""Total number of processed items"""
145-
return sum(end - start + 1 for start, end in self.ranges)
146-
147-
14872
class DumpyInferenceTask:
14973
def __init__(self, args):
15074
self.args = args

llm/benchmark/rl/utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from typing import List
17+
18+
19+
@dataclass
class RangeSet:
    """Manage processed line ranges with efficient storage and querying.

    Ranges are kept as a sorted list of inclusive ``(start, end)`` tuples
    with no overlapping or adjacent entries (``merge_ranges`` enforces this
    invariant after every mutation).
    """

    # Sorted, non-overlapping, inclusive (start, end) tuples.
    ranges: List[tuple]

    def add(self, number: int):
        """Add a number to the range set and merge adjacent ranges.

        Appending a degenerate ``(number, number)`` range and re-merging is
        equivalent to a hand-rolled insertion walk (which would end by
        calling ``merge_ranges`` anyway) and much simpler to reason about.
        """
        self.ranges = self.merge_ranges(self.ranges + [(number, number)])

    @staticmethod
    def merge_ranges(ranges: List[tuple]) -> List[tuple]:
        """Merge overlapping or adjacent ranges; returns a sorted list."""
        if not ranges:
            return []
        sorted_ranges = sorted(ranges)
        merged = [sorted_ranges[0]]
        for current in sorted_ranges[1:]:
            last = merged[-1]
            # "+ 1" also merges ranges that merely touch, e.g. (1,2) and (3,4).
            if current[0] <= last[1] + 1:
                merged[-1] = (last[0], max(last[1], current[1]))
            else:
                merged.append(current)
        return merged

    def contains(self, number: int) -> bool:
        """Check if a number exists in any range."""
        return any(start <= number <= end for start, end in self.ranges)

    def to_file_format(self) -> str:
        """Serialize ranges to a compact string, e.g. ``"1-3,7,9-12"``."""
        return ",".join(f"{start}-{end}" if start != end else str(start) for start, end in self.ranges)

    @classmethod
    def from_file(cls, content: str) -> "RangeSet":
        """Deserialize from the string format produced by ``to_file_format``."""
        if not content:
            return cls(ranges=[])
        ranges = []
        for part in content.split(","):
            if "-" in part:
                start, end = map(int, part.split("-"))
                ranges.append((start, end))
            else:
                num = int(part)
                ranges.append((num, num))
        return cls(ranges=ranges)

    @property
    def processed_count(self) -> int:
        """Total number of processed items across all ranges."""
        return sum(end - start + 1 for start, end in self.ranges)

0 commit comments

Comments
 (0)