Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6633d9c

Browse files
author
Błażej O
committed
Introducing video_during_eval hparam.
1 parent 1f4d661 commit 6633d9c

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

tensor2tensor/models/research/rl.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ def ppo_base_v1():
4848
hparams.add_hparam("epochs_num", 2000)
4949
hparams.add_hparam("eval_every_epochs", 10)
5050
hparams.add_hparam("num_eval_agents", 3)
51+
hparams.add_hparam("video_during_eval", True)
5152
return hparams
5253

5354

tensor2tensor/rl/envs/utils.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,8 @@ class EvalVideoWrapper(gym.Wrapper):
5252
returns last seen observation.
5353
Videos are only generated during the active runs.
5454
"""
55-
def __init__(self, env, directory):
56-
super(EvalVideoWrapper, self).__init__(
57-
gym.wrappers.Monitor(env, directory, video_callable=lambda i: i % 2 == 0))
55+
def __init__(self, env):
56+
super(EvalVideoWrapper, self).__init__(env)
5857
self._reset_counter = 0
5958
self._active = False
6059
self._last_returned = None

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -57,17 +57,21 @@ def define_train(hparams, environment_spec, event_dir):
5757

5858
with tf.variable_scope("eval"):
5959
eval_env_lambda = env_lambda
60-
if event_dir:
61-
eval_env_lambda = lambda: utils.EvalVideoWrapper(env_lambda(), event_dir)
60+
if event_dir and hparams.video_during_eval:
61+
eval_env_lambda = lambda: gym.wrappers.Monitor(
62+
env_lambda(), event_dir, video_callable=lambda i: i % 2 == 0)
63+
wrapped_eval_env_lambda = lambda: utils.EvalVideoWrapper(eval_env_lambda())
6264
_, eval_summary = collect.define_collect(
6365
policy_factory,
64-
utils.define_batch_env(eval_env_lambda, hparams.num_eval_agents, xvfb=True),
66+
utils.define_batch_env(wrapped_eval_env_lambda, hparams.num_eval_agents,
67+
xvfb=hparams.video_during_eval),
6568
hparams, eval_phase=True)
6669
return summary, eval_summary
6770

6871

6972
def train(hparams, environment_spec, event_dir=None):
70-
train_summary_op, eval_summary_op = define_train(hparams, environment_spec, event_dir)
73+
train_summary_op, eval_summary_op = define_train(hparams, environment_spec,
74+
event_dir)
7175

7276
if event_dir:
7377
summary_writer = tf.summary.FileWriter(

tensor2tensor/rl/rl_trainer_lib_test.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,13 @@
2626
class TrainTest(tf.test.TestCase):
2727

2828
def test_no_crash_pendulum(self):
29-
hparams = trainer_lib.create_hparams("continuous_action_base", "epochs_num=10")
29+
hparams = trainer_lib.create_hparams(
30+
"continuous_action_base", "epochs_num=11,video_during_eval=False")
3031
rl_trainer_lib.train(hparams, "Pendulum-v0")
3132

3233
def test_no_crash_cartpole(self):
33-
hparams = trainer_lib.create_hparams("discrete_action_base", "epochs_num=10")
34+
hparams = trainer_lib.create_hparams(
35+
"discrete_action_base", "epochs_num=11,video_during_eval=False")
3436
rl_trainer_lib.train(hparams, "CartPole-v0")
3537

3638

0 commit comments

Comments
 (0)