Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 7836aa7

Browse files
Lukasz KaiserCopybara-Service
authored andcommitted
Check if ffmpeg is installed and guard against it in RL env.
PiperOrigin-RevId: 228910916
1 parent 77f4437 commit 7836aa7

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

tensor2tensor/layers/common_video.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ def _encode_gif(images, fps):
357357
"""Encodes numpy images into gif string.
358358
359359
Args:
360-
images: A 5-D `uint8` `np.array` (or a list of 4-D images) of shape
361-
`[batch_size, time, height, width, channels]` where `channels` is 1 or 3.
360+
images: A 4-D `uint8` `np.array` (or a list of 3-D images) of shape
361+
`[time, height, width, channels]` where `channels` is 1 or 3.
362362
fps: frames per second of the animation
363363
364364
Returns:
@@ -372,6 +372,16 @@ def _encode_gif(images, fps):
372372
return writer.finish()
373373

374374

375+
def ffmpeg_works():
376+
"""Tries to encode images with ffmpeg to check if it works."""
377+
images = np.zeros((2, 32, 32, 3), dtype=np.uint8)
378+
try:
379+
_encode_gif(images, 2)
380+
return True
381+
except (IOError, OSError):
382+
return False
383+
384+
375385
def py_gif_summary(tag, images, max_outputs, fps, return_summary_value=False):
376386
"""Outputs a `Summary` protocol buffer with gif animations.
377387
@@ -697,7 +707,7 @@ def __init__(self, fps, output_path=None, file_format="gif"):
697707
def __init_ffmpeg(self, image_shape):
698708
"""Initializes ffmpeg to write frames."""
699709
import itertools # pylint: disable=g-import-not-at-top
700-
from subprocess import Popen, PIPE # pylint: disable=g-import-not-at-top,g-multiple-import
710+
from subprocess import Popen, PIPE # pylint: disable=g-import-not-at-top,g-multiple-import,g-importing-member
701711
ffmpeg = "ffmpeg"
702712
height, width, channels = image_shape
703713
self.cmd = [

tensor2tensor/rl/envs/simulated_batch_env.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
"""Batch of environments inside the TensorFlow graph."""
116116
super(SimulatedBatchEnv, self).__init__(observation_space, action_space)
117117

118+
self._ffmpeg_works = common_video.ffmpeg_works()
118119
self.batch_size = batch_size
119120
self._min_reward = reward_range[0]
120121
self._num_frames = frame_stack_size
@@ -267,6 +268,8 @@ def history_observations(self):
267268
return self.history_buffer.get_all_elements()
268269

269270
def _video_dump_frame(self, obs, rews):
271+
if not self._ffmpeg_works:
272+
return
270273
if self._video_writer is None:
271274
self._video_counter += 1
272275
self._video_writer = common_video.WholeVideoWriter(
@@ -280,6 +283,8 @@ def _video_dump_frame(self, obs, rews):
280283
self._video_writer.write(np.concatenate([np.asarray(img), obs[0]], axis=0))
281284

282285
def _video_dump_frames(self, obs):
286+
if not self._ffmpeg_works:
287+
return
283288
zeros = np.zeros(obs.shape[0])
284289
for i in range(obs.shape[1]):
285290
self._video_dump_frame(obs[:, i, :], zeros)

0 commit comments

Comments
 (0)