Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f1e161c

Browse files
Mehrad0711Copybara-Service
authored andcommitted
internal merge of PR #1171
PiperOrigin-RevId: 219151604
1 parent de2964a commit f1e161c

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tensor2tensor/utils/avg_checkpoints.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,11 @@ def main(_):
9090
for name in var_values: # Average.
9191
var_values[name] /= len(checkpoints)
9292

93-
tf_vars = [
94-
tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[name])
95-
for v in var_values
96-
]
93+
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
94+
tf_vars = [
95+
tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
96+
for v in var_values
97+
]
9798
placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
9899
assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
99100
global_step = tf.Variable(

0 commit comments

Comments
 (0)