import tensorflow as tf


-
-
flags = tf.flags
FLAGS = flags.FLAGS

@@ -50,6 +48,17 @@ def __init__(self, *args, **kwargs):
    super(GymDiscreteProblem, self).__init__(*args, **kwargs)
    self._env = None

+  def example_reading_spec(self, label_repr=None):
+
+    data_fields = {
+        "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "action": tf.FixedLenFeature([1], tf.int64)
+    }
+
+    return data_fields, None
+
  @property
  def env_name(self):
    # This is the name of the Gym environment for this problem.
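A minimal sketch (not part of this diff) of how a record matching the spec added above could be decoded, assuming the TF 1.x tf.parse_single_example API and the same field names:

# Illustrative only: decode one serialized tf.Example for this problem into
# dense tensors (three 210x160x3 int64 frames plus the int64 action).
import tensorflow as tf

def parse_frame_example(serialized_example):
  data_fields = {
      "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
      "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
      "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
      "action": tf.FixedLenFeature([1], tf.int64),
  }
  return tf.parse_single_example(serialized_example, data_fields)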
@@ -133,7 +142,7 @@ class GymPongRandom5k(GymDiscreteProblem):

  @property
  def env_name(self):
-    return "Pong-v0"
+    return "PongNoFrameskip-v4"

  @property
  def num_actions(self):
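A side note on the rename above (my reading, not stated in the diff): the NoFrameskip-v4 Atari variants do no internal frame skipping, so skipping is applied explicitly, e.g. by the atari_wrappers.wrap_atari(..., frame_skip=4) call that appears later in this file. A small sketch with the standard gym API:

# Sketch only: on a NoFrameskip env, every env.step() advances exactly one
# emulator frame; the wrapper (or problem code) decides how frames are skipped.
import gym

env = gym.make("PongNoFrameskip-v4")
observation = env.reset()
observation, reward, done, info = env.step(env.action_space.sample())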
@@ -148,21 +157,30 @@ def num_steps(self):
    return 5000


+
@registry.register_problem
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
  """Pong game, loaded actions."""

-  def __init__(self, event_dir, *args, **kwargs):
+  def __init__(self, *args, **kwargs):
    super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
    self._env = None
-    self._event_dir = event_dir
+    self._last_policy_op = None
+    self._max_frame_pl = None
+    self._last_action = self.env.action_space.sample()
+    self._skip = 4
+    self._skip_step = 0
+    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
+                                dtype=np.uint8)
+
+  def generator(self, data_dir, tmp_dir):
    env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
        gym.make("PongNoFrameskip-v4"),
        warp=False,
        frame_skip=4,
        frame_stack=False)
    hparams = rl.atari_base()
-    with tf.variable_scope("train"):
+    with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
      policy_lambda = hparams.network
      policy_factory = tf.make_template(
          "network",
@@ -173,14 +191,13 @@ def __init__(self, event_dir, *args, **kwargs):
              self._max_frame_pl, 0), 0))
      policy = actor_critic.policy
      self._last_policy_op = policy.mode()
-      self._last_action = self.env.action_space.sample()
-      self._skip = 4
-      self._skip_step = 0
-      self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
-                                  dtype=np.uint8)
-      self._sess = tf.Session()
-      model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
-      model_saver.restore(self._sess, FLAGS.model_path)
+      with tf.Session() as sess:
+        model_saver = tf.train.Saver(
+            tf.global_variables(".*network_parameters.*"))
+        model_saver.restore(sess, FLAGS.model_path)
+        for item in super(GymPongTrajectoriesFromPolicy,
+                          self).generator(data_dir, tmp_dir):
+          yield item


  # TODO(blazej0): For training of atari agents wrappers are usually used.
  # Below we have a hacky solution which is a workaround to be used together
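The refactor above relies on the fact that a Session used as a context manager becomes the default session, so code called inside the with block (here, get_action via the parent class's generator) can reach it without holding an explicit handle. A minimal sketch of that pattern, independent of this file:

# Sketch only: inside "with tf.Session() as sess", tf.get_default_session()
# returns that same session, so helpers need no explicit session argument.
import tensorflow as tf

x = tf.constant(3)

def helper():
  return tf.get_default_session().run(x)

with tf.Session() as sess:
  assert tf.get_default_session() is sess
  print(helper())  # prints 3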
@@ -191,7 +208,7 @@ def get_action(self, observation=None):
    self._skip_step = (self._skip_step + 1) % self._skip
    if self._skip_step == 0:
      max_frame = self._obs_buffer.max(axis=0)
-      self._last_action = int(self._sess.run(
+      self._last_action = int(tf.get_default_session().run(
          self._last_policy_op,
          feed_dict={self._max_frame_pl: max_frame})[0, 0])
    return self._last_action
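For context (not part of the diff): self._obs_buffer holds the two most recent raw frames, and the elementwise max over them is the usual Atari preprocessing trick to remove sprite flicker before querying the policy once every self._skip steps. A small numpy sketch of that pooling, with made-up frame values:

# Sketch only: elementwise max over the last two frames, mirroring
# self._obs_buffer.max(axis=0) in get_action above.
import numpy as np

obs_buffer = np.zeros((2, 210, 160, 3), dtype=np.uint8)
obs_buffer[0] = 10   # hypothetical previous frame
obs_buffer[1] = 200  # hypothetical current frame
max_frame = obs_buffer.max(axis=0)  # shape (210, 160, 3), all values 200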