import torch
import gc
import imageio
+ from scipy.ndimage import binary_dilation

def save_prediction(pred_mask, output_dir, file_name):
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
@@ -21,6 +22,37 @@ def colorize_mask(pred_mask):
    save_mask = save_mask.convert(mode='RGB')
    return np.array(save_mask)

+ def draw_mask(img, mask, alpha=0.5, id_countour=False):
+     img_mask = img.copy()  # work on a copy so the caller's frame is left untouched
+     if id_countour:
+         # very slow ~ 1s per image
+         obj_ids = np.unique(mask)
+         obj_ids = obj_ids[obj_ids != 0]
+
+         for id in obj_ids:
+             # Overlay color on binary mask
+             if id <= 255:
+                 color = _palette[id * 3:id * 3 + 3]
+             else:
+                 color = [0, 0, 0]
+             foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
+             binary_mask = (mask == id)
+
+             # Compose image
+             img_mask[binary_mask] = foreground[binary_mask]
+
+             countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
+             img_mask[countours, :] = 0
+     else:
+         binary_mask = (mask != 0)
+         countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
+         foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
+         img_mask[binary_mask] = foreground[binary_mask]
+         img_mask[countours, :] = 0
+
+     return img_mask.astype(img.dtype)
+
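A minimal usage sketch of the new draw_mask helper, for orientation only; the frame and mask shapes and values below are illustrative assumptions, not part of this commit. It expects an H x W x 3 RGB frame plus an H x W integer label mask (0 = background) and returns the blended overlay with object contours set to black.

    # hypothetical example, not part of the module
    frame = np.zeros((480, 640, 3), dtype=np.uint8)   # RGB frame
    mask = np.zeros((480, 640), dtype=np.uint8)       # per-pixel object ids, 0 = background
    mask[100:200, 150:300] = 1                        # one fake object
    overlay = draw_mask(frame, mask, alpha=0.5)       # blended frame, contours drawn in black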
aot_model2ckpt = {
    "deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth",
    "deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV",
@@ -29,11 +61,11 @@ def colorize_mask(pred_mask):


def seg_track_anything(input_video_file, aot_model, sam_gap, max_obj_num, points_per_side):
+
    video_name = os.path.basename(input_video_file).split('.')[0]
    io_args = {
        'input_video': f'{input_video_file}',
        'output_mask_dir': f'./assets/{video_name}_masks',
-         'save_video': True,
        'output_video': f'./assets/{video_name}_seg.mp4',  # keep same format as input video
        'output_gif': f'./assets/{video_name}_seg.gif',
    }
@@ -50,40 +82,35 @@ def seg_track_anything(input_video_file, aot_model, sam_gap, max_obj_num, points
    output_dir = io_args['output_mask_dir']
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
+
    # source video to segment
    cap = cv2.VideoCapture(io_args['input_video'])
    fps = cap.get(cv2.CAP_PROP_FPS)
    # output masks
    output_dir = io_args['output_mask_dir']
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
-     if io_args['save_video']:
-         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-         num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-         fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
-         out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
    pred_list = []

+     # start to track
    torch.cuda.empty_cache()
    gc.collect()
    sam_gap = segtracker_args['sam_gap']
    frame_idx = 0
    segtracker = SegTracker(segtracker_args, sam_args, aot_args)
    segtracker.restart_tracker()

-
    with torch.cuda.amp.autocast():
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
-
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if frame_idx == 0:
                pred_mask = segtracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
-                 segtracker.add_reference(frame, pred_mask, frame_idx)
+                 segtracker.add_reference(frame, pred_mask)
            elif (frame_idx % sam_gap) == 0:
                seg_mask = segtracker.seg(frame)
                torch.cuda.empty_cache()
@@ -94,30 +121,62 @@ def seg_track_anything(input_video_file, aot_model, sam_gap, max_obj_num, points
                save_prediction(new_obj_mask, output_dir, str(frame_idx) + '_new.png')
                pred_mask = track_mask + new_obj_mask
                # segtracker.restart_tracker()
-                 segtracker.add_reference(frame, pred_mask, frame_idx)
+                 segtracker.add_reference(frame, pred_mask)
            else:
                pred_mask = segtracker.track(frame, update_memory=True)
            torch.cuda.empty_cache()
            gc.collect()

            save_prediction(pred_mask, output_dir, str(frame_idx) + '.png')
-             masked_frame = (frame * 0.3 + colorize_mask(pred_mask) * 0.7).astype(np.uint8)
-             pred_list.append(masked_frame)
-             if io_args['save_video']:
-                 out.write(masked_frame)
+             pred_list.append(pred_mask)

-             print("processed and saved mask for frame {}, obj_num {}".format(frame_idx, segtracker.get_obj_num()), end='\r')
+
+             print("processed frame {}, obj_num {}".format(frame_idx, segtracker.get_obj_num()), end='\r')
            frame_idx += 1
+         cap.release()
+         print('\nfinished')
+
+     ######################
+     # Visualization
+     ######################
+
+     # draw pred mask on frame and save as a video
+     cap = cv2.VideoCapture(io_args['input_video'])
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
+     out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
+     # for .mp4
+     frame_idx = 0
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         pred_mask = pred_list[frame_idx]
+         masked_frame = draw_mask(frame, pred_mask)
+         # masked_frame = masked_pred_list[frame_idx]
+         masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
+         out.write(masked_frame)
+         print('frame {} written'.format(frame_idx), end='\r')
+         frame_idx += 1
+     out.release()
    cap.release()
-     if io_args['save_video']:
-         out.release()
-         print("\n{} saved".format(io_args['output_video']))
-     # save a gif
-     imageio.mimsave(io_args['output_gif'], pred_list, fps=fps)
-     print("{} saved".format(io_args['output_gif']))
+     print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

+     # save colorized masks as a gif
+     # pred_list now stores raw label masks, so colorize them before writing the gif
+     imageio.mimsave(io_args['output_gif'], [colorize_mask(p) for p in pred_list], fps=fps)
+     print("{} saved".format(io_args['output_gif']))
+
    # zip predicted mask
    os.system(f"zip -r ./assets/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")

+     # manually release memory (after cuda out of memory)
+     del segtracker
+     torch.cuda.empty_cache()
+     gc.collect()
+
    return io_args["output_video"], f"./assets/{video_name}_pred_mask.zip"
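For reference, a sketch of how the updated seg_track_anything entry point might be driven from a script; every argument value below is an illustrative assumption (the demo path, SAM gap, object count, and point-grid density are placeholders, not values taken from this commit).

    # hypothetical driver; paths and hyper-parameters are placeholders
    output_video, mask_zip = seg_track_anything(
        input_video_file='./assets/demo.mp4',  # assumed input clip
        aot_model='deaotb',                    # key into aot_model2ckpt
        sam_gap=5,                             # run SAM every 5 frames to detect new objects
        max_obj_num=255,                       # draw_mask falls back to black for ids > 255
        points_per_side=16,                    # SAM point-grid density
    )
    print(output_video, mask_zip)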