
Commit 7cde16d

add reinforcement learning framework with scripts

1 parent 32e9136

File tree

6 files changed (+177, -192 lines)

llm/benchmark/rl/api_serve.py

Lines changed: 58 additions & 131 deletions
@@ -27,111 +27,24 @@
 import pandas as pd
 from openai import AsyncOpenAI
 from tqdm import tqdm
+from utils import RangeSet
 
 from paddlenlp.transformers import AutoTokenizer
 
-# Configure the root logger
-logging.basicConfig(
-    level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
-)
+from transformers import logging
 
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
 
-
-@dataclass
-class RangeSet:
-    """Manage processed line ranges with efficient storage and querying"""
-
-    ranges: List[tuple]
-
-    def add(self, number: int):
-        """Add a number to the range set and merge adjacent ranges"""
-        new_ranges = []
-        added = False
-        for start, end in sorted(self.ranges):
-            if number < start - 1:
-                if not added:
-                    new_ranges.append((number, number))
-                    added = True
-                new_ranges.append((start, end))
-            elif number == start - 1:
-                new_ranges.append((number, end))
-                added = True
-            elif number <= end:
-                new_ranges.append((start, end))
-                added = True
-            else:
-                new_ranges.append((start, end))
-        if not added:
-            new_ranges.append((number, number))
-        self.ranges = self.merge_ranges(new_ranges)
-
-    @staticmethod
-    def merge_ranges(ranges: List[tuple]) -> List[tuple]:
-        """Merge overlapping or adjacent ranges"""
-        if not ranges:
-            return []
-        sorted_ranges = sorted(ranges)
-        merged = [sorted_ranges[0]]
-        for current in sorted_ranges[1:]:
-            last = merged[-1]
-            if current[0] <= last[1] + 1:
-                merged[-1] = (last[0], max(last[1], current[1]))
-            else:
-                merged.append(current)
-        return merged
-
-    def contains(self, number: int) -> bool:
-        """Check if a number exists in any range"""
-        for start, end in self.ranges:
-            if start <= number <= end:
-                return True
-        return False
-
-    def to_file_format(self) -> str:
-        """Serialize ranges to compact string format"""
-        return ",".join(f"{start}-{end}" if start != end else str(start) for start, end in self.ranges)
-
-    @classmethod
-    def from_file(cls, content: str) -> "RangeSet":
-        """Deserialize from string format"""
-        if not content:
-            return cls(ranges=[])
-        ranges = []
-        for part in content.split(","):
-            if "-" in part:
-                start, end = map(int, part.split("-"))
-                ranges.append((start, end))
-            else:
-                num = int(part)
-                ranges.append((num, num))
-        return cls(ranges=ranges)
-
-    @property
-    def processed_count(self) -> int:
-        """Total number of processed items"""
-        return sum(end - start + 1 for start, end in self.ranges)
-
-
-# Parameter class for API requests
 @dataclass
 class RequestPayload:
-    """Request payload"""
-
-    prompt: str = "你好"
+    prompt: str = ""
     num_responses: int = 8
-    temperature: float = 1.0
-    top_p: float = 1.0
-    max_tokens: int = 20 * 1024
     idx: int = 0
 
 
-# Parameter class for API responses
 @dataclass
 class ResponsePayload:
-    """Response payload"""
-
     idx: int = 0
     question: str = ""
     question_token_length: int = 0
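
Aside: RangeSet now lives in utils (imported above) rather than in this file. Assuming the utils version keeps the interface of the class removed above, a minimal round-trip sketch of the resume tracking it provides:

    from utils import RangeSet  # assumed to match the implementation removed above

    done = RangeSet(ranges=[])
    for line_no in (0, 1, 2, 5):
        done.add(line_no)                     # adjacent numbers merge into one range
    assert done.to_file_format() == "0-2,5"   # compact form written to status.txt
    assert done.contains(1) and not done.contains(3)
    assert RangeSet.from_file("0-2,5").processed_count == 4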
@@ -142,6 +55,16 @@ class ResponsePayload:
 
 
 class StatisticsManager:
+    """Manages statistics collection and analysis for batch inference operations.
+
+    This class provides methods to compute both per-group (dispersed) and aggregated (global)
+    statistics from batch inference responses, including token lengths, processing times,
+    and throughput metrics.
+    """
+    def __init__(self, batch_size: int, rollout_n: int):
+        self.batch_size = batch_size
+        self.rollout_n = rollout_n
+
     def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
         batch_group_pd = pd.DataFrame(responses)
 
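
For context, dispersed_stats builds a DataFrame directly from the ResponsePayload dataclasses and derives per-group statistics from the token_lengths column. A minimal sketch of the throughput arithmetic, with toy values and a stand-in dataclass (pandas expands a list of dataclasses column-wise):

    import pandas as pd
    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Row:  # stand-in for ResponsePayload; only the field used here
        idx: int
        token_lengths: List[int] = field(default_factory=list)

    rows = [Row(0, [900, 1100]), Row(1, [1500, 500])]    # 2 groups x 2 rollouts
    df = pd.DataFrame(rows)
    total_tokens = df["token_lengths"].apply(sum).sum()  # 4000
    print(total_tokens / 2.0)                            # 2000.0 tokens/sec for a 2 s batch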
@@ -154,10 +77,10 @@ def dispersed_stats(self, responses: List[ResponsePayload], batch_elapsed_time:
             "completion_time": batch_elapsed_time,
             "throughput_tokens_per_sec": batch_group_pd["token_lengths"].apply((lambda x: sum(x))).sum()
             / batch_elapsed_time,
-            "elapsed_times": batch_group_pd["elapsed_times"].to_list(),
-            "min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
-            "max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
-            "avg_time": batch_group_pd["elapsed_times"].apply(lambda x: sum(x) / len(x)).tolist(),
+            # "elapsed_times": batch_group_pd["elapsed_times"].to_list(),
+            # "min_time": batch_group_pd["elapsed_times"].apply(lambda x: min(x)).tolist(),
+            # "max_time": batch_group_pd["elapsed_times"].apply(lambda x: max(x)).tolist(),
+            # "avg_time": batch_group_pd["elapsed_times"].apply(lambda x: round(sum(x) / len(x), 2)).tolist(),
         }
 
         return dispersed_stats_dict
@@ -173,11 +96,11 @@ def global_stats(self, responses: List[ResponsePayload], batch_elapsed_time: float):
         global_stats_dict["batch_index"] = dispersed_stats_dict["batch_index"]
         global_stats_dict["min_response_tokens"] = min(dispersed_stats_dict["min_length"])
         global_stats_dict["max_response_tokens"] = max(dispersed_stats_dict["max_length"])
-        global_stats_dict["avg_response_tokens"] = total_response_tokens / len(responses)
+        global_stats_dict["avg_response_tokens"] = total_response_tokens / (self.batch_size * self.rollout_n)
         global_stats_dict["total_response_tokens"] = total_response_tokens
         global_stats_dict["group_max_response_tokens"] = dispersed_stats_dict["max_length"]
-        global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
-        global_stats_dict["avg_time"] = sum(dispersed_stats_dict["avg_time"]) / len(responses)
+        # global_stats_dict["min_time"] = min(dispersed_stats_dict["min_time"])
+        # global_stats_dict["avg_time"] = round(sum(dispersed_stats_dict["avg_time"]) / len(responses), 2)
         global_stats_dict["completion_time"] = dispersed_stats_dict["completion_time"]
         global_stats_dict["throughput_tokens_per_sec"] = dispersed_stats_dict["throughput_tokens_per_sec"]
 
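
Note on the avg_response_tokens change: each ResponsePayload aggregates a whole group of rollouts, so len(responses) counts groups, not individual generations; batch_size * rollout_n is the true response count. With toy numbers (batch_size=4, rollout_n=8):

    batch_size, rollout_n = 4, 8
    total_response_tokens = 64_000
    print(total_response_tokens / batch_size)                # 16000.0 -- old: mean per group
    print(total_response_tokens / (batch_size * rollout_n))  # 2000.0  -- new: mean per response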
@@ -197,21 +120,20 @@ def __init__(self, args, max_concurrency: int = 1000):
 
         self.output_dir = Path(self.args.output_dir)
 
-        # Initialize output file paths
         self.global_stats_path = self.output_dir / "global_stats.csv"
         self.dispersed_stats_path = self.output_dir / "dispersed_stats.csv"
         self.rollout_details_path = self.output_dir / "rollout_details.jsonl"
         self.status_file_path = self.output_dir / "status.txt"
 
-        self.stats_manager = StatisticsManager()
+        self.stats_manager = StatisticsManager(self.args.rollout_input_batch_size, self.args.rollout_n)
 
         self._load_status()
 
     def get_active_tasks_count(self) -> int:
         return self._max_concurrency - self.semaphore._value
 
     def get_client(self) -> AsyncOpenAI:
-        # Return an AsyncOpenAI client instance
+        # Returns an AsyncOpenAI client instance
         return next(self.clients)
 
     def _save_status(self, batch_index):
@@ -223,7 +145,6 @@ def _save_status(self, batch_index):
 
     def _load_status(self):
         """Load processing status from file"""
-        """Load processing status from file"""
         try:
             with open(self.status_file_path, "r", encoding="utf-8") as f:
                 content = f.read().strip()
@@ -236,14 +157,16 @@ def process_data(self, file_path: str) -> pd.DataFrame:
         logger.info(f"Processing data from {file_path}...")
         start_time = time.time()
         df = pd.read_parquet(file_path)
+        if self.args.limit_rows != -1:
+            df = df.iloc[:self.args.limit_rows]
         logger.info(f"Loaded {len(df)} samples in {time.time() - start_time:.2f}s")
         return df
 
     def batch_process(self, dataframe: pd.DataFrame):
         batch_prompts = []
         for idx, prompt in enumerate(dataframe[self.args.prompt_key]):
             batch_prompts.append(
-                RequestPayload(prompt=prompt[0]["content"], idx=idx, num_responses=self.args.rollout_output_num)
+                RequestPayload(prompt=prompt[0]["content"], idx=idx, num_responses=self.args.rollout_n)
             )
             if len(batch_prompts) == self.args.rollout_input_batch_size:
                 yield batch_prompts
@@ -253,21 +176,21 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
         client = self.get_client()
         try:
             async with self.semaphore:
-                logger.debug("client is : %s", client.base_url)
-                logger.debug(f"There are currently {self.get_active_tasks_count()} asynchronous tasks running")
+                # logger.debug("client is : %s", client.base_url)
+                # logger.debug(f"There are currently {self.get_active_tasks_count()} asynchronous tasks working")
                 start_time = time.perf_counter()
                 response = await client.completions.create(
                     model=self.model,
                     prompt=request.prompt,
-                    temperature=request.temperature,
-                    top_p=request.top_p,
-                    max_tokens=request.max_tokens,
+                    temperature=self.args.temperature,
+                    top_p=self.args.top_p,
+                    max_tokens=self.args.max_response_length,
                     n=1,
                     stream=True,
-                )
-                # Streamed text is stored in the chunks list
+                )
+                # Streaming text is stored in a list of chunks
                 chunks = []
-                # Process the streaming response
+                # Streaming responses
                 async for chunk in response:
                     if chunk.choices and chunk.choices[0].text:
                         chunks.append(chunk.choices[0].text)
@@ -282,7 +205,7 @@ async def call(self, request: RequestPayload) -> Tuple[str, float]:
             raise ValueError(e)
 
     async def group_call(self, request: RequestPayload) -> ResponsePayload:
-        # Asynchronously issue num_responses calls at once and return the results
+        """Performs n complete token generation rollouts for the given query."""
         tasks = [self.call(request) for _ in range(request.num_responses)]
 
         result = ResponsePayload()
@@ -298,7 +221,7 @@ async def group_call(self, request: RequestPayload) -> ResponsePayload:
         return result
 
     async def batch_call(self, requests: List[RequestPayload]) -> Tuple[List[ResponsePayload], int]:
-        """Execute requests in a batch"""
+        """Batch execution requests"""
         start_time = time.perf_counter()
         batch_results = await asyncio.gather(*[self.group_call(request) for request in requests])
         end_time = time.perf_counter()
@@ -325,8 +248,8 @@ def execute(self):
                 "avg_response_tokens",
                 "total_response_tokens",
                 "group_max_response_tokens",
-                "min_time",
-                "avg_time",
+                # "min_time",
+                # "avg_time",
                 "completion_time",
                 "throughput_tokens_per_sec",
             ]
@@ -340,10 +263,10 @@ def execute(self):
                 "avg_length",
                 "completion_time",
                 "throughput_tokens_per_sec",
-                "elapsed_times",
-                "min_time",
-                "max_time",
-                "avg_time",
+                # "elapsed_times",
+                # "min_time",
+                # "max_time",
+                # "avg_time",
             ]
         )
 
@@ -372,8 +295,8 @@ def execute(self):
                     round(global_stats_dict["avg_response_tokens"], 2),
                     global_stats_dict["total_response_tokens"],
                     global_stats_dict["group_max_response_tokens"],
-                    global_stats_dict["min_time"],
-                    global_stats_dict["avg_time"],
+                    # global_stats_dict["min_time"],
+                    # global_stats_dict["avg_time"],
                     round(global_stats_dict["completion_time"], 2),
                     round(global_stats_dict["throughput_tokens_per_sec"], 2),
                 ]
@@ -388,10 +311,10 @@ def execute(self):
                     dispersed_stats_dict["avg_length"],
                     round(dispersed_stats_dict["completion_time"], 2),
                     round(dispersed_stats_dict["throughput_tokens_per_sec"], 2),
-                    dispersed_stats_dict["elapsed_times"],
-                    dispersed_stats_dict["min_time"],
-                    dispersed_stats_dict["max_time"],
-                    dispersed_stats_dict["avg_time"],
+                    # dispersed_stats_dict["elapsed_times"],
+                    # dispersed_stats_dict["min_time"],
+                    # dispersed_stats_dict["max_time"],
+                    # dispersed_stats_dict["avg_time"],
                 ]
             )
 
@@ -436,9 +359,7 @@ def tokenize(self, response: ResponsePayload) -> ResponsePayload:
 
 
 def parse_args():
-    # Initialize the ArgumentParser
     parser = argparse.ArgumentParser(description="Process prompts with OpenAI clients.")
-    # Add arguments
     parser.add_argument("--openai_urls", type=str, nargs="+", required=True, help="List of OpenAI service URLs")
     parser.add_argument(
         "--api_keys", type=str, nargs="+", default=None, help="List of API keys (default: 'NONE' for each service)"
@@ -448,15 +369,21 @@ def parse_args():
         "--tokenizer", type=str, required=True, help="Tokenizer name (e.g., Qwen/Qwen2.5-7B-Instruct-1M)"
     )
     parser.add_argument("--rollout_input_batch_size", type=int, default=4, help="Batch size for requests")
-    parser.add_argument("--rollout_output_num", type=int, default=8, help="Number of responses per request")
+    parser.add_argument("--rollout_n", type=int, default=8, help="Number of responses per request")
     parser.add_argument(
         "--prompt_key", type=str, default="prompt", help="Key in the DataFrame for prompts (default: 'prompt')"
     )
     parser.add_argument("--input_file", type=str, required=True, help="Path to the input Parquet file")
     parser.add_argument(
-        "--output_dir", type=str, default="./output", help="Directory for output CSV files (default: './output')"
+        "--output_dir", type=str, default="./api_infer_results", help="Directory for output CSV files (default: './api_infer_results')"
+    )
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling parameter for text generation")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature parameter for text generation")
+    parser.add_argument("--max_prompt_length", type=int, default=1024 * 2, help="Maximum prompt length (in tokens)")
+    parser.add_argument(
+        "--max_response_length", type=int, default=1024 * 2, help="Maximum response length (in tokens)"
     )
-    # Parse arguments
+    parser.add_argument("--limit_rows", type=int, default=-1, help="Maximum number of rows to read from the dataset (-1 means all)")
     return parser.parse_args()
 
 
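A hypothetical invocation exercising the renamed and newly added flags (the URL and file paths below are placeholders, not values from this commit; the tokenizer name is the example from the flag's help text):

    python llm/benchmark/rl/api_serve.py \
        --openai_urls http://127.0.0.1:8000/v1 \
        --tokenizer Qwen/Qwen2.5-7B-Instruct-1M \
        --input_file data/prompts.parquet \
        --output_dir ./api_infer_results \
        --rollout_input_batch_size 4 \
        --rollout_n 8 \
        --temperature 0.7 \
        --top_p 0.9 \
        --max_response_length 2048 \
        --limit_rows 100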