|
22 | 22 | from typing import List
|
23 | 23 | from utils import RangeSet
|
24 | 24 |
|
@contextmanager
def switch_level_context(level="INFO"):
    """Temporarily switch the root logger's level inside a ``with`` block.

    The previous root level is captured on entry and restored on exit, even
    if the body raises.

    NOTE(review): a second function with this same name is defined later in
    this file (using the paddlenlp logger) and shadows this one at import
    time — confirm which definition is intended to win.

    Args:
        level: level name or number accepted by ``logging.root.setLevel``.
    """
    import logging

    previous = logging.root.level
    logging.root.setLevel(level)
    try:
        yield
    finally:
        logging.root.setLevel(previous)
| 35 | + |
25 | 36 | import paddle
|
26 | 37 | import pandas as pd
|
27 | 38 | from tqdm import tqdm
|
|
37 | 48 | from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
|
38 | 49 | from paddlenlp.utils.log import logger
|
39 | 50 |
|
@contextmanager
def switch_level_context(level="ERROR"):
    """Temporarily change the paddlenlp logger's level inside a ``with`` block.

    Captures ``logger.logLevel`` on entry and restores it on exit, even if
    the body raises.

    NOTE(review): this redefines (and therefore shadows) the earlier
    ``switch_level_context`` that operates on ``logging.root`` — confirm the
    duplication is intentional, or rename one of the two.

    Args:
        level: level name accepted by ``logger.set_level``.
    """
    saved_level = logger.logLevel
    logger.set_level(level)
    try:
        yield
    finally:
        logger.set_level(saved_level)
| 60 | + |
def chunk(all_input_ids, size):
    """Split *all_input_ids* into consecutive sub-sequences of length *size*.

    The last chunk may be shorter when ``len(all_input_ids)`` is not a
    multiple of *size*. An empty input yields an empty list.

    Args:
        all_input_ids: any sliceable sequence.
        size: positive chunk length.

    Returns:
        A list of slices of *all_input_ids*.

    Raises:
        ValueError: if *size* is not positive.
    """
    if size <= 0:
        raise ValueError("Size must be greater than 0")
    pieces = []
    start = 0
    total = len(all_input_ids)
    while start < total:
        pieces.append(all_input_ids[start : start + size])
        start += size
    return pieces
|
44 | 65 |
|
45 |
| - |
46 | 66 | @dataclass
|
47 | 67 | class DumpyTrainingArguments(TrainingArguments):
|
48 | 68 | actor_model_name_or_path: str = field(default="Qwen/Qwen2.5-7B-Instruct-1M", metadata={"help": "预训练模型名称或路径"})
|
@@ -137,8 +157,6 @@ def run_inference(self, input_ids, batch_index=0):
|
137 | 157 | if self.args.world_size > 1:
|
138 | 158 | paddle.distributed.barrier()
|
139 | 159 | end_time = time.time()
|
140 |
| - print(input_ids.shape) |
141 |
| - print(output_ids.shape) |
142 | 160 | if self.args.should_log:
|
143 | 161 | statistics = self.postprocess_data(input_ids, output_ids, batch_index=batch_index)
|
144 | 162 | statistics["total_time"] = end_time - start_time
|
|
0 commit comments