
Commit 66afb76

blazejosinski authored and Copybara-Service committed
delete hparam.force_beginning_resets
PiperOrigin-RevId: 219241227
1 parent: e6000fc · commit: 66afb76

2 files changed: +2 −20 lines

tensor2tensor/models/research/rl.py

Lines changed: 1 addition & 3 deletions
@@ -137,8 +137,7 @@ def standard_atari_env_spec(env=None, simulated=False):
       simulated_env=simulated,
       reward_range=env.reward_range,
       observation_space=env.observation_space,
-      action_space=env.action_space,
-      force_beginning_resets=True
+      action_space=env.action_space
   )
   if not simulated:
     env_spec.add_hparam("env", env)
@@ -150,7 +149,6 @@ def standard_atari_env_simulated_spec(real_env, **kwargs):
   env_spec = standard_atari_env_spec(real_env, simulated=True)
   for (name, value) in six.iteritems(kwargs):
     env_spec.add_hparam(name, value)
-  env_spec.force_beginning_resets = False
   return env_spec
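For context, a minimal sketch of what the spec looks like after this change, assuming TF 1.x's tf.contrib.training.HParams (the container tensor2tensor uses for these specs) and an illustrative Atari env name; the point is simply that force_beginning_resets is no longer an hparam, so downstream code can no longer read it:

# Hedged sketch of standard_atari_env_spec's output after this commit.
# The env name is an illustrative assumption; HParams is the TF 1.x API.
import gym
import tensorflow as tf

env = gym.make("PongNoFrameskip-v4")  # hypothetical choice of Atari env
env_spec = tf.contrib.training.HParams(
    simulated_env=False,
    reward_range=env.reward_range,
    observation_space=env.observation_space,
    action_space=env.action_space)  # force_beginning_resets no longer set
env_spec.add_hparam("env", env)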

tensor2tensor/rl/collect.py

Lines changed: 1 addition & 17 deletions
@@ -141,16 +141,12 @@ def initialization_lambda(sess):
   should_reset_var = tf.Variable(True, trainable=False)
   zeros_tensor = tf.zeros(len(batch_env))
 
-  force_beginning_resets = tf.convert_to_tensor(
-      environment_spec.force_beginning_resets
-  )
-
   def reset_ops_group():
     return tf.group(batch_env.reset(tf.range(len(batch_env))),
                     tf.assign(cumulative_rewards, zeros_tensor))
 
   reset_op = tf.cond(
-      tf.logical_or(should_reset_var.read_value(), force_beginning_resets),
+      tf.logical_or(should_reset_var.read_value(), eval_phase_t),
       reset_ops_group, tf.no_op)
 
   with tf.control_dependencies([reset_op]):
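Reduced to a runnable toy, the gating pattern this hunk keeps looks as follows (a sketch assuming TF 1.x graph mode; the batch size, the placeholder standing in for eval_phase_t, and the stubbed-out env reset are illustrative assumptions):

import tensorflow as tf  # assumes TF 1.x graph-mode APIs

batch_size = 4  # illustrative; the real value is len(batch_env)
should_reset_var = tf.Variable(True, trainable=False)
cumulative_rewards = tf.Variable(tf.zeros(batch_size), trainable=False)
eval_phase_t = tf.placeholder(tf.bool, shape=[], name="eval_phase")

def reset_ops_group():
  # The real code also resets every env in batch_env; here we only zero
  # the reward accumulator to keep the sketch self-contained.
  return tf.group(tf.assign(cumulative_rewards, tf.zeros(batch_size)))

# Resets now fire when the per-run flag is set OR we are in the eval
# phase, instead of consulting the deleted force_beginning_resets hparam.
reset_op = tf.cond(
    tf.logical_or(should_reset_var.read_value(), eval_phase_t),
    reset_ops_group, tf.no_op)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(reset_op, feed_dict={eval_phase_t: False})  # resets anyway: var is True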
@@ -238,18 +234,6 @@ def stop_condition(i, _, resets):
       parallel_iterations=1,
       back_prop=False)
 
-  # We handle force_beginning_resets differently. We assume that all envs are
-  # reseted at the end of episod (though it happens at the beginning of the
-  # next one
-  scores_num = tf.cond(force_beginning_resets,
-                       lambda: scores_num + len(batch_env), lambda: scores_num)
-
-  with tf.control_dependencies([scores_sum]):
-    scores_sum = tf.cond(
-        force_beginning_resets,
-        lambda: scores_sum + tf.reduce_sum(cumulative_rewards.read_value()),
-        lambda: scores_sum)
-
   mean_score = tf.cond(tf.greater(scores_num, 0),
                        lambda: scores_sum / tf.cast(scores_num, tf.float32),
                        lambda: 0.)
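The deleted block compensated scores_num and scores_sum for episodes that force_beginning_resets cut short; with the flag gone, the counters come only from episodes that actually finished, and the surviving tail just guards the division. A toy sketch of that guard (the scalar constants are illustrative stand-ins for the real counters, assuming TF 1.x):

import tensorflow as tf  # assumes TF 1.x

scores_sum = tf.constant(12.5)  # illustrative summed episode scores
scores_num = tf.constant(0)     # illustrative count of finished episodes

# mean_score falls back to 0.0 when no episode finished, avoiding 0/0.
mean_score = tf.cond(tf.greater(scores_num, 0),
                     lambda: scores_sum / tf.cast(scores_num, tf.float32),
                     lambda: 0.)

with tf.Session() as sess:
  print(sess.run(mean_score))  # prints 0.0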
