Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d1f9bb2

Browse files
author
Ryan Sepassi
committed
Fix memory usage of rev_block
PiperOrigin-RevId: 165021509
1 parent 94eca0c commit d1f9bb2

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

tensor2tensor/layers/rev_block.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def _rev_layer_forward(xs, f, g):
4141
y1 = x1 + f(x2)
4242
with tf.variable_scope("g"):
4343
y2 = x2 + g(y1)
44-
return (y1, y2)
44+
return tf.tuple([y1, y2])
4545

4646

4747
def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
@@ -65,17 +65,26 @@ def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
6565

6666
# Compute gradients wrt inputs
6767
# dL/dy2 * dG(y1)/dy1
68-
grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2)[0]
68+
grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2, gate_gradients=True)[0]
6969
grad_x1 = grad_y1 + grad_gy1_y2
70-
grad_x2 = (tf.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + tf.gradients(
71-
fx2, x2_stop, grad_gy1_y2)[0])
70+
grad_x2 = (
71+
tf.gradients(fx2, x2_stop, grad_y1, gate_gradients=True)[0] + grad_y2 +
72+
tf.gradients(fx2, x2_stop, grad_gy1_y2, gate_gradients=True)[0])
7273

7374
# Compute gradients wrt vars in f and g
74-
grad_g_vars = tf.gradients(gy1, g_vars, grad_y2)
75-
grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1)
76-
grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2)
75+
grad_g_vars = tf.gradients(gy1, g_vars, grad_y2, gate_gradients=True)
76+
grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1, gate_gradients=True)
77+
grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2, gate_gradients=True)
7778
grad_f_vars = [tf.add_n(grads) for grads in zip(grad_f_y1, grad_f_y2)]
7879

80+
# Put returns in a tuple to ensure a constant memory budget (i.e. don't want
81+
# the subsequent layer to start computing and consuming memory based on a
82+
# subset of these values).
83+
outs = tf.tuple([x1, x2, grad_x1, grad_x2] + grad_f_vars + grad_g_vars)
84+
x1, x2, grad_x1, grad_x2 = outs[:4]
85+
grad_f_vars = outs[4:4 + len(grad_f_vars)]
86+
grad_g_vars = outs[4 + len(grad_f_vars):]
87+
7988
return (x1, x2), (grad_x1, grad_x2), grad_f_vars, grad_g_vars
8089

8190

0 commit comments

Comments
 (0)