from __future__ import division
from __future__ import print_function

-import numpy as np
-
from tensor2tensor.models.research import rl  # pylint: disable=unused-import
from tensor2tensor.rl import rl_utils
from tensor2tensor.rl import trainer_model_based_params  # pylint: disable=unused-import
from tensor2tensor.utils import flags as t2t_flags  # pylint: disable=unused-import
from tensor2tensor.utils import trainer_lib
+from tensor2tensor.utils import registry

import tensorflow as tf


flags.DEFINE_string("policy_dir", "", "Directory with policy checkpoints.")
+flags.DEFINE_string("model_dir", "", "Directory with model checkpoints.")
flags.DEFINE_string(
    "eval_metrics_dir", "", "Directory to output the eval metrics at."
)
flags.DEFINE_bool("full_eval", True, "Whether to ignore the timestep limit.")
-flags.DEFINE_enum("agent", "policy", ["random", "policy"], "Agent type to use.")
+flags.DEFINE_enum(
+    "agent", "policy", ["random", "policy", "planner"], "Agent type to use."
+)
flags.DEFINE_bool(
    "eval_with_learner", True,
    "Whether to use the PolicyLearner.evaluate function instead of an "
    "out-of-graph one. Works only with --agent=policy."
)
+flags.DEFINE_string(
+    "planner_hparams_set", "planner_tiny", "Planner hparam set."
+)
+flags.DEFINE_string("planner_hparams", "", "Planner hparam overrides.")
+
+
+@registry.register_hparams
+def planner_tiny():
+  return tf.contrib.training.HParams(
+      num_rollouts=1,
+      planning_horizon=2,
+      rollout_agent_type="random",
+  )


def make_agent(
-    agent_type, env, policy_hparams, policy_dir, sampling_temp
+    agent_type, env, policy_hparams, policy_dir, sampling_temp,
+    sim_env_kwargs=None, frame_stack_size=None, planning_horizon=None,
+    rollout_agent_type=None
):
  """Factory function for Agents."""
  return {
@@ -68,45 +85,40 @@ def make_agent(
          env.batch_size, env.observation_space, env.action_space,
          policy_hparams, policy_dir, sampling_temp
      ),
+      "planner": lambda: rl_utils.PlannerAgent(  # pylint: disable=g-long-lambda
+          env.batch_size, make_agent(
+              rollout_agent_type, env, policy_hparams, policy_dir, sampling_temp
+          ), rl_utils.SimulatedBatchGymEnvWithFixedInitialFrames(
+              **sim_env_kwargs
+          ), lambda env: rl_utils.BatchStackWrapper(env, frame_stack_size),
+          planning_horizon
+      ),
  }[agent_type]()


-def make_eval_fn_with_agent(agent_type):
+def make_eval_fn_with_agent(agent_type, planner_hparams, model_dir):
  """Returns an out-of-graph eval_fn using the Agent API."""
-  def eval_fn(env, hparams, policy_hparams, policy_dir, sampling_temp):
+  def eval_fn(env, loop_hparams, policy_hparams, policy_dir, sampling_temp):
    """Eval function."""
    base_env = env
-    env = rl_utils.BatchStackWrapper(env, hparams.frame_stack_size)
+    env = rl_utils.BatchStackWrapper(env, loop_hparams.frame_stack_size)
+    sim_env_kwargs = rl.make_simulated_env_kwargs(
+        base_env, loop_hparams, batch_size=planner_hparams.num_rollouts,
+        model_dir=model_dir
+    )
    agent = make_agent(
-        agent_type, env, policy_hparams, policy_dir, sampling_temp
+        agent_type, env, policy_hparams, policy_dir, sampling_temp,
+        sim_env_kwargs, loop_hparams.frame_stack_size,
+        planner_hparams.planning_horizon, planner_hparams.rollout_agent_type
    )
-    num_dones = 0
-    first_dones = [False] * env.batch_size
-    observations = env.reset()
-    while num_dones < env.batch_size:
-      actions = agent.act(observations)
-      (observations, _, dones) = env.step(actions)
-      observations = list(observations)
-      now_done_indices = []
-      for (i, done) in enumerate(dones):
-        if done and not first_dones[i]:
-          now_done_indices.append(i)
-          first_dones[i] = True
-          num_dones += 1
-      if now_done_indices:
-        # Reset only envs done the first time in this timestep to ensure that
-        # we collect exactly 1 rollout from each env.
-        reset_observations = env.reset(now_done_indices)
-        for (i, observation) in zip(now_done_indices, reset_observations):
-          observations[i] = observation
-        observations = np.array(observations)
+    rl_utils.run_rollouts(env, agent, env.reset())
    assert len(base_env.current_epoch_rollouts()) == env.batch_size
  return eval_fn


def evaluate(
-    hparams, policy_dir, eval_metrics_dir, agent_type, eval_with_learner,
-    report_fn=None, report_metric=None
+    loop_hparams, planner_hparams, policy_dir, model_dir, eval_metrics_dir,
+    agent_type, eval_with_learner, report_fn=None, report_metric=None
):
  """Evaluate."""
  if eval_with_learner:
@@ -118,16 +130,20 @@ def evaluate(
  eval_metrics_writer = tf.summary.FileWriter(eval_metrics_dir)
  kwargs = {}
  if not eval_with_learner:
-    kwargs["eval_fn"] = make_eval_fn_with_agent(agent_type)
-  eval_metrics = rl_utils.evaluate_all_configs(hparams, policy_dir, **kwargs)
+    kwargs["eval_fn"] = make_eval_fn_with_agent(
+        agent_type, planner_hparams, model_dir
+    )
+  eval_metrics = rl_utils.evaluate_all_configs(
+      loop_hparams, policy_dir, **kwargs
+  )
  rl_utils.summarize_metrics(eval_metrics_writer, eval_metrics, 0)

  # Report metrics
  if report_fn:
    if report_metric == "mean_reward":
      metric_name = rl_utils.get_metric_name(
-          sampling_temp=hparams.eval_sampling_temps[0],
-          max_num_noops=hparams.eval_max_num_noops,
+          sampling_temp=loop_hparams.eval_sampling_temps[0],
+          max_num_noops=loop_hparams.eval_max_num_noops,
          clipped=False
      )
      report_fn(eval_metrics[metric_name], 0)
@@ -137,12 +153,17 @@ def main(_):


def main(_):
-  hparams = trainer_lib.create_hparams(FLAGS.hparams_set, FLAGS.hparams)
+  loop_hparams = trainer_lib.create_hparams(
+      FLAGS.loop_hparams_set, FLAGS.loop_hparams
+  )
  if FLAGS.full_eval:
-    hparams.eval_rl_env_max_episode_steps = -1
+    loop_hparams.eval_rl_env_max_episode_steps = -1
+  planner_hparams = trainer_lib.create_hparams(
+      FLAGS.planner_hparams_set, FLAGS.planner_hparams
+  )
  evaluate(
-      hparams, FLAGS.policy_dir, FLAGS.eval_metrics_dir, FLAGS.agent,
-      FLAGS.eval_with_learner
+      loop_hparams, planner_hparams, FLAGS.policy_dir, FLAGS.model_dir,
+      FLAGS.eval_metrics_dir, FLAGS.agent, FLAGS.eval_with_learner
  )
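For context, a minimal sketch of how the evaluator might be invoked with the new planner agent. The module path and the rlmb_base hparams-set name are assumptions for illustration; the flag names come from the diff above, and --eval_with_learner is disabled because, per its description, it works only with --agent=policy. The directory paths are placeholders.

    python -m tensor2tensor.rl.evaluator \
      --loop_hparams_set=rlmb_base \
      --agent=planner \
      --planner_hparams_set=planner_tiny \
      --policy_dir=$HOME/t2t/policy \
      --model_dir=$HOME/t2t/world_model \
      --eval_metrics_dir=$HOME/t2t/eval_metrics \
      --eval_with_learner=False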