import os
import torch
from collections import OrderedDict
+from . import networks


class BaseModel():
@@ -26,6 +27,22 @@ def set_input(self, input):
    def forward(self):
        pass

+    # load and print networks; create schedulers
+    def setup(self, opt):
+        if self.isTrain:
+            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+
+        if not self.isTrain or opt.continue_train:
+            self.load_networks(opt.which_epoch)
+        self.print_networks(opt.verbose)
+
+    # switch models to eval mode at test time
+    def eval(self):
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net.eval()
+
    # used at test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
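
As an aside on the hunk above: a minimal, self-contained sketch of what setup() and eval()/test() amount to at run time, assuming plain PyTorch. The option-driven behavior of networks.get_scheduler lives in networks.py and is not part of this diff, so the LambdaLR rule below is only an illustrative stand-in.

import torch

# Stand-in for networks.get_scheduler (an assumption, not this repo's code):
# setup() builds one scheduler per optimizer for the training case.
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5))
optimizers = [torch.optim.Adam(net.parameters(), lr=2e-4)]
schedulers = [torch.optim.lr_scheduler.LambdaLR(o, lr_lambda=lambda e: 1.0 - e / 100.0)
              for o in optimizers]

# What eval() + test() buy at inference time: eval() fixes Dropout/BatchNorm
# behavior, and no_grad() skips saving intermediates for backprop.
net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 4))
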
@@ -77,7 +94,6 @@ def save_networks(self, which_epoch):
                else:
                    torch.save(net.cpu().state_dict(), save_path)

-
    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
@@ -101,7 +117,7 @@ def load_networks(self, which_epoch):
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(save_path, map_location=str(self.device))
                # patch InstanceNorm checkpoints prior to 0.4
-                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)
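
For the save/load hunks, a matching sketch of the round trip that save_networks() and load_networks() implement; the filename is hypothetical, and the (empty) loop body stands in for the InstanceNorm patching done by __patch_instance_norm_state_dict().

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = torch.nn.Linear(4, 2)
# mirrors save_networks: parameters are moved to CPU before serialization
torch.save(net.cpu().state_dict(), 'latest_net_G.pth')
# mirrors load_networks: map_location=str(device) keeps the checkpoint
# loadable on CPU-only machines
state_dict = torch.load('latest_net_G.pth', map_location=str(device))
for key in list(state_dict.keys()):  # copy keys: the dict may be mutated in the loop
    pass  # e.g., drop stale InstanceNorm running-stats entries here
net.load_state_dict(state_dict)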