Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 36e1446

Browse files
blazejosinski and Copybara-Service
authored and committed
Introducing StackWrapper.
PiperOrigin-RevId: 209098352
1 parent 3b12e6c commit 36e1446

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

tensor2tensor/data_generators/gym_problems.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
def standard_atari_env_spec(env):
5050
"""Parameters of environment specification."""
51-
standard_wrappers = [[tf_atari_wrappers.StackAndSkipWrapper, {"skip": 4}]]
51+
standard_wrappers = [[tf_atari_wrappers.StackWrapper, {"history": 4}]]
5252
env_lambda = None
5353
if isinstance(env, str):
5454
env_lambda = lambda: gym.make(env)

tensor2tensor/rl/envs/tf_atari_wrappers.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,47 @@ def _reset_non_empty(self, indices):
151151
return tf.gather(self.observ, indices)
152152

153153

154+
class StackWrapper(WrapperBase):
  """A wrapper which stacks previously seen frames.

  Keeps the last `history` observations concatenated along the last
  (channel) axis of `self._observ`, so downstream consumers see a
  frame-stacked observation of shape `old_shape[:-1] + [C * history]`,
  where C is the wrapped env's channel count.
  """

  def __init__(self, batch_env, history=4):
    """Initializes the wrapper.

    Args:
      batch_env: the wrapped batched environment; its `observ` attribute
        provides the per-step observation tensor whose static shape is read
        here.
      history: number of past frames to stack along the last axis.
    """
    super(StackWrapper, self).__init__(batch_env)
    self.history = history
    # Static shape of a single (un-stacked) observation; the last entry is
    # the channel count C used to slice/concatenate the frame stack below.
    self.old_shape = batch_env.observ.shape.as_list()
    # Stacked buffer shape: same as the base observation but with the last
    # axis widened to C * history.
    observs_shape = self.old_shape[:-1] + [self.old_shape[-1] * self.history]
    # NOTE(review): dtype is hardcoded to float32 regardless of the wrapped
    # env's observation dtype — confirm callers rely on float observations.
    observ_dtype = tf.float32
    self._observ = tf.Variable(tf.zeros(observs_shape, observ_dtype),
                               trainable=False)

  def simulate(self, action):
    """Steps the wrapped env and shifts the new frame into the stack.

    Args:
      action: action tensor forwarded to the wrapped environment.

    Returns:
      Tuple `(reward, done)` tensors from the wrapped step, gated (via
      control dependencies) on the frame-stack buffer update.
    """
    reward, done = self._batch_env.simulate(action)
    with tf.control_dependencies([reward, done]):
      # `+ 0` forces a fresh tensor (copy) of the post-step observation so
      # the value is pinned before the buffer assignment below.
      new_observ = self._batch_env.observ + 0
      # Keep channels [C, C*history): i.e. drop the oldest frame from the
      # current stack, reading the variable's value before it is reassigned.
      old_observ = tf.gather(
          self._observ.read_value(),
          range(self.old_shape[-1], self.old_shape[-1] * self.history),
          axis=-1)
      with tf.control_dependencies([new_observ, old_observ]):
        # Append the new frame after the retained (history - 1) frames; the
        # nested control dependencies enforce read-then-write ordering in
        # TF1 graph mode.
        with tf.control_dependencies([self._observ.assign(
            tf.concat([old_observ, new_observ], axis=-1))]):
          return tf.identity(reward), tf.identity(done)

  def _reset_non_empty(self, indices):
    """Resets the selected batch entries and refills their frame stacks.

    Args:
      indices: int tensor of batch indices to reset.

    Returns:
      The stacked observations at `indices` after the reset, gated on the
      buffer update.
    """
    # pylint: disable=protected-access
    new_values = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    # Tile multiples [1, ..., 1, history]: repeat the reset observation
    # `history` times along the last axis so every slot of the stack holds
    # the initial frame.
    inx = tf.concat(
        [
            tf.ones(tf.size(tf.shape(new_values)), dtype=tf.int32)[:-1],
            [self.history]
        ],
        axis=0)
    assign_op = tf.scatter_update(self._observ, indices, tf.tile(
        new_values, inx))
    with tf.control_dependencies([assign_op]):
      return tf.gather(self.observ, indices)
193+
194+
154195
class AutoencoderWrapper(WrapperBase):
155196
""" Transforms the observations taking the bottleneck
156197
state of an autoencoder"""

0 commit comments

Comments
 (0)