Skip to content

Commit aa0b8a9

Browse files
committed
simplify the code and add eval mode
1 parent 714a932 commit aa0b8a9

File tree

5 files changed

+36
-23
lines changed

5 files changed

+36
-23
lines changed

models/base_model.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import torch
33
from collections import OrderedDict
4+
from . import networks
45

56

67
class BaseModel():
@@ -26,6 +27,22 @@ def set_input(self, input):
2627
def forward(self):
2728
pass
2829

30+
# load and print networks; create schedulers
31+
def setup(self, opt):
32+
if self.isTrain:
33+
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
34+
35+
if not self.isTrain or opt.continue_train:
36+
self.load_networks(opt.which_epoch)
37+
self.print_networks(opt.verbose)
38+
39+
# put models in eval mode during test time
40+
def eval(self):
41+
for name in self.model_names:
42+
if isinstance(name, str):
43+
net = getattr(self, 'net' + name)
44+
net.eval()
45+
2946
# used in test time, wrapping `forward` in no_grad() so we don't save
3047
# intermediate steps for backprop
3148
def test(self):
@@ -77,7 +94,6 @@ def save_networks(self, which_epoch):
7794
else:
7895
torch.save(net.cpu().state_dict(), save_path)
7996

80-
8197
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
8298
key = keys[i]
8399
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
@@ -101,7 +117,7 @@ def load_networks(self, which_epoch):
101117
# GitHub source), you can remove str() on self.device
102118
state_dict = torch.load(save_path, map_location=str(self.device))
103119
# patch InstanceNorm checkpoints prior to 0.4
104-
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
120+
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
105121
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
106122
net.load_state_dict(state_dict)
107123

models/cycle_gan_model.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,10 @@ def initialize(self, opt):
5858
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
5959
lr=opt.lr, betas=(opt.beta1, 0.999))
6060
self.optimizers = []
61-
self.schedulers = []
6261
self.optimizers.append(self.optimizer_G)
6362
self.optimizers.append(self.optimizer_D)
64-
for optimizer in self.optimizers:
65-
self.schedulers.append(networks.get_scheduler(optimizer, opt))
6663

67-
if not self.isTrain or opt.continue_train:
68-
self.load_networks(opt.which_epoch)
69-
self.print_networks(opt.verbose)
64+
self.setup(opt)
7065

7166
def set_input(self, input):
7267
AtoB = self.opt.which_direction == 'AtoB'

models/pix2pix_model.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,15 @@ def initialize(self, opt):
3737
self.criterionL1 = torch.nn.L1Loss()
3838

3939
# initialize optimizers
40-
self.schedulers = []
4140
self.optimizers = []
4241
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
4342
lr=opt.lr, betas=(opt.beta1, 0.999))
4443
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
4544
lr=opt.lr, betas=(opt.beta1, 0.999))
4645
self.optimizers.append(self.optimizer_G)
4746
self.optimizers.append(self.optimizer_D)
48-
for optimizer in self.optimizers:
49-
self.schedulers.append(networks.get_scheduler(optimizer, opt))
5047

51-
if not self.isTrain or opt.continue_train:
52-
self.load_networks(opt.which_epoch)
53-
54-
self.print_networks(opt.verbose)
48+
self.setup(opt)
5549

5650
def set_input(self, input):
5751
AtoB = self.opt.which_direction == 'AtoB'

models/test_model.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from torch.autograd import Variable
21
from .base_model import BaseModel
32
from . import networks
43

@@ -23,8 +22,7 @@ def initialize(self, opt):
2322
opt.norm, not opt.no_dropout,
2423
opt.init_type,
2524
self.gpu_ids)
26-
self.load_networks(opt.which_epoch)
27-
self.print_networks(opt.verbose)
25+
self.setup(opt)
2826

2927
def set_input(self, input):
3028
# we need to use single_dataset mode

scripts/check_all.sh

+15-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
set -ex
2+
DOWNLOAD=${1}
23
echo 'apply a pretrained cyclegan model'
3-
bash pretrained_models/download_cyclegan_model.sh horse2zebra
4-
bash ./datasets/download_cyclegan_dataset.sh horse2zebra
4+
if [ ${DOWNLOAD} -eq 1 ]
5+
then
6+
bash pretrained_models/download_cyclegan_model.sh horse2zebra
7+
bash ./datasets/download_cyclegan_dataset.sh horse2zebra
8+
fi
59
python test.py --dataroot datasets/horse2zebra/testA --checkpoints_dir ./checkpoints/ --name horse2zebra_pretrained --no_dropout --model test --dataset_mode single --loadSize 256
610

711
echo 'apply a pretrained pix2pix model'
8-
bash pretrained_models/download_pix2pix_model.sh facades_label2photo
9-
bash ./datasets/download_pix2pix_dataset.sh facades
12+
if [ ${DOWNLOAD} -eq 1 ]
13+
then
14+
bash pretrained_models/download_pix2pix_model.sh facades_label2photo
15+
bash ./datasets/download_pix2pix_dataset.sh facades
16+
fi
1017
python test.py --dataroot ./datasets/facades/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch
1118

1219

1320
echo 'cyclegan train (1 epoch) and test'
14-
bash ./datasets/download_cyclegan_dataset.sh maps
21+
if [ ${DOWNLOAD} -eq 1 ]
22+
then
23+
bash ./datasets/download_cyclegan_dataset.sh maps
24+
fi
1525
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --no_dropout --niter 1 --niter_decay 0 --max_dataset_size 100 --save_latest_freq 100
1626
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout
1727

0 commit comments

Comments
 (0)