Skip to content

Commit 1b1c032

Browse files
authored
Update seg_track_anything.py
1 parent 6e20c5d commit 1b1c032

File tree

1 file changed

+81
-22
lines changed

1 file changed

+81
-22
lines changed

seg_track_anything.py

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
import gc
1010
import imageio
11+
from scipy.ndimage import binary_dilation
1112

1213
def save_prediction(pred_mask,output_dir,file_name):
1314
save_mask = Image.fromarray(pred_mask.astype(np.uint8))
@@ -21,6 +22,37 @@ def colorize_mask(pred_mask):
2122
save_mask = save_mask.convert(mode='RGB')
2223
return np.array(save_mask)
2324

25+
def draw_mask(img, mask, alpha=0.5, id_countour=False):
26+
img_mask = np.zeros_like(img)
27+
img_mask = img
28+
if id_countour:
29+
# very slow ~ 1s per image
30+
obj_ids = np.unique(mask)
31+
obj_ids = obj_ids[obj_ids!=0]
32+
33+
for id in obj_ids:
34+
# Overlay color on binary mask
35+
if id <= 255:
36+
color = _palette[id*3:id*3+3]
37+
else:
38+
color = [0,0,0]
39+
foreground = img * (1-alpha) + np.ones_like(img) * alpha * np.array(color)
40+
binary_mask = (mask == id)
41+
42+
# Compose image
43+
img_mask[binary_mask] = foreground[binary_mask]
44+
45+
countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask
46+
img_mask[countours, :] = 0
47+
else:
48+
binary_mask = (mask!=0)
49+
countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask
50+
foreground = img*(1-alpha)+colorize_mask(mask)*alpha
51+
img_mask[binary_mask] = foreground[binary_mask]
52+
img_mask[countours,:] = 0
53+
54+
return img_mask.astype(img.dtype)
55+
2456
aot_model2ckpt = {
2557
"deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth",
2658
"deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV",
@@ -29,11 +61,11 @@ def colorize_mask(pred_mask):
2961

3062

3163
def seg_track_anything(input_video_file, aot_model, sam_gap, max_obj_num, points_per_side):
64+
3265
video_name = os.path.basename(input_video_file).split('.')[0]
3366
io_args = {
3467
'input_video': f'{input_video_file}',
3568
'output_mask_dir': f'./assets/{video_name}_masks',
36-
'save_video': True,
3769
'output_video': f'./assets/{video_name}_seg.mp4', # keep same format as input video
3870
'output_gif': f'./assets/{video_name}_seg.gif',
3971
}
@@ -50,40 +82,35 @@ def seg_track_anything(input_video_file, aot_model, sam_gap, max_obj_num, points
5082
output_dir = io_args['output_mask_dir']
5183
if not os.path.exists(output_dir):
5284
os.makedirs(output_dir)
85+
5386
# source video to segment
5487
cap = cv2.VideoCapture(io_args['input_video'])
5588
fps = cap.get(cv2.CAP_PROP_FPS)
5689
# output masks
5790
output_dir = io_args['output_mask_dir']
5891
if not os.path.exists(output_dir):
5992
os.makedirs(output_dir)
60-
if io_args['save_video']:
61-
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
62-
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
63-
num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
64-
fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
65-
out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
6693
pred_list = []
6794

95+
# start to track
6896
torch.cuda.empty_cache()
6997
gc.collect()
7098
sam_gap = segtracker_args['sam_gap']
7199
frame_idx = 0
72100
segtracker = SegTracker(segtracker_args,sam_args,aot_args)
73101
segtracker.restart_tracker()
74102

75-
76103
with torch.cuda.amp.autocast():
77104
while cap.isOpened():
78105
ret, frame = cap.read()
79106
if not ret:
80107
break
81-
108+
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
82109
if frame_idx == 0:
83110
pred_mask = segtracker.seg(frame)
84111
torch.cuda.empty_cache()
85112
gc.collect()
86-
segtracker.add_reference(frame, pred_mask, frame_idx)
113+
segtracker.add_reference(frame, pred_mask)
87114
elif (frame_idx % sam_gap) == 0:
88115
seg_mask = segtracker.seg(frame)
89116
torch.cuda.empty_cache()
@@ -94,30 +121,62 @@ def seg_track_anything(input_video_file, aot_model, sam_gap, max_obj_num, points
94121
save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')
95122
pred_mask = track_mask + new_obj_mask
96123
# segtracker.restart_tracker()
97-
segtracker.add_reference(frame, pred_mask, frame_idx)
124+
segtracker.add_reference(frame, pred_mask)
98125
else:
99126
pred_mask = segtracker.track(frame,update_memory=True)
100127
torch.cuda.empty_cache()
101128
gc.collect()
102129

103130
save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')
104-
masked_frame = (frame*0.3+colorize_mask(pred_mask)*0.7).astype(np.uint8)
105-
pred_list.append(masked_frame)
106-
if io_args['save_video']:
107-
out.write(masked_frame)
131+
pred_list.append(pred_mask)
108132

109-
print("processed and saved mask for frame {}, obj_num {}".format(frame_idx,segtracker.get_obj_num()),end='\r')
133+
134+
print("processed frame {}, obj_num {}".format(frame_idx,segtracker.get_obj_num()),end='\r')
110135
frame_idx += 1
136+
cap.release()
137+
print('\nfinished')
138+
139+
######################
140+
# Visualization
141+
######################
142+
143+
# draw pred mask on frame and save as a video
144+
cap = cv2.VideoCapture(io_args['input_video'])
145+
fps = cap.get(cv2.CAP_PROP_FPS)
146+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
147+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
148+
num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
149+
fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
150+
out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
151+
# for .mp4
152+
frame_idx = 0
153+
while cap.isOpened():
154+
ret, frame = cap.read()
155+
if not ret:
156+
break
157+
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
158+
pred_mask = pred_list[frame_idx]
159+
masked_frame = draw_mask(frame,pred_mask)
160+
# masked_frame = masked_pred_list[frame_idx]
161+
masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)
162+
out.write(masked_frame)
163+
print('frame {} writed'.format(frame_idx),end='\r')
164+
frame_idx += 1
165+
out.release()
111166
cap.release()
112-
if io_args['save_video']:
113-
out.release()
114-
print("\n{} saved".format(io_args['output_video']))
115-
# save a gif
116-
imageio.mimsave(io_args['output_gif'],pred_list,fps=fps)
117-
print("{} saved".format(io_args['output_gif']))
167+
print("\n{} saved".format(io_args['output_video']))
118168
print('\nfinished')
119169

170+
# save colorized masks as a gif
171+
imageio.mimsave(io_args['output_gif'],pred_list,fps=fps)
172+
print("{} saved".format(io_args['output_gif']))
173+
120174
# zip predicted mask
121175
os.system(f"zip -r ./assets/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")
122176

177+
# manually release memory (after cuda out of memory)
178+
del segtracker
179+
torch.cuda.empty_cache()
180+
gc.collect()
181+
123182
return io_args["output_video"], f"./assets/{video_name}_pred_mask.zip"

0 commit comments

Comments
 (0)