 import pandas as pd
 from openai import AsyncOpenAI
 from tqdm import tqdm
+from utils import RangeSet

 from paddlenlp.transformers import AutoTokenizer

-# Configure the root logger
-logging.basicConfig(
-    level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
-)
+from transformers import logging

-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)

-
-@dataclass
-class RangeSet:
-    """Manage processed line ranges with efficient storage and querying"""
-
-    ranges: List[tuple]
-
-    def add(self, number: int):
-        """Add a number to the range set and merge adjacent ranges"""
-        new_ranges = []
-        added = False
-        for start, end in sorted(self.ranges):
-            if number < start - 1:
-                if not added:
-                    new_ranges.append((number, number))
-                    added = True
-                new_ranges.append((start, end))
-            elif number == start - 1:
-                new_ranges.append((number, end))
-                added = True
-            elif number <= end:
-                new_ranges.append((start, end))
-                added = True
-            else:
-                new_ranges.append((start, end))
-        if not added:
-            new_ranges.append((number, number))
-        self.ranges = self.merge_ranges(new_ranges)
-
-    @staticmethod
-    def merge_ranges(ranges: List[tuple]) -> List[tuple]:
-        """Merge overlapping or adjacent ranges"""
-        if not ranges:
-            return []
-        sorted_ranges = sorted(ranges)
-        merged = [sorted_ranges[0]]
-        for current in sorted_ranges[1:]:
-            last = merged[-1]
-            if current[0] <= last[1] + 1:
-                merged[-1] = (last[0], max(last[1], current[1]))
-            else:
-                merged.append(current)
-        return merged
-
-    def contains(self, number: int) -> bool:
-        """Check if a number exists in any range"""
-        for start, end in self.ranges:
-            if start <= number <= end:
-                return True
-        return False
-
-    def to_file_format(self) -> str:
-        """Serialize ranges to compact string format"""
-        return ",".join(f"{start}-{end}" if start != end else str(start) for start, end in self.ranges)
-
-    @classmethod
-    def from_file(cls, content: str) -> "RangeSet":
-        """Deserialize from string format"""
-        if not content:
-            return cls(ranges=[])
-        ranges = []
-        for part in content.split(","):
-            if "-" in part:
-                start, end = map(int, part.split("-"))
-                ranges.append((start, end))
-            else:
-                num = int(part)
-                ranges.append((num, num))
-        return cls(ranges=ranges)
-
-    @property
-    def processed_count(self) -> int:
-        """Total number of processed items"""
-        return sum(end - start + 1 for start, end in self.ranges)
-
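Not part of the diff itself: a minimal round-trip sketch of the `RangeSet` interface removed above, assuming the class keeps the same API after moving to `utils` (the commit now imports it from there, and `_load_status` reads the serialized ranges back from status.txt).

```python
# Minimal sketch, assuming RangeSet keeps the interface shown above.
from utils import RangeSet

done = RangeSet.from_file("0-3,7")   # e.g. resume from "0-3,7" in status.txt
done.add(4)                          # merges into (0, 4)
done.add(5)                          # merges into (0, 5)
assert done.contains(2)
assert not done.contains(6)
print(done.to_file_format())         # "0-5,7"
print(done.processed_count)          # 7
```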
-# Parameter class for API requests
 @dataclass
 class RequestPayload:
-    """Request payload"""
-
-    prompt: str = "你好"
+    prompt: str = ""
     num_responses: int = 8
-    temperature: float = 1.0
-    top_p: float = 1.0
-    max_tokens: int = 20 * 1024
     idx: int = 0


-# Parameter class for API responses
 @dataclass
 class ResponsePayload:
-    """Response payload"""
-
     idx: int = 0
     question: str = ""
     question_token_length: int = 0
@@ -142,6 +55,16 @@ class ResponsePayload:


 class StatisticsManager:
+    """Manages statistics collection and analysis for batch inference operations.
+
+    This class provides methods to compute both per-group (dispersed) and aggregated (global)
+    statistics from batch inference responses, including token lengths, processing times,
+    and throughput metrics.
+    """
+    def __init__(self, batch_size: int, rollout_n: int):
+        self.batch_size = batch_size
+        self.rollout_n = rollout_n
+
     def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
         batch_group_pd = pd.DataFrame(responses)

@@ -154,10 +77,10 @@ def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
             "completion_time": batch_elapsed_time,
             "throughput_tokens_per_sec": batch_group_pd["token_lengths"].apply((lambda x: sum(x))).sum()
             / batch_elapsed_time,
-            "elapsed_times": batch_group_pd["elapsed_times"].to_list(),
-            "min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
-            "max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
-            "avg_time": batch_group_pd["elapsed_times"].apply(lambda x: sum(x) / len(x)).tolist(),
+            # "elapsed_times": batch_group_pd["elapsed_times"].to_list(),
+            # "min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
+            # "max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
+            # "avg_time": batch_group_pd["elapsed_times"].apply(lambda x: round(sum(x) / len(x), 2)).tolist(),
         }

         return dispersed_stats_dict
@@ -173,11 +96,11 @@ def global_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
         global_stats_dict["batch_index"] = dispersed_stats_dict["batch_index"]
         global_stats_dict["min_response_tokens"] = min(dispersed_stats_dict["min_length"])
         global_stats_dict["max_response_tokens"] = max(dispersed_stats_dict["max_length"])
-        global_stats_dict["avg_response_tokens"] = total_response_tokens / len(responses)
+        global_stats_dict["avg_response_tokens"] = total_response_tokens / (self.batch_size * self.rollout_n)
         global_stats_dict["total_response_tokens"] = total_response_tokens
         global_stats_dict["group_max_response_tokens"] = dispersed_stats_dict["max_length"]
-        global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
-        global_stats_dict["avg_time"] = sum(dispersed_stats_dict["avg_time"]) / len(responses)
+        # global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
+        # global_stats_dict["avg_time"] = round(sum(dispersed_stats_dict["avg_time"]) / len(responses), 2)
         global_stats_dict["completion_time"] = dispersed_stats_dict["completion_time"]
         global_stats_dict["throughput_tokens_per_sec"] = dispersed_stats_dict["throughput_tokens_per_sec"]

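The changed denominator above divides by `batch_size * rollout_n` (one count per individual response) where the old code divided by `len(responses)` (one count per prompt group). A toy calculation with made-up numbers, showing why the two averages differ:

```python
# Toy numbers (not from the source): 4 prompts per batch, 8 rollouts each.
batch_size, rollout_n = 4, 8
total_response_tokens = 51_200
response_groups = 4  # len(responses): one ResponsePayload per prompt group

old_avg = total_response_tokens / response_groups           # 12800.0 tokens per *group*
new_avg = total_response_tokens / (batch_size * rollout_n)  # 1600.0 tokens per *response*
```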
@@ -197,21 +120,20 @@ def __init__(self, args, max_concurrency: int = 1000):

         self.output_dir = Path(self.args.output_dir)

-        # Initialize output file paths
         self.global_stats_path = self.output_dir / "global_stats.csv"
         self.dispersed_stats_path = self.output_dir / "dispersed_stats.csv"
         self.rollout_details_path = self.output_dir / "rollout_details.jsonl"
         self.status_file_path = self.output_dir / "status.txt"

-        self.stats_manager = StatisticsManager()
+        self.stats_manager = StatisticsManager(self.args.rollout_input_batch_size, self.args.rollout_n)

         self._load_status()

     def get_active_tasks_count(self) -> int:
         return self._max_concurrency - self.semaphore._value

     def get_client(self) -> AsyncOpenAI:
-        # Return an AsyncOpenAI client instance
+        # Returns an AsyncOpenAI client instance
         return next(self.clients)

     def _save_status(self, batch_index):
@@ -223,7 +145,6 @@ def _save_status(self, batch_index):

     def _load_status(self):
         """Load processing status from file"""
-        """Load processing status from file"""
         try:
             with open(self.status_file_path, "r", encoding="utf-8") as f:
                 content = f.read().strip()
@@ -236,14 +157,16 @@ def process_data(self, file_path: str) -> pd.DataFrame:
         logger.info(f"Processing data from {file_path}...")
         start_time = time.time()
         df = pd.read_parquet(file_path)
+        if self.args.limit_rows != -1:
+            df = df.iloc[: self.args.limit_rows]
         logger.info(f"Loaded {len(df)} samples in {time.time() - start_time:.2f}s")
         return df

     def batch_process(self, dataframe: pd.DataFrame):
         batch_prompts = []
         for idx, prompt in enumerate(dataframe[self.args.prompt_key]):
             batch_prompts.append(
-                RequestPayload(prompt=prompt[0]["content"], idx=idx, num_responses=self.args.rollout_output_num)
+                RequestPayload(prompt=prompt[0]["content"], idx=idx, num_responses=self.args.rollout_n)
             )
             if len(batch_prompts) == self.args.rollout_input_batch_size:
                 yield batch_prompts
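As the hunk above shows, `batch_process` yields only once `batch_prompts` fills up to `rollout_input_batch_size`. A standalone toy version of that generator pattern; the names are illustrative, and whether the real method flushes a trailing partial batch lies outside this diff:

```python
# Toy re-implementation of the yield-on-full-batch pattern above; the
# drop-the-tail behavior is an assumption for illustration only.
def batches(items, size):
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) == size:  # emit only full batches
            yield buf
            buf = []

print(list(batches(range(10), 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- 8 and 9 never emitted
```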
@@ -253,21 +176,21 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
         client = self.get_client()
         try:
             async with self.semaphore:
-                logger.debug("client is : %s", client.base_url)
-                logger.debug(f"There are currently {self.get_active_tasks_count()} asynchronous tasks running")
+                # logger.debug("client is : %s", client.base_url)
+                # logger.debug(f"There are currently {self.get_active_tasks_count()} asynchronous tasks running")
                 start_time = time.perf_counter()
                 response = await client.completions.create(
                     model=self.model,
                     prompt=request.prompt,
-                    temperature=request.temperature,
-                    top_p=request.top_p,
-                    max_tokens=request.max_tokens,
+                    temperature=self.args.temperature,
+                    top_p=self.args.top_p,
+                    max_tokens=self.args.max_response_length,
                     n=1,
                     stream=True,
-                )
-                # The streamed text is stored in the chunks list
+                )
+                # Streaming text is stored in a list of chunks
                 chunks = []
-                # Stream the response
+                # Streaming responses
                 async for chunk in response:
                     if chunk.choices and chunk.choices[0].text:
                         chunks.append(chunk.choices[0].text)
@@ -282,7 +205,7 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
             raise ValueError(e)

     async def group_call(self, request: RequestPayload) -> ResponsePayload:
-        # Asynchronously call get_respose num_responses times at once and return the results
+        """Performs n complete token generation rollouts for the given query."""
        tasks = [self.call(request) for _ in range(request.num_responses)]

         result = ResponsePayload()
@@ -298,7 +221,7 @@ async def group_call(self, request: RequestPayload) -> ResponsePayload:
         return result

     async def batch_call(self, requests: List[RequestPayload]) -> Tuple[List[ResponsePayload], int]:
-        """Execute requests in batches"""
+        """Execute a batch of requests"""
         start_time = time.perf_counter()
         batch_results = await asyncio.gather(*[self.group_call(request) for request in requests])
         end_time = time.perf_counter()
@@ -325,8 +248,8 @@ def execute(self):
                     "avg_response_tokens",
                     "total_response_tokens",
                     "group_max_response_tokens",
-                    "min_time",
-                    "avg_time",
+                    # "min_time",
+                    # "avg_time",
                     "completion_time",
                     "throughput_tokens_per_sec",
                 ]
@@ -340,10 +263,10 @@ def execute(self):
                     "avg_length",
                     "completion_time",
                     "throughput_tokens_per_sec",
-                    "elapsed_times",
-                    "min_time",
-                    "max_time",
-                    "avg_time",
+                    # "elapsed_times",
+                    # "min_time",
+                    # "max_time",
+                    # "avg_time",
                 ]
             )

@@ -372,8 +295,8 @@ def execute(self):
                     round(global_stats_dict["avg_response_tokens"], 2),
                     global_stats_dict["total_response_tokens"],
                     global_stats_dict["group_max_response_tokens"],
-                    global_stats_dict["min_time"],
-                    global_stats_dict["avg_time"],
+                    # global_stats_dict["min_time"],
+                    # global_stats_dict["avg_time"],
                     round(global_stats_dict["completion_time"], 2),
                     round(global_stats_dict["throughput_tokens_per_sec"], 2),
                 ]
@@ -388,10 +311,10 @@ def execute(self):
                     dispersed_stats_dict["avg_length"],
                     round(dispersed_stats_dict["completion_time"], 2),
                     round(dispersed_stats_dict["throughput_tokens_per_sec"], 2),
-                    dispersed_stats_dict["elapsed_times"],
-                    dispersed_stats_dict["min_time"],
-                    dispersed_stats_dict["max_time"],
-                    dispersed_stats_dict["avg_time"],
+                    # dispersed_stats_dict["elapsed_times"],
+                    # dispersed_stats_dict["min_time"],
+                    # dispersed_stats_dict["max_time"],
+                    # dispersed_stats_dict["avg_time"],
                 ]
             )

@@ -436,9 +359,7 @@ def tokenize(self, response: ResponsePayload) -> ResponsePayload:


 def parse_args():
-    # Initialize the ArgumentParser
     parser = argparse.ArgumentParser(description="Process prompts with OpenAI clients.")
-    # Add arguments
     parser.add_argument("--openai_urls", type=str, nargs="+", required=True, help="List of OpenAI service URLs")
     parser.add_argument(
         "--api_keys", type=str, nargs="+", default=None, help="List of API keys (default: 'NONE' for each service)"
@@ -448,15 +369,21 @@ def parse_args():
         "--tokenizer", type=str, required=True, help="Tokenizer name (e.g., Qwen/Qwen2.5-7B-Instruct-1M)"
     )
     parser.add_argument("--rollout_input_batch_size", type=int, default=4, help="Batch size for requests")
-    parser.add_argument("--rollout_output_num", type=int, default=8, help="Number of responses per request")
+    parser.add_argument("--rollout_n", type=int, default=8, help="Number of responses per request")
     parser.add_argument(
         "--prompt_key", type=str, default="prompt", help="Key in the DataFrame for prompts (default: 'prompt')"
     )
     parser.add_argument("--input_file", type=str, required=True, help="Path to the input Parquet file")
     parser.add_argument(
-        "--output_dir", type=str, default="./output", help="Directory for output CSV files (default: './output')"
+        "--output_dir", type=str, default="./api_infer_results", help="Directory for output CSV files (default: './api_infer_results')"
+    )
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter for text generation")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature parameter for text generation")
+    parser.add_argument("--max_prompt_length", type=int, default=1024 * 2, help="Maximum prompt length (in tokens)")
+    parser.add_argument(
+        "--max_response_length", type=int, default=1024 * 2, help="Maximum response length (in tokens)"
     )
-    # Parse arguments
+    parser.add_argument("--limit_rows", type=int, default=-1, help="Maximum number of rows to read from the dataset (-1 means all)")
     return parser.parse_args()


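For reference, a hypothetical invocation sketch exercising the renamed and newly added flags from this commit (`--rollout_n`, `--limit_rows`, and the sampling parameters). The script name, endpoint URL, and file paths are placeholders, not taken from the diff:

```python
import sys

# Placeholder argv; only the flag names and defaults come from parse_args above.
sys.argv = [
    "infer.py",                                   # hypothetical script name
    "--openai_urls", "http://127.0.0.1:8000/v1",  # placeholder endpoint
    "--tokenizer", "Qwen/Qwen2.5-7B-Instruct-1M",
    "--input_file", "data/prompts.parquet",       # placeholder path
    "--rollout_input_batch_size", "4",
    "--rollout_n", "8",                           # renamed from --rollout_output_num
    "--temperature", "0.7",
    "--top_p", "0.9",
    "--max_response_length", "2048",
    "--limit_rows", "100",                        # new: cap rows read from the Parquet file
]
args = parse_args()
print(args.rollout_n, args.limit_rows)            # 8 100
```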