@@ -151,6 +151,47 @@ def _reset_non_empty(self, indices):
151
151
return tf .gather (self .observ , indices )
152
152
153
153
154
class StackWrapper(WrapperBase):
  """A wrapper which stacks previously seen frames.

  The wrapper keeps a non-trainable variable whose last axis is `history`
  times as wide as the last axis of the wrapped environment's observation.
  Every `simulate` call drops the oldest frame from the front of that axis
  and appends the newest observation at the end.
  """

  def __init__(self, batch_env, history=4):
    """Creates the stacked observation buffer.

    Args:
      batch_env: Wrapped batch environment; its `observ` must have a static
        shape that `shape.as_list()` can return.
      history: Number of consecutive frames kept in the buffer (>= 1).

    Raises:
      ValueError: If `history` is smaller than 1.
    """
    if history < 1:
      # A non-positive history would otherwise only surface later as an
      # opaque shape mismatch while building the graph.
      raise ValueError("history must be >= 1, got %r" % (history,))
    super(StackWrapper, self).__init__(batch_env)
    self.history = history
    self.old_shape = batch_env.observ.shape.as_list()
    # Same shape as the wrapped observation, except that the last axis has
    # room for `history` frames side by side.
    observs_shape = self.old_shape[:-1] + [self.old_shape[-1] * self.history]
    # NOTE(review): the buffer is always float32; this presumably matches the
    # dtype of the wrapped observation -- confirm against the wrapped env.
    observ_dtype = tf.float32
    self._observ = tf.Variable(tf.zeros(observs_shape, observ_dtype),
                               trainable=False)

  def simulate(self, action):
    """Steps the wrapped environment and pushes the new frame onto the stack.

    Args:
      action: Action passed unchanged to the wrapped environment's `simulate`.

    Returns:
      Tuple `(reward, done)` of tensors; evaluating either one also performs
      the update of the stacked observation buffer.
    """
    reward, done = self._batch_env.simulate(action)
    with tf.control_dependencies([reward, done]):
      # `+ 0` creates a fresh op inside the control-dependency scope, so the
      # observation is read only after the wrapped step has run.
      new_observ = self._batch_env.observ + 0

      if self.history == 1:
        # Single-frame history: there are no older frames to keep, and
        # gathering with an empty index list must be avoided, so simply
        # overwrite the buffer with the newest observation.
        with tf.control_dependencies([self._observ.assign(new_observ)]):
          return tf.identity(reward), tf.identity(done)

      frame_depth = self.old_shape[-1]
      # Keep everything except the oldest frame, i.e. drop the first
      # `frame_depth` entries along the last axis. The indices are
      # materialised as a list so conversion to a tensor does not depend on
      # how a lazy `range` object is handled.
      old_observ = tf.gather(
          self._observ.read_value(),
          list(range(frame_depth, frame_depth * self.history)),
          axis=-1)
      with tf.control_dependencies([new_observ, old_observ]):
        # Both inputs must be computed before the buffer is overwritten.
        with tf.control_dependencies([self._observ.assign(
            tf.concat([old_observ, new_observ], axis=-1))]):
          return tf.identity(reward), tf.identity(done)

  def _reset_non_empty(self, indices):
    """Resets the selected environments and refills their frame stacks.

    Args:
      indices: Indices of the environments in the batch to reset.

    Returns:
      The observations at `indices`, gathered from `self.observ` after the
      buffer rows have been updated.
    """
    # pylint: disable=protected-access
    new_values = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    # Tile multiples: 1 for every axis except the last, `history` for the
    # last one, so the first frame is repeated to fill the whole stack.
    inx = tf.concat(
        [
            tf.ones(tf.size(tf.shape(new_values)), dtype=tf.int32)[:-1],
            [self.history]
        ],
        axis=0)
    assign_op = tf.scatter_update(self._observ, indices, tf.tile(
        new_values, inx))
    with tf.control_dependencies([assign_op]):
      # NOTE(review): `self.observ` is presumably the base-class accessor for
      # `self._observ` -- confirm in WrapperBase.
      return tf.gather(self.observ, indices)
193
+
194
+
154
195
class AutoencoderWrapper (WrapperBase ):
155
196
""" Transforms the observations taking the bottleneck
156
197
state of an autoencoder"""
0 commit comments