Skip to content

Commit c0632ec

Browse files
committed
requires_grad for speedup
1 parent b621837 commit c0632ec

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

models/base_model.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,15 @@ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
108108
def load_networks(self, which_epoch):
109109
for name in self.model_names:
110110
if isinstance(name, str):
111-
save_filename = '%s_net_%s.pth' % (which_epoch, name)
112-
save_path = os.path.join(self.save_dir, save_filename)
111+
load_filename = '%s_net_%s.pth' % (which_epoch, name)
112+
load_path = os.path.join(self.save_dir, load_filename)
113113
net = getattr(self, 'net' + name)
114114
if isinstance(net, torch.nn.DataParallel):
115115
net = net.module
116+
print('loading the model from %s' % load_path)
116117
# if you are using PyTorch newer than 0.4 (e.g., built from
117118
# GitHub source), you can remove str() on self.device
118-
state_dict = torch.load(save_path, map_location=str(self.device))
119+
state_dict = torch.load(load_path, map_location=str(self.device))
119120
# patch InstanceNorm checkpoints prior to 0.4
120121
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
121122
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
@@ -134,3 +135,12 @@ def print_networks(self, verbose):
134135
print(net)
135136
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
136137
print('-----------------------------------------------')
138+
139+
# set requires_grad=False to avoid computation
140+
def set_requires_grad(self, nets, requires_grad=False):
141+
if not isinstance(nets, list):
142+
nets = [nets]
143+
for net in nets:
144+
if net is not None:
145+
for param in net.parameters():
146+
param.requires_grad = requires_grad

models/cycle_gan_model.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,10 @@ def backward_G(self):
113113

114114
# GAN loss D_A(G_A(A))
115115
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
116-
117116
# GAN loss D_B(G_B(B))
118117
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
119-
120118
# Forward cycle loss
121119
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
122-
123120
# Backward cycle loss
124121
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
125122
# combined loss
@@ -130,10 +127,12 @@ def optimize_parameters(self):
130127
# forward
131128
self.forward()
132129
# G_A and G_B
130+
self.set_requires_grad([self.netD_A, self.netD_B], False)
133131
self.optimizer_G.zero_grad()
134132
self.backward_G()
135133
self.optimizer_G.step()
136134
# D_A and D_B
135+
self.set_requires_grad([self.netD_A, self.netD_B], True)
137136
self.optimizer_D.zero_grad()
138137
self.backward_D_A()
139138
self.backward_D_B()

models/pix2pix_model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,14 @@ def backward_G(self):
8686

8787
def optimize_parameters(self):
8888
self.forward()
89-
89+
# update D
90+
self.set_requires_grad(self.netD, True)
9091
self.optimizer_D.zero_grad()
9192
self.backward_D()
9293
self.optimizer_D.step()
9394

95+
# update G
96+
self.set_requires_grad(self.netD, False)
9497
self.optimizer_G.zero_grad()
9598
self.backward_G()
9699
self.optimizer_G.step()

0 commit comments

Comments
 (0)