Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ffff8ae

Browse files
T2T TeamCopybara-Service
authored andcommitted
Fixing video decoding
PiperOrigin-RevId: 200653976
1 parent 3d7cd00 commit ffff8ae

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

tensor2tensor/models/research/next_frame.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,9 @@ def logits_to_samples(logits):
139139
targets_shape = [self.hparams.batch_size,
140140
self.hparams.video_num_target_frames, 1, 1, num_channels]
141141
features["targets"] = tf.zeros(targets_shape, dtype=tf.int32)
142-
features["target_reward"] = tf.zeros(
143-
[targets_shape[0], 1, 1], dtype=tf.int32)
142+
if "target_reward" in self.hparams.problem_hparams.target_modality:
143+
features["target_reward"] = tf.zeros(
144+
[targets_shape[0], 1, 1], dtype=tf.int32)
144145
logits, _ = self(features) # pylint: disable=not-callable
145146
if isinstance(logits, dict):
146147
results = {}

tensor2tensor/utils/decoding.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,22 @@ def log_decode_results(inputs,
6868
identity_output=False,
6969
log_results=True):
7070
"""Log inference results."""
71+
72+
# TODO(lukaszkaiser) refactor this into feature_encoder
73+
is_video = "video" in problem_name
74+
if is_video:
75+
def fix_and_save_video(vid, prefix):
76+
save_path_template = os.path.join(
77+
model_dir, "%s_%s_%d_{}.png" % (problem_name, prefix, prediction_idx))
78+
# this is only required for predictions
79+
if vid.shape[-1] == 1:
80+
vid = np.squeeze(vid, axis=-1)
81+
save_video(vid, save_path_template)
82+
tf.logging.info("Saving video: {}".format(prediction_idx))
83+
fix_and_save_video(inputs, "inputs")
84+
fix_and_save_video(outputs, "outputs")
85+
fix_and_save_video(targets, "targets")
86+
7187
is_image = "image" in problem_name
7288
decoded_inputs = None
7389
if is_image and save_images:
@@ -80,7 +96,7 @@ def log_decode_results(inputs,
8096
else:
8197
decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs, is_image))
8298

83-
if log_results:
99+
if log_results and not is_video:
84100
tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
85101

86102
decoded_targets = None
@@ -93,8 +109,9 @@ def log_decode_results(inputs,
93109
decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
94110
if targets is not None and log_results:
95111
decoded_targets = targets_vocab.decode(_save_until_eos(targets, is_image))
96-
tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
97-
if targets is not None and log_results:
112+
if not is_video:
113+
tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
114+
if targets is not None and log_results and not is_video:
98115
tf.logging.info("Inference results TARGET: %s" % decoded_targets)
99116
return decoded_inputs, decoded_outputs, decoded_targets
100117

@@ -518,6 +535,21 @@ def _interactive_input_fn(hparams, decode_hp):
518535
yield features
519536

520537

538+
def save_video(video, save_path_template):
539+
"""Save frames of the videos into files."""
540+
try:
541+
from PIL import Image # pylint: disable=g-import-not-at-top
542+
except ImportError as e:
543+
tf.logging.warning(
544+
"Showing and saving an image requires PIL library to be "
545+
"installed: %s", e)
546+
raise NotImplementedError("Image display and save not implemented.")
547+
548+
for i, frame in enumerate(video):
549+
save_path = save_path_template.format(i)
550+
Image.fromarray(np.uint8(frame)).save(save_path)
551+
552+
521553
def show_and_save_image(img, save_path):
522554
try:
523555
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top

0 commit comments

Comments
 (0)