Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9266e67

Browse files
konradczechowski and afrozenator
authored and committed
Player, correct mapping keyboard keys to actions. (#1364)
1 parent 27224cd commit 9266e67

File tree

2 files changed

+43
-30
lines changed

2 files changed

+43
-30
lines changed

tensor2tensor/rl/player.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,15 @@
5353
from __future__ import print_function
5454

5555
import gym
56-
from gym.envs.atari.atari_env import ACTION_MEANING
5756
from gym.utils import play
5857
import numpy as np
59-
import six
6058

6159
from tensor2tensor.bin import t2t_trainer # pylint: disable=unused-import
6260
from tensor2tensor.rl import player_utils
6361
from tensor2tensor.rl.envs.simulated_batch_env import PIL_Image
6462
from tensor2tensor.rl.envs.simulated_batch_env import PIL_ImageDraw
6563
from tensor2tensor.rl.envs.simulated_batch_gym_env import FlatBatchEnv
66-
from tensor2tensor.rl.rl_utils import absolute_hinge_difference
64+
from tensor2tensor.rl.rl_utils import absolute_hinge_difference, full_game_name
6765
# Import flags from t2t_trainer and trainer_model_based
6866
import tensor2tensor.rl.trainer_model_based_params # pylint: disable=unused-import
6967
from tensor2tensor.utils import registry
@@ -137,24 +135,23 @@ class PlayerEnv(gym.Env):
137135

138136
HEADER_HEIGHT = 27
139137

140-
def __init__(self):
138+
def __init__(self, action_meanings):
139+
"""
140+
141+
Args:
142+
action_meanings: list of strings indicating action names. Can be obtain by
143+
>>> env = gym.make("PongNoFrameskip-v4") # insert your game name
144+
>>> env.unwrapped.get_action_meanings()
145+
See gym AtariEnv get_action_meanings() for more details.
146+
"""
147+
self.action_meanings = action_meanings
141148
self._wait = True
142149
# If action_space will be needed, one could use e.g. gym.spaces.Dict.
143150
self.action_space = None
144151
self._last_step_tuples = None
145-
146-
def _init_action_mappings(self, env):
147-
# Atari dependant. In case of problems with keyboard key interpretation
148-
# switch to _action_set instead of range(env.action_space.n) (similarly to
149-
# how gym AtariEnv does). _action_set can probably be obtain from full
150-
# game name.
151-
self.action_meaning = {i: ACTION_MEANING[i]
152-
for i in range(env.action_space.n)}
153-
self.name_to_action_num = {v: k for k, v in
154-
six.iteritems(self.action_meaning)}
155-
156-
def _get_action_meanings(self):
157-
return [self.action_meaning[i] for i in range(len(self.action_meaning))]
152+
self.action_meanings = action_meanings
153+
self.name_to_action_num = {name: num for num, name in
154+
enumerate(self.action_meanings)}
158155

159156
def get_keys_to_action(self):
160157
"""Get mapping from keyboard keys to actions.
@@ -178,7 +175,7 @@ def get_keys_to_action(self):
178175

179176
keys_to_action = {}
180177

181-
for action_id, action_meaning in enumerate(self._get_action_meanings()):
178+
for action_id, action_meaning in enumerate(self.action_meanings):
182179
keys = []
183180
for keyword, key in keyword_to_key.items():
184181
if keyword in action_meaning:
@@ -255,6 +252,7 @@ def _augment_observation(self, ob, reward, cumulative_reward):
255252
pixel_fill = (0, 255, 0)
256253
else:
257254
pixel_fill = (255, 0, 0)
255+
pixel_fill = (255, 0, 0)
258256
header[0, :, :] = pixel_fill
259257
return np.concatenate([header, ob], axis=0)
260258

@@ -306,7 +304,7 @@ class SimAndRealEnvPlayer(PlayerEnv):
306304

307305
RESTART_SIMULATED_ENV_ACTION = 110
308306

309-
def __init__(self, real_env, sim_env):
307+
def __init__(self, real_env, sim_env, action_meanings):
310308
"""Init.
311309
312310
Args:
@@ -315,7 +313,7 @@ def __init__(self, real_env, sim_env):
315313
`SimulatedGymEnv` must allow to update initial frames for next reset
316314
with `add_to_initial_stack` method.
317315
"""
318-
super(SimAndRealEnvPlayer, self).__init__()
316+
super(SimAndRealEnvPlayer, self).__init__(action_meanings)
319317
assert real_env.observation_space.shape == sim_env.observation_space.shape
320318
self.real_env = real_env
321319
self.sim_env = sim_env
@@ -329,7 +327,6 @@ def __init__(self, real_env, sim_env):
329327
self.observation_space = gym.spaces.Box(low=orig.low.min(),
330328
high=orig.high.max(),
331329
shape=shape, dtype=orig.dtype)
332-
self._init_action_mappings(sim_env)
333330

334331
def _player_actions(self):
335332
actions = super(SimAndRealEnvPlayer, self)._player_actions()
@@ -438,16 +435,15 @@ class SingleEnvPlayer(PlayerEnv):
438435
Plural form used for consistency with `PlayerEnv`.
439436
"""
440437

441-
def __init__(self, env):
442-
super(SingleEnvPlayer, self).__init__()
438+
def __init__(self, env, action_meanings):
439+
super(SingleEnvPlayer, self).__init__(action_meanings)
443440
self.env = env
444441
# Set observation space
445442
orig = self.env.observation_space
446443
shape = tuple([orig.shape[0] + self.HEADER_HEIGHT] + list(orig.shape[1:]))
447444
self.observation_space = gym.spaces.Box(low=orig.low.min(),
448445
high=orig.high.max(),
449446
shape=shape, dtype=orig.dtype)
450-
self._init_action_mappings(env)
451447

452448
def _player_step_tuple(self, envs_step_tuples):
453449
"""Augment observation, return usual step tuple."""
@@ -494,6 +490,8 @@ def main(_):
494490
hparams.set_hparam(
495491
"game", player_utils.infer_game_name_from_filenames(directories["data"])
496492
)
493+
action_meanings = gym.make(full_game_name(hparams.game)).\
494+
unwrapped.get_action_meanings()
497495
epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)
498496

499497
def make_real_env():
@@ -514,18 +512,19 @@ def make_simulated_env(setable_initial_frames, which_epoch_data):
514512
sim_env = make_simulated_env(
515513
which_epoch_data=None, setable_initial_frames=True)
516514
real_env = make_real_env()
517-
env = SimAndRealEnvPlayer(real_env, sim_env)
515+
env = SimAndRealEnvPlayer(real_env, sim_env, action_meanings)
518516
else:
519517
if FLAGS.simulated_env:
520518
env = make_simulated_env( # pylint: disable=redefined-variable-type
521519
which_epoch_data=epoch, setable_initial_frames=False)
522520
else:
523521
env = make_real_env()
524-
env = SingleEnvPlayer(env) # pylint: disable=redefined-variable-type
522+
env = SingleEnvPlayer(env, action_meanings) # pylint: disable=redefined-variable-type
525523

526524
env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)
527525

528526
if FLAGS.dry_run:
527+
env.unwrapped.get_keys_to_action()
529528
for _ in range(5):
530529
env.reset()
531530
for i in range(50):

tensor2tensor/rl/rl_utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,26 @@ def summarize_metrics(eval_metrics_writer, metrics, epoch):
126126
}
127127

128128

129+
ATARI_GAME_MODE = "NoFrameskip-v4"
130+
131+
132+
def full_game_name(short_name):
133+
"""CamelCase game name with mode suffix.
134+
135+
Args:
136+
short_name: snake_case name without mode e.g "crazy_climber"
137+
138+
Returns:
139+
full game name e.g. "CrazyClimberNoFrameskip-v4"
140+
"""
141+
camel_game_name = misc_utils.snakecase_to_camelcase(short_name)
142+
full_name = camel_game_name + ATARI_GAME_MODE
143+
return full_name
144+
145+
129146
def setup_env(hparams, batch_size, max_num_noops, rl_env_max_episode_steps=-1):
130147
"""Setup."""
131-
game_mode = "NoFrameskip-v4"
132-
camel_game_name = misc_utils.snakecase_to_camelcase(hparams.game)
133-
camel_game_name += game_mode
134-
env_name = camel_game_name
148+
env_name = full_game_name(hparams.game)
135149

136150
env = T2TGymEnv(base_env_name=env_name,
137151
batch_size=batch_size,

0 commit comments

Comments (0)