@@ -41,7 +41,7 @@ def _rev_layer_forward(xs, f, g):
     y1 = x1 + f(x2)
   with tf.variable_scope("g"):
     y2 = x2 + g(y1)
-  return (y1, y2)
+  return tf.tuple([y1, y2])
 
 
 def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
@@ -65,17 +65,26 @@ def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
 
   # Compute gradients wrt to inputs
   # dL/dy2 * dG(y1)/y1
-  grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2)[0]
+  grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2, gate_gradients=True)[0]
 
   grad_x1 = grad_y1 + grad_gy1_y2
-  grad_x2 = (tf.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + tf.gradients(
-      fx2, x2_stop, grad_gy1_y2)[0])
+  grad_x2 = (
+      tf.gradients(fx2, x2_stop, grad_y1, gate_gradients=True)[0] + grad_y2 +
+      tf.gradients(fx2, x2_stop, grad_gy1_y2, gate_gradients=True)[0])
 
   # Compute gradients wrt to vars in f and g
-  grad_g_vars = tf.gradients(gy1, g_vars, grad_y2)
-  grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1)
-  grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2)
+  grad_g_vars = tf.gradients(gy1, g_vars, grad_y2, gate_gradients=True)
+  grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1, gate_gradients=True)
+  grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2, gate_gradients=True)
 
   grad_f_vars = [tf.add_n(grads) for grads in zip(grad_f_y1, grad_f_y2)]
 
+  # Put returns in a tuple to ensure a constant memory budget (i.e. don't want
+  # the subsequent layer to start computing and consuming memory based on a
+  # subset of these values).
+  outs = tf.tuple([x1, x2, grad_x1, grad_x2] + grad_f_vars + grad_g_vars)
+  x1, x2, grad_x1, grad_x2 = outs[:4]
+  grad_f_vars = outs[4:4 + len(grad_f_vars)]
+  grad_g_vars = outs[4 + len(grad_f_vars):]
+
   return (x1, x2), (grad_x1, grad_x2), grad_f_vars, grad_g_vars
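For context on the tf.tuple calls introduced above: tf.tuple wraps each listed tensor in an identity op that carries a control dependency on all of the listed tensors, so none of the grouped values can be consumed downstream until every one of them has been computed. That is what keeps the following layer's backward pass from starting early off a subset of these outputs and growing peak memory. Below is a minimal sketch of that gating behavior, assuming the TF1 graph-mode API used in this file; the tensors a and b are hypothetical stand-ins, not part of the change.

import tensorflow as tf

# Two independent computations standing in for (y1, y2) or the gradient outputs.
x = tf.random_normal([4, 8])
a = tf.square(x)
b = tf.reduce_sum(x, axis=1)

# Without tf.tuple, an op that consumes only `a` could run (and allocate memory)
# as soon as `a` is ready. tf.tuple returns identities of `a` and `b` that are
# both blocked until the whole group has been computed.
a_gated, b_gated = tf.tuple([a, b])

with tf.Session() as sess:
  sess.run([a_gated, b_gated])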