import pandas as pd
from openai import AsyncOpenAI
from tqdm import tqdm
+ from utils import RangeSet

from paddlenlp.transformers import AutoTokenizer

- # Configure the root logger
- logging.basicConfig(
-     level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
- )
+ from transformers import logging

- logger = logging.getLogger(__name__)
- logger.setLevel(logging.DEBUG)
+ logging.set_verbosity_info()
+ logger = logging.get_logger(__name__)

-
- @dataclass
- class RangeSet:
-     """Manage processed line ranges with efficient storage and querying"""
-
-     ranges: List[tuple]
-
-     def add(self, number: int):
-         """Add a number to the range set and merge adjacent ranges"""
-         new_ranges = []
-         added = False
-         for start, end in sorted(self.ranges):
-             if number < start - 1:
-                 if not added:
-                     new_ranges.append((number, number))
-                     added = True
-                 new_ranges.append((start, end))
-             elif number == start - 1:
-                 new_ranges.append((number, end))
-                 added = True
-             elif number <= end:
-                 new_ranges.append((start, end))
-                 added = True
-             else:
-                 new_ranges.append((start, end))
-         if not added:
-             new_ranges.append((number, number))
-         self.ranges = self.merge_ranges(new_ranges)
-
-     @staticmethod
-     def merge_ranges(ranges: List[tuple]) -> List[tuple]:
-         """Merge overlapping or adjacent ranges"""
-         if not ranges:
-             return []
-         sorted_ranges = sorted(ranges)
-         merged = [sorted_ranges[0]]
-         for current in sorted_ranges[1:]:
-             last = merged[-1]
-             if current[0] <= last[1] + 1:
-                 merged[-1] = (last[0], max(last[1], current[1]))
-             else:
-                 merged.append(current)
-         return merged
-
-     def contains(self, number: int) -> bool:
-         """Check if a number exists in any range"""
-         for start, end in self.ranges:
-             if start <= number <= end:
-                 return True
-         return False
-
-     def to_file_format(self) -> str:
-         """Serialize ranges to compact string format"""
-         return ",".join(f"{start}-{end}" if start != end else str(start) for start, end in self.ranges)
-
-     @classmethod
-     def from_file(cls, content: str) -> "RangeSet":
-         """Deserialize from string format"""
-         if not content:
-             return cls(ranges=[])
-         ranges = []
-         for part in content.split(","):
-             if "-" in part:
-                 start, end = map(int, part.split("-"))
-                 ranges.append((start, end))
-             else:
-                 num = int(part)
-                 ranges.append((num, num))
-         return cls(ranges=ranges)
-
-     @property
-     def processed_count(self) -> int:
-         """Total number of processed items"""
-         return sum(end - start + 1 for start, end in self.ranges)
-
-
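The class above now lives in `utils` (see the new `from utils import RangeSet` import). Assuming `utils.RangeSet` preserves the implementation removed here, a quick round-trip shows the intended behavior:

```python
# Sketch of the expected RangeSet behavior, assuming utils.RangeSet
# keeps the implementation that this diff removes.
from utils import RangeSet

rs = RangeSet(ranges=[])
for line_no in (1, 2, 3, 7, 5, 4):
    rs.add(line_no)  # adjacent numbers are merged into ranges

assert rs.ranges == [(1, 5), (7, 7)]
assert rs.contains(4) and not rs.contains(6)
assert rs.to_file_format() == "1-5,7"  # compact serialization for status.txt
assert RangeSet.from_file("1-5,7").processed_count == 6
```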
- # Parameter class for API requests
@dataclass
class RequestPayload:
-     """Request payload"""
-
-     prompt: str = "你好"
+     prompt: str = ""
    num_responses: int = 8
-     temperature: float = 1.0
-     top_p: float = 1.0
-     max_tokens: int = 20 * 1024
    idx: int = 0


- # Parameter class for API responses
@dataclass
class ResponsePayload:
-     """Response payload"""
-
    idx: int = 0
    question: str = ""
    question_token_length: int = 0
@@ -142,6 +55,16 @@ class ResponsePayload:
class StatisticsManager:
+     """Manages statistics collection and analysis for batch inference operations.
+
+     This class provides methods to compute both per-group (dispersed) and aggregated (global)
+     statistics from batch inference responses, including token lengths, processing times,
+     and throughput metrics.
+     """
+     def __init__(self, batch_size: int, rollout_n: int):
+         self.batch_size = batch_size
+         self.rollout_n = rollout_n
+
    def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
        batch_group_pd = pd.DataFrame(responses)
@@ -154,10 +77,10 @@ def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time:
            "completion_time": batch_elapsed_time,
            "throughput_tokens_per_sec": batch_group_pd["token_lengths"].apply((lambda x: sum(x))).sum()
            / batch_elapsed_time,
-             "elapsed_times": batch_group_pd["elapsed_times"].to_list(),
-             "min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
-             "max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
-             "avg_time": batch_group_pd["elapsed_times"].apply(lambda x: sum(x) / len(x)).tolist(),
+             # "elapsed_times": batch_group_pd["elapsed_times"].to_list(),
+             # "min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
+             # "max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
+             # "avg_time": batch_group_pd["elapsed_times"].apply(lambda x: round(sum(x) / len(x), 2)).tolist(),
        }

        return dispersed_stats_dict
@@ -173,11 +96,11 @@ def global_stats(self, responses: List[ResponsePayload], batch_elapsed_time: flo
        global_stats_dict["batch_index"] = dispersed_stats_dict["batch_index"]
        global_stats_dict["min_response_tokens"] = min(dispersed_stats_dict["min_length"])
        global_stats_dict["max_response_tokens"] = max(dispersed_stats_dict["max_length"])
-         global_stats_dict["avg_response_tokens"] = total_response_tokens / len(responses)
+         global_stats_dict["avg_response_tokens"] = total_response_tokens / (self.batch_size * self.rollout_n)
        global_stats_dict["total_response_tokens"] = total_response_tokens
        global_stats_dict["group_max_response_tokens"] = dispersed_stats_dict["max_length"]
-         global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
-         global_stats_dict["avg_time"] = sum(dispersed_stats_dict["avg_time"]) / len(responses)
+         # global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
+         # global_stats_dict["avg_time"] = round(sum(dispersed_stats_dict["avg_time"]) / len(responses), 2)
        global_stats_dict["completion_time"] = dispersed_stats_dict["completion_time"]
        global_stats_dict["throughput_tokens_per_sec"] = dispersed_stats_dict["throughput_tokens_per_sec"]
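The denominator change above matters because `responses` holds one `ResponsePayload` per prompt group, not one per completion, so dividing by `len(responses)` averaged over groups rather than over individual rollouts. A worked check with illustrative numbers:

```python
# Illustrative arithmetic only; the values are made up.
batch_size, rollout_n = 4, 8      # 4 prompt groups x 8 rollouts each
total_response_tokens = 64_000    # summed over all 32 completions

total_response_tokens / batch_size                    # 16000.0: old per-group average
total_response_tokens / (batch_size * rollout_n)      # 2000.0: new per-completion average
```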
@@ -197,21 +120,20 @@ def __init__(self, args, max_concurrency: int = 1000):
        self.output_dir = Path(self.args.output_dir)

-         # Initialize output file paths
        self.global_stats_path = self.output_dir / "global_stats.csv"
        self.dispersed_stats_path = self.output_dir / "dispersed_stats.csv"
        self.rollout_details_path = self.output_dir / "rollout_details.jsonl"
        self.status_file_path = self.output_dir / "status.txt"

-         self.stats_manager = StatisticsManager()
+         self.stats_manager = StatisticsManager(self.args.rollout_input_batch_size, self.args.rollout_n)

        self._load_status()

    def get_active_tasks_count(self) -> int:
        return self._max_concurrency - self.semaphore._value

    def get_client(self) -> AsyncOpenAI:
-         # Return an AsyncOpenAI client instance
+         # Returns an AsyncOpenAI client instance
        return next(self.clients)
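`next(self.clients)` implies the client list is wrapped in a cyclic iterator. A minimal sketch of that setup, assuming `itertools.cycle` and placeholder URLs (the real `__init__` is not part of this hunk):

```python
# Assumed round-robin setup; itertools.cycle and the URLs are illustrative.
from itertools import cycle
from openai import AsyncOpenAI

urls = ["http://10.0.0.1:8000/v1", "http://10.0.0.2:8000/v1"]
clients = cycle(AsyncOpenAI(base_url=url, api_key="NONE") for url in urls)

client = next(clients)  # each call hands out the next endpoint in turn
```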

    def _save_status(self, batch_index):
@@ -223,7 +145,6 @@ def _save_status(self, batch_index):
    def _load_status(self):
        """Load processing status from file"""
-         """Load processing status from file"""
        try:
            with open(self.status_file_path, "r", encoding="utf-8") as f:
                content = f.read().strip()
@@ -236,14 +157,16 @@ def process_data(self, file_path: str) -> pd.DataFrame:
        logger.info(f"Processing data from {file_path}...")
        start_time = time.time()
        df = pd.read_parquet(file_path)
+         if self.args.limit_rows != -1:
+             df = df.iloc[: self.args.limit_rows]
        logger.info(f"Loaded {len(df)} samples in {time.time() - start_time:.2f}s")
        return df

    def batch_process(self, dataframe: pd.DataFrame):
        batch_prompts = []
        for idx, prompt in enumerate(dataframe[self.args.prompt_key]):
            batch_prompts.append(
-                 RequestPayload(prompt=prompt[0]["content"], idx=idx, num_responses=self.args.rollout_output_num)
+                 RequestPayload(prompt=prompt[0]["content"], idx=idx, num_responses=self.args.rollout_n)
            )
            if len(batch_prompts) == self.args.rollout_input_batch_size:
                yield batch_prompts
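`batch_process` yields `RequestPayload` batches of exactly `rollout_input_batch_size` prompts. A hypothetical driver loop (the `processor` name is a placeholder for the class that owns these methods, which this hunk does not show):

```python
# Hypothetical driver; `processor` and `args` stand in for objects
# defined elsewhere in the script.
import asyncio

async def run(processor, args):
    df = processor.process_data(args.input_file)
    for batch in processor.batch_process(df):
        responses, batch_elapsed_time = await processor.batch_call(batch)
```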
@@ -253,21 +176,21 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
        client = self.get_client()
        try:
            async with self.semaphore:
-                 logger.debug("client is : %s", client.base_url)
-                 logger.debug(f"There are currently {self.get_active_tasks_count()} asynchronous tasks working")
+                 # logger.debug("client is : %s", client.base_url)
+                 # logger.debug(f"There are currently {self.get_active_tasks_count()} asynchronous tasks working")
                start_time = time.perf_counter()
                response = await client.completions.create(
                    model=self.model,
                    prompt=request.prompt,
-                     temperature=request.temperature,
-                     top_p=request.top_p,
-                     max_tokens=request.max_tokens,
+                     temperature=self.args.temperature,
+                     top_p=self.args.top_p,
+                     max_tokens=self.args.max_response_length,
                    n=1,
                    stream=True,
-                 )
-                 # Streaming text is stored in the chunks list
+                 )
+                 # Streaming text is stored in a list of chunks
                chunks = []
-                 # Process the streaming response
+                 # Process the streaming response
                async for chunk in response:
                    if chunk.choices and chunk.choices[0].text:
                        chunks.append(chunk.choices[0].text)
@@ -282,7 +205,7 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
            raise ValueError(e)

    async def group_call(self, request: RequestPayload) -> ResponsePayload:
-         # Asynchronously call the get_response method num_responses times and return the results
+         """Performs n complete token generation rollouts for the given query."""
        tasks = [self.call(request) for _ in range(request.num_responses)]

        result = ResponsePayload()
@@ -298,7 +221,7 @@ async def group_call(self, request: RequestPayload) -> ResponsePayload:
return result
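The gathering step of `group_call` is cut off by the hunk; based on the visible `tasks` list, the `(text, elapsed)` return of `call`, and the `elapsed_times`/`token_lengths` columns consumed by `dispersed_stats`, the aggregation presumably looks something like this sketch (field names beyond those columns are assumptions):

```python
# Presumed aggregation inside group_call; `result.responses` is a
# hypothetical field name, `result.elapsed_times` matches the column
# later read by dispersed_stats.
outputs = await asyncio.gather(*tasks)
for text, elapsed in outputs:
    result.responses.append(text)
    result.elapsed_times.append(elapsed)
```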

    async def batch_call(self, requests: List[RequestPayload]) -> Tuple[List[ResponsePayload], int]:
-         """Execute requests in batch"""
+         """Execute requests in a batch."""
        start_time = time.perf_counter()
        batch_results = await asyncio.gather(*[self.group_call(request) for request in requests])
        end_time = time.perf_counter()
@@ -325,8 +248,8 @@ def execute(self):
                "avg_response_tokens",
                "total_response_tokens",
                "group_max_response_tokens",
-                 "min_time",
-                 "avg_time",
+                 # "min_time",
+                 # "avg_time",
                "completion_time",
                "throughput_tokens_per_sec",
            ]
@@ -340,10 +263,10 @@ def execute(self):
                "avg_length",
                "completion_time",
                "throughput_tokens_per_sec",
-                 "elapsed_times",
-                 "min_time",
-                 "max_time",
-                 "avg_time",
+                 # "elapsed_times",
+                 # "min_time",
+                 # "max_time",
+                 # "avg_time",
            ]
        )
@@ -372,8 +295,8 @@ def execute(self):
                    round(global_stats_dict["avg_response_tokens"], 2),
                    global_stats_dict["total_response_tokens"],
                    global_stats_dict["group_max_response_tokens"],
-                     global_stats_dict["min_time"],
-                     global_stats_dict["avg_time"],
+                     # global_stats_dict["min_time"],
+                     # global_stats_dict["avg_time"],
                    round(global_stats_dict["completion_time"], 2),
                    round(global_stats_dict["throughput_tokens_per_sec"], 2),
                ]
@@ -388,10 +311,10 @@ def execute(self):
                    dispersed_stats_dict["avg_length"],
                    round(dispersed_stats_dict["completion_time"], 2),
                    round(dispersed_stats_dict["throughput_tokens_per_sec"], 2),
-                     dispersed_stats_dict["elapsed_times"],
-                     dispersed_stats_dict["min_time"],
-                     dispersed_stats_dict["max_time"],
-                     dispersed_stats_dict["avg_time"],
+                     # dispersed_stats_dict["elapsed_times"],
+                     # dispersed_stats_dict["min_time"],
+                     # dispersed_stats_dict["max_time"],
+                     # dispersed_stats_dict["avg_time"],
                ]
            )
@@ -436,9 +359,7 @@ def tokenize(self, response: ResponsePayload) -> ResponsePayload:

def parse_args():
-     # Initialize the ArgumentParser
    parser = argparse.ArgumentParser(description="Process prompts with OpenAI clients.")
-     # Add arguments
    parser.add_argument("--openai_urls", type=str, nargs="+", required=True, help="List of OpenAI service URLs")
    parser.add_argument(
        "--api_keys", type=str, nargs="+", default=None, help="List of API keys (default: 'NONE' for each service)"
@@ -448,15 +369,21 @@ def parse_args():
        "--tokenizer", type=str, required=True, help="Tokenizer name (e.g., Qwen/Qwen2.5-7B-Instruct-1M)"
    )
    parser.add_argument("--rollout_input_batch_size", type=int, default=4, help="Batch size for requests")
-     parser.add_argument("--rollout_output_num", type=int, default=8, help="Number of responses per request")
+     parser.add_argument("--rollout_n", type=int, default=8, help="Number of responses per request")
    parser.add_argument(
        "--prompt_key", type=str, default="prompt", help="Key in the DataFrame for prompts (default: 'prompt')"
    )
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input Parquet file")
    parser.add_argument(
-         "--output_dir", type=str, default="./output", help="Directory for output CSV files (default: './output')"
+         "--output_dir", type=str, default="./api_infer_results", help="Directory for output CSV files (default: './api_infer_results')"
+     )
+     parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter for text generation")
+     parser.add_argument("--temperature", type=float, default=0.7, help="Temperature parameter for text generation")
+     parser.add_argument("--max_prompt_length", type=int, default=1024 * 2, help="Maximum prompt length (in tokens)")
+     parser.add_argument(
+         "--max_response_length", type=int, default=1024 * 2, help="Maximum response length (in tokens)"
    )
-     # Parse the arguments
+     parser.add_argument("--limit_rows", type=int, default=-1, help="Maximum number of rows to read from the dataset (-1 means all)")
    return parser.parse_args()
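With the renamed `--rollout_n` and the new sampling and length flags, a smoke test of the parser could look like this (script name, paths, and URL are placeholders):

```python
# Hypothetical smoke test for the new flags; values are placeholders.
import sys

sys.argv = [
    "infer.py",
    "--openai_urls", "http://localhost:8000/v1",
    "--tokenizer", "Qwen/Qwen2.5-7B-Instruct-1M",
    "--input_file", "prompts.parquet",
    "--rollout_n", "4",
    "--temperature", "0.7",
    "--top_p", "0.9",
    "--max_response_length", "2048",
    "--limit_rows", "100",
]
args = parse_args()
assert args.rollout_n == 4 and args.limit_rows == 100
```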