Commit 7331b7c

Merge pull request #643 from deepsense-ai/rl_notebook
Notebook presenting rl module and simple model for generating Pong frames
2 parents: fd9b315 + 9e8f6b5

File tree: 7 files changed, +503 −31 lines

.travis.yml

Lines changed: 0 additions & 1 deletion

@@ -41,7 +41,6 @@ script:
     --ignore=tensor2tensor/problems_test.py
     --ignore=tensor2tensor/bin/t2t_trainer_test.py
     --ignore=tensor2tensor/data_generators/algorithmic_math_test.py
-    --ignore=tensor2tensor/rl/rl_trainer_lib_test.py
   - pytest tensor2tensor/utils/registry_test.py
   - pytest tensor2tensor/utils/trainer_lib_test.py
   - pytest tensor2tensor/visualization/visualization_test.py

tensor2tensor/data_generators/gym.py

Lines changed: 32 additions & 15 deletions

@@ -35,8 +35,6 @@
 import tensorflow as tf
 
 
-
-
 flags = tf.flags
 FLAGS = flags.FLAGS
 
@@ -50,6 +48,17 @@ def __init__(self, *args, **kwargs):
     super(GymDiscreteProblem, self).__init__(*args, **kwargs)
     self._env = None
 
+  def example_reading_spec(self, label_repr=None):
+
+    data_fields = {
+        "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "action": tf.FixedLenFeature([1], tf.int64)
+    }
+
+    return data_fields, None
+
   @property
   def env_name(self):
     # This is the name of the Gym environment for this problem.
@@ -133,7 +142,7 @@ class GymPongRandom5k(GymDiscreteProblem):
 
   @property
   def env_name(self):
-    return "Pong-v0"
+    return "PongNoFrameskip-v4"
 
   @property
   def num_actions(self):
@@ -148,21 +157,30 @@ def num_steps(self):
     return 5000
 
 
+
 @registry.register_problem
 class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
   """Pong game, loaded actions."""
 
-  def __init__(self, event_dir, *args, **kwargs):
+  def __init__(self, *args, **kwargs):
     super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
     self._env = None
-    self._event_dir = event_dir
+    self._last_policy_op = None
+    self._max_frame_pl = None
+    self._last_action = self.env.action_space.sample()
+    self._skip = 4
+    self._skip_step = 0
+    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
+                                dtype=np.uint8)
+
+  def generator(self, data_dir, tmp_dir):
     env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
         gym.make("PongNoFrameskip-v4"),
         warp=False,
         frame_skip=4,
         frame_stack=False)
     hparams = rl.atari_base()
-    with tf.variable_scope("train"):
+    with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
       policy_lambda = hparams.network
       policy_factory = tf.make_template(
           "network",
@@ -173,14 +191,13 @@ def __init__(self, event_dir, *args, **kwargs):
           self._max_frame_pl, 0), 0))
       policy = actor_critic.policy
      self._last_policy_op = policy.mode()
-    self._last_action = self.env.action_space.sample()
-    self._skip = 4
-    self._skip_step = 0
-    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
-                                dtype=np.uint8)
-    self._sess = tf.Session()
-    model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
-    model_saver.restore(self._sess, FLAGS.model_path)
+    with tf.Session() as sess:
+      model_saver = tf.train.Saver(
+          tf.global_variables(".*network_parameters.*"))
+      model_saver.restore(sess, FLAGS.model_path)
+      for item in super(GymPongTrajectoriesFromPolicy,
+                        self).generator(data_dir, tmp_dir):
+        yield item
 
   # TODO(blazej0): For training of atari agents wrappers are usually used.
   # Below we have a hacky solution which is a workaround to be used together
@@ -191,7 +208,7 @@ def get_action(self, observation=None):
     self._skip_step = (self._skip_step + 1) % self._skip
     if self._skip_step == 0:
       max_frame = self._obs_buffer.max(axis=0)
-      self._last_action = int(self._sess.run(
+      self._last_action = int(tf.get_default_session().run(
           self._last_policy_op,
          feed_dict={self._max_frame_pl: max_frame})[0, 0])
     return self._last_action
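
The example_reading_spec added above tells the problem how to decode one stored trajectory step back into tensors. As a quick illustration (not part of the commit; the round trip below is a hypothetical, self-contained TF 1.x sketch), the same FixedLenFeature map can be handed to tf.parse_single_example to recover the two input frames, the target frame, and the action:

import numpy as np
import tensorflow as tf

def _int64_feature(values):
  # Wrap a flat list of ints as an int64 Feature.
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

frame = np.zeros((210, 160, 3), dtype=np.int64)  # dummy Pong-sized frame
example = tf.train.Example(features=tf.train.Features(feature={
    "inputs": _int64_feature(frame.flatten().tolist()),
    "inputs_prev": _int64_feature(frame.flatten().tolist()),
    "targets": _int64_feature(frame.flatten().tolist()),
    "action": _int64_feature([2]),
}))

# Same spec as example_reading_spec() in the diff above.
data_fields = {
    "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
    "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
    "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
    "action": tf.FixedLenFeature([1], tf.int64),
}
parsed = tf.parse_single_example(example.SerializeToString(), data_fields)

with tf.Session() as sess:
  out = sess.run(parsed)
  print(out["inputs"].shape, out["action"])  # (210, 160, 3) [2]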

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@
 from tensor2tensor.models.research import aligned
 from tensor2tensor.models.research import attention_lm
 from tensor2tensor.models.research import attention_lm_moe
+from tensor2tensor.models.research import basic_conv_gen
 from tensor2tensor.models.research import cycle_gan
 from tensor2tensor.models.research import gene_expression
 from tensor2tensor.models.research import multimodel
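
Registration happens as a side effect of this import: pulling basic_conv_gen into tensor2tensor.models runs its @registry.register_model decorator. A hedged sketch of looking the model up afterwards (not in the diff; it assumes the registry's usual snake_case naming for BasicConvGen):

from tensor2tensor import models  # noqa: F401 -- triggers the research imports above
from tensor2tensor.utils import registry

# BasicConvGen should now be retrievable by its snake_case name.
print(registry.model("basic_conv_gen"))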
tensor2tensor/models/research/basic_conv_gen.py (new file, per the import added above)

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+
+# coding=utf-8
+# Copyright 2018 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Basic models for testing simple tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from tensor2tensor.layers import common_hparams
+from tensor2tensor.layers import common_layers
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+@registry.register_model
+class BasicConvGen(t2t_model.T2TModel):
+
+  def body(self, features):
+    print(features)
+    filters = self.hparams.hidden_size
+    cur_frame = tf.to_float(features["inputs"])
+    prev_frame = tf.to_float(features["inputs_prev"])
+    print(features["inputs"].shape, cur_frame.shape, prev_frame.shape)
+    action = common_layers.embedding(tf.to_int64(features["action"]),
+                                     10, filters)
+    action = tf.reshape(action, [-1, 1, 1, filters])
+
+    frames = tf.concat([cur_frame, prev_frame], axis=3)
+    h1 = tf.layers.conv2d(frames, filters, kernel_size=(3, 3), padding="SAME")
+    h2 = tf.layers.conv2d(tf.nn.relu(h1 + action), filters,
+                          kernel_size=(5, 5), padding="SAME")
+    res = tf.layers.conv2d(tf.nn.relu(h2 + action), 3 * 256,
+                           kernel_size=(3, 3), padding="SAME")
+
+    height = tf.shape(res)[1]
+    width = tf.shape(res)[2]
+    res = tf.reshape(res, [-1, height, width, 3, 256])
+    return res
+
+
+@registry.register_hparams
+def basic_conv_small():
+  # """Small conv model."""
+  hparams = common_hparams.basic_params1()
+  hparams.hidden_size = 32
+  hparams.batch_size = 2
+  return hparams
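
BasicConvGen concatenates the current and previous frame along the channel axis, injects an embedded action by broadcast addition, and ends with 3 * 256 output channels that are reshaped into per-pixel, per-RGB-channel 256-way logits. The standalone TF 1.x sketch below (not part of the commit; it assumes batch size 1, basic_conv_small's hidden_size of 32, and a dummy zero action embedding) mirrors that stack to make the shapes explicit:

import tensorflow as tf

filters = 32  # basic_conv_small.hidden_size
cur_frame = tf.zeros([1, 210, 160, 3], tf.float32)   # stand-in for features["inputs"]
prev_frame = tf.zeros([1, 210, 160, 3], tf.float32)  # stand-in for features["inputs_prev"]
action = tf.zeros([1, 1, 1, filters], tf.float32)    # stand-in for the embedded action

frames = tf.concat([cur_frame, prev_frame], axis=3)            # [1, 210, 160, 6]
h1 = tf.layers.conv2d(frames, filters, kernel_size=(3, 3), padding="SAME")
h2 = tf.layers.conv2d(tf.nn.relu(h1 + action), filters,
                      kernel_size=(5, 5), padding="SAME")
res = tf.layers.conv2d(tf.nn.relu(h2 + action), 3 * 256,
                       kernel_size=(3, 3), padding="SAME")     # [1, 210, 160, 768]
res = tf.reshape(res, [-1, 210, 160, 3, 256])                  # 256-way logits per RGB value

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(res).shape)  # (1, 210, 160, 3, 256)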
