Skip to content

Commit f1ecbef

Browse files
authored
Dev (#246)
* refine demo.py * refine demo
1 parent 09d88c8 commit f1ecbef

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

tools/demo.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import sys
33
sys.path.insert(0, '.')
44
import argparse
5+
import math
56
import torch
67
import torch.nn as nn
8+
import torch.nn.functional as F
79
from PIL import Image
810
import numpy as np
911
import cv2
1012

11-
import lib.transform_cv2 as T
13+
import lib.data.transform_cv2 as T
1214
from lib.models import model_factory
1315
from configs import set_cfg_from_file
1416

@@ -32,7 +34,7 @@
3234
palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)
3335

3436
# define model
35-
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred')
37+
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='eval')
3638
net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False)
3739
net.eval()
3840
net.cuda()
@@ -45,7 +47,17 @@
4547
im = cv2.imread(args.img_path)[:, :, ::-1]
4648
im = to_tensor(dict(im=im, lb=None))['im'].unsqueeze(0).cuda()
4749

50+
# shape divisor
51+
org_size = im.size()[2:]
52+
new_size = [math.ceil(el / 32) * 32 for el in im.size()[2:]]
53+
4854
# inference
49-
out = net(im).squeeze().detach().cpu().numpy()
55+
im = F.interpolate(im, size=new_size, align_corners=False, mode='bilinear')
56+
out = net(im)[0]
57+
out = F.interpolate(out, size=org_size, align_corners=False, mode='bilinear')
58+
out = out.argmax(dim=1)
59+
60+
# visualize
61+
out = out.squeeze().detach().cpu().numpy()
5062
pred = palette[out]
5163
cv2.imwrite('./res.jpg', pred)

tools/demo_video.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import torch
66
import torch.nn as nn
77
import torch.nn.functional as F
8-
from torch.multiprocessing import Process, Queue
8+
import torch.multiprocessing as mp
99
import time
1010
from PIL import Image
1111
import numpy as np
1212
import cv2
1313

14-
import lib.transform_cv2 as T
14+
import lib.data.transform_cv2 as T
1515
from lib.models import model_factory
1616
from configs import set_cfg_from_file
1717

@@ -40,7 +40,7 @@ def get_model():
4040

4141

4242
# fetch frames
43-
def get_func(inpth, in_q):
43+
def get_func(inpth, in_q, done):
4444
cap = cv2.VideoCapture(args.input)
4545
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # type is float
4646
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # type is float
@@ -59,7 +59,8 @@ def get_func(inpth, in_q):
5959
in_q.put(frame)
6060

6161
in_q.put('quit')
62-
while not in_q.empty(): continue
62+
done.wait()
63+
6364
cap.release()
6465
time.sleep(1)
6566
print('input queue done')
@@ -105,14 +106,15 @@ def infer_batch(frames):
105106

106107

107108
if __name__ == '__main__':
108-
torch.multiprocessing.set_start_method('spawn')
109+
mp.set_start_method('spawn')
109110

110-
in_q = Queue(1024)
111-
out_q = Queue(1024)
111+
in_q = mp.Queue(1024)
112+
out_q = mp.Queue(1024)
113+
done = mp.Event()
112114

113-
in_worker = Process(target=get_func,
114-
args=(args.input, in_q))
115-
out_worker = Process(target=save_func,
115+
in_worker = mp.Process(target=get_func,
116+
args=(args.input, in_q, done))
117+
out_worker = mp.Process(target=save_func,
116118
args=(args.input, args.output, out_q))
117119

118120
in_worker.start()
@@ -133,6 +135,7 @@ def infer_batch(frames):
133135
infer_batch(frames)
134136

135137
out_q.put('quit')
138+
done.set()
136139

137140
out_worker.join()
138141
in_worker.join()

0 commit comments

Comments
 (0)