-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
40 lines (34 loc) · 1.33 KB
/
train.py
File metadata and controls
40 lines (34 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# train.py
import argparse
import subprocess
import sys
import os
def main():
    """Dispatch to the trainer script selected on the command line.

    The first positional argument names a training strategy. Every other
    command-line argument (anything this parser does not itself recognise)
    is forwarded unchanged to the selected trainer script, which is run as a
    child process under the same Python interpreter (``sys.executable``).

    Raises:
        ValueError: if ``train_strategy`` is not one of the known strategies.
        SystemExit: with status 1 if the trainer script file does not exist,
            or with the trainer's own non-zero status if the trainer fails.
    """
    # Strategy name -> trainer script. Paths are relative to the current
    # working directory, matching the original lookup behaviour.
    trainer_scripts = {
        "ml_bow": "trainers/ml_bagofwords.py",
        "ml_wv": "trainers/ml_wordvectors.py",
        "fcn_bow": "trainers/fcn_bagofwords.py",
        "fcn_wv": "trainers/fcn_wordvectors.py",
        "lstm_wv": "trainers/lstm_wordvectors.py",
        "transformer_wv": "trainers/transformer_wordvectors.py",
        "llm": "trainers/llm.py",
    }
    valid_names = ", ".join(trainer_scripts)

    parser = argparse.ArgumentParser(description="Master training script")
    parser.add_argument(
        "train_strategy",
        type=str,
        help=f"Trainer to run (one of: {valid_names})",
    )
    # parse_known_args leaves unrecognised options in remaining_args so they
    # can be passed through to the trainer untouched.
    args, remaining_args = parser.parse_known_args()

    trainer_script = trainer_scripts.get(args.train_strategy)
    if trainer_script is None:
        raise ValueError(
            f"Trainer not found: {args.train_strategy!r} "
            f"(expected one of: {valid_names})"
        )

    if not os.path.isfile(trainer_script):
        print(f"Trainer script '{trainer_script}' not found.", file=sys.stderr)
        sys.exit(1)

    # Call the trainer script and pass all remaining CLI args to it.
    cmd = [sys.executable, trainer_script, *remaining_args]
    result = subprocess.run(cmd)

    if result.returncode != 0:
        # Propagate the trainer's failure so shells/CI see a non-zero status.
        # A negative return code means the child was killed by that signal;
        # map it to the shell convention of 128 + signal number.
        code = result.returncode
        sys.exit(code if code > 0 else 128 - code)
# Run the dispatcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()