@@ -136,7 +136,7 @@ def train_autoencoder(problem_name, data_dir, output_dir, hparams, epoch):

 def train_agent(problem_name, agent_model_dir,
                 event_dir, world_model_dir, epoch_data_dir, hparams,
-                autoencoder_path=None):
+                autoencoder_path=None, epoch=0):
   """Train the PPO agent in the simulated environment."""
   gym_problem = registry.problem(problem_name)
   ppo_hparams = trainer_lib.create_hparams(hparams.ppo_params)
@@ -151,6 +151,8 @@ def train_agent(problem_name, agent_model_dir,
   ppo_hparams.num_agents = hparams.ppo_num_agents
   ppo_hparams.problem = gym_problem
   ppo_hparams.world_model_dir = world_model_dir
+  if hparams.ppo_learning_rate:
+    ppo_hparams.learning_rate = hparams.ppo_learning_rate
   # 4x for the StackAndSkipWrapper minus one to always finish for reporting.
   ppo_time_limit = (ppo_hparams.epoch_length - 1) * 4

@@ -169,7 +171,7 @@ def train_agent(problem_name, agent_model_dir,
       "autoencoder_path": autoencoder_path,
   }):
     rl_trainer_lib.train(ppo_hparams, gym_problem.env_name, event_dir,
-                         agent_model_dir)
+                         agent_model_dir, epoch=epoch)


 def evaluate_world_model(simulated_problem_name, problem_name, hparams,
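Note: the new `epoch` argument is threaded from `training_loop` (see the hunk at line 392 below) through `train_agent` into `rl_trainer_lib.train`, so the RL trainer knows which iteration of the model-based loop it is serving. A minimal, hypothetical sketch of that plumbing (simplified names and signatures, not the real API):

    # Hypothetical sketch of the epoch plumbing; rl_train stands in for
    # rl_trainer_lib.train, which now accepts an `epoch` keyword.
    def rl_train(ppo_hparams, env_name, event_dir, model_dir, epoch=0):
      print("PPO training for loop epoch", epoch)

    def train_agent_sketch(agent_model_dir, epoch=0):
      # train_agent simply forwards its new keyword argument.
      rl_train(None, "SimulatedEnv", "/tmp/events", agent_model_dir, epoch=epoch)

    for epoch in range(3):  # training_loop passes its loop index down
      train_agent_sketch("/tmp/ppo", epoch=epoch)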
@@ -281,19 +283,32 @@ def encode_env_frames(problem_name, ae_problem_name, autoencoder_path,
   ae_training_paths = ae_problem.training_filepaths(epoch_data_dir, 10, True)
   ae_eval_paths = ae_problem.dev_filepaths(epoch_data_dir, 1, True)

+  skip_train = False
+  skip_eval = False
+  for path in ae_training_paths:
+    if tf.gfile.Exists(path):
+      skip_train = True
+      break
+  for path in ae_eval_paths:
+    if tf.gfile.Exists(path):
+      skip_eval = True
+      break
+
   # Encode train data
-  dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, epoch_data_dir,
-                            shuffle_files=False, output_buffer_size=100,
-                            preprocess=False)
-  encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
-                 ae_training_paths)
+  if not skip_train:
+    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, epoch_data_dir,
+                              shuffle_files=False, output_buffer_size=100,
+                              preprocess=False)
+    encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
+                   ae_training_paths)

   # Encode eval data
-  dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, epoch_data_dir,
-                            shuffle_files=False, output_buffer_size=100,
-                            preprocess=False)
-  encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
-                 ae_eval_paths)
+  if not skip_eval:
+    dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, epoch_data_dir,
+                              shuffle_files=False, output_buffer_size=100,
+                              preprocess=False)
+    encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
+                   ae_eval_paths)


 def check_problems(problem_names):
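The guard above skips re-encoding when any of the target shard files already exists on disk, which makes `encode_env_frames` safe to call again after a restart. A standalone sketch of the same check, assuming TF 1.x's `tf.gfile` (the helper name is hypothetical):

    import tensorflow as tf

    def any_shard_exists(paths):
      """True if at least one of the expected output shards is already written."""
      return any(tf.gfile.Exists(path) for path in paths)

    # Mirrors the diff: only run the expensive encoding pass when no shard exists.
    # if not any_shard_exists(ae_training_paths):
    #   encode_dataset(...)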
@@ -392,7 +407,7 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
       ppo_model_dir = ppo_event_dir
     train_agent(world_model_problem, ppo_model_dir,
                 ppo_event_dir, directories["world_model"], epoch_data_dir,
-                hparams, autoencoder_path=autoencoder_model_dir)
+                hparams, autoencoder_path=autoencoder_model_dir, epoch=epoch)

     # Collect data from the real environment.
     log("Generating real environment data")
@@ -465,6 +480,7 @@ def rl_modelrl_base():
       # though it is not necessary.
       ppo_epoch_length=60,
       ppo_num_agents=16,
+      ppo_learning_rate=0.,
      # Whether the PPO agent should be restored from the previous iteration, or
       # should start fresh each time.
       ppo_continue_training=True,
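With the default of `0.`, the new `ppo_learning_rate` field is falsy, so the `if hparams.ppo_learning_rate:` check in `train_agent` keeps whatever learning rate `hparams.ppo_params` defines; any non-zero value overrides it. A minimal sketch of that convention (standalone, hypothetical names):

    def resolve_learning_rate(ppo_default_lr, override_lr=0.):
      # 0. means "no override", matching the ppo_learning_rate default above.
      return override_lr if override_lr else ppo_default_lr

    assert resolve_learning_rate(2e-4) == 2e-4          # default kept
    assert resolve_learning_rate(2e-4, 1e-4) == 1e-4    # explicit override wins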
@@ -483,6 +499,14 @@ def rl_modelrl_medium():
   return hparams


+@registry.register_hparams
+def rl_modelrl_25k():
+  """Small set for larger testing."""
+  hparams = rl_modelrl_medium()
+  hparams.true_env_generator_num_steps //= 2
+  return hparams
+
+
 @registry.register_hparams
 def rl_modelrl_short():
   """Small set for larger testing."""
@@ -583,6 +607,13 @@ def rl_modelrl_ae_base():
   return hparams


+@registry.register_hparams
+def rl_modelrl_ae_25k():
+  hparams = rl_modelrl_ae_base()
+  hparams.true_env_generator_num_steps //= 4
+  return hparams
+
+
 @registry.register_hparams
 def rl_modelrl_ae_l1_base():
   """Parameter set for autoencoders and L1 loss."""