This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5ac81b4

Merge pull request #720 from deepsense-ai/master
Attempt to fit GymDiscreteProblemWithAgent into the GymDiscreteProblem interface
2 parents c669cda + ef5bd6e commit 5ac81b4

File tree: 6 files changed, +115 −24 lines

  tensor2tensor/data_generators/gym.py
  tensor2tensor/data_generators/video_utils.py
  tensor2tensor/models/research/basic_conv_gen.py
  tensor2tensor/rl/envs/utils.py
  tensor2tensor/rl/model_rl_experiment.py
  tensor2tensor/utils/t2t_model.py

tensor2tensor/data_generators/gym.py

Lines changed: 86 additions & 1 deletion
@@ -182,6 +182,91 @@ def moviepy_editor():
     raise ImportError("pip install moviepy to record videos")
   return editor
 
+@registry.register_problem
+class GymDiscreteProblemWithAgent2(GymDiscreteProblem):
+  """Gym environment with discrete actions and rewards."""
+
+  def __init__(self, *args, **kwargs):
+    super(GymDiscreteProblemWithAgent2, self).__init__(*args, **kwargs)
+    self._env = None
+
+  @property
+  def extra_reading_spec(self):
+    """Additional data fields to store on disk and their decoders."""
+    data_fields = {
+        "action": tf.FixedLenFeature([1], tf.int64),
+        "reward": tf.FixedLenFeature([1], tf.int64)
+    }
+    decoders = {
+        "action": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="action"),
+        "reward": tf.contrib.slim.tfexample_decoder.Tensor(tensor_key="reward"),
+    }
+    return data_fields, decoders
+
+  @property
+  def num_input_frames(self):
+    """Number of frames to batch on one input."""
+    return 4
+
+  @property
+  def env_name(self):
+    """This is the name of the Gym environment for this problem."""
+    return "PongDeterministic-v4"
+
+  @property
+  def num_actions(self):
+    return self.env.action_space.n
+
+  @property
+  def num_rewards(self):
+    return 3
+
+  @property
+  def num_steps(self):
+    return 200
+
+  @property
+  def frame_height(self):
+    return 210
+
+  @property
+  def frame_width(self):
+    return 160
+
+  @property
+  def min_reward(self):
+    return -1
+
+  def get_action(self, observation=None):
+    return self.env.action_space.sample()
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.input_modality = {"inputs": ("video", 256),
+                        "input_reward": ("symbol", self.num_rewards),
+                        "input_action": ("symbol", self.num_actions)}
+    # p.input_modality = {"inputs": ("video", 256),
+    #                     "reward": ("symbol", self.num_rewards),
+    #                     "input_action": ("symbol", self.num_actions)}
+    # p.target_modality = ("video", 256)
+    p.target_modality = {"targets": ("video", 256),
+                         "target_reward": ("symbol", self.num_rewards)}
+    # p.target_modality = {"targets": ("image", 256),
+    #                      "reward": ("symbol", self.num_rewards + 1)}  # ("video", 256)
+    p.input_space_id = problem.SpaceID.IMAGE
+    p.target_space_id = problem.SpaceID.IMAGE
+
+  def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
+    self.env.reset()
+    action = self.get_action()
+    for _ in range(self.num_steps):
+      observation, reward, done, _ = self.env.step(action)
+      action = self.get_action(observation)
+      yield {"frame": observation,
+             "action": [action],
+             "done": [done],
+             "reward": [int(reward - self.min_reward)]}
+
 
 @registry.register_problem
 class GymDiscreteProblemWithAgent(problem.Problem):

@@ -197,7 +282,7 @@ def __init__(self, *args, **kwargs):
     self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
     self.collect_hparams = rl.atari_base()
     self.num_steps = 1000
-    self.movies = False
+    self.movies = True
     self.movies_fps = 24
     self.simulated_environment = None
     self.warm_up = 70
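For context, a minimal usage sketch (not part of this commit) of how the newly registered problem could be exercised on its own. The name "gym_discrete_problem_with_agent2" is the snake_case registration name derived from the class name and matches the name set in tensor2tensor/rl/model_rl_experiment.py below; the directories are placeholder paths.

from tensor2tensor import problems

# Assumes a working tensor2tensor + gym installation; paths are placeholders.
gym_problem = problems.problem("gym_discrete_problem_with_agent2")
# Plays PongDeterministic-v4 with random actions for num_steps (200) steps and
# writes TFRecords with the "frame", "action", "done" and "reward" fields
# yielded by generate_samples above.
gym_problem.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")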

tensor2tensor/data_generators/video_utils.py

Lines changed: 2 additions & 2 deletions
@@ -157,7 +157,7 @@ def features_from_batch(batched_prefeatures):
       Features dictionary with joint features per-frame.
     """
     features = {}
-    for k, v in batched_prefeatures.iteritems():
+    for k, v in batched_prefeatures.items():
      if k == "frame":  # We rename past frames to inputs and targets.
        s1, s2 = split_on_batch(v)
        # Reshape just to make sure shapes are right and set.

@@ -242,7 +242,7 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
       if width != self.frame_width:
         raise ValueError("Generated frame has width %d while the class "
                          "assumes width %d." % (width, self.frame_width))
-      encoded_frame = image_utils.encode_images_as_png([unencoded_frame]).next()
+      encoded_frame = image_utils.encode_images_as_png([unencoded_frame]).__next__()
       features["image/encoded"] = [encoded_frame]
       features["image/format"] = ["png"]
       features["image/height"] = [height]
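Both hunks in this file are Python 3 compatibility fixes: dict.iteritems() and the generator method .next() exist only in Python 2. A minimal sketch of the difference, in plain Python with made-up data:

d = {"frame": 0, "action": 1}
for k, v in d.items():  # .iteritems() raises AttributeError on Python 3
  print(k, v)

gen = (x * x for x in range(3))
print(next(gen))  # the next() builtin works on both Python 2 and 3;
                  # .next() is Python-2-only and .__next__() is Python-3-only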

tensor2tensor/models/research/basic_conv_gen.py

Lines changed: 8 additions & 6 deletions
@@ -91,7 +91,9 @@ def body(self, features):
     reward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         labels=reward_gold, logits=reward_pred, name="reward_loss")
     reward_loss = tf.reduce_mean(reward_loss)
-    return x, {"reward": reward_loss}
+    return {"targets": x, "target_reward": reward_pred_h1}
+    # return x, {"reward": reward_loss}
+    # return x
 
 
 @registry.register_hparams

@@ -147,11 +149,11 @@ def deconv2d(cur, i, kernel_size, output_filters, activation=tf.nn.relu):
                              name="deconv2d" + str(i))
       return tf.depth_to_space(thicker, 2)
 
-    cur_frame = common_layers.standardize_images(features["inputs_0"])
-    prev_frame = common_layers.standardize_images(features["inputs_1"])
-
-    frames = tf.concat([cur_frame, prev_frame], axis=3)
-    frames = tf.reshape(frames, [-1, 210, 160, 6])
+    # cur_frame = common_layers.standardize_images(features["inputs_0"])
+    # prev_frame = common_layers.standardize_images(features["inputs_1"])
+    # frames = tf.concat([cur_frame, prev_frame], axis=3)
+    # frames = tf.reshape(frames, [-1, 210, 160, 6])
+    frames = common_layers.standardize_images(features["inputs"])
 
     h1 = tf.layers.conv2d(frames, filters=64, strides=2, kernel_size=(8, 8),
                           padding="SAME", activation=tf.nn.relu)
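With this change the model body returns a dict of outputs instead of a (logits, losses) pair. A minimal sketch (with placeholder strings standing in for the real tensors) of the contract this appears to target: the keys of the body's output dict line up with the dict-valued target_modality defined in GymDiscreteProblemWithAgent2.hparams above, which is exactly the check that t2t_model.top() used to assert (commented out later in this commit).

# Placeholders standing in for the real frame-prediction and reward tensors.
body_output = {"targets": "frame_prediction_tensor",
               "target_reward": "reward_logits_tensor"}
target_modality = {"targets": ("video", 256),
                   "target_reward": ("symbol", 3)}
assert set(body_output.keys()) == set(target_modality.keys())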

tensor2tensor/rl/envs/utils.py

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ def batch_env_factory(environment_lambda, hparams, num_agents, xvfb=False):
   else:
     cur_batch_env = define_batch_env(environment_lambda, num_agents, xvfb=xvfb)
   for w in wrappers:
-    cur_batch_env = w[0](batch_env, **w[1])
+    cur_batch_env = w[0](cur_batch_env, **w[1])
   return cur_batch_env
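This one-line fix makes the wrappers compose: each wrapper now wraps the environment produced by the previous loop iteration rather than a fixed environment, so every wrapper in the list takes effect. A minimal sketch with a hypothetical no-op wrapper (the real wrappers, e.g. MaxAndSkipWrapper, take a batched environment):

class NoOpWrapper(object):  # hypothetical stand-in for a real wrapper class
  def __init__(self, env, **kwargs):
    self.env = env

wrappers = [(NoOpWrapper, {}), (NoOpWrapper, {})]
cur_batch_env = "raw_batch_env"  # placeholder for the real batched environment
for w in wrappers:
  # Each iteration wraps the previously wrapped env, building a wrapper chain.
  cur_batch_env = w[0](cur_batch_env, **w[1])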
tensor2tensor/rl/model_rl_experiment.py

Lines changed: 12 additions & 8 deletions
@@ -25,7 +25,7 @@
 from tensor2tensor import problems
 from tensor2tensor.bin import t2t_trainer
 from tensor2tensor.rl import rl_trainer_lib
-from tensor2tensor.rl.envs.tf_atari_wrappers import PongT2TGeneratorHackWrapper
+from tensor2tensor.rl.envs.tf_atari_wrappers import ShiftRewardWrapper
 from tensor2tensor.rl.envs.tf_atari_wrappers import TimeLimitWrapper
 from tensor2tensor.utils import trainer_lib
 

@@ -52,10 +52,11 @@ def train(hparams, output_dir):
     time_delta = time.time() - start_time
     print(line+"Step {}.1. - generate data from policy. "
           "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
-    FLAGS.problems = "gym_discrete_problem"
+    # FLAGS.problems = "gym_discrete_problem_with_agent"
+    FLAGS.problems = "gym_discrete_problem_with_agent2"
     FLAGS.agent_policy_path = last_model
     gym_problem = problems.problem(FLAGS.problems)
-    gym_problem.num_steps = hparams.true_env_generator_num_steps
+    # gym_problem.num_steps = hparams.true_env_generator_num_steps
     iter_data_dir = os.path.join(data_dir, str(iloop))
     tf.gfile.MakeDirs(iter_data_dir)
     gym_problem.generate_data(iter_data_dir, tmp_dir)

@@ -66,16 +67,19 @@ def train(hparams, output_dir):
     # 2. generate env model
     FLAGS.data_dir = iter_data_dir
     FLAGS.output_dir = output_dir
-    FLAGS.model = hparams.generative_model
+    # FLAGS.model = hparams.generative_model
+    FLAGS.model = "basic_conv_gen"
+    # FLAGS.model = "michigan_basic_conv_gen"
     FLAGS.hparams_set = hparams.generative_model_params
-    FLAGS.train_steps = hparams.model_train_steps
+    # FLAGS.train_steps = hparams.model_train_steps
+    FLAGS.train_steps = 1
     FLAGS.eval_steps = 1
     t2t_trainer.main([])
 
     time_delta = time.time() - start_time
-    print(line+"Step {}.3. - evalue env model. "
+    print(line+"Step {}.3. - evaluate env model. "
           "Time: {}".format(iloop, str(datetime.timedelta(seconds=time_delta))))
-    gym_simulated_problem = problems.problem("gym_simulated_discrete_problem")
+    gym_simulated_problem = problems.problem("gym_simulated_discrete_problem_with_agent")
     gym_simulated_problem.num_steps = hparams.simulated_env_generator_num_steps
     gym_simulated_problem.generate_data(iter_data_dir, tmp_dir)
 

@@ -93,7 +97,7 @@ def train(hparams, output_dir):
     ppo_dir = tempfile.mkdtemp(dir=data_dir, prefix="ppo_")
     in_graph_wrappers = [
         (TimeLimitWrapper, {"timelimit": 150}),
-        (PongT2TGeneratorHackWrapper, {"add_value": -2})]
+        (ShiftRewardWrapper, {"add_value": -2})]
     in_graph_wrappers += gym_problem.in_graph_wrappers
     ppo_hparams.add_hparam("in_graph_wrappers", in_graph_wrappers)
     rl_trainer_lib.train(ppo_hparams, "PongNoFrameskip-v4", ppo_dir)

tensor2tensor/utils/t2t_model.py

Lines changed: 6 additions & 6 deletions
@@ -338,9 +338,9 @@ def top(self, body_output, features):
         target_modality = self._problem_hparams.target_modality
       else:
         target_modality = {k: None for k in body_output.keys()}
-      assert set(body_output.keys()) == set(target_modality.keys()), (
-          "The keys of model_body's returned logits dict must match the keys "
-          "of problem_hparams.target_modality's dict.")
+      # assert set(body_output.keys()) == set(target_modality.keys()), (
+      #     "The keys of model_body's returned logits dict must match the keys "
+      #     "of problem_hparams.target_modality's dict.")
       logits = {}
       for k, v in six.iteritems(body_output):
         with tf.variable_scope(k):  # TODO(aidangomez): share variables here?

@@ -351,9 +351,9 @@ def top(self, body_output, features):
         target_modality = self._problem_hparams.target_modality
       else:
         target_modality = None
-      assert not isinstance(target_modality, dict), (
-          "model_body must return a dictionary of logits when "
-          "problem_hparams.target_modality is a dict.")
+      # assert not isinstance(target_modality, dict), (
+      #     "model_body must return a dictionary of logits when "
+      #     "problem_hparams.target_modality is a dict.")
       return self._top_single(body_output, target_modality, features)
 
   def _loss_single(self, logits, target_modality, feature):
