Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8d8f8ba

Browse files
MechCoderafrozenator
authored andcommitted
Draw a border around each frame to differentiate between the conditioned and target frames during decoding.
PiperOrigin-RevId: 219179149
1 parent ea69336 commit 8d8f8ba

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

tensor2tensor/data_generators/video_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,33 @@ def resize_video_frames(images, size):
5252
return resized_images
5353

5454

55+
def create_border(video, color="blue", border_percent=2):
56+
"""Creates a border around each frame to differentiate input and target.
57+
58+
Args:
59+
video: 5-D NumPy array.
60+
color: string, "blue", "red" or "green".
61+
border_percent: Percentarge of the frame covered by the border.
62+
Returns:
63+
video: 5-D NumPy array.
64+
"""
65+
color_to_axis = {"blue": 2, "red": 0, "green": 1}
66+
axis = color_to_axis[color]
67+
_, _, height, width, _ = video.shape
68+
border_height = np.ceil(border_percent * height / 100.0).astype(np.int)
69+
border_width = np.ceil(border_percent * width / 100.0).astype(np.int)
70+
video[:, :, :border_height, :, axis] = 255
71+
video[:, :, -border_height:, :, axis] = 255
72+
video[:, :, :, :border_width, axis] = 255
73+
video[:, :, :, -border_width:, axis] = 255
74+
return video
75+
76+
5577
def display_video_hooks(hook_args):
5678
"""Hooks to display videos at decode time."""
5779
predictions = hook_args.predictions
5880
fps = hook_args.decode_hparams.frames_per_second
81+
border_percent = hook_args.decode_hparams.border_percent
5982

6083
all_summaries = []
6184
for decode_ind, decode in enumerate(predictions):
@@ -67,9 +90,17 @@ def display_video_hooks(hook_args):
6790
output_videos = np.asarray(output_videos, dtype=np.uint8)
6891
input_videos = np.asarray(input_videos, dtype=np.uint8)
6992

93+
input_videos = create_border(
94+
input_videos, color="blue", border_percent=border_percent)
95+
target_videos = create_border(
96+
target_videos, color="red", border_percent=border_percent)
97+
output_videos = create_border(
98+
output_videos, color="red", border_percent=border_percent)
99+
70100
# Video gif.
71101
all_input = np.concatenate((input_videos, target_videos), axis=1)
72102
all_output = np.concatenate((input_videos, output_videos), axis=1)
103+
73104
input_summ_vals, _ = common_video.py_gif_summary(
74105
"decode_%d/input" % decode_ind,
75106
all_input, max_outputs=10,

tensor2tensor/utils/decoding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def decode_hparams(overrides=""):
7171
# Used for video decoding.
7272
frames_per_second=10,
7373
skip_eos_postprocess=False,
74+
# Creates a blue/red border covering border_percent of the frame.
75+
border_percent=2,
7476
# Used for MLPerf compliance logging.
7577
mlperf_mode=False,
7678
mlperf_threshold=25.0,

0 commit comments

Comments
 (0)