 from collections import deque

 import functools
-import os

 # Dependency imports

 import gym

-from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import video_utils

 from tensor2tensor.rl.envs import tf_atari_wrappers as atari
 from tensor2tensor.rl.envs.utils import batch_env_factory

+from tensor2tensor.utils import metrics
 from tensor2tensor.utils import registry

 import tensorflow as tf
@@ -63,6 +62,12 @@ def num_target_frames(self):
     """Number of frames to batch on one target."""
     return 1

+  def eval_metrics(self):
+    eval_metrics = [
+        metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
+        metrics.Metrics.NEG_LOG_PERPLEXITY]
+    return eval_metrics
+
   @property
   def extra_reading_spec(self):
     """Additional data fields to store on disk and their decoders."""
@@ -116,7 +121,8 @@ def hparams(self, defaults, unused_model_hparams):
     p.input_modality = {"inputs": ("video", 256),
                         "input_reward": ("symbol", self.num_rewards),
                         "input_action": ("symbol", self.num_actions)}
-    p.target_modality = ("video", 256)
+    p.target_modality = {"targets": ("video", 256),
+                         "target_reward": ("symbol", self.num_rewards)}
     p.input_space_id = problem.SpaceID.IMAGE
     p.target_space_id = problem.SpaceID.IMAGE
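
With target_modality now a dict, every example supervises two streams: the next frame under "targets" (a video modality over 256 pixel values) and a reward id under "target_reward" (a symbol modality over num_rewards values), so a model trained on this problem must predict both. A rough sketch of inspecting the resulting features, assuming tensor2tensor's standard Problem.dataset() API, TF 1.x sessions, and the snake_case registry name derived from the class defined further below; the data directory is illustrative:

# Sketch only: peek at one example produced by the registered problem.
import tensorflow as tf
from tensor2tensor.utils import registry

gym_problem = registry.problem("gym_discrete_problem_with_agent")  # assumed registry name
dataset = gym_problem.dataset(tf.estimator.ModeKeys.TRAIN, data_dir="/tmp/t2t_data")
features = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  example = sess.run(features)
  # Both target streams are present after this change.
  print(example["targets"].shape, example["target_reward"])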
@@ -174,34 +180,27 @@ def num_steps(self):
     return 50000


-def moviepy_editor():
-  """Access to moviepy that fails gracefully without a moviepy install."""
-  try:
-    from moviepy import editor  # pylint: disable=g-import-not-at-top
-  except ImportError:
-    raise ImportError("pip install moviepy to record videos")
-  return editor
-
-
 @registry.register_problem
-class GymDiscreteProblemWithAgent(problem.Problem):
-  """Gym environment with discrete actions and rewards."""
+class GymDiscreteProblemWithAgent(GymPongRandom5k):
+  """Gym environment with discrete actions and rewards and an agent."""

   def __init__(self, *args, **kwargs):
     super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
-    self.num_channels = 3
+    self._env = None
     self.history_size = 2

     # defaults
-    self.environment_spec = lambda: gym.make("PongNoFrameskip-v4")
+    self.environment_spec = lambda: gym.make("PongDeterministic-v4")
     self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
     self.collect_hparams = rl.atari_base()
-    self.num_steps = 1000
-    self.movies = False
-    self.movies_fps = 24
+    self.settable_num_steps = 1000
     self.simulated_environment = None
     self.warm_up = 70

+  @property
+  def num_steps(self):
+    return self.settable_num_steps
+
   def _setup(self):
     in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
                          (atari.MemoryWrapper, {})] + self.in_graph_wrappers
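
Replacing the writable num_steps attribute with settable_num_steps behind a read-only property keeps the Problem-level num_steps interface while letting callers change the rollout length before data generation. A small usage sketch under that assumption; the step count and directories are illustrative:

# Sketch only: lengthen the rollout before generating data.
# generate_data() is the standard tensor2tensor Problem entry point.
gym_problem = GymDiscreteProblemWithAgent()
gym_problem.settable_num_steps = 20000   # default in this change is 1000
assert gym_problem.num_steps == 20000
gym_problem.generate_data(data_dir="/tmp/t2t_data", tmp_dir="/tmp/t2t_tmp")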
@@ -234,85 +233,23 @@ def _setup(self):
     self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
     self.history_buffer = deque(maxlen=self.history_size + 1)

-  def example_reading_spec(self, label_repr=None):
-    data_fields = {
-        "targets_encoded": tf.FixedLenFeature((), tf.string),
-        "image/format": tf.FixedLenFeature((), tf.string),
-        "action": tf.FixedLenFeature([1], tf.int64),
-        "reward": tf.FixedLenFeature([1], tf.int64),
-        # "done": tf.FixedLenFeature([1], tf.int64)
-    }
-
-    for x in range(self.history_size):
-      data_fields["inputs_encoded_{}".format(x)] = tf.FixedLenFeature(
-          (), tf.string)
-
-    data_items_to_decoders = {
-        "targets": tf.contrib.slim.tfexample_decoder.Image(
-            image_key="targets_encoded",
-            format_key="image/format",
-            shape=[210, 160, 3],
-            channels=3),
-        # Just do a pass through.
-        "action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
-        "reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
-    }
-
-    for x in range(self.history_size):
-      key = "inputs_{}".format(x)
-      data_items_to_decoders[key] = tf.contrib.slim.tfexample_decoder.Image(
-          image_key="inputs_encoded_{}".format(x),
-          format_key="image/format",
-          shape=[210, 160, 3],
-          channels=3)
-
-    return data_fields, data_items_to_decoders
-
-  @property
-  def num_actions(self):
-    return 4
-
-  @property
-  def num_rewards(self):
-    return 2
-
-  @property
-  def num_shards(self):
-    return 10
-
-  @property
-  def num_dev_shards(self):
-    return 1
-
-  def get_action(self, observation=None):
-    return self.env.action_space.sample()
-
-  def hparams(self, defaults, unused_model_hparams):
-    p = defaults
-    # The hard coded +1 after "symbol" refers to the fact
-    # that 0 is a special symbol meaning padding
-    # when symbols are e.g. 0, 1, 2, 3 we
-    # shift them to 0, 1, 2, 3, 4.
-    p.input_modality = {"action": ("symbol:identity", self.num_actions)}
-
-    for x in range(self.history_size):
-      p.input_modality["inputs_{}".format(x)] = ("image", 256)
-
-    p.target_modality = {"targets": ("image", 256),
-                         "reward": ("symbol", self.num_rewards + 1)}
-
-    p.input_space_id = problem.SpaceID.IMAGE
-    p.target_space_id = problem.SpaceID.IMAGE
-
-
   def restore_networks(self, sess):
     model_saver = tf.train.Saver(
         tf.global_variables(".*network_parameters.*"))
     if FLAGS.agent_policy_path:
       model_saver.restore(sess, FLAGS.agent_policy_path)

-  def generator(self, data_dir, tmp_dir):
+  def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
     self._setup()
-    clip_files = []
+
+    # When no agent_policy_path is set, just generate random samples.
+    if not FLAGS.agent_policy_path:
+      for sample in super(GymDiscreteProblemWithAgent,
+                          self).generate_encoded_samples(
+                              data_dir, tmp_dir, unused_dataset_split):
+        yield sample
+      return
+
     with tf.Session() as sess:
       sess.run(tf.global_variables_initializer())
       self.restore_networks(sess)
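
Replacing the hand-rolled generator/generate_data pair with generate_encoded_samples hands sharding, serialization, and shuffling back to the parent video problem, and adds a fallback: when the agent_policy_path flag is empty, the problem yields the parent class's random-agent samples instead of rolling out a restored policy. A sketch of setting that flag programmatically, assuming TF 1.x flags and a policy checkpoint trained elsewhere; the path is purely illustrative:

# Sketch only: the flag that gates the two code paths above is defined in this
# module and read via FLAGS.agent_policy_path; leave it unset to get random
# rollouts from the parent problem.
import tensorflow as tf

FLAGS = tf.flags.FLAGS
FLAGS.agent_policy_path = "/tmp/ppo_policy/model.ckpt-1000"  # hypothetical checkpoint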
@@ -324,61 +261,33 @@ def generator(self, data_dir, tmp_dir):
           observ, reward, action, _ = sess.run(self.data_get_op)
           self.history_buffer.append(observ)

-          if self.movies and pieces_generated > self.warm_up:
-            file_name = os.path.join(tmp_dir,
-                                     "output_{}.png".format(pieces_generated))
-            clip_files.append(file_name)
-            with open(file_name, "wb") as f:
-              f.write(observ)
-
-          if len(self.history_buffer) == self.history_size + 1:
+          if len(self.history_buffer) == self.history_size + 1:
             pieces_generated += 1
-            ret_dict = {
-                "targets_encoded": [observ],
-                "image/format": ["png"],
-                "action": [int(action)],
-                # "done": [bool(done)],
-                "reward": [int(reward)],
-            }
-            for i, v in enumerate(list(self.history_buffer)[:-1]):
-              ret_dict["inputs_encoded_{}".format(i)] = [v]
+            ret_dict = {"image/encoded": [observ],
+                        "image/format": ["png"],
+                        "image/height": [self.frame_height],
+                        "image/width": [self.frame_width],
+                        "action": [int(action)],
+                        "done": [int(False)],
+                        "reward": [int(reward) - self.min_reward]}
             if pieces_generated > self.warm_up:
               yield ret_dict
         else:
           sess.run(self.collect_trigger_op)

-    if self.movies:
-      clip = moviepy_editor().ImageSequenceClip(clip_files, fps=self.movies_fps)
-      clip_path = os.path.join(data_dir, "output_{}.mp4".format(self.name))
-      clip.write_videofile(clip_path, fps=self.movies_fps, codec="mpeg4")
-
-  def generate_data(self, data_dir, tmp_dir, task_id=-1):
-    train_paths = self.training_filepaths(
-        data_dir, self.num_shards, shuffled=False)
-    dev_paths = self.dev_filepaths(
-        data_dir, self.num_dev_shards, shuffled=False)
-    all_paths = train_paths + dev_paths
-    generator_utils.generate_files(
-        self.generator(data_dir, tmp_dir), all_paths)
-    generator_utils.shuffle_dataset(all_paths)
-

 @registry.register_problem
 class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
   """Simulated gym environment with discrete actions and rewards."""

   def __init__(self, *args, **kwargs):
     super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
-    # TODO(lukaszkaiser): pull it outside
-    self.in_graph_wrappers = [(atari.TimeLimitWrapper, {"timelimit": 150}),
-                              (atari.MaxAndSkipWrapper, {"skip": 4})]
     self.simulated_environment = True
-    self.movies_fps = 2
+    self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames"

   def restore_networks(self, sess):
     super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)
-
-    # TODO(lukaszkaiser): adjust regexp for different models
+    # TODO(blazej): adjust regexp for different models.
     env_model_loader = tf.train.Saver(tf.global_variables(".*basic_conv_gen.*"))
     sess = tf.get_default_session()