Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8fd79f4

Lukasz Kaiser authored and Ryan Sepassi committed
Fix PPO training in model_rl_experiment.
PiperOrigin-RevId: 197098255
1 parent b801e54 commit 8fd79f4

File tree

1 file changed (+11, -11 lines)


tensor2tensor/rl/model_rl_experiment.py

Lines changed: 11 additions & 11 deletions
@@ -151,8 +151,8 @@ def train_agent(problem_name, agent_model_dir,
   ppo_hparams.num_agents = hparams.ppo_num_agents
   ppo_hparams.problem = gym_problem
   ppo_hparams.world_model_dir = world_model_dir
-  # 4x for the StackAndSkipWrapper
-  ppo_time_limit = max(ppo_hparams.epoch_length * 4 + 20, 250)
+  # 4x for the StackAndSkipWrapper minus one to always finish for reporting.
+  ppo_time_limit = (ppo_hparams.epoch_length - 1) * 4

   in_graph_wrappers = [
       (TimeLimitWrapper, {"timelimit": ppo_time_limit}),
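The time limit is now derived from the PPO epoch length instead of being padded up to a fixed floor of 250 frames. A quick numeric check (a sketch using the ppo_epoch_length=60 that this same commit sets further down, not code from the file):

# Rough comparison of the old and new time-limit formulas (illustrative only).
# One agent step spans 4 emulator frames under StackAndSkipWrapper, and one
# step is held back so every episode finishes inside the epoch and is reported.
epoch_length = 60                     # new rl_modelrl_base value, see below
new_limit = (epoch_length - 1) * 4    # 236 frames, under the 240-frame epoch budget
old_limit = max(40 * 4 + 20, 250)     # 250 frames with the old epoch_length of 40
print(new_limit, old_limit)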
@@ -456,15 +456,15 @@ def rl_modelrl_base():
       simulated_env_generator_num_steps=2000,
       simulation_random_starts=True,
       intrinsic_reward_scale=0.,
-      ppo_epochs_num=250,  # This should be enough to see something
+      ppo_epochs_num=200,  # This should be enough to see something
       # Our simulated envs do not know how to reset.
       # You should set ppo_time_limit to the value you believe that
       # the simulated env produces a reasonable output.
       ppo_time_limit=200,  # TODO(blazej): this param is unused
       # It makes sense to have ppo_time_limit=ppo_epoch_length,
       # though it is not necessary.
-      ppo_epoch_length=40,
-      ppo_num_agents=20,
+      ppo_epoch_length=60,
+      ppo_num_agents=16,
       # Whether the PPO agent should be restored from the previous iteration, or
       # should start fresh each time.
       ppo_continue_training=True,
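Fewer epochs and fewer agents roughly offset the longer epochs. A back-of-the-envelope budget comparison (this assumes each PPO epoch collects ppo_epoch_length transitions per agent, which is a reading of these hparams rather than something stated in the diff):

# Approximate transitions gathered per PPO training run, old vs. new defaults.
old_budget = 250 * 40 * 20   # ppo_epochs_num * ppo_epoch_length * ppo_num_agents = 200000
new_budget = 200 * 60 * 16   # = 192000
print(old_budget, new_budget)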
@@ -579,7 +579,7 @@ def rl_modelrl_ae_base():
   hparams = rl_modelrl_base()
   hparams.ppo_params = "ppo_pong_ae_base"
   hparams.generative_model_params = "basic_conv_ae"
-  hparams.autoencoder_train_steps = 100000
+  hparams.autoencoder_train_steps = 30000
   return hparams
@@ -603,10 +603,7 @@ def rl_modelrl_ae_l2_base():
 def rl_modelrl_ae_medium():
   """Medium parameter set for autoencoders."""
   hparams = rl_modelrl_ae_base()
-  hparams.autoencoder_train_steps //= 2
   hparams.true_env_generator_num_steps //= 2
-  hparams.model_train_steps //= 2
-  hparams.ppo_epochs_num //= 2
   return hparams
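Read together with the previous hunk, the medium autoencoder set no longer halves the training schedule, so it inherits the shorter base values unchanged. A hypothetical sanity check, assuming the module imports cleanly in your environment:

# Hypothetical check of the effective values after this commit.
from tensor2tensor.rl.model_rl_experiment import (
    rl_modelrl_ae_base, rl_modelrl_ae_medium)

assert rl_modelrl_ae_base().autoencoder_train_steps == 30000
# No longer halved to 15000 by rl_modelrl_ae_medium; only the number of
# real-environment generator steps is still reduced there.
assert rl_modelrl_ae_medium().autoencoder_train_steps == 30000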
@@ -730,8 +727,11 @@ def rl_modelrl_freeway_ae_medium():
 @registry.register_hparams
 def rl_modelrl_freeway_short():
   """Short set for testing Freeway."""
-  hparams = rl_modelrl_short()
-  hparams.game = "freeway"
+  hparams = rl_modelrl_freeway_medium()
+  hparams.true_env_generator_num_steps //= 5
+  hparams.model_train_steps //= 2
+  hparams.ppo_epochs_num //= 2
+  hparams.intrinsic_reward_scale = 0.1
   return hparams
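The short Freeway set is now carved out of the medium Freeway set rather than the generic short one, presumably so game="freeway" and the other Freeway-specific settings are inherited instead of repeated here. The genuinely new knob is the intrinsic reward; a generic illustration of what such a scale typically does (not this file's code, and curiosity_bonus is a hypothetical stand-in for whatever exploration signal the agent computes):

def shaped_reward(env_reward, curiosity_bonus, intrinsic_reward_scale=0.1):
  # Illustrative only: mix a small exploration bonus into the environment
  # reward; the actual combination happens inside the RL training code.
  return env_reward + intrinsic_reward_scale * curiosity_bonus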