Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3b34470

Browse files
Lukasz KaiserCopybara-Service
authored andcommitted
Add hparams for RL.
PiperOrigin-RevId: 228941032
1 parent e9eb66a commit 3b34470

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

tensor2tensor/models/research/rl.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,22 @@ def ppo_original_params():
130130
return hparams
131131

132132

133+
@registry.register_hparams
134+
def ppo_original_params_gamma95():
135+
"""Parameters based on the original PPO paper, changed gamma."""
136+
hparams = ppo_original_params()
137+
hparams.gae_gamma = 0.95
138+
return hparams
139+
140+
141+
@registry.register_hparams
142+
def ppo_original_params_gamma90():
143+
"""Parameters based on the original PPO paper, changed gamma."""
144+
hparams = ppo_original_params()
145+
hparams.gae_gamma = 0.90
146+
return hparams
147+
148+
133149
@registry.register_hparams
134150
def ppo_original_world_model():
135151
"""Atari parameters with world model as policy."""

tensor2tensor/models/video/basic_deterministic_params.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ def next_frame_pixel_noise():
5959
return hparams
6060

6161

62+
@registry.register_hparams
63+
def next_frame_pixel_noise_long():
64+
"""Long scheduled sampling setting."""
65+
hparams = next_frame_pixel_noise()
66+
hparams.batch_size = 2
67+
hparams.video_num_target_frames = 16
68+
return hparams
69+
70+
6271
@registry.register_hparams
6372
def next_frame_sampling():
6473
"""Basic conv model with scheduled sampling."""

tensor2tensor/rl/trainer_model_based_params.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,10 @@ def rlmb_base_stochastic_discrete():
306306
hparams.grayscale = False
307307
hparams.generative_model = "next_frame_basic_stochastic_discrete"
308308
hparams.generative_model_params = "next_frame_basic_stochastic_discrete"
309+
# The parameters below are the same as base, but repeated for easier reading.
309310
hparams.ppo_epoch_length = 50
310311
hparams.simulated_rollout_length = 50
312+
hparams.simulated_batch_size = 16
311313
return hparams
312314

313315

@@ -320,6 +322,14 @@ def rlmb_base_stochastic_discrete_param_sharing():
320322
return hparams
321323

322324

325+
@registry.register_hparams
326+
def rlmb_long():
327+
"""Long setting with base model."""
328+
hparams = rlmb_base()
329+
hparams.generative_model_params = "next_frame_pixel_noise_long"
330+
return hparams
331+
332+
323333
@registry.register_hparams
324334
def rlmb_long_stochastic_discrete():
325335
"""Long setting with stochastic discrete model."""
@@ -330,7 +340,41 @@ def rlmb_long_stochastic_discrete():
330340

331341

332342
@registry.register_hparams
333-
def rlmb_base_stochastic_recurrent():
343+
def rlmb_long_stochastic_discrete_100steps():
344+
"""Long setting with stochastic discrete model, changed ppo steps."""
345+
hparams = rlmb_long_stochastic_discrete()
346+
hparams.ppo_epoch_length = 100
347+
hparams.simulated_rollout_length = 100
348+
hparams.simulated_batch_size = 8
349+
return hparams
350+
351+
352+
@registry.register_hparams
353+
def rlmb_long_stochastic_discrete_25steps():
354+
"""Long setting with stochastic discrete model, changed ppo steps."""
355+
hparams = rlmb_long_stochastic_discrete()
356+
hparams.ppo_epoch_length = 25
357+
hparams.simulated_rollout_length = 25
358+
hparams.simulated_batch_size = 32
359+
return hparams
360+
361+
362+
def rlmb_long_stochastic_discrete_gamma95():
363+
"""Long setting with stochastic discrete model, changed gamma."""
364+
hparams = rlmb_long_stochastic_discrete()
365+
hparams.base_algo_params = "ppo_original_params_gamma95"
366+
return hparams
367+
368+
369+
def rlmb_long_stochastic_discrete_gamma90():
370+
"""Long setting with stochastic discrete model, changed gamma."""
371+
hparams = rlmb_long_stochastic_discrete()
372+
hparams.base_algo_params = "ppo_original_params_gamma90"
373+
return hparams
374+
375+
376+
@registry.register_hparams
377+
def rlmb_base_recurrent():
334378
"""Base setting with recurrent model."""
335379
hparams = rlmb_base()
336380
hparams.generative_model = "next_frame_basic_recurrent"

0 commit comments

Comments
 (0)