Skip to content

Commit 97ea008

Browse files
committed
refine demo
1 parent 83ab7a5 commit 97ea008

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

tools/demo.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
import cv2
1212

13-
import lib.transform_cv2 as T
13+
import lib.data.transform_cv2 as T
1414
from lib.models import model_factory
1515
from configs import set_cfg_from_file
1616

@@ -34,7 +34,7 @@
3434
palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8)
3535

3636
# define model
37-
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred')
37+
net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='eval')
3838
net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False)
3939
net.eval()
4040
net.cuda()
@@ -53,8 +53,9 @@
5353

5454
# inference
5555
im = F.interpolate(im, size=new_size, align_corners=False, mode='bilinear')
56-
out = net(im)
56+
out = net(im)[0]
5757
out = F.interpolate(out, size=org_size, align_corners=False, mode='bilinear')
58+
out = out.argmax(dim=1)
5859

5960
# visualize
6061
out = out.squeeze().detach().cpu().numpy()

tools/demo_video.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import torch
66
import torch.nn as nn
77
import torch.nn.functional as F
8-
from torch.multiprocessing import Process, Queue
8+
import torch.multiprocessing as mp
99
import time
1010
from PIL import Image
1111
import numpy as np
1212
import cv2
1313

14-
import lib.transform_cv2 as T
14+
import lib.data.transform_cv2 as T
1515
from lib.models import model_factory
1616
from configs import set_cfg_from_file
1717

@@ -40,7 +40,7 @@ def get_model():
4040

4141

4242
# fetch frames
43-
def get_func(inpth, in_q):
43+
def get_func(inpth, in_q, done):
4444
cap = cv2.VideoCapture(args.input)
4545
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # type is float
4646
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # type is float
@@ -59,7 +59,8 @@ def get_func(inpth, in_q):
5959
in_q.put(frame)
6060

6161
in_q.put('quit')
62-
while not in_q.empty(): continue
62+
done.wait()
63+
6364
cap.release()
6465
time.sleep(1)
6566
print('input queue done')
@@ -105,14 +106,15 @@ def infer_batch(frames):
105106

106107

107108
if __name__ == '__main__':
108-
torch.multiprocessing.set_start_method('spawn')
109+
mp.set_start_method('spawn')
109110

110-
in_q = Queue(1024)
111-
out_q = Queue(1024)
111+
in_q = mp.Queue(1024)
112+
out_q = mp.Queue(1024)
113+
done = mp.Event()
112114

113-
in_worker = Process(target=get_func,
114-
args=(args.input, in_q))
115-
out_worker = Process(target=save_func,
115+
in_worker = mp.Process(target=get_func,
116+
args=(args.input, in_q, done))
117+
out_worker = mp.Process(target=save_func,
116118
args=(args.input, args.output, out_q))
117119

118120
in_worker.start()
@@ -133,6 +135,7 @@ def infer_batch(frames):
133135
infer_batch(frames)
134136

135137
out_q.put('quit')
138+
done.set()
136139

137140
out_worker.join()
138141
in_worker.join()

0 commit comments

Comments
 (0)