53
53
from __future__ import print_function
54
54
55
55
import gym
56
- from gym .envs .atari .atari_env import ACTION_MEANING
57
56
from gym .utils import play
58
57
import numpy as np
59
- import six
60
58
61
59
from tensor2tensor .bin import t2t_trainer # pylint: disable=unused-import
62
60
from tensor2tensor .rl import player_utils
63
61
from tensor2tensor .rl .envs .simulated_batch_env import PIL_Image
64
62
from tensor2tensor .rl .envs .simulated_batch_env import PIL_ImageDraw
65
63
from tensor2tensor .rl .envs .simulated_batch_gym_env import FlatBatchEnv
66
- from tensor2tensor .rl .rl_utils import absolute_hinge_difference
64
+ from tensor2tensor .rl .rl_utils import absolute_hinge_difference , full_game_name
67
65
# Import flags from t2t_trainer and trainer_model_based
68
66
import tensor2tensor .rl .trainer_model_based_params # pylint: disable=unused-import
69
67
from tensor2tensor .utils import registry
@@ -137,24 +135,23 @@ class PlayerEnv(gym.Env):
137
135
138
136
HEADER_HEIGHT = 27
139
137
140
- def __init__ (self ):
138
+ def __init__ (self , action_meanings ):
139
+ """
140
+
141
+ Args:
142
+ action_meanings: list of strings indicating action names. Can be obtained by
143
+ >>> env = gym.make("PongNoFrameskip-v4") # insert your game name
144
+ >>> env.unwrapped.get_action_meanings()
145
+ See gym AtariEnv get_action_meanings() for more details.
146
+ """
147
+ self .action_meanings = action_meanings
141
148
self ._wait = True
142
149
# If action_space will be needed, one could use e.g. gym.spaces.Dict.
143
150
self .action_space = None
144
151
self ._last_step_tuples = None
145
-
146
- def _init_action_mappings (self , env ):
147
- # Atari dependant. In case of problems with keyboard key interpretation
148
- # switch to _action_set instead of range(env.action_space.n) (similarly to
149
- # how gym AtariEnv does). _action_set can probably be obtain from full
150
- # game name.
151
- self .action_meaning = {i : ACTION_MEANING [i ]
152
- for i in range (env .action_space .n )}
153
- self .name_to_action_num = {v : k for k , v in
154
- six .iteritems (self .action_meaning )}
155
-
156
- def _get_action_meanings (self ):
157
- return [self .action_meaning [i ] for i in range (len (self .action_meaning ))]
152
+ self .action_meanings = action_meanings
153
+ self .name_to_action_num = {name : num for num , name in
154
+ enumerate (self .action_meanings )}
158
155
159
156
def get_keys_to_action (self ):
160
157
"""Get mapping from keyboard keys to actions.
@@ -178,7 +175,7 @@ def get_keys_to_action(self):
178
175
179
176
keys_to_action = {}
180
177
181
- for action_id , action_meaning in enumerate (self ._get_action_meanings () ):
178
+ for action_id , action_meaning in enumerate (self .action_meanings ):
182
179
keys = []
183
180
for keyword , key in keyword_to_key .items ():
184
181
if keyword in action_meaning :
@@ -255,6 +252,7 @@ def _augment_observation(self, ob, reward, cumulative_reward):
255
252
pixel_fill = (0 , 255 , 0 )
256
253
else :
257
254
pixel_fill = (255 , 0 , 0 )
255
+ pixel_fill = (255 , 0 , 0 )
258
256
header [0 , :, :] = pixel_fill
259
257
return np .concatenate ([header , ob ], axis = 0 )
260
258
@@ -306,7 +304,7 @@ class SimAndRealEnvPlayer(PlayerEnv):
306
304
307
305
RESTART_SIMULATED_ENV_ACTION = 110
308
306
309
- def __init__ (self , real_env , sim_env ):
307
+ def __init__ (self , real_env , sim_env , action_meanings ):
310
308
"""Init.
311
309
312
310
Args:
@@ -315,7 +313,7 @@ def __init__(self, real_env, sim_env):
315
313
`SimulatedGymEnv` must allow to update initial frames for next reset
316
314
with `add_to_initial_stack` method.
317
315
"""
318
- super (SimAndRealEnvPlayer , self ).__init__ ()
316
+ super (SimAndRealEnvPlayer , self ).__init__ (action_meanings )
319
317
assert real_env .observation_space .shape == sim_env .observation_space .shape
320
318
self .real_env = real_env
321
319
self .sim_env = sim_env
@@ -329,7 +327,6 @@ def __init__(self, real_env, sim_env):
329
327
self .observation_space = gym .spaces .Box (low = orig .low .min (),
330
328
high = orig .high .max (),
331
329
shape = shape , dtype = orig .dtype )
332
- self ._init_action_mappings (sim_env )
333
330
334
331
def _player_actions (self ):
335
332
actions = super (SimAndRealEnvPlayer , self )._player_actions ()
@@ -438,16 +435,15 @@ class SingleEnvPlayer(PlayerEnv):
438
435
Plural form used for consistency with `PlayerEnv`.
439
436
"""
440
437
441
- def __init__ (self , env ):
442
- super (SingleEnvPlayer , self ).__init__ ()
438
+ def __init__ (self , env , action_meanings ):
439
+ super (SingleEnvPlayer , self ).__init__ (action_meanings )
443
440
self .env = env
444
441
# Set observation space
445
442
orig = self .env .observation_space
446
443
shape = tuple ([orig .shape [0 ] + self .HEADER_HEIGHT ] + list (orig .shape [1 :]))
447
444
self .observation_space = gym .spaces .Box (low = orig .low .min (),
448
445
high = orig .high .max (),
449
446
shape = shape , dtype = orig .dtype )
450
- self ._init_action_mappings (env )
451
447
452
448
def _player_step_tuple (self , envs_step_tuples ):
453
449
"""Augment observation, return usual step tuple."""
@@ -494,6 +490,8 @@ def main(_):
494
490
hparams .set_hparam (
495
491
"game" , player_utils .infer_game_name_from_filenames (directories ["data" ])
496
492
)
493
+ action_meanings = gym .make (full_game_name (hparams .game )).\
494
+ unwrapped .get_action_meanings ()
497
495
epoch = FLAGS .epoch if FLAGS .epoch == "last" else int (FLAGS .epoch )
498
496
499
497
def make_real_env ():
@@ -514,18 +512,19 @@ def make_simulated_env(setable_initial_frames, which_epoch_data):
514
512
sim_env = make_simulated_env (
515
513
which_epoch_data = None , setable_initial_frames = True )
516
514
real_env = make_real_env ()
517
- env = SimAndRealEnvPlayer (real_env , sim_env )
515
+ env = SimAndRealEnvPlayer (real_env , sim_env , action_meanings )
518
516
else :
519
517
if FLAGS .simulated_env :
520
518
env = make_simulated_env ( # pylint: disable=redefined-variable-type
521
519
which_epoch_data = epoch , setable_initial_frames = False )
522
520
else :
523
521
env = make_real_env ()
524
- env = SingleEnvPlayer (env ) # pylint: disable=redefined-variable-type
522
+ env = SingleEnvPlayer (env , action_meanings ) # pylint: disable=redefined-variable-type
525
523
526
524
env = player_utils .wrap_with_monitor (env , FLAGS .video_dir )
527
525
528
526
if FLAGS .dry_run :
527
+ env .unwrapped .get_keys_to_action ()
529
528
for _ in range (5 ):
530
529
env .reset ()
531
530
for i in range (50 ):
0 commit comments