@@ -108,14 +108,15 @@ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
108
108
def load_networks (self , which_epoch ):
109
109
for name in self .model_names :
110
110
if isinstance (name , str ):
111
- save_filename = '%s_net_%s.pth' % (which_epoch , name )
112
- save_path = os .path .join (self .save_dir , save_filename )
111
+ load_filename = '%s_net_%s.pth' % (which_epoch , name )
112
+ load_path = os .path .join (self .save_dir , load_filename )
113
113
net = getattr (self , 'net' + name )
114
114
if isinstance (net , torch .nn .DataParallel ):
115
115
net = net .module
116
+ print ('loading the model from %s' % load_path )
116
117
# if you are using PyTorch newer than 0.4 (e.g., built from
117
118
# GitHub source), you can remove str() on self.device
118
- state_dict = torch .load (save_path , map_location = str (self .device ))
119
+ state_dict = torch .load (load_path , map_location = str (self .device ))
119
120
# patch InstanceNorm checkpoints prior to 0.4
120
121
for key in list (state_dict .keys ()): # need to copy keys here because we mutate in loop
121
122
self .__patch_instance_norm_state_dict (state_dict , net , key .split ('.' ))
@@ -134,3 +135,12 @@ def print_networks(self, verbose):
134
135
print (net )
135
136
print ('[Network %s] Total number of parameters : %.3f M' % (name , num_params / 1e6 ))
136
137
print ('-----------------------------------------------' )
138
+
139
+ # set requies_grad=Fasle to avoid computation
140
+ def set_requires_grad (self , nets , requires_grad = False ):
141
+ if not isinstance (nets , list ):
142
+ nets = [nets ]
143
+ for net in nets :
144
+ if net is not None :
145
+ for param in net .parameters ():
146
+ param .requires_grad = requires_grad
0 commit comments