Skip to content

Commit 32e9136

Browse files
committed
fix some bug
1 parent ad585eb commit 32e9136

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

llm/benchmark/rl/paddle_infer.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@
2222
from typing import List
2323
from utils import RangeSet
2424

25+
@contextmanager
26+
def switch_level_context(level="INFO"):
27+
"""临时切换日志级别的上下文管理器"""
28+
import logging
29+
original_level = logging.root.level
30+
logging.root.setLevel(level)
31+
try:
32+
yield
33+
finally:
34+
logging.root.setLevel(original_level)
35+
2536
import paddle
2637
import pandas as pd
2738
from tqdm import tqdm
@@ -37,12 +48,21 @@
3748
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
3849
from paddlenlp.utils.log import logger
3950

51+
@contextmanager
52+
def switch_level_context(level="ERROR"):
53+
original_level = logger.logLevel
54+
logger.set_level(level)
55+
56+
try:
57+
yield
58+
finally:
59+
logger.set_level(original_level)
60+
4061
def chunk(all_input_ids, size):
4162
if size <= 0:
4263
raise ValueError("Size must be greater than 0")
4364
return [all_input_ids[i : i + size] for i in range(0, len(all_input_ids), size)]
4465

45-
4666
@dataclass
4767
class DumpyTrainingArguments(TrainingArguments):
4868
actor_model_name_or_path: str = field(default="Qwen/Qwen2.5-7B-Instruct-1M", metadata={"help": "预训练模型名称或路径"})
@@ -137,8 +157,6 @@ def run_inference(self, input_ids, batch_index=0):
137157
if self.args.world_size > 1:
138158
paddle.distributed.barrier()
139159
end_time = time.time()
140-
print(input_ids.shape)
141-
print(output_ids.shape)
142160
if self.args.should_log:
143161
statistics = self.postprocess_data(input_ids, output_ids, batch_index=batch_index)
144162
statistics["total_time"] = end_time - start_time

0 commit comments

Comments
 (0)