This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ffe2386

nshazeer authored and Ryan Sepassi committed
Added optional memory-efficient versions of conv-hidden-relu and self-attention.
PiperOrigin-RevId: 166915506
1 parent 357c9d4 commit ffe2386
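Both new layers in this commit save memory the same way: the forward computation is wrapped in a function.Defun whose grad_func recomputes the forward activations during backprop instead of keeping them alive. Below is a minimal toy sketch of that pattern, written for illustration only (it is not code from this commit; TF 1.x graph mode is assumed, and the names toy_forward, _toy_forward and _toy_grad are made up):

import tensorflow as tf
from tensorflow.python.framework import function


def _toy_forward(x, w):
  # The hidden activation h is what we want to avoid keeping alive
  # between the forward and backward passes.
  h = tf.nn.relu(tf.matmul(x, w))
  return tf.matmul(h, w, transpose_b=True)


# Gradient function: receives the forward inputs plus the output gradient dy,
# recomputes the forward graph, and differentiates the recomputed copy.
@function.Defun()
def _toy_grad(x, w, dy):
  with tf.control_dependencies([dy]):
    y = _toy_forward(x, w)  # recompute instead of storing activations
    dx, dw = tf.gradients(ys=[y], xs=[x, w], grad_ys=[dy])
  return dx, dw


@function.Defun(grad_func=_toy_grad)
def toy_forward(x, w):
  return _toy_forward(x, w)


x = tf.random_normal([4, 8])
w = tf.get_variable("w", [8, 8])
y = toy_forward(x, w)
# Backprop through the Defun invokes _toy_grad, so only x, w and dy need to
# survive between the forward and backward passes.
dx, dw = tf.gradients(y, [x, w])

The real implementations below additionally pass compiled=True and separate_compiled_gradients=True to Defun, and structure the recomputation head by head (self-attention) or batch-chunk by batch-chunk (conv-hidden-relu) so that only one slice of activations is live at a time.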

File tree

5 files changed: +398 -18 lines changed


tensor2tensor/layers/common_attention.py

Lines changed: 149 additions & 0 deletions
@@ -30,6 +30,8 @@
 
 import tensorflow as tf
 
+from tensorflow.python.framework import function
+
 
 def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
   """Adds a bunch of sinusoids of different frequencies to a Tensor.
@@ -1100,3 +1102,150 @@ def local_expert_attention(
         additional_dispatch_params=additional_dispatch_params,
         pad_remover=pad_remover
     )
+
+
+def scaled_dot_product_attention_simple(q, k, v, bias, name=None):
+  """scaled dot-product attention. One head. One spatial dimension.
+
+  Args:
+    q: a Tensor with shape [batch, length_q, depth_k]
+    k: a Tensor with shape [batch, length_kv, depth_k]
+    v: a Tensor with shape [batch, length_kv, depth_v]
+    bias: optional Tensor broadcastable to [batch, length_q, length_kv]
+    name: an optional string
+
+  Returns:
+    A Tensor.
+  """
+  with tf.variable_scope(
+      name, default_name="scaled_dot_product_attention_simple"):
+    scalar = tf.rsqrt(tf.to_float(tf.shape(q)[2]))
+    logits = tf.matmul(q * scalar, k, transpose_b=True)
+    if bias is not None:
+      logits += bias
+    weights = tf.nn.softmax(logits, name="attention_weights")
+    return tf.matmul(weights, v)
+
+
+_function_cache = {}
+
+
+def multihead_self_attention_memory_efficient(x,
+                                              bias,
+                                              num_heads,
+                                              head_size=None,
+                                              epsilon=1e-6,
+                                              forget=True,
+                                              test_vars=None,
+                                              name=None):
+  """Multihead scaled-dot-product self-attention.
+
+  Includes layer norm.
+
+  Returns multihead-self-attention(layer_norm(x))
+
+  Computes one attention head at a time to avoid exhausting memory.
+
+  If forget=True, then forget all forwards activations and recompute on
+  the backwards pass.
+
+  Args:
+    x: a Tensor with shape [batch, length, input_size]
+    bias: an attention bias tensor broadcastable to [batch, 1, length, length]
+    num_heads: an integer
+    head_size: an optional integer - defaults to input_size/num_heads
+    epsilon: a float, for layer norm
+    forget: a boolean - forget forwards activations and recompute on backprop
+    test_vars: optional tuple of variables for testing purposes
+    name: an optional string
+
+  Returns:
+    A Tensor.
+  """
+  io_size = x.get_shape().as_list()[-1]
+  if head_size is None:
+    assert io_size % num_heads == 0
+    head_size = io_size / num_heads
+
+  def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
+    """Forward function."""
+    n = common_layers.layer_norm_compute_python(
+        x, epsilon, norm_scale, norm_bias)
+    wqkv_split = tf.unstack(wqkv, num=num_heads)
+    wo_split = tf.unstack(wo, num=num_heads)
+    y = 0
+    for h in xrange(num_heads):
+      with tf.control_dependencies([y] if h > 0 else []):
+        combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
+        q, k, v = tf.split(combined, 3, axis=2)
+        o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
+        y += tf.nn.conv1d(o, wo_split[h], 1, "SAME")
+    return y
+
+  key = ("multihead_self_attention_memory_efficient %s %s" %
+         (num_heads, epsilon))
+  if not forget:
+    forward_fn = forward_internal
+  elif key in _function_cache:
+    forward_fn = _function_cache[key]
+  else:
+    @function.Defun(compiled=True)
+    def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
+      with tf.control_dependencies([dy]):
+        n = common_layers.layer_norm_compute_python(
+            x, epsilon, norm_scale, norm_bias)
+        wqkv_split = tf.unstack(wqkv, num=num_heads)
+        wo_split = tf.unstack(wo, num=num_heads)
+        deps = []
+        dwqkvs = []
+        dwos = []
+        dn = 0
+        for h in xrange(num_heads):
+          with tf.control_dependencies(deps):
+            combined = tf.nn.conv1d(n, wqkv_split[h], 1, "SAME")
+            q, k, v = tf.split(combined, 3, axis=2)
+            o = scaled_dot_product_attention_simple(q, k, v, attention_bias)
+            partial_y = tf.nn.conv1d(o, wo_split[h], 1, "SAME")
+            pdn, dwqkvh, dwoh = tf.gradients(
+                ys=[partial_y],
+                xs=[n, wqkv_split[h], wo_split[h]],
+                grad_ys=[dy])
+            dn += pdn
+            dwqkvs.append(dwqkvh)
+            dwos.append(dwoh)
+            deps = [dn, dwqkvh, dwoh]
+        dwqkv = tf.stack(dwqkvs)
+        dwo = tf.stack(dwos)
+        with tf.control_dependencies(deps):
+          dx, dnorm_scale, dnorm_bias = tf.gradients(
+              ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
+        return (dx, dwqkv, dwo, tf.zeros_like(attention_bias),
+                dnorm_scale, dnorm_bias)
+
+    @function.Defun(grad_func=grad_fn, compiled=True,
+                    separate_compiled_gradients=True)
+    def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
+      return forward_internal(
+          x, wqkv, wo, attention_bias, norm_scale, norm_bias)
+    _function_cache[key] = forward_fn
+
+  if bias is not None:
+    bias = tf.squeeze(bias, 1)
+  with tf.variable_scope(name, default_name="multihead_attention", values=[x]):
+    # TODO(noam): it would be nice to save memory by casting x to float16
+    # here, but this causes problems with the gradients. Figure out if there
+    # is a way to leave the gradients as float32.
+    if test_vars is not None:
+      wqkv, wo, norm_scale, norm_bias = list(test_vars)
+    else:
+      wqkv = tf.get_variable(
+          "wqkv", [num_heads, 1, io_size, 3 * head_size],
+          initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
+      wo = tf.get_variable(
+          "wo", [num_heads, 1, head_size, io_size],
+          initializer=tf.random_normal_initializer(
+              stddev=(head_size * num_heads)**-0.5))
+      norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
+    y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias)
+    y.set_shape(x.get_shape())
+    return y
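For reference, an illustrative usage sketch of the new attention layer (not part of the commit; shapes and the name argument are arbitrary, TF 1.x graph mode assumed):

import tensorflow as tf
from tensor2tensor.layers import common_attention

# Illustrative only: batch/length/io_size/num_heads are arbitrary choices.
batch, length, io_size, num_heads = 2, 64, 128, 8
x = tf.random_normal([batch, length, io_size])
# Causal mask, broadcastable to [batch, 1, length, length].
bias = common_attention.attention_bias_lower_triangle(length)
# forget=True discards forward activations and recomputes them head by head
# on the backward pass via the Defun-wrapped forward_fn defined above.
y = common_attention.multihead_self_attention_memory_efficient(
    x, bias, num_heads, forget=True, name="self_attention")
dx = tf.gradients(tf.reduce_sum(y), x)[0]

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  y_val, dx_val = sess.run([y, dx])
  print(y_val.shape, dx_val.shape)  # both are (2, 64, 128)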

tensor2tensor/layers/common_attention_test.py

Lines changed: 44 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 import numpy as np
 from tensor2tensor.layers import common_attention
+from tensor2tensor.layers import common_layers
 
 import tensorflow as tf
 
@@ -117,6 +118,49 @@ def testLocalUnmaskedAttention2DMatchingBlockLength(self):
       res = session.run(a)
     self.assertEqual(res.shape, (5, 4, 25, 25, 16))
 
+  def testMultiheadSelfAttentionMemoryEfficient(self):
+    num_heads = 4
+    io_size = 16
+    batch = 2
+    length = 7
+    head_size = 5
+    x = np.random.rand(batch, length, io_size)
+    dy = np.random.rand(batch, length, io_size)
+    with self.test_session() as session:
+      x = tf.to_float(x)
+      dy = tf.to_float(dy)
+      bias = common_attention.attention_bias_lower_triangle(length)
+      wqkv = tf.get_variable(
+          "wqkv", [num_heads, 1, io_size, 3 * head_size],
+          initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
+      wo = tf.get_variable(
+          "wo", [num_heads, 1, head_size, io_size],
+          initializer=tf.random_normal_initializer(
+              stddev=(head_size * num_heads)**-0.5))
+      norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
+      y = common_attention.multihead_self_attention_memory_efficient(
+          x, bias, num_heads, head_size=head_size, forget=False,
+          test_vars=(wqkv, wo, norm_scale, norm_bias))
+      y_forget = common_attention.multihead_self_attention_memory_efficient(
+          x, bias, num_heads, head_size=head_size, forget=True,
+          test_vars=(wqkv, wo, norm_scale, norm_bias))
+      dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
+          ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
+      dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
+          ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
+      session.run(tf.global_variables_initializer())
+      (y, y_forget,
+       dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
+       dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run(
+           [y, y_forget,
+            dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
+            dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f])
+    self.assertAllClose(y, y_forget)
+    self.assertAllClose(dwo, dwo_f)
+    self.assertAllClose(dwqkv, dwqkv_f)
+    self.assertAllClose(dnorm_scale, dnorm_scale_f)
+    self.assertAllClose(dnorm_bias, dnorm_bias_f)
+    self.assertAllClose(dx, dx_f)
 
 
 if __name__ == "__main__":
   tf.test.main()

tensor2tensor/layers/common_layers.py

Lines changed: 121 additions & 2 deletions
@@ -425,6 +425,15 @@ def conv_fn(inputs, filters, kernel_size, **kwargs):
   return conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs)
 
 
+def layer_norm_vars(filters):
+  """Create Variables for layer norm."""
+  scale = tf.get_variable(
+      "layer_norm_scale", [filters], initializer=tf.ones_initializer())
+  bias = tf.get_variable(
+      "layer_norm_bias", [filters], initializer=tf.zeros_initializer())
+  return scale, bias
+
+
 def layer_norm_compute_python(x, epsilon, scale, bias):
   """Layer norm raw computation."""
   mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
@@ -1773,7 +1782,7 @@ def smoothing_cross_entropy_factored_grad(op, dy):
   b = op.inputs[1]
   labels = op.inputs[2]
   confidence = op.inputs[3]
-  num_splits = 32
+  num_splits = 16
   vocab_size = tf.shape(b)[0]
   labels = approximate_split(labels, num_splits)
   a = approximate_split(a, num_splits)
@@ -1817,7 +1826,7 @@ def smoothing_cross_entropy_factored(a, b, labels, confidence):
   Returns:
     A Tensor with shape [batch]
   """
-  num_splits = 32
+  num_splits = 16
   vocab_size = tf.shape(b)[0]
   labels = approximate_split(labels, num_splits)
   a = approximate_split(a, num_splits)
@@ -1957,3 +1966,113 @@ def identity(*args):
 
   id_out = identity(*(inputs + train_vars + outputs))
   return id_out
+
+
+_function_cache = {}
+
+
+def conv_hidden_relu_memory_efficient(x,
+                                      filter_size,
+                                      epsilon=1e-6,
+                                      forget=True,
+                                      test_vars=None,
+                                      name=None):
+  """LayerNorm, Conv, ReLU, Conv.
+
+  All convolutions have kernel size 1.
+
+  returns conv(relu(conv(layer_norm(x))))
+
+  Args:
+    x: input Tensor with shape [batch, length, io_size]
+    filter_size: an integer - size of the hidden layer.
+    epsilon: a float (for layer norm)
+    forget: a boolean - forget forwards activations and recompute on backprop
+    test_vars: optional tuple of variables for testing purposes
+    name: an optional string
+
+  Returns:
+    a Tensor with shape [batch, length, io_size]
+  """
+  io_size = x.get_shape().as_list()[-1]
+
+  def forward_internal(x, f1, f2, scale, bias):
+    """Forward function."""
+    # split batch-wise to avoid exhausting memory in case the batch is large
+    # and the hidden layer is large.
+    num_splits = 4
+    x_flat = tf.reshape(x, [-1, 1, tf.shape(x)[2]])
+    xs = approximate_split(x_flat, num_splits)
+    ys = []
+    for i in xrange(num_splits):
+      with tf.control_dependencies(ys[-1:]):
+        n = layer_norm_compute_python(xs[i], epsilon, scale, bias)
+        y = tf.nn.conv1d(n, f1, 1, "SAME")
+        y = tf.nn.relu(y)
+        y = tf.nn.conv1d(y, f2, 1, "SAME")
+        ys.append(y)
+    y = tf.concat(ys, 0)
+    y = tf.reshape(y, tf.shape(x))
+    return y
+  key = ("conv_hidden_relu_memory_efficient %s" % epsilon)
+  if not forget:
+    forward_fn = forward_internal
+  elif key in _function_cache:
+    forward_fn = _function_cache[key]
+  else:
+    @function.Defun(compiled=True)
+    def grad_fn(x, f1, f2, scale, bias, dy):
+      with tf.control_dependencies([dy]):
+        num_splits = 4
+        x_shape = tf.shape(x)
+        flat_shape = [-1, 1, x_shape[2]]
+        x = tf.reshape(x, flat_shape)
+        dy = tf.reshape(dy, flat_shape)
+        xs = approximate_split(x, num_splits)
+        dys = approximate_split(dy, num_splits)
+        dxs = []
+        df1 = 0
+        df2 = 0
+        dscale = 0
+        dbias = 0
+        deps = []
+        for i in xrange(num_splits):
+          with tf.control_dependencies(deps):
+            n = layer_norm_compute_python(xs[i], epsilon, scale, bias)
+            y = tf.nn.conv1d(n, f1, 1, "SAME")
+            y = tf.nn.relu(y)
+            y = tf.nn.conv1d(y, f2, 1, "SAME")
+            dxi, pdf1, pdf2, pdscale, pdbias = tf.gradients(
+                ys=[y], xs=[xs[i], f1, f2, scale, bias], grad_ys=[dys[i]])
+            df1 += pdf1
+            df2 += pdf2
+            dscale += pdscale
+            dbias += pdbias
+            dxs.append(dxi)
+            deps = [dxi, df1, df2, dscale, dbias]
+        with tf.control_dependencies(deps):
+          dx = tf.concat(dxs, 0)
+          dx = tf.reshape(dx, x_shape)
+          return dx, df1, df2, dscale, dbias
+
+    @function.Defun(grad_func=grad_fn, compiled=True,
+                    separate_compiled_gradients=True)
+    def forward_fn(x, f1, f2, scale, bias):
+      return forward_internal(x, f1, f2, scale, bias)
+
+  with tf.variable_scope(name, default_name="ffn2", values=[x]):
+    # TODO(noam): it would be nice to save memory by casting x to float16
+    # here, but this causes problems with the gradients. Figure out if there
+    # is a way to leave the gradients as float32.
+    if test_vars is not None:
+      f1, f2, scale, bias = list(test_vars)
+    else:
+      f1 = tf.get_variable("f1", [1, io_size, filter_size])
+      f2 = tf.get_variable("f2", [1, filter_size, io_size])
+      scale, bias = layer_norm_vars(io_size)
+    if forget:
+      y = forward_fn(x, f1, f2, scale, bias)
+    else:
+      y = forward_internal(x, f1, f2, scale, bias)
+    y.set_shape(x.get_shape())
+    return y
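Likewise, an illustrative usage sketch for the new feed-forward layer (not part of the commit; sizes and the name argument are arbitrary, TF 1.x graph mode assumed):

import tensorflow as tf
from tensor2tensor.layers import common_layers

# Illustrative only: batch/length/io_size/filter_size are arbitrary choices.
batch, length, io_size, filter_size = 2, 64, 128, 512
x = tf.random_normal([batch, length, io_size])
# Computes conv(relu(conv(layer_norm(x)))) with kernel-size-1 convolutions;
# forget=True splits the flattened batch into chunks and recomputes the
# hidden activations chunk by chunk during backprop.
y = common_layers.conv_hidden_relu_memory_efficient(
    x, filter_size, forget=True, name="memory_efficient_ffn")
dx = tf.gradients(tf.reduce_sum(y), x)[0]

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  y_val, dx_val = sess.run([y, dx])
  print(y_val.shape, dx_val.shape)  # both are (2, 64, 128)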

0 commit comments
