Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 15bd9e3

Browse files
Lukasz Kaiser authored and Ryan Sepassi committed
Style corrections to recent RL code.
PiperOrigin-RevId: 193623810
1 parent 9c751a2 commit 15bd9e3

File tree

9 files changed

+182
-200
lines changed

9 files changed

+182
-200
lines changed

tensor2tensor/data_generators/gym.py

Lines changed: 38 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@
2222
from collections import deque
2323

2424
import functools
25-
import os
2625
# Dependency imports
2726
import gym
2827

29-
from tensor2tensor.data_generators import generator_utils
3028
from tensor2tensor.data_generators import problem
3129
from tensor2tensor.data_generators import video_utils
3230

@@ -35,6 +33,7 @@
3533
from tensor2tensor.rl.envs import tf_atari_wrappers as atari
3634
from tensor2tensor.rl.envs.utils import batch_env_factory
3735

36+
from tensor2tensor.utils import metrics
3837
from tensor2tensor.utils import registry
3938

4039
import tensorflow as tf
@@ -63,6 +62,12 @@ def num_target_frames(self):
6362
"""Number of frames to batch on one target."""
6463
return 1
6564

65+
def eval_metrics(self):
66+
eval_metrics = [
67+
metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
68+
metrics.Metrics.NEG_LOG_PERPLEXITY]
69+
return eval_metrics
70+
6671
@property
6772
def extra_reading_spec(self):
6873
"""Additional data fields to store on disk and their decoders."""
@@ -116,7 +121,8 @@ def hparams(self, defaults, unused_model_hparams):
116121
p.input_modality = {"inputs": ("video", 256),
117122
"input_reward": ("symbol", self.num_rewards),
118123
"input_action": ("symbol", self.num_actions)}
119-
p.target_modality = ("video", 256)
124+
p.target_modality = {"targets": ("video", 256),
125+
"target_reward": ("symbol", self.num_rewards)}
120126
p.input_space_id = problem.SpaceID.IMAGE
121127
p.target_space_id = problem.SpaceID.IMAGE
122128

@@ -174,34 +180,27 @@ def num_steps(self):
174180
return 50000
175181

176182

177-
def moviepy_editor():
178-
"""Access to moviepy that fails gracefully without a moviepy install."""
179-
try:
180-
from moviepy import editor # pylint: disable=g-import-not-at-top
181-
except ImportError:
182-
raise ImportError("pip install moviepy to record videos")
183-
return editor
184-
185-
186183
@registry.register_problem
187-
class GymDiscreteProblemWithAgent(problem.Problem):
188-
"""Gym environment with discrete actions and rewards."""
184+
class GymDiscreteProblemWithAgent(GymPongRandom5k):
185+
"""Gym environment with discrete actions and rewards and an agent."""
189186

190187
def __init__(self, *args, **kwargs):
191188
super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
192-
self.num_channels = 3
189+
self._env = None
193190
self.history_size = 2
194191

195192
# defaults
196-
self.environment_spec = lambda: gym.make("PongNoFrameskip-v4")
193+
self.environment_spec = lambda: gym.make("PongDeterministic-v4")
197194
self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
198195
self.collect_hparams = rl.atari_base()
199-
self.num_steps = 1000
200-
self.movies = False
201-
self.movies_fps = 24
196+
self.settable_num_steps = 1000
202197
self.simulated_environment = None
203198
self.warm_up = 70
204199

200+
@property
201+
def num_steps(self):
202+
return self.settable_num_steps
203+
205204
def _setup(self):
206205
in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
207206
(atari.MemoryWrapper, {})] + self.in_graph_wrappers
@@ -234,85 +233,23 @@ def _setup(self):
234233
self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
235234
self.history_buffer = deque(maxlen=self.history_size+1)
236235

237-
def example_reading_spec(self, label_repr=None):
238-
data_fields = {
239-
"targets_encoded": tf.FixedLenFeature((), tf.string),
240-
"image/format": tf.FixedLenFeature((), tf.string),
241-
"action": tf.FixedLenFeature([1], tf.int64),
242-
"reward": tf.FixedLenFeature([1], tf.int64),
243-
# "done": tf.FixedLenFeature([1], tf.int64)
244-
}
245-
246-
for x in range(self.history_size):
247-
data_fields["inputs_encoded_{}".format(x)] = tf.FixedLenFeature(
248-
(), tf.string)
249-
250-
data_items_to_decoders = {
251-
"targets": tf.contrib.slim.tfexample_decoder.Image(
252-
image_key="targets_encoded",
253-
format_key="image/format",
254-
shape=[210, 160, 3],
255-
channels=3),
256-
# Just do a pass through.
257-
"action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
258-
"reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
259-
}
260-
261-
for x in range(self.history_size):
262-
key = "inputs_{}".format(x)
263-
data_items_to_decoders[key] = tf.contrib.slim.tfexample_decoder.Image(
264-
image_key="inputs_encoded_{}".format(x),
265-
format_key="image/format",
266-
shape=[210, 160, 3],
267-
channels=3)
268-
269-
return data_fields, data_items_to_decoders
270-
271-
@property
272-
def num_actions(self):
273-
return 4
274-
275-
@property
276-
def num_rewards(self):
277-
return 2
278-
279-
@property
280-
def num_shards(self):
281-
return 10
282-
283-
@property
284-
def num_dev_shards(self):
285-
return 1
286-
287-
def get_action(self, observation=None):
288-
return self.env.action_space.sample()
289-
290-
def hparams(self, defaults, unused_model_hparams):
291-
p = defaults
292-
# The hard coded +1 after "symbol" refers to the fact
293-
# that 0 is a special symbol meaning padding
294-
# when symbols are e.g. 0, 1, 2, 3 we
295-
# shift them to 0, 1, 2, 3, 4.
296-
p.input_modality = {"action": ("symbol:identity", self.num_actions)}
297-
298-
for x in range(self.history_size):
299-
p.input_modality["inputs_{}".format(x)] = ("image", 256)
300-
301-
p.target_modality = {"targets": ("image", 256),
302-
"reward": ("symbol", self.num_rewards + 1)}
303-
304-
p.input_space_id = problem.SpaceID.IMAGE
305-
p.target_space_id = problem.SpaceID.IMAGE
306-
307236
def restore_networks(self, sess):
308237
model_saver = tf.train.Saver(
309238
tf.global_variables(".*network_parameters.*"))
310239
if FLAGS.agent_policy_path:
311240
model_saver.restore(sess, FLAGS.agent_policy_path)
312241

313-
def generator(self, data_dir, tmp_dir):
242+
def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
314243
self._setup()
315-
clip_files = []
244+
245+
# When no agent_policy_path is set, just generate random samples.
246+
if not FLAGS.agent_policy_path:
247+
for sample in super(GymDiscreteProblemWithAgent,
248+
self).generate_encoded_samples(
249+
data_dir, tmp_dir, unused_dataset_split):
250+
yield sample
251+
return
252+
316253
with tf.Session() as sess:
317254
sess.run(tf.global_variables_initializer())
318255
self.restore_networks(sess)
@@ -324,61 +261,33 @@ def generator(self, data_dir, tmp_dir):
324261
observ, reward, action, _ = sess.run(self.data_get_op)
325262
self.history_buffer.append(observ)
326263

327-
if self.movies and pieces_generated > self.warm_up:
328-
file_name = os.path.join(tmp_dir,
329-
"output_{}.png".format(pieces_generated))
330-
clip_files.append(file_name)
331-
with open(file_name, "wb") as f:
332-
f.write(observ)
333-
334-
if len(self.history_buffer) == self.history_size+1:
264+
if len(self.history_buffer) == self.history_size + 1:
335265
pieces_generated += 1
336-
ret_dict = {
337-
"targets_encoded": [observ],
338-
"image/format": ["png"],
339-
"action": [int(action)],
340-
# "done": [bool(done)],
341-
"reward": [int(reward)],
342-
}
343-
for i, v in enumerate(list(self.history_buffer)[:-1]):
344-
ret_dict["inputs_encoded_{}".format(i)] = [v]
266+
ret_dict = {"image/encoded": [observ],
267+
"image/format": ["png"],
268+
"image/height": [self.frame_height],
269+
"image/width": [self.frame_width],
270+
"action": [int(action)],
271+
"done": [int(False)],
272+
"reward": [int(reward) - self.min_reward]}
345273
if pieces_generated > self.warm_up:
346274
yield ret_dict
347275
else:
348276
sess.run(self.collect_trigger_op)
349277

350-
if self.movies:
351-
clip = moviepy_editor().ImageSequenceClip(clip_files, fps=self.movies_fps)
352-
clip_path = os.path.join(data_dir, "output_{}.mp4".format(self.name))
353-
clip.write_videofile(clip_path, fps=self.movies_fps, codec="mpeg4")
354-
355-
def generate_data(self, data_dir, tmp_dir, task_id=-1):
356-
train_paths = self.training_filepaths(
357-
data_dir, self.num_shards, shuffled=False)
358-
dev_paths = self.dev_filepaths(
359-
data_dir, self.num_dev_shards, shuffled=False)
360-
all_paths = train_paths + dev_paths
361-
generator_utils.generate_files(
362-
self.generator(data_dir, tmp_dir), all_paths)
363-
generator_utils.shuffle_dataset(all_paths)
364-
365278

366279
@registry.register_problem
367280
class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
368281
"""Simulated gym environment with discrete actions and rewards."""
369282

370283
def __init__(self, *args, **kwargs):
371284
super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
372-
# TODO(lukaszkaiser): pull it outside
373-
self.in_graph_wrappers = [(atari.TimeLimitWrapper, {"timelimit": 150}),
374-
(atari.MaxAndSkipWrapper, {"skip": 4})]
375285
self.simulated_environment = True
376-
self.movies_fps = 2
286+
self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames"
377287

378288
def restore_networks(self, sess):
379289
super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)
380-
381-
# TODO(lukaszkaiser): adjust regexp for different models
290+
# TODO(blazej): adjust regexp for different models.
382291
env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))
383292
sess = tf.get_default_session()
384293

tensor2tensor/data_generators/video_utils.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import os
23+
2224
# Dependency imports
2325

26+
import six
27+
2428
from tensor2tensor.data_generators import generator_utils
2529
from tensor2tensor.data_generators import image_utils
2630
from tensor2tensor.data_generators import problem
@@ -43,6 +47,12 @@ def resize_video_frames(images, size):
4347
class VideoProblem(problem.Problem):
4448
"""Base class for problems with videos."""
4549

50+
def __init__(self, *args, **kwargs):
51+
super(VideoProblem, self).__init__(*args, **kwargs)
52+
# Path to a directory to dump generated frames as png for debugging.
53+
# If empty, no debug frames will be generated.
54+
self.debug_dump_frames_path = ""
55+
4656
@property
4757
def num_channels(self):
4858
"""Number of color channels in each frame."""
@@ -157,7 +167,7 @@ def features_from_batch(batched_prefeatures):
157167
Features dictionary with joint features per-frame.
158168
"""
159169
features = {}
160-
for k, v in batched_prefeatures.iteritems():
170+
for k, v in six.iteritems(batched_prefeatures):
161171
if k == "frame": # We rename past frames to inputs and targets.
162172
s1, s2 = split_on_batch(v)
163173
# Reshape just to make sure shapes are right and set.
@@ -242,13 +252,27 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
242252
if width != self.frame_width:
243253
raise ValueError("Generated frame has width %d while the class "
244254
"assumes width %d." % (width, self.frame_width))
245-
encoded_frame = image_utils.encode_images_as_png([unencoded_frame]).next()
255+
encoded_frame = six.next(
256+
image_utils.encode_images_as_png([unencoded_frame]))
246257
features["image/encoded"] = [encoded_frame]
247258
features["image/format"] = ["png"]
248259
features["image/height"] = [height]
249260
features["image/width"] = [width]
250261
yield features
251262

263+
def generate_encoded_samples_debug(self, data_dir, tmp_dir, dataset_split):
264+
"""Generate samples of the encoded frames and dump for debug if needed."""
265+
counter = 0
266+
for sample in self.generate_encoded_samples(
267+
data_dir, tmp_dir, dataset_split):
268+
if self.debug_dump_frames_path:
269+
path = os.path.join(self.debug_dump_frames_path,
270+
"frame_%d.png" % counter)
271+
with tf.gfile.Open(path, "wb") as f:
272+
f.write(sample["image/encoded"][0])
273+
counter += 1
274+
yield sample
275+
252276
def generate_data(self, data_dir, tmp_dir, task_id=-1):
253277
"""The function generating the data."""
254278
filepath_fns = {
@@ -268,10 +292,11 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
268292
if self.is_generate_per_split:
269293
for split, paths in split_paths:
270294
generator_utils.generate_files(
271-
self.generate_encoded_samples(data_dir, tmp_dir, split), paths)
295+
self.generate_encoded_samples_debug(
296+
data_dir, tmp_dir, split), paths)
272297
else:
273298
generator_utils.generate_files(
274-
self.generate_encoded_samples(
299+
self.generate_encoded_samples_debug(
275300
data_dir, tmp_dir, problem.DatasetSplit.TRAIN), all_paths)
276301

277302

tensor2tensor/layers/modalities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,8 @@ def bottom(self, inputs):
464464
inputs_shape = common_layers.shape_list(inputs)
465465
if len(inputs_shape) != 5:
466466
raise ValueError("Assuming videos given as tensors in the format "
467-
"[batch, time, height, width, channels].")
467+
"[batch, time, height, width, channels] but got one "
468+
"of shape: %s" % str(inputs_shape))
468469
if not context.in_eager_mode():
469470
tf.summary.image("inputs", tf.cast(inputs[:, -1, :, :, :], tf.uint8),
470471
max_outputs=1)
@@ -484,7 +485,8 @@ def targets_bottom(self, inputs):
484485
inputs_shape = common_layers.shape_list(inputs)
485486
if len(inputs_shape) != 5:
486487
raise ValueError("Assuming videos given as tensors in the format "
487-
"[batch, time, height, width, channels].")
488+
"[batch, time, height, width, channels] but got one "
489+
"of shape: %s" % str(inputs_shape))
488490
if not context.in_eager_mode():
489491
tf.summary.image(
490492
"targets_bottom", tf.cast(inputs[:, -1, :, :, :], tf.uint8),

0 commit comments

Comments
 (0)