This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3cf9c38

Author: Błażej O (committed)
Merge branch 'stop_hejt_and_gradients' into eval_implemantation
2 parents: e97ce1b + b1da810

File tree (4 files changed, +74 / -57 lines):
  tensor2tensor/models/research/rl.py
  tensor2tensor/rl/collect.py
  tensor2tensor/rl/ppo.py
  tensor2tensor/rl/rl_trainer_lib.py


tensor2tensor/models/research/rl.py

Lines changed: 23 additions & 0 deletions
@@ -131,3 +131,26 @@ def feed_forward_categorical_fun(action_space, config, observations):
   value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
   policy = tf.contrib.distributions.Categorical(logits=logits)
   return NetworkOutput(policy, value, lambda a: a)
+
+
+def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
+  """Small cnn network with categorical output."""
+  obs_shape = observations.shape.as_list()
+  x = tf.reshape(observations, [-1] + obs_shape[2:])
+
+  with tf.variable_scope('policy'):
+    x = tf.to_float(x) / 255.0
+    x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2], activation_fn=tf.nn.relu, padding="SAME")
+    x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2], activation_fn=tf.nn.relu, padding="SAME")
+
+    flat_x = tf.reshape(x, [
+        tf.shape(observations)[0], tf.shape(observations)[1],
+        functools.reduce(operator.mul, x.shape.as_list()[1:], 1)])
+
+    x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)
+    logits = tf.contrib.layers.fully_connected(x, action_space.n, activation_fn=None)
+
+    value = tf.contrib.layers.fully_connected(x, 1, activation_fn=None)[..., 0]
+    policy = tf.contrib.distributions.Categorical(logits=logits)
+
+  return NetworkOutput(policy, value, lambda a: a)
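
Note: the new CNN policy expects observations shaped [batch, time, height, width, channels]. The NumPy sketch below (sizes are Atari-style assumptions, not part of the commit) only traces how the two reshapes fold and restore those dimensions.

# Illustration only -- assumed Atari-style shapes, not taken from the commit.
import numpy as np

batch, time, height, width, channels = 2, 3, 84, 84, 3
observations = np.zeros((batch, time, height, width, channels), dtype=np.uint8)

# tf.reshape(observations, [-1] + obs_shape[2:]) folds batch and time together
# so the 2-D convolutions see ordinary image batches.
x = observations.reshape([-1] + list(observations.shape[2:]))
print(x.shape)  # (6, 84, 84, 3)

# Two stride-2 "SAME" convolutions with 32 filters halve each spatial dim twice.
conv_height, conv_width, filters = height // 4, width // 4, 32

# flat_x restores the [batch, time, features] layout expected downstream.
flat_x = np.zeros((batch, time, conv_height * conv_width * filters))
print(flat_x.shape)  # (2, 3, 14112)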

tensor2tensor/rl/collect.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def define_collect(policy_factory, batch_env, hparams, eval_phase):
   memory_shape = [hparams.epoch_length] + [batch_env.observ.shape.as_list()[0]]
   memories_shapes_and_types = [
       # observation
-      (memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
+      (memory_shape + batch_env.observ.shape.as_list()[1:], tf.float32),
       (memory_shape, tf.float32),  # reward
       (memory_shape, tf.bool),  # done
       # action
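
Note: the old line kept only the first observation dimension in the rollout memory; the new one stores the full observation shape, which matters once observations are images rather than flat vectors. A small sketch with assumed sizes (not taken from the commit):

# Illustration only -- assumed sizes, not part of the commit.
epoch_length = 5
observ_shape = [16, 84, 84, 3]  # stand-in for batch_env.observ.shape.as_list()
memory_shape = [epoch_length] + [observ_shape[0]]

old_entry = memory_shape + [observ_shape[1]]  # [5, 16, 84] -- drops width and channels
new_entry = memory_shape + observ_shape[1:]   # [5, 16, 84, 84, 3] -- keeps the full frame
print(old_entry, new_entry)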

tensor2tensor/rl/ppo.py

Lines changed: 43 additions & 52 deletions
@@ -21,9 +21,17 @@
 import tensorflow as tf


+def get_optimiser(config):
+
+  if config.optimizer == 'Adam':
+    return tf.train.AdamOptimizer(config.learning_rate)
+
+  return config.optimizer(config.learning_rate)
+
+
 def define_ppo_step(observation, action, reward, done, value, old_pdf,
                     policy_factory, config):
-  """A step of PPO."""
+
   new_policy_dist, new_value, _ = policy_factory(observation)
   new_pdf = new_policy_dist.prob(action)

@@ -43,27 +51,30 @@ def define_ppo_step(observation, action, reward, done, value, old_pdf,
                                    ratio * advantage_normalized)
   policy_loss = -tf.reduce_mean(surrogate_objective)

-  value_error = calculate_discounted_return(
-      reward, new_value, done, config.gae_gamma, config.gae_lambda) - value
+  value_error = calculate_generalized_advantage_estimator(
+      reward, new_value, done, config.gae_gamma, config.gae_lambda)
   value_loss = config.value_loss_coef * tf.reduce_mean(value_error ** 2)

   entropy = new_policy_dist.entropy()
   entropy_loss = -config.entropy_loss_coef * tf.reduce_mean(entropy)

-  total_loss = policy_loss + value_loss + entropy_loss
+  optimizer = get_optimiser(config)
+  losses = [policy_loss, value_loss, entropy_loss]

-  optimization_op = tf.contrib.layers.optimize_loss(
-      loss=total_loss,
-      global_step=tf.train.get_or_create_global_step(),
-      optimizer=config.optimizer,
-      learning_rate=config.learning_rate)
+  gradients = [list(zip(*optimizer.compute_gradients(loss))) for loss in losses]

-  with tf.control_dependencies([optimization_op]):
-    return [tf.identity(x) for x in (policy_loss, value_loss, entropy_loss)]
+  gradients_norms = [tf.global_norm(gradient[0]) for gradient in gradients]
+
+  gradients_flat = sum([gradient[0] for gradient in gradients], ())
+  gradients_variables_flat = sum([gradient[1] for gradient in gradients], ())
+
+  optimize_op = optimizer.apply_gradients(zip(gradients_flat, gradients_variables_flat))
+
+  with tf.control_dependencies([optimize_op]):
+    return [tf.identity(x) for x in losses + gradients_norms]


 def define_ppo_epoch(memory, policy_factory, config):
-  """An epoch of PPO."""
   observation, reward, done, action, old_pdf, value = memory

   # This is to avoid propagating gradients though simulation of simulation
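
Note: instead of a single optimize_loss call on the summed loss, the step now computes gradients per loss so each loss can report its own gradient norm, then applies everything with one apply_gradients call. A minimal standalone sketch of that pattern (toy variables and losses, TF 1.x graph mode assumed; this is not the project's code):

# Illustration only -- toy losses, TF 1.x graph mode assumed.
import tensorflow as tf

w = tf.Variable([1.0, 2.0])
losses = [tf.reduce_sum(w ** 2), tf.reduce_sum(tf.abs(w))]

optimizer = tf.train.AdamOptimizer(1e-3)

# One compute_gradients call per loss keeps each loss's (gradients, variables) separate...
gradients = [list(zip(*optimizer.compute_gradients(loss))) for loss in losses]

# ...so every loss can expose its own global gradient norm for logging...
gradient_norms = [tf.global_norm(grads) for grads, _ in gradients]

# ...while a single apply_gradients step still performs the combined update.
grads_flat = sum([grads for grads, _ in gradients], ())
vars_flat = sum([variables for _, variables in gradients], ())
train_op = optimizer.apply_gradients(zip(grads_flat, vars_flat))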
@@ -74,59 +85,39 @@ def define_ppo_epoch(memory, policy_factory, config):
   value = tf.stop_gradient(value)
   old_pdf = tf.stop_gradient(old_pdf)

-  policy_loss, value_loss, entropy_loss = tf.scan(
-      lambda _1, _2: define_ppo_step(  # pylint: disable=g-long-lambda
-          observation, action, reward, done, value,
-          old_pdf, policy_factory, config),
+  ppo_step_rets = tf.scan(
+      lambda _1, _2: define_ppo_step(observation, action, reward, done, value,
+                                     old_pdf, policy_factory, config),
       tf.range(config.optimization_epochs),
-      [0., 0., 0.],
+      [0., 0., 0., 0., 0., 0.],
       parallel_iterations=1)

-  summaries = [tf.summary.scalar("policy loss", tf.reduce_mean(policy_loss)),
-               tf.summary.scalar("value loss", tf.reduce_mean(value_loss)),
-               tf.summary.scalar("entropy loss", tf.reduce_mean(entropy_loss))]
+  ppo_summaries = [tf.reduce_mean(ret) for ret in ppo_step_rets]
+  summaries_names = ["policy_loss", "value_loss", "entropy_loss",
+                     "policy_gradient", "value_gradient", "entropy_gradient"]

+  summaries = [tf.summary.scalar(summary_name, summary)
+               for summary_name, summary in zip(summaries_names, ppo_summaries)]
   losses_summary = tf.summary.merge(summaries)

-  losses_summary = tf.Print(losses_summary,
-                            [tf.reduce_mean(policy_loss)], "policy loss: ")
-  losses_summary = tf.Print(losses_summary,
-                            [tf.reduce_mean(value_loss)], "value loss: ")
-  losses_summary = tf.Print(losses_summary,
-                            [tf.reduce_mean(entropy_loss)], "entropy loss: ")
+  for summary_name, summary in zip(summaries_names, ppo_summaries):
+    losses_summary = tf.Print(losses_summary, [summary], summary_name + ": ")

   return losses_summary

+def calculate_generalized_advantage_estimator(reward, value, done, gae_gamma, gae_lambda):
+  """Generalized advantage estimator."""

-def calculate_discounted_return(reward, value, done, discount, unused_lambda):
-  """Discounted Monte-Carlo returns."""
-  done = tf.cast(done, tf.float32)
-  reward2 = done[-1, :] * reward[-1, :] + (1 - done[-1, :]) * value[-1, :]
-  reward = tf.concat([reward[:-1,], reward2[None, ...]], axis=0)
-  return_ = tf.reverse(tf.scan(
-      lambda agg, cur: cur[0] + (1 - cur[1]) * discount * agg,  # fn
-      [tf.reverse(reward, [0]),  # elem
-       tf.reverse(done, [0])],
-      tf.zeros_like(reward[0, :]),  # initializer
-      1,
-      False), [0])
-  return tf.check_numerics(return_, "return")
-
-
-def calculate_generalized_advantage_estimator(
-    reward, value, done, gae_gamma, gae_lambda):
-  """Generalized advantage estimator."""
-  # Below is slight weirdness, we set the last reward to 0.
-  # This makes the adventantage to be 0 in the last timestep.
-  reward = tf.concat([reward[:-1, :], value[-1:, :]], axis=0)
-  next_value = tf.concat([value[1:, :], tf.zeros_like(value[-1:, :])], axis=0)
-  next_not_done = 1 - tf.cast(tf.concat(
-      [done[1:, :], tf.zeros_like(done[-1:, :])], axis=0), tf.float32)
+  # Below is slight weirdness: we set the last reward to 0.
+  # This makes the advantage 0 in the last timestep.
+  reward = tf.concat([reward[:-1, :], value[-1:, :]], axis=0)
+  next_value = tf.concat([value[1:, :], tf.zeros_like(value[-1:, :])], axis=0)
+  next_not_done = 1 - tf.cast(tf.concat([done[1:, :], tf.zeros_like(done[-1:, :])], axis=0), tf.float32)
   delta = reward + gae_gamma * next_value * next_not_done - value

   return_ = tf.reverse(tf.scan(
       lambda agg, cur: cur[0] + cur[1] * gae_gamma * gae_lambda * agg,
       [tf.reverse(delta, [0]), tf.reverse(next_not_done, [0])],
       tf.zeros_like(delta[0, :]),
-      1, False), [0])
-  return tf.check_numerics(tf.stop_gradient(return_), "return")
+      parallel_iterations=1), [0])
+  return tf.check_numerics(return_, 'return')
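
Note: for readers checking calculate_generalized_advantage_estimator, here is an equivalent NumPy version of the same reverse-time recursion (shapes are [time, batch]; the numbers are invented for illustration, not taken from the commit):

# Illustration only -- NumPy mirror of the TF GAE recursion above; toy numbers.
import numpy as np

def gae(reward, value, done, gae_gamma, gae_lambda):
  # Replace the last reward with the last value so the final advantage is 0.
  reward = np.concatenate([reward[:-1], value[-1:]], axis=0)
  next_value = np.concatenate([value[1:], np.zeros_like(value[-1:])], axis=0)
  next_not_done = 1.0 - np.concatenate([done[1:], np.zeros_like(done[-1:])], axis=0)
  delta = reward + gae_gamma * next_value * next_not_done - value

  # Reverse-time scan: agg <- delta_t + next_not_done_t * gamma * lambda * agg
  advantage = np.zeros_like(delta)
  agg = np.zeros_like(delta[0])
  for t in reversed(range(delta.shape[0])):
    agg = delta[t] + next_not_done[t] * gae_gamma * gae_lambda * agg
    advantage[t] = agg
  return advantage

reward = np.array([[1.0], [0.0], [1.0]])  # [time, batch]
value = np.array([[0.5], [0.4], [0.3]])
done = np.array([[0.0], [0.0], [1.0]])
print(gae(reward, value, done, gae_gamma=0.99, gae_lambda=0.95))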

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 7 additions & 4 deletions
@@ -33,9 +33,12 @@
 import tensorflow as tf


-def define_train(hparams, environment_name, event_dir):
+def define_train(hparams, environment_spec, event_dir):
   """Define the training setup."""
-  env_lambda = lambda: gym.make(environment_name)
+  if isinstance(environment_spec, str):
+    env_lambda = lambda: gym.make(environment_spec)
+  else:
+    env_lambda = environment_spec
   policy_lambda = hparams.network
   env = env_lambda()
   action_space = env.action_space

@@ -63,8 +66,8 @@ def define_train(hparams, environment_name, event_dir):
   return summary, eval_summary


-def train(hparams, environment_name, event_dir=None):
-  train_summary_op, eval_summary_op = define_train(hparams, environment_name, event_dir)
+def train(hparams, environment_spec, event_dir=None):
+  train_summary_op, eval_summary_op = define_train(hparams, environment_spec, event_dir)

   if event_dir:
     summary_writer = tf.summary.FileWriter(
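
Note: define_train and train now accept either a gym environment id string or a zero-argument callable that builds the environment. A hedged usage sketch (the hparams object and the env id are stand-ins, not from the commit):

# Illustration only -- `hparams` is assumed to be the project's PPO hparams object
# (not constructed here), and the gym id is just an example.
import gym

# Before: only a gym id string was accepted.
train(hparams, "PongNoFrameskip-v4")

# After: a zero-argument callable that constructs the environment also works,
# e.g. to wrap or preprocess it before training.
def make_env():
  return gym.make("PongNoFrameskip-v4")  # wrap/preprocess here as needed

train(hparams, make_env)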
