@@ -5,6 +5,7 @@
 import yaml
 from data import setup_data
 from ignite.engine import Events
+from ignite.handlers import PiecewiseLinear
 from ignite.metrics import Accuracy, Loss
 from ignite.utils import manual_seed
 from models import setup_model
@@ -29,6 +30,15 @@ def run(local_rank: int, config: Any):
     model = idist.auto_model(setup_model(config.model))
     optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr))
     loss_fn = nn.CrossEntropyLoss().to(device=device)
+    milestones_values = [
+        (0, 0.0),
+        (
+            len(dataloader_train),
+            config.lr,
+        ),
+        (config.max_epochs * len(dataloader_train), 0.0),
+    ]
+    lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)
 
     # trainer and evaluator
     trainer = setup_trainer(
@@ -53,10 +63,17 @@ def run(local_rank: int, config: Any):
     (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config))
     trainer.logger = evaluator.logger = logger
 
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
+
     # setup ignite handlers
     #::: if (it.save_training || it.save_evaluation) { :::#
     #::: if (it.save_training) { :::#
-    to_save_train = {"model": model, "optimizer": optimizer, "trainer": trainer}
+    to_save_train = {
+        "model": model,
+        "optimizer": optimizer,
+        "trainer": trainer,
+        "lr_scheduler": lr_scheduler,
+    }
     #::: } else { :::#
     to_save_train = None
     #::: } :::#
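
For context on what the new milestones describe: the learning rate ramps linearly from 0 to config.lr over the first epoch, then decays linearly back to 0 by the final iteration. The sketch below (not part of the diff) inspects that shape offline with PiecewiseLinear.simulate_values; the values 100, 3, and 1e-3 are stand-ins for len(dataloader_train), config.max_epochs, and config.lr.

from ignite.handlers import PiecewiseLinear

iters_per_epoch = 100  # stand-in for len(dataloader_train)
max_epochs = 3         # stand-in for config.max_epochs
peak_lr = 1e-3         # stand-in for config.lr

milestones_values = [
    (0, 0.0),                             # start from zero
    (iters_per_epoch, peak_lr),           # linear warmup over the first epoch
    (max_epochs * iters_per_epoch, 0.0),  # linear decay to zero by the end
]

# simulate_values builds the scheduler on an internal dummy optimizer and
# returns [event_index, lr] pairs, so the schedule can be checked without
# running a training loop
values = PiecewiseLinear.simulate_values(
    num_events=max_epochs * iters_per_epoch,
    param_name="lr",
    milestones_values=milestones_values,
)
print(values[0], values[iters_per_epoch], values[-1])

The printed triple should show roughly 0 at the start, the peak value after the first "epoch", and a value near 0 again at the last event.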
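
The scheduler is also added to to_save_train because ignite's Checkpoint serializes any object exposing state_dict()/load_state_dict(); without it, a resumed run would restart the warmup ramp from zero. A minimal sketch of that round trip, using a hypothetical toy model and optimizer in place of the real ones:

import torch
from ignite.handlers import PiecewiseLinear

model = torch.nn.Linear(2, 2)  # toy stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
milestones_values = [(0, 0.0), (100, 1e-3), (300, 0.0)]

scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)
for _ in range(50):
    scheduler(None)  # advance 50 events, as the trainer does once per iteration

state = scheduler.state_dict()  # this is what Checkpoint serializes

# a fresh scheduler resumes mid-ramp after load_state_dict instead of
# restarting the warmup
resumed = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)
resumed.load_state_dict(state)
print(resumed.event_index)  # should print 50, the restored event counter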