This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2b2b46d
Author: Ryan Sepassi

Unset random seed by default, RL fixes

PiperOrigin-RevId: 197170249

1 parent: 8fd79f4

3 files changed (+54, -16 lines)

tensor2tensor/bin/t2t_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@
                     "The imported files should contain registrations, "
                     "e.g. @registry.register_model calls, that will then be "
                     "available to the t2t-trainer.")
-flags.DEFINE_integer("random_seed", 1234, "Random seed.")
+flags.DEFINE_integer("random_seed", None, "Random seed.")
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 100,
                      "Number of iterations in a TPU training loop.")

tensor2tensor/rl/model_rl_experiment.py

Lines changed: 44 additions & 13 deletions
@@ -136,7 +136,7 @@ def train_autoencoder(problem_name, data_dir, output_dir, hparams, epoch):
 
 def train_agent(problem_name, agent_model_dir,
                 event_dir, world_model_dir, epoch_data_dir, hparams,
-                autoencoder_path=None):
+                autoencoder_path=None, epoch=0):
   """Train the PPO agent in the simulated environment."""
   gym_problem = registry.problem(problem_name)
   ppo_hparams = trainer_lib.create_hparams(hparams.ppo_params)

@@ -151,6 +151,8 @@ def train_agent(problem_name, agent_model_dir,
   ppo_hparams.num_agents = hparams.ppo_num_agents
   ppo_hparams.problem = gym_problem
   ppo_hparams.world_model_dir = world_model_dir
+  if hparams.ppo_learning_rate:
+    ppo_hparams.learning_rate = hparams.ppo_learning_rate
   # 4x for the StackAndSkipWrapper minus one to always finish for reporting.
   ppo_time_limit = (ppo_hparams.epoch_length - 1) * 4
 
@@ -169,7 +171,7 @@ def train_agent(problem_name, agent_model_dir,
       "autoencoder_path": autoencoder_path,
   }):
     rl_trainer_lib.train(ppo_hparams, gym_problem.env_name, event_dir,
-                         agent_model_dir)
+                         agent_model_dir, epoch=epoch)
 
 
 def evaluate_world_model(simulated_problem_name, problem_name, hparams,
@@ -281,19 +283,32 @@ def encode_env_frames(problem_name, ae_problem_name, autoencoder_path,
   ae_training_paths = ae_problem.training_filepaths(epoch_data_dir, 10, True)
   ae_eval_paths = ae_problem.dev_filepaths(epoch_data_dir, 1, True)
 
+  skip_train = False
+  skip_eval = False
+  for path in ae_training_paths:
+    if tf.gfile.Exists(path):
+      skip_train = True
+      break
+  for path in ae_eval_paths:
+    if tf.gfile.Exists(path):
+      skip_eval = True
+      break
+
   # Encode train data
-  dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, epoch_data_dir,
-                            shuffle_files=False, output_buffer_size=100,
-                            preprocess=False)
-  encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
-                 ae_training_paths)
+  if not skip_train:
+    dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, epoch_data_dir,
+                              shuffle_files=False, output_buffer_size=100,
+                              preprocess=False)
+    encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
+                   ae_training_paths)
 
   # Encode eval data
-  dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, epoch_data_dir,
-                            shuffle_files=False, output_buffer_size=100,
-                            preprocess=False)
-  encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
-                 ae_eval_paths)
+  if not skip_eval:
+    dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, epoch_data_dir,
+                              shuffle_files=False, output_buffer_size=100,
+                              preprocess=False)
+    encode_dataset(model, dataset, problem, ae_hparams, autoencoder_path,
+                   ae_eval_paths)
 
 
 def check_problems(problem_names):
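The new skip_train / skip_eval guards make frame encoding idempotent across restarts: if any autoencoded TFRecord shard for a split is already on disk, that split is not re-encoded. A standalone sketch of the same existence check, assuming the TF 1.x tf.gfile API used throughout this file:

import tensorflow as tf

def split_already_encoded(shard_paths):
  # Mirrors the loops above: one existing output shard is enough to skip
  # regenerating the whole split.
  return any(tf.gfile.Exists(path) for path in shard_paths)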
@@ -392,7 +407,7 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
     ppo_model_dir = ppo_event_dir
     train_agent(world_model_problem, ppo_model_dir,
                 ppo_event_dir, directories["world_model"], epoch_data_dir,
-                hparams, autoencoder_path=autoencoder_model_dir)
+                hparams, autoencoder_path=autoencoder_model_dir, epoch=epoch)
 
     # Collect data from the real environment.
     log("Generating real environment data")

@@ -465,6 +480,7 @@ def rl_modelrl_base():
       # though it is not necessary.
       ppo_epoch_length=60,
       ppo_num_agents=16,
+      ppo_learning_rate=0.,
       # Whether the PPO agent should be restored from the previous iteration, or
       # should start fresh each time.
       ppo_continue_training=True,
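The new ppo_learning_rate hparam defaults to 0., and the `if hparams.ppo_learning_rate:` check added to train_agent treats that zero as "not set", so the learning rate already baked into the PPO hparams is only overridden when a non-zero value is supplied. A small sketch of that selection logic (the helper is hypothetical):

def effective_learning_rate(ppo_default, override):
  # A falsy override (0. or None) keeps the PPO hparams' own learning rate.
  return override if override else ppo_default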
@@ -483,6 +499,14 @@ def rl_modelrl_medium():
   return hparams
 
 
+@registry.register_hparams
+def rl_modelrl_25k():
+  """Small set for larger testing."""
+  hparams = rl_modelrl_medium()
+  hparams.true_env_generator_num_steps //= 2
+  return hparams
+
+
 @registry.register_hparams
 def rl_modelrl_short():
   """Small set for larger testing."""

@@ -583,6 +607,13 @@ def rl_modelrl_ae_base():
   return hparams
 
 
+@registry.register_hparams
+def rl_modelrl_ae_25k():
+  hparams = rl_modelrl_ae_base()
+  hparams.true_env_generator_num_steps //= 4
+  return hparams
+
+
 @registry.register_hparams
 def rl_modelrl_ae_l1_base():
   """Parameter set for autoencoders and L1 loss."""

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 9 additions & 2 deletions
@@ -83,7 +83,7 @@ def define_train(hparams, environment_spec, event_dir):
 
 
 def train(hparams, environment_spec, event_dir=None, model_dir=None,
-          restore_agent=True):
+          restore_agent=True, epoch=0):
   """Train."""
   with tf.name_scope("rl_train"):
     train_summary_op, eval_summary_op = define_train(hparams, environment_spec,

@@ -112,6 +112,13 @@ def train(hparams, environment_spec, event_dir=None, model_dir=None,
     if model_saver and restore_agent:
       start_step = trainer_lib.restore_checkpoint(
           model_dir, model_saver, sess)
+
+    # Fail-friendly, don't train if already trained for this epoch
+    if start_step >= ((hparams.epochs_num * (epoch+1)) - 5):
+      tf.logging.info("Skipping PPO training for epoch %d as train steps "
+                      "(%d) already reached", epoch, start_step)
+      return
+
     for epoch_index in range(hparams.epochs_num):
       summary = sess.run(train_summary_op)
       if summary_writer:

@@ -127,5 +134,5 @@ def train(hparams, environment_spec, event_dir=None, model_dir=None,
           (epoch_index % hparams.save_models_every_epochs == 0 or
            (epoch_index + 1) == hparams.epochs_num)):
         ckpt_path = os.path.join(
-            model_dir, "model.ckpt-{}".format(epoch_index + start_step))
+            model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
         model_saver.save(sess, ckpt_path)
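The "fail-friendly" guard compares the step restored from the checkpoint against the number of PPO iterations the agent should have completed by the end of the given outer epoch (epochs_num iterations per epoch, with a tolerance of 5), and returns early when that target has already been reached, e.g. after the outer model-based loop is restarted. A pure-Python sketch of the same arithmetic, assuming each outer epoch runs hparams.epochs_num PPO iterations:

def should_skip_ppo_training(start_step, epochs_num, epoch, tolerance=5):
  # After finishing 0-based outer epoch `epoch`, the agent should have run
  # roughly epochs_num * (epoch + 1) iterations; skip if the restored step is
  # already within `tolerance` of that target.
  return start_step >= epochs_num * (epoch + 1) - tolerance

# Example: with 200 iterations per epoch, a checkpoint restored at step 396
# is past the 395-step threshold for epoch 1, so training is skipped.
assert should_skip_ppo_training(396, 200, 1)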
