This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 152beb0

koz4k authored and lukaszkaiser committed
Implement PlannerAgent (#1365)
* Extract a function for running rollouts to rl_utils
* Extract a base class for batch wrappers
* Factorize make_simulated_env_fn_from_hparams by a function returning just kwargs
* Implement PlannerAgent
1 parent 7836aa7 commit 152beb0

File tree

6 files changed (+224, -83 lines)


tensor2tensor/models/research/rl.py

Lines changed: 26 additions & 17 deletions
@@ -208,25 +208,34 @@ def env_fn(in_graph):
   return env_fn
 
 
-def make_simulated_env_fn_from_hparams(
-    real_env, hparams, batch_size, initial_frame_chooser, model_dir,
-    sim_video_dir=None):
-  """Creates a simulated env_fn."""
-  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
+# TODO(koz4k): Move this and the one below to rl_utils.
+def make_simulated_env_kwargs(real_env, hparams, **extra_kwargs):
+  """Extracts simulated env kwargs from real_env and loop hparams."""
+  objs_and_attrs = [
+      (real_env, [
+          "reward_range", "observation_space", "action_space", "frame_height",
+          "frame_width"
+      ]),
+      (hparams, ["frame_stack_size", "intrinsic_reward_scale"])
+  ]
+  kwargs = {
+      attr: getattr(obj, attr)
+      for (obj, attrs) in objs_and_attrs for attr in attrs
+  }
+  kwargs["model_name"] = hparams.generative_model
+  kwargs["model_hparams"] = trainer_lib.create_hparams(
+      hparams.generative_model_params
+  )
   if hparams.wm_policy_param_sharing:
-    model_hparams.optimizer_zero_grads = True
+    kwargs["model_hparams"].optimizer_zero_grads = True
+  kwargs.update(extra_kwargs)
+  return kwargs
+
+
+def make_simulated_env_fn_from_hparams(real_env, hparams, **extra_kwargs):
+  """Creates a simulated env_fn."""
   return make_simulated_env_fn(
-      reward_range=real_env.reward_range,
-      observation_space=real_env.observation_space,
-      action_space=real_env.action_space,
-      frame_stack_size=hparams.frame_stack_size,
-      frame_height=real_env.frame_height, frame_width=real_env.frame_width,
-      initial_frame_chooser=initial_frame_chooser, batch_size=batch_size,
-      model_name=hparams.generative_model,
-      model_hparams=trainer_lib.create_hparams(hparams.generative_model_params),
-      model_dir=model_dir,
-      intrinsic_reward_scale=hparams.intrinsic_reward_scale,
-      sim_video_dir=sim_video_dir,
+      **make_simulated_env_kwargs(real_env, hparams, **extra_kwargs)
   )
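Callers of make_simulated_env_fn_from_hparams that used to pass batch_size, initial_frame_chooser, model_dir and sim_video_dir as explicit parameters now pass them as keyword arguments, which flow through **extra_kwargs and are merged by kwargs.update(extra_kwargs). A minimal sketch of such a call, assuming real_env, loop_hparams and initial_frame_chooser already exist (the batch size and model_dir path are placeholders):

# Sketch only: the former explicit parameters are now forwarded via **extra_kwargs.
env_fn = rl.make_simulated_env_fn_from_hparams(
    real_env, loop_hparams,
    batch_size=8,
    initial_frame_chooser=initial_frame_chooser,
    model_dir="/tmp/world_model",
)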

tensor2tensor/rl/envs/simulated_batch_gym_env.py

Lines changed: 5 additions & 3 deletions
@@ -22,6 +22,7 @@
 from gym import Env
 from tensor2tensor.rl.envs.simulated_batch_env import SimulatedBatchEnv
 
+import numpy as np
 import tensorflow as tf
 
 
@@ -55,6 +56,7 @@ def __init__(self, *args, **kwargs):
     self._rewards_t, self._dones_t = self._batch_env.simulate(self._actions_t)
     with tf.control_dependencies([self._rewards_t]):
       self._obs_t = self._batch_env.observ
+    self._indices_t = tf.placeholder(shape=(self.batch_size,), dtype=tf.int32)
     self._reset_op = self._batch_env.reset(
         tf.range(self.batch_size, dtype=tf.int32)
     )
@@ -79,9 +81,9 @@ def render(self, mode="human"):
     raise NotImplementedError()
 
   def reset(self, indices=None):
-    if indices:
-      raise NotImplementedError()
-    obs = self._sess.run(self._reset_op)
+    if indices is None:
+      indices = np.array(range(self.batch_size))
+    obs = self._sess.run(self._reset_op, feed_dict={self._indices_t: indices})
     # TODO(pmilos): remove if possible
     # obs[:, 0, 0, 0] = 0
     # obs[:, 0, 0, 1] = 255
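Accepting an explicit list of indices in reset brings this simulated env in line with the batch-env interface that the rollout loop relies on: only the environments that have just finished get restarted while the others keep running. A rough usage sketch, based on the loop removed from evaluator.py below; env, dones and observations are assumed to be a batch env, the done flags from the last step, and the current observation list:

# Sketch: restart only the environments that just finished; env.reset(indices)
# is expected to return one observation per requested index.
done_indices = [i for (i, done) in enumerate(dones) if done]
if done_indices:
  reset_obs = env.reset(done_indices)
  for (i, obs) in zip(done_indices, reset_obs):
    observations[i] = obs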

tensor2tensor/rl/evaluator.py

Lines changed: 59 additions & 38 deletions
@@ -28,13 +28,12 @@
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
 from tensor2tensor.models.research import rl  # pylint: disable=unused-import
 from tensor2tensor.rl import rl_utils
 from tensor2tensor.rl import trainer_model_based_params  # pylint: disable=unused-import
 from tensor2tensor.utils import flags as t2t_flags  # pylint: disable=unused-import
 from tensor2tensor.utils import trainer_lib
+from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
@@ -44,20 +43,38 @@
 
 
 flags.DEFINE_string("policy_dir", "", "Directory with policy checkpoints.")
+flags.DEFINE_string("model_dir", "", "Directory with model checkpoints.")
 flags.DEFINE_string(
     "eval_metrics_dir", "", "Directory to output the eval metrics at."
 )
 flags.DEFINE_bool("full_eval", True, "Whether to ignore the timestep limit.")
-flags.DEFINE_enum("agent", "policy", ["random", "policy"], "Agent type to use.")
+flags.DEFINE_enum(
+    "agent", "policy", ["random", "policy", "planner"], "Agent type to use."
+)
 flags.DEFINE_bool(
     "eval_with_learner", True,
     "Whether to use the PolicyLearner.evaluate function instead of an "
     "out-of-graph one. Works only with --agent=policy."
 )
+flags.DEFINE_string(
+    "planner_hparams_set", "planner_tiny", "Planner hparam set."
+)
+flags.DEFINE_string("planner_hparams", "", "Planner hparam overrides.")
+
+
+@registry.register_hparams
+def planner_tiny():
+  return tf.contrib.training.HParams(
+      num_rollouts=1,
+      planning_horizon=2,
+      rollout_agent_type="random",
+  )
 
 
 def make_agent(
-    agent_type, env, policy_hparams, policy_dir, sampling_temp
+    agent_type, env, policy_hparams, policy_dir, sampling_temp,
+    sim_env_kwargs=None, frame_stack_size=None, planning_horizon=None,
+    rollout_agent_type=None
 ):
   """Factory function for Agents."""
   return {
@@ -68,45 +85,40 @@ def make_agent(
           env.batch_size, env.observation_space, env.action_space,
           policy_hparams, policy_dir, sampling_temp
       ),
+      "planner": lambda: rl_utils.PlannerAgent(  # pylint: disable=g-long-lambda
+          env.batch_size, make_agent(
+              rollout_agent_type, env, policy_hparams, policy_dir, sampling_temp
+          ), rl_utils.SimulatedBatchGymEnvWithFixedInitialFrames(
+              **sim_env_kwargs
+          ), lambda env: rl_utils.BatchStackWrapper(env, frame_stack_size),
+          planning_horizon
+      ),
   }[agent_type]()
 
 
-def make_eval_fn_with_agent(agent_type):
+def make_eval_fn_with_agent(agent_type, planner_hparams, model_dir):
   """Returns an out-of-graph eval_fn using the Agent API."""
-  def eval_fn(env, hparams, policy_hparams, policy_dir, sampling_temp):
+  def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
     """Eval function."""
     base_env = env
-    env = rl_utils.BatchStackWrapper(env, hparams.frame_stack_size)
+    env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
+    sim_env_kwargs = rl.make_simulated_env_kwargs(
+        base_env, loop_hparams, batch_size=planner_hparams.num_rollouts,
+        model_dir=model_dir
+    )
     agent = make_agent(
-        agent_type, env, policy_hparams, policy_dir, sampling_temp
+        agent_type, env, policy_hparams, policy_dir, sampling_temp,
+        sim_env_kwargs, loop_hparams.frame_stack_size,
+        planner_hparams.planning_horizon, planner_hparams.rollout_agent_type
     )
-    num_dones = 0
-    first_dones = [False] * env.batch_size
-    observations = env.reset()
-    while num_dones < env.batch_size:
-      actions = agent.act(observations)
-      (observations, _, dones) = env.step(actions)
-      observations = list(observations)
-      now_done_indices = []
-      for (i, done) in enumerate(dones):
-        if done and not first_dones[i]:
-          now_done_indices.append(i)
-          first_dones[i] = True
-          num_dones += 1
-      if now_done_indices:
-        # Reset only envs done the first time in this timestep to ensure that
-        # we collect exactly 1 rollout from each env.
-        reset_observations = env.reset(now_done_indices)
-        for (i, observation) in zip(now_done_indices, reset_observations):
-          observations[i] = observation
-        observations = np.array(observations)
+    rl_utils.run_rollouts(env, agent, env.reset())
     assert len(base_env.current_epoch_rollouts()) == env.batch_size
   return eval_fn
 
 
 def evaluate(
-    hparams, policy_dir, eval_metrics_dir, agent_type, eval_with_learner,
-    report_fn=None, report_metric=None
+    loop_hparams, planner_hparams, policy_dir, model_dir, eval_metrics_dir,
+    agent_type, eval_with_learner, report_fn=None, report_metric=None
 ):
   """Evaluate."""
   if eval_with_learner:
@@ -118,16 +130,20 @@ def evaluate(
   eval_metrics_writer = tf.summary.FileWriter(eval_metrics_dir)
   kwargs = {}
   if not eval_with_learner:
-    kwargs["eval_fn"] = make_eval_fn_with_agent(agent_type)
-  eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_dir, **kwargs)
+    kwargs["eval_fn"] = make_eval_fn_with_agent(
+        agent_type, planner_hparams, model_dir
+    )
+  eval_metrics = rl_utils.evaluate_all_configs(
+      loop_hparams, policy_dir, **kwargs
+  )
   rl_utils.summarize_metrics(eval_metrics_writer, eval_metrics, 0)
 
   # Report metrics
   if report_fn:
     if report_metric == "mean_reward":
       metric_name = rl_utils.get_metric_name(
-          sampling_temp=hparams.eval_sampling_temps[0],
-          max_num_noops=hparams.eval_max_num_noops,
+          sampling_temp=loop_hparams.eval_sampling_temps[0],
+          max_num_noops=loop_hparams.eval_max_num_noops,
           clipped=False
       )
       report_fn(eval_metrics[metric_name], 0)
@@ -137,12 +153,17 @@ def evaluate(
 
 
 def main(_):
-  hparams = trainer_lib.create_hparams(FLAGS.hparams_set, FLAGS.hparams)
+  loop_hparams = trainer_lib.create_hparams(
+      FLAGS.loop_hparams_set, FLAGS.loop_hparams
+  )
   if FLAGS.full_eval:
-    hparams.eval_rl_env_max_episode_steps = -1
+    loop_hparams.eval_rl_env_max_episode_steps = -1
+  planner_hparams = trainer_lib.create_hparams(
+      FLAGS.planner_hparams_set, FLAGS.planner_hparams
+  )
   evaluate(
-      hparams, FLAGS.policy_dir, FLAGS.eval_metrics_dir, FLAGS.agent,
-      FLAGS.eval_with_learner
+      loop_hparams, planner_hparams, FLAGS.policy_dir, FLAGS.model_dir,
+      FLAGS.eval_metrics_dir, FLAGS.agent, FLAGS.eval_with_learner
   )
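The planner is configured through a registered hparams set plus an override flag; planner_tiny above is sized for tests. A larger configuration would be registered the same way, using the registry and tf imports already present in this file. The set name and values below are hypothetical, purely for illustration and not part of this commit:

# Hypothetical example (not from the commit): a bigger planner configuration
# registered alongside planner_tiny.
@registry.register_hparams
def planner_small():
  return tf.contrib.training.HParams(
      num_rollouts=16,
      planning_horizon=10,
      rollout_agent_type="policy",
  )

It could then be selected with --planner_hparams_set=planner_small, and individual values could presumably be overridden via --planner_hparams (e.g. "planning_horizon=20"), mirroring how the loop hparams are selected and overridden.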

tensor2tensor/rl/evaluator_test.py

Lines changed: 4 additions & 3 deletions
@@ -27,11 +27,12 @@
 class EvalTest(tf.test.TestCase):
 
   def test_evaluate_pong_random_agent(self):
-    hparams = registry.hparams("rlmb_tiny")
+    loop_hparams = registry.hparams("rlmb_tiny")
+    planner_hparams = registry.hparams("planner_tiny")
     temp_dir = tf.test.get_temp_dir()
     evaluator.evaluate(
-        hparams, temp_dir, temp_dir, agent_type="random",
-        eval_with_learner=False
+        loop_hparams, planner_hparams, temp_dir, temp_dir, temp_dir,
+        agent_type="random", eval_with_learner=False
     )
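In the updated test call, the three temp_dir arguments map positionally to policy_dir, model_dir and eval_metrics_dir in the new evaluate signature. The same call spelled with keyword arguments, as a sketch with identical behaviour:

# Sketch: the positional temp_dir arguments made explicit.
evaluator.evaluate(
    loop_hparams, planner_hparams,
    policy_dir=temp_dir,
    model_dir=temp_dir,
    eval_metrics_dir=temp_dir,
    agent_type="random",
    eval_with_learner=False,
)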
