@@ -75,15 +75,17 @@ def create_border(video, color="blue", border_percent=2):
75
75
76
76
77
77
def convert_videos_to_summaries (input_videos , output_videos , target_videos ,
78
- tag , decode_hparams ):
78
+ tag , decode_hparams ,
79
+ display_ground_truth = False ):
79
80
"""Converts input, output and target videos into video summaries.
80
81
81
82
Args:
82
83
input_videos: 5-D NumPy array, (NTHWC) conditioning frames.
83
- output_videos: 5-D NumPy array, (NTHWC) ground truth .
84
+ output_videos: 5-D NumPy array, (NTHWC) model predictions .
84
85
target_videos: 5-D NumPy array, (NTHWC) target frames.
85
86
tag: tf summary tag.
86
87
decode_hparams: tf.contrib.training.HParams.
88
+ display_ground_truth: Whether or not to display ground truth videos.
87
89
Returns:
88
90
summaries: a list of tf frame-by-frame and video summaries.
89
91
"""
@@ -98,18 +100,20 @@ def convert_videos_to_summaries(input_videos, output_videos, target_videos,
98
100
output_videos = create_border (
99
101
output_videos , color = "red" , border_percent = border_percent )
100
102
101
- # Video gif.
102
103
all_input = np .concatenate ((input_videos , target_videos ), axis = 1 )
103
104
all_output = np .concatenate ((input_videos , output_videos ), axis = 1 )
104
- input_summ_vals , _ = common_video .py_gif_summary (
105
- "%s/input" % tag , all_input , max_outputs = max_outputs , fps = fps ,
106
- return_summary_value = True )
107
105
output_summ_vals , _ = common_video .py_gif_summary (
108
106
"%s/output" % tag , all_output , max_outputs = max_outputs , fps = fps ,
109
107
return_summary_value = True )
110
- all_summaries .extend (input_summ_vals )
111
108
all_summaries .extend (output_summ_vals )
112
109
110
+ # Optionally display ground truth.
111
+ if display_ground_truth :
112
+ input_summ_vals , _ = common_video .py_gif_summary (
113
+ "%s/input" % tag , all_input , max_outputs = max_outputs , fps = fps ,
114
+ return_summary_value = True )
115
+ all_summaries .extend (input_summ_vals )
116
+
113
117
# Frame-by-frame summaries
114
118
iterable = zip (all_input [:max_outputs ], all_output [:max_outputs ])
115
119
for ind , (input_video , output_video ) in enumerate (iterable ):
@@ -164,7 +168,8 @@ def display_video_hooks(hook_args):
164
168
input_videos = np .asarray (input_videos , dtype = np .uint8 )
165
169
summaries = convert_videos_to_summaries (
166
170
input_videos , output_videos , target_videos ,
167
- tag = "decode_%d" % decode_ind , decode_hparams = hook_args .decode_hparams )
171
+ tag = "decode_%d" % decode_ind , decode_hparams = hook_args .decode_hparams ,
172
+ display_ground_truth = decode_ind == 0 )
168
173
all_summaries .extend (summaries )
169
174
return all_summaries
170
175
0 commit comments