Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8070eb4

Browse files
MechCoder and Copybara-Service
authored and committed
Minor, display ground truth videos only once during decode.
PiperOrigin-RevId: 221561341
1 parent 9c6402b commit 8070eb4

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

tensor2tensor/data_generators/video_utils.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -75,15 +75,17 @@ def create_border(video, color="blue", border_percent=2):
7575

7676

7777
def convert_videos_to_summaries(input_videos, output_videos, target_videos,
78-
tag, decode_hparams):
78+
tag, decode_hparams,
79+
display_ground_truth=False):
7980
"""Converts input, output and target videos into video summaries.
8081
8182
Args:
8283
input_videos: 5-D NumPy array, (NTHWC) conditioning frames.
83-
output_videos: 5-D NumPy array, (NTHWC) ground truth.
84+
output_videos: 5-D NumPy array, (NTHWC) model predictions.
8485
target_videos: 5-D NumPy array, (NTHWC) target frames.
8586
tag: tf summary tag.
8687
decode_hparams: tf.contrib.training.HParams.
88+
display_ground_truth: Whether or not to display ground truth videos.
8789
Returns:
8890
summaries: a list of tf frame-by-frame and video summaries.
8991
"""
@@ -98,18 +100,20 @@ def convert_videos_to_summaries(input_videos, output_videos, target_videos,
98100
output_videos = create_border(
99101
output_videos, color="red", border_percent=border_percent)
100102

101-
# Video gif.
102103
all_input = np.concatenate((input_videos, target_videos), axis=1)
103104
all_output = np.concatenate((input_videos, output_videos), axis=1)
104-
input_summ_vals, _ = common_video.py_gif_summary(
105-
"%s/input" % tag, all_input, max_outputs=max_outputs, fps=fps,
106-
return_summary_value=True)
107105
output_summ_vals, _ = common_video.py_gif_summary(
108106
"%s/output" % tag, all_output, max_outputs=max_outputs, fps=fps,
109107
return_summary_value=True)
110-
all_summaries.extend(input_summ_vals)
111108
all_summaries.extend(output_summ_vals)
112109

110+
# Optionally display ground truth.
111+
if display_ground_truth:
112+
input_summ_vals, _ = common_video.py_gif_summary(
113+
"%s/input" % tag, all_input, max_outputs=max_outputs, fps=fps,
114+
return_summary_value=True)
115+
all_summaries.extend(input_summ_vals)
116+
113117
# Frame-by-frame summaries
114118
iterable = zip(all_input[:max_outputs], all_output[:max_outputs])
115119
for ind, (input_video, output_video) in enumerate(iterable):
@@ -164,7 +168,8 @@ def display_video_hooks(hook_args):
164168
input_videos = np.asarray(input_videos, dtype=np.uint8)
165169
summaries = convert_videos_to_summaries(
166170
input_videos, output_videos, target_videos,
167-
tag="decode_%d" % decode_ind, decode_hparams=hook_args.decode_hparams)
171+
tag="decode_%d" % decode_ind, decode_hparams=hook_args.decode_hparams,
172+
display_ground_truth=decode_ind == 0)
168173
all_summaries.extend(summaries)
169174
return all_summaries
170175

tensor2tensor/data_generators/video_utils_test.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,9 +73,11 @@ def testConvertPredictionsToVideoSummaries(self):
7373
hparams=decode_hparams, decode_hparams=decode_hparams,
7474
predictions=predictions)
7575
summaries = video_utils.display_video_hooks(decode_hooks)
76-
# for {random, psnr_max, psnr_min, ssim_max, ssim_min}
76+
# for {psnr_max, psnr_min, ssim_max, ssim_min}
77+
# 10 output vids + 10 frame-by-frame.
78+
# for {random}
7779
# 10 input vids + 10 output vids + 10 frame-by-frame.
78-
self.assertEqual(len(summaries), 150)
80+
self.assertEqual(len(summaries), 110)
7981
for summary in summaries:
8082
self.assertTrue(isinstance(summary, tf.Summary.Value))
8183

0 commit comments

Comments
 (0)