Skip to content

Commit aad5cbe

Browse files
authored
[ArgParser] Support passing a Python file as config. (#10489)
* Fix import. * Add Python script. * Support passing a Python file as config.
1 parent 370702a commit aad5cbe

File tree

4 files changed

+129
-0
lines changed

4 files changed

+129
-0
lines changed

llm/config/llama/pretrain_argument.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Pretraining argument config for Llama, consumed via
# `ArgParser.parse_python_file_and_cmd_lines`. Top-level dict variables are
# flattened by the parser, so only the leaf keys matter; grouping here is
# purely for readability.

# models
model_name_or_path = "meta-llama/Meta-Llama-3-8B"
tokenizer_name_or_path = "meta-llama/Meta-Llama-3-8B"

# data / checkpointing
checkpoint_dirs = {
    "input_dir": "./data",
    "output_dir": "./checkpoints/pretrain_ckpts",
    "unified_checkpoint": True,
    "save_total_limit": 2,
}

# training control flags
# (renamed from the misspelled "training_contronl"; safe because the parser
# flattens dicts and discards top-level variable names)
training_control = {
    "do_train": True,
    "do_eval": True,
    "do_predict": True,
    "disable_tqdm": True,
    "recompute": False,
    "distributed_dataloader": 1,
    "recompute_granularity": "full",
}

# optimizer / schedule / parallelism hyperparameters
training_args = {
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 16,
    "per_device_eval_batch_size": 2,
    "tensor_parallel_degree": 2,
    "pipeline_parallel_degree": 1,
    "sharding": "stage2",
    "virtual_pp_degree": 1,
    "sequence_parallel": 0,
    "max_seq_length": 4096,
    "learning_rate": 3e-05,
    "min_learning_rate": 3e-06,
    "warmup_steps": 30,
    "logging_steps": 1,
    "max_steps": 10000,
    "save_steps": 5000,
    "eval_steps": 1000,
    "weight_decay": 0.01,
    "warmup_ratio": 0.01,
    "max_grad_norm": 1.0,
    "dataloader_num_workers": 1,
    "continue_training": 0,
}

# hardware acceleration / mixed precision
accelerate = {
    "use_flash_attention": True,
    "use_fused_rms_norm": True,
    "use_fused_rope": True,
    "bf16": True,
    "fp16_opt_level": "O2",
}

llm/run_finetune.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ def main():
112112
gen_args, model_args, reft_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
113113
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"):
114114
gen_args, model_args, reft_args, data_args, training_args = parser.parse_yaml_file_and_cmd_lines()
115+
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".py"):
116+
gen_args, model_args, reft_args, data_args, training_args = parser.parse_python_file_and_cmd_lines()
115117
else:
116118
gen_args, model_args, reft_args, data_args, training_args = parser.parse_args_into_dataclasses()
117119

llm/run_pretrain.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ def main():
354354
model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines()
355355
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"):
356356
model_args, data_args, training_args = parser.parse_yaml_file_and_cmd_lines()
357+
elif len(sys.argv) >= 2 and sys.argv[1].endswith(".py"):
358+
model_args, data_args, training_args = parser.parse_python_file_and_cmd_lines()
357359
else:
358360
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
359361

paddlenlp/trainer/argparser.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,64 @@ def parse_yaml_file_and_cmd_lines(self, return_remaining_strings=False) -> Tuple
342342
args = yaml_args + sys.argv[2:]
343343
return self.common_parse(args, return_remaining_strings)
344344

345+
def read_python(self, python_file: str) -> list:
    """Load a Python config file and convert its globals to CLI-style args.

    The file is executed and every non-dunder top-level name is collected.
    Nested dicts are flattened so only leaf keys survive: top-level dict
    variable names are discarded, and colliding leaf keys silently
    overwrite earlier ones. Each key/value pair becomes ``--key value``;
    list values expand to one token per element, and dict values (defensive
    only — flattening leaves none) are JSON-encoded.

    Args:
        python_file: Path to the ``.py`` config file.

    Returns:
        list: Flat list of command-line argument strings.

    Raises:
        FileNotFoundError: If ``python_file`` does not exist.
    """
    config_path = Path(python_file)
    # Check existence up front, before attempting to execute anything.
    if not config_path.exists():
        raise FileNotFoundError(f"The argument file {config_path} does not exist.")

    def _flatten(config):
        # Merge nested dicts into a single level, keeping only leaf keys.
        flat = {}
        for key, value in config.items():
            # isinstance (not `type(...) is dict`) so dict subclasses
            # such as OrderedDict are also flattened.
            if isinstance(value, dict):
                flat.update(_flatten(value))
            else:
                flat[key] = value
        return flat

    with open(config_path, "r", encoding="utf-8") as f:
        code = compile(f.read(), str(config_path), "exec")
    namespace = {}
    # SECURITY NOTE: exec runs arbitrary code from the config file —
    # only pass trusted files here.
    exec(code, namespace)
    # NOTE(review): imported modules in the config would also land in the
    # namespace and be stringified below; configs are assumed import-free.
    data = _flatten({k: v for k, v in namespace.items() if not k.startswith("__")})

    python_args = []
    for key, value in data.items():
        if isinstance(value, list):
            # Lists become repeated value tokens after the flag.
            python_args.extend([f"--{key}", *[str(v) for v in value]])
        elif isinstance(value, dict):
            # Defensive: _flatten leaves no dict leaves, but keep parity
            # with the JSON/YAML parsing paths just in case.
            python_args.extend([f"--{key}", json.dumps(value)])
        else:
            python_args.extend([f"--{key}", str(value)])
    return python_args
382+
383+
def parse_python_file_and_cmd_lines(self, return_remaining_strings=False) -> Tuple[DataClass, ...]:
    """
    Parse arguments from a Python config file (given as ``sys.argv[1]``) combined with the
    remaining command line arguments.

    When there is a conflict between the command line arguments and the Python file configuration,
    the command line arguments will take precedence (they are appended after the file-derived
    arguments, so they override on re-parse).

    Returns:
        Tuple consisting of:

            - the dataclass instances in the same order as they were passed to the initializer.

    Raises:
        ValueError: If ``sys.argv[1]`` does not end with ``.py``.
    """
    if not sys.argv[1].endswith(".py"):
        raise ValueError(f"The first argument should be a PYTHON file, but it is {sys.argv[1]}")
    # Convert the config file's variables into "--key value" tokens.
    python_args = self.read_python(sys.argv[1])
    # In case of conflict, command line arguments take precedence
    args = python_args + sys.argv[2:]
    return self.common_parse(args, return_remaining_strings)
402+
345403
def parse_dict(self, args: dict) -> Tuple[DataClass, ...]:
346404
"""
347405
Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass

0 commit comments

Comments
 (0)